diff --git a/CMakeLists.txt b/CMakeLists.txt index 4fb78e59f..fd9499826 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -228,6 +228,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-cann.h include/ggml-cuda.h include/ggml-kompute.h + include/ggml-opt.h include/ggml-metal.h include/ggml-rpc.h include/ggml-sycl.h diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 79c7f2442..b273a1add 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,7 +20,7 @@ target_include_directories(common-ggml PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) add_subdirectory(gpt-2) add_subdirectory(gpt-j) -# add_subdirectory(mnist) +add_subdirectory(mnist) add_subdirectory(sam) add_subdirectory(yolo) add_subdirectory(simple) diff --git a/examples/mnist/README.md b/examples/mnist/README.md index 9e0966f44..10df02dbb 100644 --- a/examples/mnist/README.md +++ b/examples/mnist/README.md @@ -18,7 +18,7 @@ $ python3 mnist-train-fc.py mnist-fc-f32.gguf ... -Test loss: 0.066051+-0.011630, Test accuracy: 98.07+-0.14% +Test loss: 0.066377+-0.010468, Test accuracy: 97.94+-0.14% Model tensors saved to mnist-fc-f32.gguf: fc1.weight (500, 784) @@ -61,22 +61,21 @@ ________________________________________________________ ________________________________________________________ ________________________________________________________ ________________________________________________________ -mnist_graph_eval: trying to load a ggml graph from mnist-fc-f32.gguf -ggml_graph_import: invalid magic number, got 46554747 -mnist_graph_eval: could not load a ggml graph from mnist-fc-f32.gguf ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes -mnist_model: using CPU backend +mnist_model: using CUDA0 (NVIDIA GeForce RTX 3090) as primary backend +mnist_model: unsupported operations will be executed on the following fallback backends (in order of priority): +mnist_model: - CPU (AMD Ryzen 9 5950X 16-Core Processor) mnist_model_init_from_file: loading model weights from 'mnist-fc-f32.gguf' mnist_model_init_from_file: model arch is mnist-fc mnist_model_init_from_file: successfully loaded weights from mnist-fc-f32.gguf -main: loaded model in 13.03 ms -mnist_model_eval: model evaluation on 10000 images took 95.02 ms, 9.50 us/image +main: loaded model in 109.44 ms +mnist_model_eval: model evaluation on 10000 images took 76.92 ms, 7.69 us/image main: predicted digit is 3 -main: test_loss=0.066051+-0.009343 -main: test_acc=98.07+-0.14% +main: test_loss=0.066379+-0.009101 +main: test_acc=97.94+-0.14% ``` In addition to the evaluation on the test set the GGML evaluation also prints a random image from the test set as well as the model prediction for said image. @@ -87,10 +86,6 @@ $ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-im ``` It can then be evaluated with the same binary as above. -When training a model with GGML the computation graph for the forward pass is also exported to `mnist-fc-f32.ggml`. -Compared to the GGUF (which only contains the weights) this file also contains the model architecture. -As long as the input and output tensors are well-defined an exported GGML graph is fully agnostic w.r.t. the model architecture. -It can be evaluated using the `mnist-eval` binary by substituting the argument for the GGUF file. ## Convolutional network @@ -101,8 +96,8 @@ $ python3 mnist-train-cnn.py mnist-cnn-f32.gguf ... -Test loss: 0.045483 -Test accuracy: 98.56% +Test loss: 0.047947 +Test accuracy: 98.46% GGUF model saved to 'mnist-cnn-f32.gguf' ``` @@ -139,25 +134,24 @@ ________________________________________________________ ________________________________________________________ ________________________________________________________ ________________________________________________________ -mnist_graph_eval: trying to load a ggml graph from mnist-cnn-f32.gguf -ggml_graph_import: invalid magic number, got 46554747 -mnist_graph_eval: could not load a ggml graph from mnist-cnn-f32.gguf ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 1 CUDA devices: Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes -mnist_model: using CPU backend +mnist_model: using CUDA0 (NVIDIA GeForce RTX 3090) as primary backend +mnist_model: unsupported operations will be executed on the following fallback backends (in order of priority): +mnist_model: - CPU (AMD Ryzen 9 5950X 16-Core Processor) mnist_model_init_from_file: loading model weights from 'mnist-cnn-f32.gguf' mnist_model_init_from_file: model arch is mnist-cnn mnist_model_init_from_file: successfully loaded weights from mnist-cnn-f32.gguf -main: loaded model in 11.88 ms -mnist_model_eval: model evaluation on 10000 images took 1074.09 ms, 107.41 us/image +main: loaded model in 91.99 ms +mnist_model_eval: model evaluation on 10000 images took 267.61 ms, 26.76 us/image main: predicted digit is 1 -main: test_loss=0.045483+-0.006884 -main: test_acc=98.56+-0.12% +main: test_loss=0.047955+-0.007029 +main: test_acc=98.46+-0.12% ``` -Like with the fully connected network the convolutional network can also be trained on the CPU using GGML: +Like with the fully connected network the convolutional network can also be trained using GGML: ``` bash $ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte @@ -165,11 +159,12 @@ $ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train- As always, the evaluation is done using `mnist-eval` and like with the fully connected network the GGML graph is exported to `mnist-cnn-f32.ggml`. -## CUDA +## Hardware Acceleration -The fully connected model can be trained and evaluated using CUDA. -`mnist-train` and `mnist-eval` accept an additional, optional argument behind those listed so far to specify the backend. -The default is `CPU`, by specifying `CUDA0` the first available CUDA device can be used instead (make sure to compile GGML with CUDA cupport). +Both the training and evaluation code is agnostic in terms of hardware as long as the corresponding GGML backend has implemented the necessary operations. +A specific backend can be selected by appending the above commands with a backend name. +The compute graphs then schedule the operations to preferentially use the specified backend. +Note that if a backend does not implement some of the necessary operations a CPU fallback is used instead which may result in bad performance. ## Web demo diff --git a/examples/mnist/mnist-common.cpp b/examples/mnist/mnist-common.cpp index 0c4ca07d0..c848823b0 100644 --- a/examples/mnist/mnist-common.cpp +++ b/examples/mnist/mnist-common.cpp @@ -1,6 +1,7 @@ +#include "ggml.h" #include "ggml-alloc.h" #include "ggml-backend.h" -#include "ggml.h" +#include "ggml-opt.h" #include "mnist-common.h" @@ -14,7 +15,7 @@ #include #include -bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) { +bool mnist_image_load(const std::string & fname, ggml_opt_dataset_t dataset) { auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { fprintf(stderr, "failed to open images file %s\n", fname.c_str()); @@ -23,12 +24,14 @@ bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) { fin.seekg(16); uint8_t image[MNIST_NINPUT]; - float * buf = ggml_get_data_f32(dataset.data); + struct ggml_tensor * images = ggml_opt_dataset_data(dataset); + float * buf = ggml_get_data_f32(images); - for (int iex = 0; iex < dataset.nex; ++iex) { + GGML_ASSERT(images->ne[0] == MNIST_NINPUT); + for (int64_t iex = 0; iex < images->ne[1]; ++iex) { fin.read((char *) image, sizeof(image)); - for (int i = 0; i < MNIST_NINPUT; ++i) { + for (int64_t i = 0; i < MNIST_NINPUT; ++i) { buf[iex*MNIST_NINPUT + i] = image[i] / 255.0f; // Normalize to [0, 1] } } @@ -36,11 +39,14 @@ bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) { return true; } -void mnist_image_print(FILE * stream, mnist_dataset & dataset, const int iex) { - const float * image = ggml_get_data_f32(dataset.data) + iex*MNIST_NINPUT; +void mnist_image_print(FILE * stream, ggml_opt_dataset_t dataset, const int iex) { + struct ggml_tensor * images = ggml_opt_dataset_data(dataset); + GGML_ASSERT(images->ne[0] == MNIST_NINPUT); + GGML_ASSERT(iex < images->ne[1]); + const float * image = ggml_get_data_f32(images) + iex*MNIST_NINPUT; - for (int row = 0; row < MNIST_HW; row++) { - for (int col = 0; col < MNIST_HW; col++) { + for (int64_t row = 0; row < MNIST_HW; row++) { + for (int64_t col = 0; col < MNIST_HW; col++) { const int rgb = roundf(255.0f * image[row*MNIST_HW + col]); #ifdef _WIN32 fprintf(stream, "%s", rgb >= 220 ? "##" : "__"); // Represented via text. @@ -52,7 +58,7 @@ void mnist_image_print(FILE * stream, mnist_dataset & dataset, const int iex) { } } -bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) { +bool mnist_label_load(const std::string & fname, ggml_opt_dataset_t dataset) { auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { fprintf(stderr, "failed to open labels file %s\n", fname.c_str()); @@ -61,12 +67,14 @@ bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) { fin.seekg(8); uint8_t label; - float * buf = ggml_get_data_f32(dataset.labels); + struct ggml_tensor * labels = ggml_opt_dataset_labels(dataset); + float * buf = ggml_get_data_f32(labels); - for (int iex = 0; iex < dataset.nex; ++iex) { + GGML_ASSERT(labels->ne[0] == MNIST_NCLASSES); + for (int64_t iex = 0; iex < labels->ne[1]; ++iex) { fin.read((char *) &label, sizeof(label)); - for (int i = 0; i < MNIST_NCLASSES; ++i) { + for (int64_t i = 0; i < MNIST_NCLASSES; ++i) { buf[iex*MNIST_NCLASSES + i] = i == label ? 1.0f : 0.0f; } } @@ -74,99 +82,6 @@ bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) { return true; } -mnist_eval_result mnist_graph_eval(const std::string & fname, const float * images, const float * labels, const int nex, const int nthreads) { - fprintf(stderr, "%s: trying to load a ggml graph from %s\n", __func__, fname.c_str()); - mnist_eval_result result; - - struct ggml_context * ctx_data = nullptr; - struct ggml_context * ctx_eval = nullptr; - - struct ggml_cgraph * gf; - { - const int64_t t_start_us = ggml_time_us(); - - gf = ggml_graph_import(fname.c_str(), &ctx_data, &ctx_eval); - - const int64_t t_total_us = ggml_time_us() - t_start_us; - const double t_total_ms = 1e-3*t_total_us; - if (gf) { - fprintf(stderr, "%s: graph import took %.2lf ms\n", __func__, t_total_ms); - } - } - - if (!gf) { - fprintf(stderr, "%s: could not load a ggml graph from %s\n", __func__, fname.c_str()); - return result; - } - fprintf(stderr, "%s: successfully loaded a ggml graph from %s\n", __func__, fname.c_str()); - - const size_t buf_size = 100 * 1024*1024; - void * buf_compute = malloc(buf_size); - - struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf_compute, - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx_compute = ggml_init(params); - - struct ggml_tensor * images_batch = ggml_graph_get_tensor(gf, "images"); - GGML_ASSERT(images_batch); - GGML_ASSERT(images_batch->ne[0] == MNIST_NINPUT || (images_batch->ne[0] == MNIST_HW && images_batch->ne[1] == MNIST_HW)); - - struct ggml_tensor * labels_batch = ggml_graph_get_tensor(gf, "labels"); - GGML_ASSERT(labels_batch); - GGML_ASSERT(labels_batch->ne[0] == MNIST_NCLASSES); - GGML_ASSERT(labels_batch->ne[2] == 1); - GGML_ASSERT(labels_batch->ne[3] == 1); - - const int nbatch = labels_batch->ne[1]; - GGML_ASSERT(nex % nbatch == 0); - - struct ggml_tensor * logits_batch = ggml_graph_get_tensor(gf, "logits"); - GGML_ASSERT(logits_batch); - GGML_ASSERT(logits_batch->ne[0] == MNIST_NCLASSES); - GGML_ASSERT(logits_batch->ne[1] == nbatch); - GGML_ASSERT(logits_batch->ne[2] == 1); - GGML_ASSERT(logits_batch->ne[3] == 1); - - GGML_ASSERT(images_batch->ne[1] == logits_batch->ne[1] || images_batch->ne[3] == logits_batch->ne[1]); - - struct ggml_tensor * loss = ggml_graph_get_tensor(gf, "loss"); - - { - const int64_t t_start_us = ggml_time_us(); - - for (int iex0; iex0 < nex; iex0 += nbatch) { - memcpy(images_batch->data, images + iex0*MNIST_NINPUT, ggml_nbytes(images_batch)); - memcpy(labels_batch->data, labels + iex0*MNIST_NCLASSES, ggml_nbytes(labels_batch)); - ggml_graph_compute_with_ctx(ctx_compute, gf, nthreads); - - for (int iexb = 0; iexb < nbatch; ++iexb) { - const float * probs_data = ggml_get_data_f32(logits_batch) + iexb*MNIST_NCLASSES; - - result.pred.push_back(std::max_element(probs_data, probs_data + MNIST_NCLASSES) - probs_data); - } - - result.loss.push_back(*ggml_get_data_f32(loss)); - } - - const int64_t t_total_us = ggml_time_us() - t_start_us; - const double t_total_ms = 1e-3*t_total_us; - fprintf(stderr, "%s: model evaluation on %d images took %.2lf ms, %.2lf us/image\n", - __func__, nex, t_total_ms, (double) t_total_us/nex); - } - - ggml_free(ctx_data); - ggml_free(ctx_eval); - ggml_free(ctx_compute); - free(buf_compute); - - result.success = true; - return result; -} - // Temporary util function for loading data from GGUF to a backend != CPU until GGML itself provides this functionality: bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf) { FILE * f = ggml_fopen(fname, "rb"); @@ -213,15 +128,15 @@ bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct g return true; } -mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend) { - mnist_model model(backend); +mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend, const int nbatch_logical, const int nbatch_physical) { + mnist_model model(backend, nbatch_logical, nbatch_physical); fprintf(stderr, "%s: loading model weights from '%s'\n", __func__, fname.c_str()); struct gguf_context * ctx; { struct gguf_init_params params = { /*.no_alloc =*/ true, - /*.ctx =*/ &model.ctx_weight, + /*.ctx =*/ &model.ctx_gguf, }; ctx = gguf_init_from_file(fname.c_str(), params); if (!ctx) { @@ -233,66 +148,66 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str fprintf(stderr, "%s: model arch is %s\n", __func__, model.arch.c_str()); if (model.arch == "mnist-fc") { - model.fc1_weight = ggml_get_tensor(model.ctx_weight, "fc1.weight"); + model.fc1_weight = ggml_get_tensor(model.ctx_gguf, "fc1.weight"); GGML_ASSERT(model.fc1_weight->ne[0] == MNIST_NINPUT); GGML_ASSERT(model.fc1_weight->ne[1] == MNIST_NHIDDEN); GGML_ASSERT(model.fc1_weight->ne[2] == 1); GGML_ASSERT(model.fc1_weight->ne[3] == 1); - model.fc1_bias = ggml_get_tensor(model.ctx_weight, "fc1.bias"); + model.fc1_bias = ggml_get_tensor(model.ctx_gguf, "fc1.bias"); GGML_ASSERT(model.fc1_bias->ne[0] == MNIST_NHIDDEN); GGML_ASSERT(model.fc1_bias->ne[1] == 1); GGML_ASSERT(model.fc1_bias->ne[2] == 1); GGML_ASSERT(model.fc1_bias->ne[3] == 1); - model.fc2_weight = ggml_get_tensor(model.ctx_weight, "fc2.weight"); + model.fc2_weight = ggml_get_tensor(model.ctx_gguf, "fc2.weight"); GGML_ASSERT(model.fc2_weight->ne[0] == MNIST_NHIDDEN); GGML_ASSERT(model.fc2_weight->ne[1] == MNIST_NCLASSES); GGML_ASSERT(model.fc2_weight->ne[2] == 1); GGML_ASSERT(model.fc2_weight->ne[3] == 1); - model.fc2_bias = ggml_get_tensor(model.ctx_weight, "fc2.bias"); + model.fc2_bias = ggml_get_tensor(model.ctx_gguf, "fc2.bias"); GGML_ASSERT(model.fc2_bias->ne[0] == MNIST_NCLASSES); GGML_ASSERT(model.fc2_bias->ne[1] == 1); GGML_ASSERT(model.fc2_bias->ne[2] == 1); GGML_ASSERT(model.fc2_bias->ne[3] == 1); } else if (model.arch == "mnist-cnn") { - model.conv1_kernel = ggml_get_tensor(model.ctx_weight, "conv1.kernel"); + model.conv1_kernel = ggml_get_tensor(model.ctx_gguf, "conv1.kernel"); GGML_ASSERT(model.conv1_kernel->type == GGML_TYPE_F32); GGML_ASSERT(model.conv1_kernel->ne[0] == 3); GGML_ASSERT(model.conv1_kernel->ne[1] == 3); GGML_ASSERT(model.conv1_kernel->ne[2] == 1); GGML_ASSERT(model.conv1_kernel->ne[3] == MNIST_CNN_NCB); - model.conv1_bias = ggml_get_tensor(model.ctx_weight, "conv1.bias"); + model.conv1_bias = ggml_get_tensor(model.ctx_gguf, "conv1.bias"); GGML_ASSERT(model.conv1_bias->type == GGML_TYPE_F32); GGML_ASSERT(model.conv1_bias->ne[0] == 1); GGML_ASSERT(model.conv1_bias->ne[1] == 1); GGML_ASSERT(model.conv1_bias->ne[2] == MNIST_CNN_NCB); GGML_ASSERT(model.conv1_bias->ne[3] == 1); - model.conv2_kernel = ggml_get_tensor(model.ctx_weight, "conv2.kernel"); + model.conv2_kernel = ggml_get_tensor(model.ctx_gguf, "conv2.kernel"); GGML_ASSERT(model.conv2_kernel->type == GGML_TYPE_F32); GGML_ASSERT(model.conv2_kernel->ne[0] == 3); GGML_ASSERT(model.conv2_kernel->ne[1] == 3); GGML_ASSERT(model.conv2_kernel->ne[2] == MNIST_CNN_NCB); GGML_ASSERT(model.conv2_kernel->ne[3] == MNIST_CNN_NCB*2); - model.conv2_bias = ggml_get_tensor(model.ctx_weight, "conv2.bias"); + model.conv2_bias = ggml_get_tensor(model.ctx_gguf, "conv2.bias"); GGML_ASSERT(model.conv2_bias->type == GGML_TYPE_F32); GGML_ASSERT(model.conv2_bias->ne[0] == 1); GGML_ASSERT(model.conv2_bias->ne[1] == 1); GGML_ASSERT(model.conv2_bias->ne[2] == MNIST_CNN_NCB*2); GGML_ASSERT(model.conv2_bias->ne[3] == 1); - model.dense_weight = ggml_get_tensor(model.ctx_weight, "dense.weight"); + model.dense_weight = ggml_get_tensor(model.ctx_gguf, "dense.weight"); GGML_ASSERT(model.dense_weight->type == GGML_TYPE_F32); GGML_ASSERT(model.dense_weight->ne[0] == (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2)); GGML_ASSERT(model.dense_weight->ne[1] == MNIST_NCLASSES); GGML_ASSERT(model.dense_weight->ne[2] == 1); GGML_ASSERT(model.dense_weight->ne[3] == 1); - model.dense_bias = ggml_get_tensor(model.ctx_weight, "dense.bias"); + model.dense_bias = ggml_get_tensor(model.ctx_gguf, "dense.bias"); GGML_ASSERT(model.dense_bias->type == GGML_TYPE_F32); GGML_ASSERT(model.dense_bias->ne[0] == MNIST_NCLASSES); GGML_ASSERT(model.dense_bias->ne[1] == 1); @@ -301,19 +216,29 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str } else { fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str()); } - model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend); - if(!load_from_gguf(fname.c_str(), model.ctx_weight, ctx)) { + model.buf_gguf = ggml_backend_alloc_ctx_tensors(model.ctx_gguf, model.backends[0]); + + if(!load_from_gguf(fname.c_str(), model.ctx_gguf, ctx)) { fprintf(stderr, "%s: loading weights from %s failed\n", __func__, fname.c_str()); exit(1); } + // The space in ctx_gguf exactly fits the model weights, + // the images (which also need to be statically allocated) need to be put in a different context. + + model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL); + ggml_set_name(model.images, "images"); + ggml_set_input(model.images); + + model.buf_static = ggml_backend_alloc_ctx_tensors(model.ctx_static, model.backends[0]); + fprintf(stderr, "%s: successfully loaded weights from %s\n", __func__, fname.c_str()); return model; } -mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend) { - mnist_model model(backend); +mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend, const int nbatch_logical, const int nbatch_physical) { + mnist_model model(backend, nbatch_logical, nbatch_physical); model.arch = arch; std::random_device rd{}; @@ -324,10 +249,10 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string if (model.arch == "mnist-fc") { fprintf(stderr, "%s: initializing random weights for a fully connected model\n", __func__); - model.fc1_weight = ggml_new_tensor_2d(model.ctx_weight, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NHIDDEN); - model.fc1_bias = ggml_new_tensor_1d(model.ctx_weight, GGML_TYPE_F32, MNIST_NHIDDEN); - model.fc2_weight = ggml_new_tensor_2d(model.ctx_weight, GGML_TYPE_F32, MNIST_NHIDDEN, MNIST_NCLASSES); - model.fc2_bias = ggml_new_tensor_1d(model.ctx_weight, GGML_TYPE_F32, MNIST_NCLASSES); + model.fc1_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NHIDDEN); + model.fc1_bias = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32, MNIST_NHIDDEN); + model.fc2_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NHIDDEN, MNIST_NCLASSES); + model.fc2_bias = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32, MNIST_NCLASSES); ggml_set_name(model.fc1_weight, "fc1.weight"); ggml_set_name(model.fc1_bias, "fc1.bias"); @@ -339,12 +264,12 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string init_tensors.push_back(model.fc2_weight); init_tensors.push_back(model.fc2_bias); } else if (model.arch == "mnist-cnn") { - model.conv1_kernel = ggml_new_tensor_4d(model.ctx_weight, GGML_TYPE_F32, 3, 3, 1, MNIST_CNN_NCB); - model.conv1_bias = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, 1, 1, MNIST_CNN_NCB); - model.conv2_kernel = ggml_new_tensor_4d(model.ctx_weight, GGML_TYPE_F32, 3, 3, MNIST_CNN_NCB, MNIST_CNN_NCB*2); - model.conv2_bias = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, 1, 1, MNIST_CNN_NCB*2); - model.dense_weight = ggml_new_tensor_2d(model.ctx_weight, GGML_TYPE_F32, (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), MNIST_NCLASSES); - model.dense_bias = ggml_new_tensor_1d(model.ctx_weight, GGML_TYPE_F32, MNIST_NCLASSES); + model.conv1_kernel = ggml_new_tensor_4d(model.ctx_static, GGML_TYPE_F32, 3, 3, 1, MNIST_CNN_NCB); + model.conv1_bias = ggml_new_tensor_3d(model.ctx_static, GGML_TYPE_F32, 1, 1, MNIST_CNN_NCB); + model.conv2_kernel = ggml_new_tensor_4d(model.ctx_static, GGML_TYPE_F32, 3, 3, MNIST_CNN_NCB, MNIST_CNN_NCB*2); + model.conv2_bias = ggml_new_tensor_3d(model.ctx_static, GGML_TYPE_F32, 1, 1, MNIST_CNN_NCB*2); + model.dense_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), MNIST_NCLASSES); + model.dense_bias = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32, MNIST_NCLASSES); ggml_set_name(model.conv1_kernel, "conv1.kernel"); ggml_set_name(model.conv1_bias, "conv1.bias"); @@ -363,7 +288,11 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str()); } - model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend); + model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL); + ggml_set_name(model.images, "images"); + ggml_set_input(model.images); + + model.buf_static = ggml_backend_alloc_ctx_tensors(model.ctx_static, model.backends[0]); for (ggml_tensor * t : init_tensors) { GGML_ASSERT(t->type == GGML_TYPE_F32); @@ -379,20 +308,13 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string return model; } -void mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical) { - model.nbatch_logical = nbatch_logical; - model.nbatch_physical = nbatch_physical; - +void mnist_model_build(mnist_model & model) { if (model.arch == "mnist-fc") { ggml_set_param(model.ctx_compute, model.fc1_weight); ggml_set_param(model.ctx_compute, model.fc1_bias); ggml_set_param(model.ctx_compute, model.fc2_weight); ggml_set_param(model.ctx_compute, model.fc2_bias); - model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch_physical); - ggml_set_name(model.images, "images"); - ggml_set_input(model.images); - ggml_tensor * fc1 = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute, ggml_mul_mat(model.ctx_compute, model.fc1_weight, model.images), model.fc1_bias)); @@ -407,12 +329,10 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int ggml_set_param(model.ctx_compute, model.dense_weight); ggml_set_param(model.ctx_compute, model.dense_bias); - model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch_physical); - ggml_set_name(model.images, "images"); - ggml_set_input(model.images); + struct ggml_tensor * images_2D = ggml_reshape_4d(model.ctx_compute, model.images, MNIST_HW, MNIST_HW, 1, model.images->ne[1]); struct ggml_tensor * conv1_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute, - ggml_conv_2d(model.ctx_compute, model.conv1_kernel, model.images, 1, 1, 1, 1, 1, 1), + ggml_conv_2d(model.ctx_compute, model.conv1_kernel, images_2D, 1, 1, 1, 1, 1, 1), model.conv1_bias)); GGML_ASSERT(conv1_out->ne[0] == MNIST_HW); GGML_ASSERT(conv1_out->ne[1] == MNIST_HW); @@ -459,212 +379,35 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int GGML_ASSERT(model.logits->ne[1] == model.nbatch_physical); GGML_ASSERT(model.logits->ne[2] == 1); GGML_ASSERT(model.logits->ne[3] == 1); - - model.probs = ggml_soft_max(model.ctx_compute, model.logits); - ggml_set_name(model.probs, "probs"); - ggml_set_output(model.probs); - GGML_ASSERT(model.probs->type == GGML_TYPE_F32); - GGML_ASSERT(model.probs->ne[0] == MNIST_NCLASSES); - GGML_ASSERT(model.probs->ne[1] == model.nbatch_physical); - GGML_ASSERT(model.probs->ne[2] == 1); - GGML_ASSERT(model.probs->ne[3] == 1); - - model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch_physical); - ggml_set_name(model.labels, "labels"); - ggml_set_input(model.labels); - - model.loss = ggml_cross_entropy_loss(model.ctx_compute, model.logits, model.labels); - ggml_set_name(model.loss, "loss"); - ggml_set_output(model.loss); - ggml_set_loss(model.loss); - GGML_ASSERT(model.loss->type == GGML_TYPE_F32); - GGML_ASSERT(model.loss->ne[0] == 1); - GGML_ASSERT(model.loss->ne[1] == 1); - GGML_ASSERT(model.loss->ne[2] == 1); - GGML_ASSERT(model.loss->ne[3] == 1); - - model.pred = ggml_argmax(model.ctx_compute, model.logits); - ggml_set_name(model.pred, "predictions"); - ggml_set_output(model.pred); - GGML_ASSERT(model.pred->type == GGML_TYPE_I32); - GGML_ASSERT(model.pred->ne[0] == model.nbatch_physical); - GGML_ASSERT(model.pred->ne[1] == 1); - GGML_ASSERT(model.pred->ne[2] == 1); - GGML_ASSERT(model.pred->ne[3] == 1); - - model.acc_count = ggml_count_equal(model.ctx_compute, model.pred, ggml_argmax(model.ctx_compute, model.labels)); - ggml_set_name(model.acc_count, "accuracy_count"); - ggml_set_output(model.acc_count); - GGML_ASSERT(model.acc_count->type == GGML_TYPE_I64); - GGML_ASSERT(model.acc_count->ne[0] == 1); - GGML_ASSERT(model.acc_count->ne[1] == 1); - GGML_ASSERT(model.acc_count->ne[2] == 1); - GGML_ASSERT(model.acc_count->ne[3] == 1); } -mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset) { - mnist_eval_result result; - - struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute); - // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed. - ggml_build_forward_expand(gf, model.loss); - ggml_build_forward_expand(gf, model.pred); - ggml_build_forward_expand(gf, model.acc_count); +ggml_opt_result_t mnist_model_eval(mnist_model & model, ggml_opt_dataset_t dataset) { + ggml_opt_result_t result = ggml_opt_result_init(); - model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend); + ggml_opt_params params = ggml_opt_default_params(model.backend_sched, model.ctx_compute, model.images, model.logits, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); + params.build_type = GGML_OPT_BUILD_TYPE_FORWARD; + ggml_opt_context_t opt_ctx = ggml_opt_init(params); { const int64_t t_start_us = ggml_time_us(); - float tmp_loss; - std::vector tmp_pred(model.nbatch_physical); - int64_t tmp_acc_count; - - GGML_ASSERT(sizeof(tmp_loss) == ggml_nbytes(model.loss)); - GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred)); - GGML_ASSERT(sizeof(tmp_acc_count) == ggml_nbytes(model.acc_count)); - - GGML_ASSERT(dataset.nex % model.nbatch_physical == 0); - const int nbatches = dataset.nex/model.nbatch_physical; - for (int ibatch = 0; ibatch < nbatches; ++ibatch) { - dataset.get_batch(model.images, model.labels, ibatch); - - ggml_backend_graph_compute(model.backend, gf); - - ggml_backend_tensor_get(model.loss, &tmp_loss, 0, ggml_nbytes(model.loss)); - ggml_backend_tensor_get(model.pred, tmp_pred.data(), 0, ggml_nbytes(model.pred)); - ggml_backend_tensor_get(model.acc_count, &tmp_acc_count, 0, ggml_nbytes(model.acc_count)); - - result.loss.push_back(tmp_loss); - result.pred.insert(result.pred.end(), tmp_pred.begin(), tmp_pred.end()); - result.ncorrect += tmp_acc_count; - result.ntotal += model.nbatch_physical; - } + ggml_opt_epoch(opt_ctx, dataset, nullptr, result, /*idata_split =*/ 0, nullptr, nullptr); const int64_t t_total_us = ggml_time_us() - t_start_us; const double t_total_ms = 1e-3*t_total_us; + const int nex = ggml_opt_dataset_data(dataset)->ne[1]; fprintf(stderr, "%s: model evaluation on %d images took %.2lf ms, %.2lf us/image\n", - __func__, (int)dataset.nex, t_total_ms, (double) t_total_us/dataset.nex); + __func__, nex, t_total_ms, (double) t_total_us/nex); } - result.success = true; + ggml_opt_free(opt_ctx); + return result; } -void mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split) { - const int64_t t_start_us = ggml_time_us(); - - const int opt_period = model.nbatch_logical / model.nbatch_physical; - const int nbatches_logical = dataset.nex / model.nbatch_logical; - const int nbatches_physical = dataset.nex / model.nbatch_physical; - const int ibatch_split = ((int)((1.0f - val_split)*nbatches_logical))*opt_period; // train <-> val split index (physical) - const int ishard_split = ibatch_split * model.nbatch_physical/dataset.shard_size; - - // gf == graph forward, forward pass only. - struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. - // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed. - ggml_build_forward_expand(gf, model.loss); - ggml_build_forward_expand(gf, model.pred); - ggml_build_forward_expand(gf, model.acc_count); - - // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. - struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf); - ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ opt_period > 1); - - // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. - struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad); - ggml_build_opt_adamw(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f); - - model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend); - ggml_graph_reset(gb_opt); // Set gradients to zero, reset optimizer. - - dataset.shuffle(-1); // Shuffle all data (train + validation). - - for (int epoch = 0; epoch < nepoch; ++epoch) { - fprintf(stderr, "%s: epoch %02d start...", __func__, epoch); - const int64_t t_start_us = ggml_time_us(); - - dataset.shuffle(ishard_split); // Shuffle only the training data, keeping training and validation set separate. - - int ibatch_physical = 0; - - float tmp_loss; - std::vector tmp_pred(model.nbatch_physical); - int64_t tmp_acc_count; - - GGML_ASSERT(sizeof(tmp_loss) == ggml_nbytes(model.loss)); - GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred)); - GGML_ASSERT(sizeof(tmp_acc_count) == ggml_nbytes(model.acc_count)); - - mnist_eval_result result_train; - for (; ibatch_physical < ibatch_split; ++ibatch_physical) { - dataset.get_batch(model.images, model.labels, ibatch_physical); - - // With a period of opt_period == nbatch_logical/nbatch_physical iterations: - if ((ibatch_physical + 1) % opt_period != 0) { - // For the first opt_period - 1 iterations, only calculate gradients and accumulate them: - ggml_backend_graph_compute(model.backend, gb_grad); - } else { - // For the last iteration, calculate gradients and also apply the optimizer: - ggml_backend_graph_compute(model.backend, gb_opt); // gb_opt contains all nodes of gb_grad so no extra call for gb_grad is needed. - ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer. - } - - ggml_backend_tensor_get(model.loss, &tmp_loss, 0, ggml_nbytes(model.loss)); - ggml_backend_tensor_get(model.pred, tmp_pred.data(), 0, ggml_nbytes(model.pred)); - ggml_backend_tensor_get(model.acc_count, &tmp_acc_count, 0, ggml_nbytes(model.acc_count)); - - result_train.loss.push_back(tmp_loss); - result_train.pred.insert(result_train.pred.end(), tmp_pred.begin(), tmp_pred.end()); - result_train.ncorrect += tmp_acc_count; - result_train.ntotal += model.nbatch_physical; - } - - mnist_eval_result result_val; - for (; ibatch_physical < nbatches_physical; ++ibatch_physical) { - dataset.get_batch(model.images, model.labels, ibatch_physical); - - ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed. - - ggml_backend_tensor_get(model.loss, &tmp_loss, 0, ggml_nbytes(model.loss)); - ggml_backend_tensor_get(model.pred, tmp_pred.data(), 0, ggml_nbytes(model.pred)); - ggml_backend_tensor_get(model.acc_count, &tmp_acc_count, 0, ggml_nbytes(model.acc_count)); - - result_val.loss.push_back(tmp_loss); - result_val.pred.insert(result_val.pred.end(), tmp_pred.begin(), tmp_pred.end()); - result_val.ncorrect += tmp_acc_count; - result_val.ntotal += model.nbatch_physical; - } - - { - const double loss_mean = mnist_loss(result_train).first; - const double percent_correct = 100.0 * mnist_accuracy(result_train).first; - - const int64_t t_epoch_us = ggml_time_us() - t_start_us; - const double t_epoch_s = 1e-6*t_epoch_us; - fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%", t_epoch_s, loss_mean, percent_correct); - } - - if (ibatch_split < nbatches_physical) { - const std::pair loss = mnist_loss(result_val); - const std::pair acc = mnist_accuracy(result_val); - - fprintf(stderr, ", val_loss=%.6lf+-%.6lf, val_acc=%.2f+-%.2f%%", loss.first, loss.second, 100.0*acc.first, 100.0*acc.second); - } - fprintf(stderr, "\n"); - } - - const int64_t t_total_us = ggml_time_us() - t_start_us; - const double t_total_s = 1e-6*t_total_us; - fprintf(stderr, "%s: training took %.2lfs\n", __func__, t_total_s); - - if (ggml_backend_is_cpu(model.backend)) { - std::string fname = model.arch + "-f32.ggml"; - fprintf(stderr, "%s: saving the GGML graph for the forward pass to %s\n", __func__, fname.c_str()); - ggml_graph_export(gf, fname.c_str()); - } else { - fprintf(stderr, "%s: not saving the GGML graph for the forward pass because this is only supported for the CPU backend\n", __func__); - } +void mnist_model_train(mnist_model & model, ggml_opt_dataset_t dataset, const int nepoch, const float val_split) { + ggml_opt_fit(model.backend_sched, model.ctx_compute, model.images, model.logits, dataset, + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, ggml_opt_get_default_optimizer_params, nepoch, model.nbatch_logical, val_split, false); } void mnist_model_save(mnist_model & model, const std::string & fname) { @@ -703,34 +446,6 @@ void mnist_model_save(mnist_model & model, const std::string & fname) { gguf_free(gguf_ctx); } -std::pair mnist_loss(const mnist_eval_result & result) { - const size_t nbatches = result.loss.size(); - GGML_ASSERT(nbatches >= 2); - - double sum = 0.0; - double sum_squared = 0.0; - - for (const float & loss : result.loss) { - sum += loss; - sum_squared += loss*loss; - } - - const double mean = sum/nbatches; - const double uncertainty = sqrt((sum_squared/nbatches - mean*mean) / (nbatches - 1)); - - return std::make_pair(mean, uncertainty); -} - -std::pair mnist_accuracy(const mnist_eval_result & result) { - GGML_ASSERT(result.ntotal >= result.ncorrect); - GGML_ASSERT(result.ntotal >= 2); - - const double fraction_correct = ((double) result.ncorrect) / ((double) result.ntotal); - const double uncertainty = sqrt(fraction_correct * (1.0 - fraction_correct) / (result.ncorrect - 1)); - - return std::make_pair(fraction_correct, uncertainty); -} - #ifdef __cplusplus extern "C" { #endif @@ -738,15 +453,19 @@ extern "C" { int wasm_eval(uint8_t * digitPtr) { std::vector digit(digitPtr, digitPtr + MNIST_NINPUT); - struct mnist_dataset dataset(1, 1); - memcpy(dataset.data->data, digitPtr, ggml_nbytes(dataset.data)); - ggml_set_zero(dataset.labels); // The labels are not needed. + ggml_opt_dataset_t dataset = ggml_opt_dataset_init(MNIST_NINPUT, MNIST_NCLASSES, 1, 1); + struct ggml_tensor * data = ggml_opt_dataset_data(dataset); + memcpy(data->data, digitPtr, ggml_nbytes(data)); + ggml_set_zero(ggml_opt_dataset_labels(dataset)); // The labels are not needed. + + mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU", /*nbatch_logical =*/ 1, /*nbatch_physical =*/ 1); + mnist_model_build(model); + ggml_opt_result_t result = mnist_model_eval(model, dataset); - mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU"); - mnist_model_build(model, 1, 1); - mnist_eval_result result = mnist_model_eval(model, dataset); + int32_t pred; + ggml_opt_result_pred(result, &pred); - return result.pred[0]; + return pred; } int wasm_random_digit(char * digitPtr) { diff --git a/examples/mnist/mnist-common.h b/examples/mnist/mnist-common.h index 22d59b989..090cf3715 100644 --- a/examples/mnist/mnist-common.h +++ b/examples/mnist/mnist-common.h @@ -9,11 +9,16 @@ #include "ggml-backend.h" #include "ggml.h" #include "ggml-cpu.h" +#include "ggml-opt.h" -#define MNIST_NTRAIN 60000 -#define MNIST_NTEST 10000 -#define MNIST_NBATCH_LOGICAL 1000 -#define MNIST_NBATCH_PHYSICAL 500 +#define MNIST_NTRAIN 60000 +#define MNIST_NTEST 10000 + +// Gradient accumulation can be achieved by setting the logical batch size to a multiple of the physical one. +// The logical batch size determines how many datapoints are used for a gradient update. +// The physical batch size determines how many datapoints are processed in parallel, larger values utilize compute better but need more memory. +#define MNIST_NBATCH_LOGICAL 1000 +#define MNIST_NBATCH_PHYSICAL 500 static_assert(MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL == 0, "MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL != 0"); static_assert(MNIST_NTRAIN % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0"); @@ -28,77 +33,15 @@ static_assert(MNIST_NTEST % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NB // NCB = number of channels base #define MNIST_CNN_NCB 8 -struct mnist_dataset { - struct ggml_context * ctx; - struct ggml_tensor * data; - struct ggml_tensor * labels; - - int64_t nex; - int64_t shard_size; - size_t nbs_data; - size_t nbs_labels; - - std::vector permutation; - std::mt19937 rng; - - mnist_dataset(const int64_t nex, const int64_t shard_size) : nex(nex), shard_size(shard_size) { - const size_t nbytes_images = nex*MNIST_NINPUT *sizeof(float) + ggml_tensor_overhead(); - const size_t nbytes_labels = nex*MNIST_NCLASSES*sizeof(float) + ggml_tensor_overhead(); - struct ggml_init_params params = { - /*.mem_size =*/ nbytes_images + nbytes_labels, - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ false, - }; - ctx = ggml_init(params); - - data = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, MNIST_HW, MNIST_HW, nex); - labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, MNIST_NCLASSES, nex); - - nbs_data = ggml_nbytes(data) * shard_size/nex; - nbs_labels = ggml_nbytes(labels) * shard_size/nex; - - permutation.resize(nex/shard_size); - for (size_t i = 0; i < permutation.size(); ++i) { - permutation[i] = i; - } - } - - ~mnist_dataset() { - ggml_free(ctx); - } - - void shuffle(const size_t ishard_max) { - if (ishard_max < permutation.size()) { - std::shuffle(permutation.begin(), permutation.begin() + ishard_max, rng); - return; - } - std::shuffle(permutation.begin(), permutation.end(), rng); - } - - void get_batch(struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, const int64_t ibatch) { - const int64_t shards_per_batch = ggml_nbytes(data_batch) / nbs_data; - for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { - const int64_t ishard = permutation[ibatch*shards_per_batch + ishard_batch]; - - ggml_backend_tensor_set(data_batch, (const char *) data->data + ishard*nbs_data, ishard_batch*nbs_data, nbs_data); - ggml_backend_tensor_set(labels_batch, (const char *) labels->data + ishard*nbs_labels, ishard_batch*nbs_labels, nbs_labels); - } - } -}; - struct mnist_model { std::string arch; - ggml_backend_t backend; - int nbatch_logical; - int nbatch_physical; - - struct ggml_tensor * images = nullptr; - struct ggml_tensor * labels = nullptr; - struct ggml_tensor * logits = nullptr; - struct ggml_tensor * probs = nullptr; - struct ggml_tensor * loss = nullptr; - struct ggml_tensor * pred = nullptr; - struct ggml_tensor * acc_count = nullptr; + ggml_backend_sched_t backend_sched; + std::vector backends; + const int nbatch_logical; + const int nbatch_physical; + + struct ggml_tensor * images = nullptr; + struct ggml_tensor * logits = nullptr; struct ggml_tensor * fc1_weight = nullptr; struct ggml_tensor * fc1_bias = nullptr; @@ -112,28 +55,66 @@ struct mnist_model { struct ggml_tensor * dense_weight = nullptr; struct ggml_tensor * dense_bias = nullptr; - struct ggml_context * ctx_weight = nullptr; + struct ggml_context * ctx_gguf = nullptr; + struct ggml_context * ctx_static = nullptr; struct ggml_context * ctx_compute = nullptr; - ggml_backend_buffer_t buf_weight = nullptr; - ggml_backend_buffer_t buf_compute = nullptr; - - mnist_model(const std::string & backend_name) { - ggml_backend_dev_t dev = ggml_backend_dev_by_name(backend_name.c_str()); - if (dev == nullptr) { - fprintf(stderr, "%s: ERROR: backend %s not found, available:\n", __func__, backend_name.c_str()); - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t this_dev = ggml_backend_dev_get(i); - fprintf(stderr, " - %s (%s)\n", ggml_backend_dev_name(this_dev), ggml_backend_dev_description(this_dev)); + ggml_backend_buffer_t buf_gguf = nullptr; + ggml_backend_buffer_t buf_static = nullptr; + + mnist_model(const std::string & backend_name, const int nbatch_logical, const int nbatch_physical) + : nbatch_logical(nbatch_logical), nbatch_physical(nbatch_physical) { + std::vector devices; + const int ncores_logical = std::thread::hardware_concurrency(); + const int nthreads = std::min(ncores_logical, (ncores_logical + 4) / 2); + + // Add primary backend: + if (!backend_name.empty()) { + ggml_backend_dev_t dev = ggml_backend_dev_by_name(backend_name.c_str()); + if (dev == nullptr) { + fprintf(stderr, "%s: ERROR: backend %s not found, available:\n", __func__, backend_name.c_str()); + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev_i = ggml_backend_dev_get(i); + fprintf(stderr, " - %s (%s)\n", ggml_backend_dev_name(dev_i), ggml_backend_dev_description(dev_i)); + } + exit(1); } - exit(1); + + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + GGML_ASSERT(backend); + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, nthreads); + } + + backends.push_back(backend); + devices.push_back(dev); } - fprintf(stderr, "%s: using %s (%s) backend\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev)); + // Add all available backends as fallback. + // A "backend" is a stream on a physical device so there is no problem with adding multiple backends for the same device. + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + GGML_ASSERT(backend); - backend = ggml_backend_dev_init(dev, NULL); - if (ggml_backend_is_cpu(backend)) { - const int ncores_logical = std::thread::hardware_concurrency(); - ggml_backend_cpu_set_n_threads(backend, std::min(ncores_logical, (ncores_logical + 4)/2)); + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, nthreads); + } + + backends.push_back(backend); + devices.push_back(dev); + } + + // The order of the backends passed to ggml_backend_sched_new determines which backend is given priority. + backend_sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), GGML_DEFAULT_GRAPH_SIZE, false); + fprintf(stderr, "%s: using %s (%s) as primary backend\n", + __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0])); + if (backends.size() >= 2) { + fprintf(stderr, "%s: unsupported operations will be executed on the following fallback backends (in order of priority):\n", __func__); + for (size_t i = 1; i < backends.size(); ++i) { + fprintf(stderr, "%s: - %s (%s)\n", __func__, ggml_backend_name(backends[i]), ggml_backend_dev_description(devices[i])); + } } { @@ -143,7 +124,7 @@ struct mnist_model { /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - ctx_weight = ggml_init(params); + ctx_static = ggml_init(params); } { @@ -159,36 +140,26 @@ struct mnist_model { } ~mnist_model() { - ggml_free(ctx_weight); + ggml_free(ctx_gguf); + ggml_free(ctx_static); ggml_free(ctx_compute); - ggml_backend_buffer_free(buf_weight); - ggml_backend_buffer_free(buf_compute); - ggml_backend_free(backend); + ggml_backend_buffer_free(buf_gguf); + ggml_backend_buffer_free(buf_static); + ggml_backend_sched_free(backend_sched); + for (ggml_backend_t backend : backends) { + ggml_backend_free(backend); + } } }; -struct mnist_eval_result { - bool success = false; - - std::vector loss; - std::vector pred; - int64_t ncorrect = 0; - int64_t ntotal = 0; -}; +bool mnist_image_load(const std::string & fname, ggml_opt_dataset_t dataset); +void mnist_image_print(FILE * f, ggml_opt_dataset_t dataset, const int iex); +bool mnist_label_load(const std::string & fname, ggml_opt_dataset_t dataset); -bool mnist_image_load(const std::string & fname, mnist_dataset & dataset); -void mnist_image_print(FILE * f, mnist_dataset & dataset, const int iex); -bool mnist_label_load(const std::string & fname, mnist_dataset & dataset); - -mnist_eval_result mnist_graph_eval(const std::string & fname, const float * images, const float * labels, const int nex, const int nthreads); - -mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend); -mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend); -void mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical); -mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset); -void mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split); +mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend, const int nbatch_logical, const int nbatch_physical); +mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend, const int nbatch_logical, const int nbatch_physical); +void mnist_model_build(mnist_model & model); +ggml_opt_result_t mnist_model_eval(mnist_model & model, ggml_opt_dataset_t dataset); +void mnist_model_train(mnist_model & model, ggml_opt_dataset_t dataset, const int nepoch, const float val_split); void mnist_model_save(mnist_model & model, const std::string & fname); - -std::pair mnist_loss(const mnist_eval_result & result); -std::pair mnist_accuracy(const mnist_eval_result & result); diff --git a/examples/mnist/mnist-eval.cpp b/examples/mnist/mnist-eval.cpp index 125bbb8cb..218742cfe 100644 --- a/examples/mnist/mnist-eval.cpp +++ b/examples/mnist/mnist-eval.cpp @@ -1,4 +1,5 @@ #include "ggml.h" +#include "ggml-opt.h" #include "mnist-common.h" @@ -24,7 +25,7 @@ int main(int argc, char ** argv) { exit(1); } - struct mnist_dataset dataset(/*nex =*/ MNIST_NTEST, /*shard_size =*/ MNIST_NBATCH_PHYSICAL); + ggml_opt_dataset_t dataset = ggml_opt_dataset_init(MNIST_NINPUT, MNIST_NCLASSES, MNIST_NTEST, MNIST_NBATCH_PHYSICAL); if (!mnist_image_load(argv[2], dataset)) { return 1; @@ -36,46 +37,31 @@ int main(int argc, char ** argv) { const int iex = rand() % MNIST_NTEST; mnist_image_print(stdout, dataset, iex); - const std::string backend = argc >= 5 ? argv[4] : "CPU"; - - mnist_eval_result result_eval; - - if (backend == "CPU") { - const int ncores_logical = std::thread::hardware_concurrency(); - result_eval = mnist_graph_eval( - argv[1], ggml_get_data_f32(dataset.data), ggml_get_data_f32(dataset.labels), MNIST_NTEST, std::min(ncores_logical, (ncores_logical + 4)/2)); - if (result_eval.success) { - fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]); - - std::pair result_loss = mnist_loss(result_eval); - fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second); - - std::pair result_acc = mnist_accuracy(result_eval); - fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second); - - return 0; - } - } else { - fprintf(stdout, "%s: not trying to load a GGML graph from %s because this is only supported for the CPU backend\n", __func__, argv[1]); - } + const std::string backend = argc >= 5 ? argv[4] : ""; const int64_t t_start_us = ggml_time_us(); + mnist_model model = mnist_model_init_from_file(argv[1], backend, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL); + mnist_model_build(model); + const int64_t t_load_us = ggml_time_us() - t_start_us; + fprintf(stdout, "%s: loaded model in %.2lf ms\n", __func__, t_load_us / 1000.0); - mnist_model model = mnist_model_init_from_file(argv[1], backend); + ggml_opt_result_t result_eval = mnist_model_eval(model, dataset); - mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL); + std::vector pred(MNIST_NTEST); + ggml_opt_result_pred(result_eval, pred.data()); + fprintf(stdout, "%s: predicted digit is %d\n", __func__, pred[iex]); - const int64_t t_load_us = ggml_time_us() - t_start_us; - - fprintf(stdout, "%s: loaded model in %.2lf ms\n", __func__, t_load_us / 1000.0); - result_eval = mnist_model_eval(model, dataset); - fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]); + double loss; + double loss_unc; + ggml_opt_result_loss(result_eval, &loss, &loss_unc); + fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, loss, loss_unc); - std::pair result_loss = mnist_loss(result_eval); - fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second); + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(result_eval, &accuracy, &accuracy_unc); + fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*accuracy, 100.0*accuracy_unc); - std::pair result_acc = mnist_accuracy(result_eval); - fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second); + ggml_opt_result_free(result_eval); return 0; } diff --git a/examples/mnist/mnist-train.cpp b/examples/mnist/mnist-train.cpp index 161dcf80f..a61dd05b0 100644 --- a/examples/mnist/mnist-train.cpp +++ b/examples/mnist/mnist-train.cpp @@ -1,3 +1,4 @@ +#include "ggml-opt.h" #include "mnist-common.h" #include @@ -19,7 +20,7 @@ int main(int argc, char ** argv) { // The MNIST model is so small that the overhead from data shuffling is non-negligible, especially with CUDA. // With a shard size of 10 this overhead is greatly reduced at the cost of less shuffling (does not seem to have a significant impact). // A batch of 500 images then consists of 50 random shards of size 10 instead of 500 random shards of size 1. - struct mnist_dataset dataset(/*nex =*/ MNIST_NTRAIN, /*shard_size =*/ 10); + ggml_opt_dataset_t dataset = ggml_opt_dataset_init(MNIST_NINPUT, MNIST_NCLASSES, MNIST_NTRAIN, /*ndata_shard =*/ 10); if (!mnist_image_load(argv[3], dataset)) { return 1; @@ -28,9 +29,9 @@ int main(int argc, char ** argv) { return 1; } - mnist_model model = mnist_model_init_random(argv[1], argc >= 6 ? argv[5] : "CPU"); + mnist_model model = mnist_model_init_random(argv[1], argc >= 6 ? argv[5] : "", MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL); - mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL); + mnist_model_build(model); mnist_model_train(model, dataset, /*nepoch =*/ 30, /*val_split =*/ 0.05f); diff --git a/include/ggml-backend.h b/include/ggml-backend.h index 0a65dbfca..cef164764 100644 --- a/include/ggml-backend.h +++ b/include/ggml-backend.h @@ -86,7 +86,7 @@ extern "C" { GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - // "offset" refers to the offset of the tensor data for setting/getting data + // "offset" refers to the offset in tensor->data for setting/getting data GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); @@ -242,14 +242,20 @@ extern "C" { ggml_backend_sched_reserve(sched, reserve_graph); // compute - graph = build_graph(sched); - ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation + for (int i = 0; i < 10; ++i) { + ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically + } // if there are graph inputs: - ggml_backend_sched_reset(sched); - ggml_backend_sched_alloc_graph(sched, graph); - ggml_backend_tensor_set(input_tensor, ...); - ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called) + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph + ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it + ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, graph); // execute the graph + + // as an alternative to the above it is also possible to assign the inputs to a dedicated context and + // allocate them statically via ggml_backend_alloc_ctx_tensors } */ @@ -264,7 +270,7 @@ extern "C" { // typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); - // Initialize a backend scheduler + // Initialize a backend scheduler, backends with low index are given priority over backends with high index GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); @@ -289,7 +295,9 @@ extern "C" { GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph); GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); - // Reset all assignments and allocators - must be called before changing the node backends + // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph. + // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers. + // The correct way to use this API is to discard the deallocated tensors and create new ones. GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); // Set a callback to be called for each resulting node during graph compute diff --git a/include/ggml-opt.h b/include/ggml-opt.h new file mode 100644 index 000000000..eb5eab9de --- /dev/null +++ b/include/ggml-opt.h @@ -0,0 +1,216 @@ +// This file contains functionality for training models using GGML. +// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets. +// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code. +// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + + struct ggml_opt_dataset; + struct ggml_opt_context; + struct ggml_opt_result; + + typedef struct ggml_opt_dataset * ggml_opt_dataset_t; + typedef struct ggml_opt_context * ggml_opt_context_t; + typedef struct ggml_opt_result * ggml_opt_result_t; + + // ====== Loss ====== + + // built-in loss types, i.e. the built-in quantities minimized by the optimizer + // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value + enum ggml_opt_loss_type { + GGML_OPT_LOSS_TYPE_MEAN, + GGML_OPT_LOSS_TYPE_SUM, + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, + GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, + }; + + // ====== Dataset ====== + + GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); + + // get underlying tensors that store the data + GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] + GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] + + // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative + GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata); + + // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch + GGML_API void ggml_opt_dataset_get_batch( + ggml_opt_dataset_t dataset, + struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] + struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] + int64_t ibatch); + + // ====== Model / Context ====== + + enum ggml_opt_build_type { + GGML_OPT_BUILD_TYPE_FORWARD, + GGML_OPT_BUILD_TYPE_GRAD, + GGML_OPT_BUILD_TYPE_OPT, + }; + + // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss + struct ggml_opt_optimizer_params { + // AdamW optimizer parameters + struct { + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float wd; // weight decay for AdamW, use 0.0f to disable + } adamw; + }; + + // callback to calculate optimizer parameters prior to a backward pass + // userdata can be used to pass arbitrary data + typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); + + // returns the default optimizer params (constant) + // userdata is not used + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + + // parameters for initializing a new optimization context + struct ggml_opt_params { + ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs + + struct ggml_context * ctx_compute; // created in user code, holds non-static tensors + + // the forward graph is defined by inputs and outputs + // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; + + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + + int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done + + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + }; + + // get parameters for an optimization context with defaults set where possible + // parameters for which no sensible defaults exist are supplied as arguments to this function + GGML_API ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + struct ggml_context * ctx_compute, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs, + enum ggml_opt_loss_type loss_type); + + GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); + GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); + + // set gradients to zero, initilize loss, and optionally reset the optimizer + GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + + // get underlying tensors that store data + GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor + GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor + GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against + GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss + GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs + GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + + GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); + + // ====== Optimization Result ====== + + GGML_API ggml_opt_result_t ggml_opt_result_init(); + GGML_API void ggml_opt_result_free(ggml_opt_result_t result); + GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); + + // get data from result, uncertainties are optional and can be ignored by passing NULL + GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints + GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value + GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values + GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value + + // ====== Computation ====== + + // do forward pass, increment result if not NULL + GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + + // do forward pass, increment result if not NULL, do backward pass + GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + + // ############################################################################ + // ## The high-level functions start here. They do not depend on any private ## + // ## functions or structs and can be copied to and adapted for user code. ## + // ############################################################################ + + // ====== Intended Usage ====== + // + // 1. Select the appropriate loss for your problem. + // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them. + // Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster). + // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors. + // The first context should contain the model parameters and inputs and be allocated statically in user code. + // The second context should contain all other tensors and will be (re)allocated automatically. + // Due to this automated allocation the data of the second context is not defined when accessed in user code. + // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors. + // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead. + + // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation + typedef void (*ggml_opt_epoch_callback)( + bool train, // true after training evaluation, false after validation evaluation + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, // result associated with the dataset subsection + int64_t ibatch, // number of batches that have been evaluated so far + int64_t ibatch_max, // total number of batches in this dataset subsection + int64_t t_start_us); // time at which the evaluation on the dataset subsection was started + + // do training on front of dataset, do evaluation only on back of dataset + GGML_API void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, // result to increment during training, ignored if NULL + ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL + int64_t idata_split, // data index at which to split training and evaluation + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + // callback that prints a progress bar on stderr + GGML_API void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us); + + // fit model defined by inputs and outputs to dataset + GGML_API void ggml_opt_fit( + ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs + ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + ggml_opt_dataset_t dataset, // dataset with data and optionally also labels + enum ggml_opt_loss_type loss_type, // loss to minimize + ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) + int64_t nepoch, // how many times the dataset should be iterated over + int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs + float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) + bool silent); // whether or not info prints to stderr should be suppressed + +#ifdef __cplusplus +} +#endif diff --git a/include/ggml.h b/include/ggml.h index 3b3f6798a..69e6a2434 100644 --- a/include/ggml.h +++ b/include/ggml.h @@ -602,7 +602,6 @@ extern "C" { int32_t flags; - struct ggml_tensor * grad; struct ggml_tensor * src[GGML_MAX_SRC]; // source tensor and offset for views @@ -615,7 +614,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - // char padding[4]; + char padding[8]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -1985,28 +1984,20 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad, - float alpha, - float beta1, - float beta2, - float eps, - float wd); // weight decay + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params); // parameters such a the learning rate // // automatic differentiation // - GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); - GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate); - - GGML_API void ggml_build_opt_adamw( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - float alpha, - float beta1, - float beta2, - float eps, - float wd); // weight decay + GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand( + struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) + struct ggml_context * ctx_compute, // context for gradient computation + struct ggml_cgraph * cgraph, + bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static // graph allocation in a context GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false @@ -2026,7 +2017,9 @@ extern "C" { GGML_API size_t ggml_graph_overhead(void); GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); - GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); + GGML_API struct ggml_tensor * ggml_graph_get_tensor (const struct ggml_cgraph * cgraph, const char * name); + GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); + GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); @@ -2037,198 +2030,15 @@ extern "C" { // dump the graph into a file using the dot format GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); - // build gradient checkpointing backward graph gb for gf using provided checkpoints - // gb_tmp will contain original backward graph with rewritten backward process nodes, - // but without the second forward pass nodes. - GGML_API void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints); - // - // optimization - // - - // optimization methods - enum ggml_opt_type { - GGML_OPT_TYPE_ADAM, - GGML_OPT_TYPE_LBFGS, - }; - - // linesearch methods - enum ggml_linesearch { - GGML_LINESEARCH_DEFAULT = 1, - - GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, - GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, - GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, - }; - - // optimization return values - enum ggml_opt_result { - GGML_OPT_RESULT_OK = 0, - GGML_OPT_RESULT_DID_NOT_CONVERGE, - GGML_OPT_RESULT_NO_CONTEXT, - GGML_OPT_RESULT_INVALID_WOLFE, - GGML_OPT_RESULT_FAIL, - GGML_OPT_RESULT_CANCEL, - - GGML_LINESEARCH_FAIL = -128, - GGML_LINESEARCH_MINIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_ITERATIONS, - GGML_LINESEARCH_INVALID_PARAMETERS, - }; - - typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); // Set callback for all future logging events. // If this is not called, or NULL is supplied, everything is output on stderr. GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data); - // optimization parameters - // - // see ggml.c (ggml_opt_default_params) for default values - // - struct ggml_opt_params { - enum ggml_opt_type type; - - size_t graph_size; - - int n_threads; - - // delta-based convergence test - // - // if past == 0 - disabled - // if past > 0: - // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) - // - int past; - float delta; - - // maximum number of iterations without improvement - // - // if 0 - disabled - // if > 0: - // assume convergence if no cost improvement in this number of iterations - // - int max_no_improvement; - - bool print_forward_graph; - bool print_backward_graph; - - int n_gradient_accumulation; - - // ADAM parameters - struct { - int n_iter; - - float sched; // schedule multiplier (fixed, decay or warmup) - float decay; // weight decay for AdamW, use 0.0f to disable - int decay_min_ndim; // minimum number of tensor dimension to apply weight decay - float alpha; // learning rate - float beta1; - float beta2; - float eps; // epsilon for numerical stability - float eps_f; // epsilon for convergence test - float eps_g; // epsilon for convergence test - float gclip; // gradient clipping - } adam; - - // LBFGS parameters - struct { - int m; // number of corrections to approximate the inv. Hessian - int n_iter; - int max_linesearch; - - float eps; // convergence tolerance - float ftol; // line search tolerance - float wolfe; - float min_step; - float max_step; - - enum ggml_linesearch linesearch; - } lbfgs; - }; - - struct ggml_opt_context { - struct ggml_context * ctx; - struct ggml_opt_params params; - - int iter; - int64_t nx; // number of parameter elements - - bool just_initialized; - - float loss_before; - float loss_after; - - struct { - struct ggml_tensor * g; // current gradient - struct ggml_tensor * m; // first moment - struct ggml_tensor * v; // second moment - struct ggml_tensor * pf; // past function values - float fx_best; - float fx_prev; - int n_no_improvement; - } adam; - - struct { - struct ggml_tensor * x; // current parameters - struct ggml_tensor * xp; // previous parameters - struct ggml_tensor * g; // current gradient - struct ggml_tensor * gp; // previous gradient - struct ggml_tensor * d; // search direction - struct ggml_tensor * pf; // past function values - struct ggml_tensor * lmal; // the L-BFGS memory alpha - struct ggml_tensor * lmys; // the L-BFGS memory ys - struct ggml_tensor * lms; // the L-BFGS memory s - struct ggml_tensor * lmy; // the L-BFGS memory y - float fx_best; - float step; - int j; - int k; - int end; - int n_no_improvement; - } lbfgs; - }; - GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); - GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); - - // optimize the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f); - - // initialize optimizer context - GGML_API void ggml_opt_init( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume_g( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data); - // // quantization // diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 71934c679..ae7d3abc8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -207,9 +207,11 @@ add_library(ggml-base ../include/ggml-alloc.h ../include/ggml-backend.h ../include/ggml-cpp.h + ../include/ggml-opt.h ggml.c ggml-alloc.c ggml-backend.cpp + ggml-opt.cpp ggml-threading.cpp ggml-threading.h ggml-quants.c diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c index 041de9e3e..2b2240be8 100644 --- a/src/ggml-alloc.c +++ b/src/ggml-alloc.c @@ -466,18 +466,12 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) { return ggml_gallocr_hash_get(galloc, t)->allocated; } -static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) { - struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - hn->buffer_id = buffer_id; - hn->offset = offset; - hn->allocated = true; -} - static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) { return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated; } static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) { + GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { @@ -816,7 +810,11 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * } static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) { - size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + size_t node_size = 0; + if (!node->data && !node->view_src) { + GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API + node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + } return talloc->size_max >= node_size; } diff --git a/src/ggml-backend.cpp b/src/ggml-backend.cpp index e48877ba8..634fe38ee 100644 --- a/src/ggml-backend.cpp +++ b/src/ggml-backend.cpp @@ -279,7 +279,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz buf->iface.get_tensor(buf, tensor, data, offset, size); } -GGML_API void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { +void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (size == 0) { diff --git a/src/ggml-cpu/ggml-cpu.c b/src/ggml-cpu/ggml-cpu.c index 61f53cd01..df6487929 100644 --- a/src/ggml-cpu/ggml-cpu.c +++ b/src/ggml-cpu/ggml-cpu.c @@ -12216,11 +12216,16 @@ static void ggml_compute_forward_opt_step_adamw_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src0_grad = dst->src[1]; - const struct ggml_tensor * src0_grad_m = dst->src[2]; - const struct ggml_tensor * src0_grad_v = dst->src[3]; + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src0_grad = dst->src[1]; + const struct ggml_tensor * src0_grad_m = dst->src[2]; + const struct ggml_tensor * src0_grad_v = dst->src[3]; + const struct ggml_tensor * adamw_params = dst->src[4]; + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); const int ith = params->ith; const int nth = params->nth; @@ -12237,16 +12242,14 @@ static void ggml_compute_forward_opt_step_adamw_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - /* const float gnorm = 1.0f; */ - int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); - const float alpha = ggml_get_op_params_f32(dst, 2); - const float beta1 = ggml_get_op_params_f32(dst, 3); - const float beta2 = ggml_get_op_params_f32(dst, 4); - const float eps = ggml_get_op_params_f32(dst, 5); - const float wd = ggml_get_op_params_f32(dst, 6); - - const float beta1h = alpha/(1.0f - powf(beta1, iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); + const float * adamw_params_ptr = ggml_get_data_f32(adamw_params); + const float alpha = adamw_params_ptr[0]; + const float beta1 = adamw_params_ptr[1]; + const float beta2 = adamw_params_ptr[2]; + const float eps = adamw_params_ptr[3]; + const float wd = adamw_params_ptr[4]; + const float beta1h = adamw_params_ptr[5]; + const float beta2h = adamw_params_ptr[6]; for (int ir = ir0; ir < ir1; ++ir) { const int64_t i03 = ir/(ne02*ne01); @@ -12270,17 +12273,9 @@ static void ggml_compute_forward_opt_step_adamw_f32( // The weight decay is applied independently of the Adam momenta m and v. // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. // See: https://arxiv.org/pdf/1711.05101v3.pdf - w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; + w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh; } } - - ggml_barrier(params->threadpool); - if (ith != 0) { - return; - } - - iter++; - memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); } static void ggml_compute_forward_opt_step_adamw( diff --git a/src/ggml-cuda/opt-step-adamw.cu b/src/ggml-cuda/opt-step-adamw.cu index d6f13a9c6..35154f299 100644 --- a/src/ggml-cuda/opt-step-adamw.cu +++ b/src/ggml-cuda/opt-step-adamw.cu @@ -1,11 +1,11 @@ +#include "ggml-impl.h" #include "opt-step-adamw.cuh" #include static __global__ void opt_step_adamw_f32( - float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k, - const float alpha, const float beta1, const float beta2, const float eps, const float wd, - const float beta1h, const float beta2h) { + float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, + const float * __restrict__ pars, const int64_t k) { const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; @@ -13,6 +13,14 @@ static __global__ void opt_step_adamw_f32( return; } + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + const float gi = g[i]; const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1); const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2); @@ -23,58 +31,48 @@ static __global__ void opt_step_adamw_f32( const float mh = gmi*beta1h; const float vh = sqrtf(gvi*beta2h) + eps; - x[i] = x[i]*(1.0f - alpha*wd) - mh/vh; + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; } static void opt_step_adamw_f32_cuda( - float * x, const float * g, float * g_m, float * g_v, const int64_t k, - const float alpha, const float beta1, const float beta2, const float eps, const float wd, - const float beta1h, const float beta2h, cudaStream_t stream) { + float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) { const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); - opt_step_adamw_f32<<>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h); + opt_step_adamw_f32<<>>(x, g, g_m, g_v, pars, k); } void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src0_grad = dst->src[1]; - const ggml_tensor * src0_grad_m = dst->src[2]; - const ggml_tensor * src0_grad_v = dst->src[3]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); - GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32); - GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * src0_grad_m = dst->src[2]; + const ggml_tensor * src0_grad_v = dst->src[3]; + const ggml_tensor * adamw_params = dst->src[4]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0_grad)); GGML_ASSERT(ggml_is_contiguous(src0_grad_m)); GGML_ASSERT(ggml_is_contiguous(src0_grad_v)); + GGML_ASSERT(ggml_is_contiguous(adamw_params)); GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); - float * src0_d = (float *) src0->data; - const float * src0_grad_d = (const float *) src0_grad->data; - float * src0_grad_m_d = (float *) src0_grad_m->data; - float * src0_grad_v_d = (float *) src0_grad_v->data; + float * src0_d = (float *) src0->data; + const float * src0_grad_d = (const float *) src0_grad->data; + float * src0_grad_m_d = (float *) src0_grad_m->data; + float * src0_grad_v_d = (float *) src0_grad_v->data; + const float * adamw_params_d = (const float *) adamw_params->data; cudaStream_t stream = ctx.stream(); const int64_t ne = ggml_nelements(src0); - int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); - float alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float)); - float beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float)); - float beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float)); - float eps; memcpy(&eps, &dst->op_params[5], sizeof(float)); - float wd; memcpy(&wd, &dst->op_params[6], sizeof(float)); - - const float beta1h = alpha/(1.0f - powf(beta1, iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); - - opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream); - - iter++; - memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); + opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream); } diff --git a/src/ggml-impl.h b/src/ggml-impl.h index aa4d2b85d..92a64fe5a 100644 --- a/src/ggml-impl.h +++ b/src/ggml-impl.h @@ -196,7 +196,7 @@ void ggml_hash_set_reset(struct ggml_hash_set * hash_set); static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); // returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key); // returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); @@ -210,7 +210,7 @@ static inline size_t ggml_hash(const struct ggml_tensor * p) { return (size_t)(uintptr_t)p >> 4; } -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) { size_t h = ggml_hash(key) % hash_set->size; // linear probing @@ -281,13 +281,14 @@ enum ggml_cgraph_eval_order { }; struct ggml_cgraph { - int size; - int n_nodes; - int n_leafs; - - struct ggml_tensor ** nodes; - struct ggml_tensor ** grads; - struct ggml_tensor ** leafs; + int size; // maximum number of nodes/leafs/grads/grad_accs + int n_nodes; // number of nodes currently in use + int n_leafs; // number of leafs currently in use + + struct ggml_tensor ** nodes; // tensors with data that can change if the graph is evaluated + struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes + struct ggml_tensor ** grad_accs; // accumulators for node gradients + struct ggml_tensor ** leafs; // tensors with constant data struct ggml_hash_set visited_hash_set; diff --git a/src/ggml-metal/ggml-metal.m b/src/ggml-metal/ggml-metal.m index b4b5cfd26..95b21fbf9 100644 --- a/src/ggml-metal/ggml-metal.m +++ b/src/ggml-metal/ggml-metal.m @@ -3639,6 +3639,12 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) return ctx->all_data; } +static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + UNUSED(buffer); +} + static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); @@ -3671,7 +3677,7 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, /* .get_base = */ ggml_backend_metal_buffer_get_base, /* .init_tensor = */ NULL, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, diff --git a/src/ggml-opt.cpp b/src/ggml-opt.cpp new file mode 100644 index 000000000..808aa0d02 --- /dev/null +++ b/src/ggml-opt.cpp @@ -0,0 +1,867 @@ +#include "ggml-opt.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include + +struct ggml_opt_dataset { + struct ggml_context * ctx; + ggml_backend_buffer_t buf; + struct ggml_tensor * data; + struct ggml_tensor * labels; + + int64_t ndata; + int64_t ndata_shard; + size_t nbs_data; + size_t nbs_labels; + + std::vector permutation; +}; + +struct ggml_opt_context { + ggml_backend_sched_t backend_sched; + ggml_cgraph * allocated_graph; + ggml_cgraph * allocated_graph_copy; + struct ggml_context * ctx_static; + struct ggml_context * ctx_static_cpu; + struct ggml_context * ctx_compute; + struct ggml_context * ctx_copy; + ggml_backend_buffer_t buf_static; + ggml_backend_buffer_t buf_static_cpu; + std::mt19937 rng; + + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; + struct ggml_tensor * labels; + + struct ggml_tensor * loss; + struct ggml_tensor * pred; + struct ggml_tensor * ncorrect; + + struct ggml_cgraph * gf; + struct ggml_cgraph * gb_grad; + struct ggml_cgraph * gb_opt; + + int64_t iter; + int32_t opt_period; + int32_t opt_i; + bool loss_per_datapoint; + + ggml_opt_get_optimizer_params get_opt_pars; + void * get_opt_pars_ud; + struct ggml_tensor * adamw_params; +}; + +struct ggml_opt_result { + int64_t ndata = 0; + std::vector loss; + std::vector pred; + int64_t ncorrect = 0; + + bool loss_per_datapoint = false; + int64_t opt_period = -1; +}; + +// ====== Dataset ====== + +ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { + GGML_ASSERT(ne_datapoint > 0); + GGML_ASSERT(ne_label >= 0); + GGML_ASSERT(ndata > 0); + GGML_ASSERT(ndata_shard > 0); + + ggml_opt_dataset_t result = new ggml_opt_dataset; + result->ndata = ndata; + result->ndata_shard = ndata_shard; + + { + struct ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx = ggml_init(params); + } + + result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata); + result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; + + if (ne_label > 0) { + result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata); + result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; + } else { + result->labels = nullptr; + result->nbs_labels = 0; + } + + result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type()); + + const int64_t nshards = ndata/ndata_shard; + result->permutation.resize(nshards); + for (int64_t i = 0; i < nshards; ++i) { + result->permutation[i] = i; + } + return result; +} + +void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { + ggml_backend_buffer_free(dataset->buf); + ggml_free(dataset->ctx); + delete dataset; +} + +struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { + return dataset->data; +} + +struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) { + return dataset->labels; +} + +void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) { + GGML_ASSERT(idata <= dataset->ndata); + + if (idata < 0) { + std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng); + return; + } + + GGML_ASSERT(idata % dataset->ndata_shard == 0); + const int64_t ishard_max = idata / dataset->ndata_shard; + std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng); +} + +void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) { + GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); + GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + + const size_t nb_data_batch = ggml_nbytes(data_batch); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + if (labels_batch) { + const size_t nb_labels_batch = ggml_nbytes(labels_batch); + GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels); + } + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data; + ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels; + ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels); + } +} + +// ====== Model / Context ====== + +struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { + GGML_UNUSED(userdata); + + ggml_opt_optimizer_params result; + + result.adamw.alpha = 0.001f; + result.adamw.beta1 = 0.9f; + result.adamw.beta2 = 0.999f; + result.adamw.eps = 1e-8f; + result.adamw.wd = 0.0f; + + return result; +} + +struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + struct ggml_context * ctx_compute, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs, + enum ggml_opt_loss_type loss_type) { + return { + /*backend_sched =*/ backend_sched, + /*ctx_compute =*/ ctx_compute, + /*inputs =*/ inputs, + /*logits =*/ outputs, + /*loss_type =*/ loss_type, + /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, + /*opt_period =*/ 1, + /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, + /*get_opt_pars_ud =*/ nullptr, + }; +} + +static ggml_tensor * map_tensor(std::map & tensor_map, ggml_context * ctx, ggml_tensor * tensor) { + if (!tensor) { + return nullptr; + } + + if (tensor_map.find(tensor) != tensor_map.end()) { + return tensor_map[tensor]; + } + + ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor); + tensor_map[tensor] = new_tensor; + + new_tensor->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + new_tensor->nb[i] = tensor->nb[i]; + } + new_tensor->flags = tensor->flags; + memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params)); + strcpy(new_tensor->name, tensor->name); + new_tensor->data = tensor->data; + new_tensor->buffer = tensor->buffer; + new_tensor->extra = tensor->extra; + new_tensor->view_offs = tensor->view_offs; + new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src); + for (int i = 0; i < GGML_MAX_SRC; i++) { + new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]); + } + + return new_tensor; +} + +static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) { + std::map tensor_map; + + ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); + + for (int i = 0; i < graph->n_leafs; i++) { + ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i])); + } + for (int i = 0; i < graph->n_nodes; i++) { + ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i])); + } + for (int i = 0; i < graph->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]); + graph->grads[igrad_dst] = new_graph->grads[igrad_src]; + graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src]; + } + + return new_graph; +} + +static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) { + GGML_ASSERT(graph); + if (opt_ctx->allocated_graph == graph) { + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + } + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; +} + +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->allocated_graph = nullptr; + result->allocated_graph_copy = nullptr; + result->ctx_compute = params.ctx_compute; + result->ctx_copy = nullptr; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->iter = 1; + result->opt_period = params.opt_period; + result->opt_i = 0; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); + GGML_ASSERT(result->opt_period >= 1); + + const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD || + (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); + + ggml_set_input(result->inputs); + ggml_set_output(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + ggml_build_forward_expand(result->gf, result->outputs); + + int n_param = 0; + for (int i = 0; i < result->gf->n_nodes; ++i) { + if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { + n_param++; + } + } + + { + // The static context is used for: + // - gradients (1 tensor per param if using gradient accumulation) + // - optimizer momenta (2 tensors per param) + // - labels + // - loss + its gradient (up to 5 tensors) + // - pred + // - ncorrect (2 tensors). + const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static = ggml_init(params); + } + { + // The static cpu context is used for: + // - optimizer parameters (1 for the entire context) + const size_t size_meta = 1 * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static_cpu = ggml_init(params); + } + + + switch (params.loss_type) { + case GGML_OPT_LOSS_TYPE_MEAN: { + result->labels = nullptr; + result->loss = ggml_sum(result->ctx_static, result->outputs); + ggml_set_name(result->loss, "loss_sum"); + const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); + result->loss = ggml_scale(result->ctx_static, result->loss, scale); + ggml_set_name(result->loss, "loss_mean"); + result->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_SUM: { + result->labels = nullptr; + result->loss = ggml_sum(result->ctx_static, result->outputs); + ggml_set_name(result->loss, "loss_sum"); + result->loss_per_datapoint = false; + break; + } + case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { + result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); + ggml_set_input(result->labels); + ggml_set_name(result->labels, "labels"); + result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); + ggml_set_name(result->loss, "loss_cross_entropy"); + if (result->opt_period > 1) { + result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); + ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + } + result->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { + result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); + ggml_set_input(result->labels); + ggml_set_name(result->labels, "labels"); + result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels); + ggml_set_name(result->loss, "loss_error"); + result->loss = ggml_sqr(result->ctx_static, result->loss); + ggml_set_name(result->loss, "loss_squared_error"); + result->loss = ggml_sum(result->ctx_static, result->loss); + ggml_set_name(result->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); + result->loss = ggml_scale(result->ctx_static, result->loss, scale); + ggml_set_name(result->loss, "loss_mean_squared_error"); + result->loss_per_datapoint = true; + break; + } + } + ggml_set_output(result->loss); + ggml_set_loss(result->loss); + ggml_build_forward_expand(result->gf, result->loss); + + result->pred = ggml_argmax(result->ctx_static, result->outputs); + ggml_set_name(result->pred, "pred"); + ggml_set_output(result->pred); + ggml_build_forward_expand(result->gf, result->pred); + + if (result->labels) { + result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels)); + ggml_set_name(result->ncorrect, "ncorrect"); + ggml_set_output(result->ncorrect); + ggml_build_forward_expand(result->gf, result->ncorrect); + } else { + result->ncorrect = nullptr; + } + + if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + result->gb_grad = nullptr; + result->gb_opt = nullptr; + + result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + result->buf_static_cpu = nullptr; + + ggml_opt_alloc_graph(result, result->gf); + + return result; + } + + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. + result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf); + ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + + if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) { + result->gb_opt = nullptr; + + result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + result->buf_static_cpu = nullptr; + + ggml_opt_alloc_graph(result, result->gb_grad); + ggml_graph_reset(result->gb_grad); + + return result; + } + + GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad); + + result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7); + ggml_set_input(result->adamw_params); + ggml_set_name(result->adamw_params, "adamw_params"); + + for (int i = result->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = result->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); + + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); + struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node); + struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); + ggml_build_forward_expand(result->gb_opt, opt_step); + } + } + + result->buf_static = ggml_backend_alloc_ctx_tensors( + result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + + result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type()); + + ggml_opt_alloc_graph(result, result->gb_opt); + ggml_graph_reset(result->gb_opt); + + return result; +} + +void ggml_opt_free(ggml_opt_context_t opt_ctx) { + if (opt_ctx == nullptr) { + return; + } + ggml_backend_buffer_free(opt_ctx->buf_static); + ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + ggml_free(opt_ctx->ctx_static); + ggml_free(opt_ctx->ctx_static_cpu); + delete opt_ctx; +} + +void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { + if (optimizer) { + ggml_graph_reset(opt_ctx->gb_opt); + opt_ctx->iter = 1; + } else { + ggml_graph_reset(opt_ctx->gb_grad); + } +} + +struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->inputs; +} + +struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->outputs; +} + +struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) { + return opt_ctx->labels; +} + +struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) { + return opt_ctx->loss; +} + +struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) { + return opt_ctx->pred; +} + +struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) { + return opt_ctx->ncorrect; +} + +struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) { + return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node); +} + +// ====== Optimization Result ====== + +ggml_opt_result_t ggml_opt_result_init() { + return new ggml_opt_result; +} + +void ggml_opt_result_free(ggml_opt_result_t result) { + delete result; +} + +void ggml_opt_result_reset(ggml_opt_result_t result) { + result->ndata = 0; + result->loss.clear(); + result->pred.clear(); + result->ncorrect = 0; +} + +void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) { + *ndata = result->ndata; +} + +void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) { + const int64_t nbatches = result->loss.size(); // Number of physical batches. + + if (nbatches == 0) { + *loss = 0.0; + *unc = NAN; + return; + } + + double sum = 0.0; + double sum_squared = 0.0; + + for (const float & loss : result->loss) { + // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch. + const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss; + sum += loss_scaled; + sum_squared += loss_scaled*loss_scaled; + } + + const double mean = sum/nbatches; + *loss = result->loss_per_datapoint ? mean : sum; + + if (!unc) { + return; + } + + if (nbatches < 2) { + *unc = NAN; + return; + } + + const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1) + *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1)); +} + +void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) { + for (size_t i = 0; i < result->pred.size(); ++i) { + pred[i] = result->pred[i]; + } +} + +void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) { + *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN; + + if (!unc) { + return; + } + + *unc = result->ncorrect >= 0 && result->ndata >= 2 ? + sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN; +} + +// ====== Computation ====== + +static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) { + if (graph != opt_ctx->gf) { + struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); + + GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.eps >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd <= 1.0f); + + // beta1, beta2 after applying warmup + const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); + const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); + + float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params); + adamw_par_data[0] = opt_pars.adamw.alpha; + adamw_par_data[1] = opt_pars.adamw.beta1; + adamw_par_data[2] = opt_pars.adamw.beta2; + adamw_par_data[3] = opt_pars.adamw.eps; + adamw_par_data[4] = opt_pars.adamw.wd; + adamw_par_data[5] = beta1h; + adamw_par_data[6] = beta2h; + } + + ggml_opt_alloc_graph(opt_ctx, graph); + ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + + if (!result) { + return; + } + + if (result->ndata == 0) { + result->loss_per_datapoint = opt_ctx->loss_per_datapoint; + result->opt_period = opt_ctx->opt_period; + } else { + GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint); + GGML_ASSERT(result->opt_period == opt_ctx->opt_period); + } + + const int64_t ndata = opt_ctx->outputs->ne[1]; + GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported"); + result->ndata += ndata; + + GGML_ASSERT(ggml_is_scalar(opt_ctx->loss)); + GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32); + float loss; + ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); + result->loss.push_back(loss); + + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + + if (!opt_ctx->labels || result->ncorrect < 0) { + result->ncorrect = -1; + return; + } + + GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect)); + GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64); + int64_t ncorrect; + ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect)); + result->ncorrect += ncorrect; +} + +void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); +} + +void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { + if (opt_ctx->opt_period == 1) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + return; + } + + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + if (opt_i_next == 0) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + ggml_opt_reset(opt_ctx, /*optimizer =*/ false); + } else { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); + } + opt_ctx->opt_i = opt_i_next; +} + +// ====== High-Level Functions ====== + +void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + struct ggml_tensor * data = ggml_opt_dataset_data(dataset); + GGML_ASSERT(data->ne[0] == inputs->ne[0]); + + const int64_t ndata = data->ne[1]; + const int64_t ndata_batch = inputs->ne[1]; + + GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0); + const int64_t nbatches = ndata/ndata_batch; + + idata_split = idata_split < 0 ? ndata : idata_split; + GGML_ASSERT(idata_split % ndata_batch == 0); + const int64_t ibatch_split = idata_split / ndata_batch; + + int64_t ibatch = 0; + int64_t t_loop_start = ggml_time_us(); + for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_forward_backward(opt_ctx, result_train); + if (callback_train) { + callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); + } + } + t_loop_start = ggml_time_us(); + for (; ibatch < nbatches; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_forward(opt_ctx, result_eval); + if (callback_eval) { + callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); + } + } +} + +void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us) { + fprintf(stderr, "%s[", train ? "train: " : "val: "); + + constexpr int64_t bar_length = 25; + for (int64_t j = 0; j < bar_length; ++j) { + const int64_t ibatch_j = ibatch_max * j/bar_length; + if (ibatch_j < ibatch) { + fprintf(stderr, "="); + } else if (ibatch_max * (j - 1)/bar_length < ibatch) { + fprintf(stderr, ">"); + } else { + fprintf(stderr, " "); + } + } + + const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1]; + const int64_t idata = ibatch*batch_size; + const int64_t idata_max = ibatch_max*batch_size; + + double loss; + double loss_unc; + ggml_opt_result_loss(result, &loss, &loss_unc); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc); + + const int64_t t_ibatch_us = ggml_time_us() - t_start_us; + int64_t t_ibatch_s = t_ibatch_us / 1000000; + const int64_t t_ibatch_h = t_ibatch_s / 3600; + t_ibatch_s -= t_ibatch_h * 3600; + const int64_t t_ibatch_m = t_ibatch_s / 60; + t_ibatch_s -= t_ibatch_m * 60; + + const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch; + int64_t t_eta_s = t_eta_us / 1000000; + const int64_t t_eta_h = t_eta_s / 3600; + t_eta_s -= t_eta_h * 3600; + const int64_t t_eta_m = t_eta_s / 60; + t_eta_s -= t_eta_m * 60; + + fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, + t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); + if (ibatch == ibatch_max) { + fprintf(stderr, "\n"); + } + fflush(stderr); + + GGML_UNUSED(dataset); +} + +void ggml_opt_fit( + ggml_backend_sched_t backend_sched, + ggml_context * ctx_compute, + ggml_tensor * inputs, + ggml_tensor * outputs, + ggml_opt_dataset_t dataset, + enum ggml_opt_loss_type loss_type, + ggml_opt_get_optimizer_params get_opt_pars, + int64_t nepoch, + int64_t nbatch_logical, + float val_split, + bool silent) { + ggml_time_init(); + const int64_t t_start_us = ggml_time_us(); + + const int64_t ndata = ggml_opt_dataset_data(dataset)->ne[1]; + const int64_t nbatch_physical = inputs->ne[1]; + GGML_ASSERT(ndata % nbatch_logical == 0); + GGML_ASSERT(nbatch_logical % nbatch_physical == 0); + + const int64_t opt_period = nbatch_logical / nbatch_physical; + const int64_t nbatches_logical = ndata / nbatch_logical; + + GGML_ASSERT(val_split >= 0.0f); + GGML_ASSERT(val_split < 1.0f); + const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical) + const int64_t idata_split = ibatch_split * nbatch_physical; + + int64_t epoch = 1; + + ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + params.opt_period = opt_period; + params.get_opt_pars = get_opt_pars; + params.get_opt_pars_ud = &epoch; + ggml_opt_context_t opt_ctx = ggml_opt_init(params); + + // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch. + if (nbatch_logical < ndata) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation). + } + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_val = ggml_opt_result_init(); + + ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar; + + for (; epoch <= nepoch; ++epoch) { + if (nbatch_logical < idata_split) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split); + } + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_val); + + if (!silent) { + fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch); + } + ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback); + if (!silent) { + fprintf(stderr, "\n"); + } + } + + if (!silent) { + int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000; + const int64_t t_total_h = t_total_s / 3600; + t_total_s -= t_total_h * 3600; + const int64_t t_total_m = t_total_s / 60; + t_total_s -= t_total_m * 60; + fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s); + } + + ggml_opt_free(opt_ctx); + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_val); +} diff --git a/src/ggml.c b/src/ggml.c index 5cdf59f25..4a478fcaa 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -1592,14 +1592,13 @@ static struct ggml_tensor * ggml_new_tensor_impl( /*.op =*/ GGML_OP_NONE, /*.op_params =*/ { 0 }, /*.flags =*/ 0, - /*.grad =*/ NULL, /*.src =*/ { NULL }, /*.view_src =*/ view_src, /*.view_offs =*/ view_offs, /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, /*.name =*/ { 0 }, /*.extra =*/ NULL, - ///*.padding =*/ { 0 }, + /*.padding =*/ { 0 }, }; #ifdef __clang__ @@ -4194,8 +4193,6 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(mask); } - bool is_node = false; - // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); @@ -4203,8 +4200,7 @@ struct ggml_tensor * ggml_flash_attn_ext( float params[] = { scale, max_bias, logit_softcap }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_FLASH_ATTN_EXT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_FLASH_ATTN_EXT; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -4272,14 +4268,6 @@ struct ggml_tensor * ggml_flash_attn_back( GGML_ASSERT(ne2 % kvne2 == 0); - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - // when using this operation (in backwards pass) these grads are set. - // we don't want to create (big) grad of our result, so is_node is false. - is_node = false; - } - // store gradients of q, k and v as continuous tensors concatenated in result. // note: v and gradv are actually transposed, i.e. v->ne[0] != D. const int64_t elem_q = ggml_nelements(q); @@ -4302,8 +4290,7 @@ struct ggml_tensor * ggml_flash_attn_back( int32_t masked_i = masked ? 1 : 0; ggml_set_op_params(result, &masked_i, sizeof(masked_i)); - result->op = GGML_OP_FLASH_ATTN_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_FLASH_ATTN_BACK; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -4945,34 +4932,24 @@ struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params) { GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); GGML_ASSERT(ggml_are_same_shape(a, grad)); - GGML_ASSERT(alpha > 0.0f); - GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); - GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); - GGML_ASSERT(eps >= 0.0f); - GGML_ASSERT(wd >= 0.0f && wd <= 1.0f); + GGML_ASSERT(ggml_are_same_shape(a, m)); + GGML_ASSERT(ggml_are_same_shape(a, v)); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); struct ggml_tensor * result = ggml_view_tensor(ctx, a); - const int64_t iter = 1; - memcpy(&result->op_params[0], &iter, sizeof(int64_t)); - ggml_set_op_params_f32(result, 2, alpha); - ggml_set_op_params_f32(result, 3, beta1); - ggml_set_op_params_f32(result, 4, beta2); - ggml_set_op_params_f32(result, 5, eps); - ggml_set_op_params_f32(result, 6, wd); - result->op = GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; result->src[1] = grad; - result->src[2] = ggml_dup_tensor(ctx, grad); - result->src[3] = ggml_dup_tensor(ctx, grad); + result->src[2] = m; + result->src[3] = v; + result->src[4] = adamw_params; return result; } @@ -5041,1112 +5018,514 @@ static void ggml_hash_map_free(struct hash_map * map) { GGML_FREE(map); } -// gradient checkpointing - -static struct ggml_tensor * ggml_recompute_graph_node( - struct ggml_context * ctx, - struct ggml_cgraph * graph, - struct hash_map * replacements, - struct ggml_tensor * node) { - - if (node == NULL) { - return NULL; - } - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - return node; - } - - if (!ggml_hash_contains(&graph->visited_hash_set, node)) { - return node; - } - - int count_children = 0; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - if (node->src[k]) { - ++count_children; - } - } - - if (count_children == 0) { - return node; - } - - size_t i = ggml_hash_find(&replacements->set, node); - GGML_ASSERT(i != GGML_HASHSET_FULL); // assert that not full - if (replacements->set.keys[i] == node) { - return replacements->vals[i]; - } - - struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne); - - // insert clone into replacements - GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite - replacements->set.keys[i] = node; - replacements->vals[i] = clone; - - clone->op = node->op; - clone->grad = node->grad; - clone->flags = node->flags; - clone->extra = node->extra; - for (int k = 0; k < GGML_MAX_DIMS; ++k) { - clone->nb[k] = node->nb[k]; - } - for (int k = 0; k < GGML_MAX_SRC; ++k) { - clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); - } - if (node->view_src != NULL) { - clone->data = (node->view_src->data == NULL) - ? NULL // view_src not yet allocated - : (char *) node->view_src->data // view_src already allocated - + node->view_offs; - clone->view_src = node->view_src; - clone->view_offs = node->view_offs; - } - - GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); - GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); - memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); - ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); - - return clone; -} - -void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints) { - ggml_graph_cpy(gf, gb_tmp); - ggml_build_backward_expand(ctx, gf, gb_tmp, false); - - if (n_checkpoints <= 0) { - ggml_graph_cpy(gb_tmp, gb); - return; - } - - struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); - - // insert checkpoints in replacements - for (int i = 0; i < n_checkpoints; ++i) { - size_t k = ggml_hash_find(&replacements->set, checkpoints[i]); - GGML_ASSERT(k != GGML_HASHSET_FULL); // assert that not full - GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite - replacements->set.keys[k] = checkpoints[i]; - replacements->vals[k] = checkpoints[i]; - } - - ggml_graph_cpy(gf, gb); - // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], - // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), - // by recomputing them from checkpoints - for (int i = gf->n_nodes; in_nodes; ++i) { - struct ggml_tensor * node = gb_tmp->nodes[i]; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - // insert new tensors recomputing src, reusing already made replacements, - // remember replacements: remember new tensors with mapping from corresponding gf nodes - // recurse for input tensors, - // unless (i.e. terminating when) input tensors are replacements (like checkpoints) - node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); - } - // insert rewritten backward node with replacements made into resulting backward graph gb - ggml_build_forward_expand(gb, node); - } - - ggml_hash_map_free(replacements); -} - // utility functions to change gradients // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator // else if a is in zero_table, replace a // else, just add/subtract/etc. the gradients -static struct ggml_tensor * ggml_add_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (ggml_hash_contains(zero_table, a)) { - return b; +static void ggml_add_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = tensor; } - return ggml_add_impl(ctx, a, b, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct ggml_tensor * ggml_acc_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const size_t nb1, - const size_t nb2, - const size_t nb3, - const size_t offset, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (ggml_hash_contains(zero_table, a)) { - struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN - return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); +static void ggml_acc_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * src, + struct ggml_tensor * tensor, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]); + } else { + struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN + cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false); } - return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct ggml_tensor * ggml_add1_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (ggml_hash_contains(zero_table, a)) { - return ggml_repeat(ctx, b, a); +static void ggml_add1_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * src, + struct ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src); } - return ggml_add1_impl(ctx, a, b, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct ggml_tensor * ggml_sub_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; - } - if (ggml_hash_contains(zero_table, a)) { - return ggml_neg(ctx, b); +static void ggml_sub_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_neg(ctx, tensor); } - return ggml_sub_impl(ctx, a, b, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) { +static void ggml_compute_backward( + struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) { + struct ggml_tensor * tensor = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); + + if (!grad) { + return; + } + struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; struct ggml_tensor * src2 = tensor->src[2]; + struct ggml_hash_set * hash_set = &cgraph->visited_hash_set; + const size_t isrc0 = ggml_hash_find(hash_set, src0); + const size_t isrc1 = ggml_hash_find(hash_set, src1); + const size_t isrc2 = ggml_hash_find(hash_set, src2); + const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; + const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; + const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; switch (tensor->op) { - case GGML_OP_DUP: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_OP_ADD: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - if (ggml_are_same_shape(src0, src1)) { - src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } else { - src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table); - } - } - } break; - case GGML_OP_ADD1: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = ggml_add_or_set(ctx, - src1->grad, - ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table, acc_table); - } - } break; - case GGML_OP_ACC: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], - nb1, nb2, nb3, offset); - - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_SUB: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_OP_MUL: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, src1, tensor->grad), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_mul(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_DIV: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, tensor->grad, src1), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - ggml_sub_or_set(ctx, - src1->grad, - ggml_mul(ctx, - tensor->grad, - ggml_div(ctx, tensor, src1)), - zero_table, acc_table); - } - } break; - case GGML_OP_SQR: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_mul(ctx, src0, tensor->grad), - 2.0f), - zero_table, acc_table); - } - } break; - case GGML_OP_SQRT: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_div(ctx, - tensor->grad, - tensor), - 0.5f), - zero_table, acc_table); - } - } break; - case GGML_OP_LOG: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, - tensor->grad, - src0), - zero_table, acc_table); - } - } break; - case GGML_OP_SIN: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - tensor->grad, - ggml_cos(ctx, src0)), - zero_table, acc_table); - } - } break; - case GGML_OP_COS: - { - if (src0->grad) { - src0->grad = - ggml_sub_or_set(ctx, - src0->grad, - ggml_mul(ctx, - tensor->grad, - ggml_sin(ctx, src0)), - zero_table, acc_table); - } - } break; - case GGML_OP_SUM: - { - if (src0->grad) { - src0->grad = - ggml_add1_or_set(ctx, - src0->grad, - tensor->grad, - zero_table, acc_table); - } - } break; - case GGML_OP_SUM_ROWS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, - tensor->grad, - src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_COUNT_EQUAL: - { - GGML_ABORT("fatal error"); // TODO: implement - } - case GGML_OP_REPEAT: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_REPEAT_BACK: - { - if (src0->grad) { - // TODO: test this - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_CONCAT: - { - GGML_ABORT("fatal error"); // TODO: implement - } - case GGML_OP_SILU_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + case GGML_OP_DUP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_NORM: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_ADD: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_RMS_NORM: - { - // necessary for llama - if (src0->grad) { - float eps; - memcpy(&eps, tensor->op_params, sizeof(float)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table, acc_table); + if (src1_needs_grads) { + struct ggml_tensor * tmp = grad; + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); } - } break; - case GGML_OP_RMS_NORM_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc1, tmp); } - case GGML_OP_GROUP_NORM: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_ADD1: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_MUL_MAT: - { - // https://cs231n.github.io/optimization-2/#staged - // # forward pass - // s0 = np.random.randn(5, 10) - // s1 = np.random.randn(10, 3) - // t = s0.dot(s1) - - // # now suppose we had the gradient on t from above in the circuit - // dt = np.random.randn(*t.shape) # same shape as t - // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix - // ds1 = t.T.dot(dt) - - // tensor.shape [m,p,qq,rr] - // src0.shape [n,m,q1,r1] - // src1.shape [n,p,qq,rr] - - // necessary for llama - if (src0->grad) { - struct ggml_tensor * s1_tg = - ggml_out_prod(ctx, // [n,m,qq,rr] - src1, // [n,p,qq,rr] - tensor->grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); - } - src0->grad = - ggml_add_or_set(ctx, - src0->grad, // [n,m,q1,r1] - s1_tg, // [n,m,q1,r1] - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, // [n,p,qq,rr] - // ggml_mul_mat(ctx, // [n,p,qq,rr] - // ggml_cont(ctx, // [m,n,q1,r1] - // ggml_transpose(ctx, src0)), // [m,n,q1,r1] - // tensor->grad), // [m,p,qq,rr] - - // // when src0 is bigger than tensor->grad (this is mostly the case in llama), - // // avoid transpose of src0, rather transpose smaller tensor->grad - // // and then use ggml_out_prod - ggml_out_prod(ctx, // [n,p,qq,rr] - src0, // [n,m,q1,r1] - ggml_transpose(ctx, // [p,m,qq,rr] - tensor->grad)), // [m,p,qq,rr] - zero_table, acc_table); - } - } break; - case GGML_OP_MUL_MAT_ID: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean } - case GGML_OP_OUT_PROD: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_ACC: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_SCALE: - { - // necessary for llama - if (src0->grad) { - float s; - memcpy(&s, tensor->op_params, sizeof(float)); - - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale_impl(ctx, tensor->grad, s, false), - zero_table, acc_table); - } - } break; - case GGML_OP_SET: - { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = NULL; - - if (src0->grad || src1->grad) { - GGML_ASSERT(src0->type == tensor->type); - GGML_ASSERT(tensor->grad->type == tensor->type); - GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type); - - tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - nb1, nb2, nb3, offset); - } + if (src1_needs_grads) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_acc_impl(ctx, - tensor->grad, - ggml_neg(ctx, tensor_grad_view), - nb1, nb2, nb3, offset, false), - zero_table, acc_table); - } + struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_CPY: - { - // necessary for llama - // cpy overwrites value of src1 by src0 and returns view(src1) - // the overwriting is mathematically equivalent to: - // tensor = src0 * 1 + src1 * 0 - if (src0->grad) { - // dsrc0 = dtensor * 1 - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - // dsrc1 = dtensor * 0 -> noop - } - } break; - case GGML_OP_CONT: - { - // same as cpy - if (src0->grad) { - GGML_ASSERT(ggml_is_contiguous(src0->grad)); - GGML_ASSERT(ggml_is_contiguous(tensor->grad)); - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_OP_RESHAPE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_reshape(ctx, - ggml_is_contiguous(tensor->grad) - ? tensor->grad - : ggml_cont(ctx, tensor->grad), - src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_VIEW: - { - // necessary for llama - if (src0->grad) { - size_t offset; - - memcpy(&offset, tensor->op_params, sizeof(offset)); - - size_t nb1 = tensor->nb[1]; - size_t nb2 = tensor->nb[2]; - size_t nb3 = tensor->nb[3]; - - if (src0->type != src0->grad->type) { - // gradient is typically F32, but src0 could be other type - size_t ng = ggml_element_size(src0->grad); - size_t n0 = ggml_element_size(src0); - GGML_ASSERT(offset % n0 == 0); - GGML_ASSERT(nb1 % n0 == 0); - GGML_ASSERT(nb2 % n0 == 0); - GGML_ASSERT(nb3 % n0 == 0); - offset = (offset / n0) * ng; - nb1 = (nb1 / n0) * ng; - nb2 = (nb2 / n0) * ng; - nb3 = (nb3 / n0) * ng; - } - - src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table); - } - } break; - case GGML_OP_PERMUTE: - { - // necessary for llama - if (src0->grad) { - int32_t * axes = (int32_t *) tensor->op_params; - int axis0 = axes[0] & 0x3; - int axis1 = axes[1] & 0x3; - int axis2 = axes[2] & 0x3; - int axis3 = axes[3] & 0x3; - int axes_backward[4] = {0,0,0,0}; - axes_backward[axis0] = 0; - axes_backward[axis1] = 1; - axes_backward[axis2] = 2; - axes_backward[axis3] = 3; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_permute(ctx, - tensor->grad, - axes_backward[0], - axes_backward[1], - axes_backward[2], - axes_backward[3]), - zero_table, acc_table); - } - } break; - case GGML_OP_TRANSPOSE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_transpose(ctx, tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_GET_ROWS: - { - // necessary for llama (only for tokenizer) - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - // last ggml_get_rows_back argument src0->grad is only - // necessary to setup correct output shape - ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table, acc_table); - } - if (src1->grad) { - // noop - } - } break; - case GGML_OP_GET_ROWS_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); } - case GGML_OP_DIAG: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SUB: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_DIAG_MASK_INF: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - /* ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); - } - } break; - case GGML_OP_DIAG_MASK_ZERO: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); - } - } break; - case GGML_OP_SOFT_MAX: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table, acc_table); - } - GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented"); - } break; - case GGML_OP_SOFT_MAX_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, grad); } - case GGML_OP_ROPE: - { - // necessary for llama - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_back(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow), - zero_table, acc_table); - } - GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented"); - } break; - case GGML_OP_ROPE_BACK: - { - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_impl(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - false), - zero_table, acc_table); + } break; + case GGML_OP_MUL: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad)); + } + if (src1_needs_grads) { + struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); } - } break; - case GGML_OP_CLAMP: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc1, tmp); } - case GGML_OP_CONV_TRANSPOSE_1D: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_DIV: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1)); } - case GGML_OP_IM2COL: - { - if (src1->grad) { - const int32_t s0 = ggml_get_op_params_i32(tensor, 0); - const int32_t s1 = ggml_get_op_params_i32(tensor, 1); - const int32_t p0 = ggml_get_op_params_i32(tensor, 2); - const int32_t p1 = ggml_get_op_params_i32(tensor, 3); - const int32_t d0 = ggml_get_op_params_i32(tensor, 4); - const int32_t d1 = ggml_get_op_params_i32(tensor, 5); - const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; - - src1->grad = ggml_add_or_set(ctx, - src1->grad, - ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), - zero_table, acc_table); - } - } break; - case GGML_OP_IM2COL_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1))); } - case GGML_OP_CONV_TRANSPOSE_2D: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SQR: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f)); } - case GGML_OP_POOL_1D: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SQRT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f)); } - case GGML_OP_POOL_2D: - { - if (src0->grad) { - const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); - const int32_t k0 = ggml_get_op_params_i32(tensor, 1); - const int32_t k1 = ggml_get_op_params_i32(tensor, 2); - const int32_t s0 = ggml_get_op_params_i32(tensor, 3); - const int32_t s1 = ggml_get_op_params_i32(tensor, 4); - const int32_t p0 = ggml_get_op_params_i32(tensor, 5); - const int32_t p1 = ggml_get_op_params_i32(tensor, 6); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), - zero_table, acc_table); - } - } break; - case GGML_OP_POOL_2D_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_LOG: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0)); } - case GGML_OP_UPSCALE: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SIN: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0))); } - case GGML_OP_PAD: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_COS: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0))); } - case GGML_OP_ARANGE: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SUM: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad); } - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SUM_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); } - case GGML_OP_ARGSORT: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_MEAN: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); } - case GGML_OP_LEAKY_RELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_REPEAT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0)); } - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ABORT("FA backward pass not adapted after rework"); - struct ggml_tensor * flash_grad = NULL; - if (src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - flash_grad = - ggml_flash_attn_back(ctx, - src0, - src1, - tensor->src[2], - tensor->grad, - masked); + } break; + case GGML_OP_REPEAT_BACK: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); + } + } break; + case GGML_OP_RMS_NORM: { + if (src0_needs_grads) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps)); + } + } break; + case GGML_OP_MUL_MAT: { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) + + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) + + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] + + if (src0_needs_grads) { + struct ggml_tensor * s1_tg = + ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + grad); // [m,p,qq,rr] + const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = ggml_repeat_back(ctx, s1_tg, src0); } + ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, + // ggml_mul_mat(ctx, // [n,p,qq,rr] + // ggml_cont(ctx, // [m,n,q1,r1] + // ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // grad), // [m,p,qq,rr] + + // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // avoid transpose of src0, rather transpose smaller tensor->grad + // and then use ggml_out_prod + ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + ggml_transpose(ctx, // [p,m,qq,rr] + grad))); // [m,p,qq,rr] + } + } break; + case GGML_OP_SCALE: { + if (src0_needs_grads) { + float s; + memcpy(&s, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false)); + } + } break; + case GGML_OP_SET: { + const size_t nb1 = ((const int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((const int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((const int32_t *) tensor->op_params)[2]; + const size_t offset = ((const int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = NULL; + + if (src0_needs_grads || src1_needs_grads) { + GGML_ASSERT(src0->type == tensor->type); + GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type); + GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type); + + tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); + } - const int64_t elem_q = ggml_nelements(src0); - const int64_t elem_k = ggml_nelements(src1); - const int64_t elem_v = ggml_nelements(src2); - - enum ggml_type result_type = flash_grad->type; - GGML_ASSERT(ggml_blck_size(result_type) == 1); - const size_t tsize = ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); - const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); - - if (src0->grad) { - struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q); - struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - grad_q, - zero_table, acc_table); - } - if (src1->grad) { - struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); - struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1); - src1->grad = ggml_add_or_set(ctx, - src1->grad, - grad_k, - zero_table, acc_table); - } - if (src2->grad) { - struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); - struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2); - src2->grad = ggml_add_or_set(ctx, - src2->grad, - grad_v, - zero_table, acc_table); + if (src0_needs_grads) { + struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false)); + } + + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); + } + } break; + case GGML_OP_CPY: { + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0_needs_grads) { + // dsrc0 = dtensor * 1 + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + if (src1_needs_grads) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case GGML_OP_CONT: { + // same as cpy + if (src0_needs_grads) { + GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); + GGML_ASSERT(ggml_is_contiguous(grad)); + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_OP_RESHAPE: { + if (src0_needs_grads) { + struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0)); + } + } break; + case GGML_OP_VIEW: { + if (src0_needs_grads) { + size_t offset; + + memcpy(&offset, tensor->op_params, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = ggml_element_size(cgraph->grads[isrc0]); + size_t n0 = ggml_element_size(src0); + GGML_ASSERT(offset % n0 == 0); + GGML_ASSERT(nb1 % n0 == 0); + GGML_ASSERT(nb2 % n0 == 0); + GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; } - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - GGML_ABORT("fatal error"); // not supported + + ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset); } - case GGML_OP_SSM_CONV: - case GGML_OP_SSM_SCAN: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_PERMUTE: { + if (src0_needs_grads) { + const int32_t * axes = (const int32_t *) tensor->op_params; + const int axis0 = axes[0] & 0x3; + const int axis1 = axes[1] & 0x3; + const int axis2 = axes[2] & 0x3; + const int axis3 = axes[3] & 0x3; + int axb[4] = {0,0,0,0}; // axes backward + axb[axis0] = 0; + axb[axis1] = 1; + axb[axis2] = 2; + axb[axis3] = 3; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3])); } + } break; + case GGML_OP_TRANSPOSE: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad)); + } + } break; + case GGML_OP_GET_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0)); + } + if (src1_needs_grads) { + // noop + } + } break; + case GGML_OP_DIAG_MASK_INF: { + if (src0_needs_grads) { + /* ggml_diag_mask_inf_impl() shouldn't be here */ + /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_DIAG_MASK_ZERO: { + if (src0_needs_grads) { + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_SOFT_MAX: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor)); + } + GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); + } break; + case GGML_OP_ROPE: { + if (src0_needs_grads) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((const int32_t *) tensor->op_params)[1]; + const int mode = ((const int32_t *) tensor->op_params)[2]; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + + memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); + + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow)); + } + GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); + } break; + case GGML_OP_IM2COL: { + if (src1_needs_grads) { + const int32_t s0 = ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = ggml_get_op_params_i32(tensor, 5); + const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; + + ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); + } + } break; + case GGML_OP_POOL_2D: { + if (src0_needs_grads) { + const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = ggml_get_op_params_i32(tensor, 6); + + ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1)); + } + } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: - case GGML_OP_UNARY: - { - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_ABS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_sgn(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_NEG: - { - if (src0->grad) { - src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_TANH: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_ELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_RELU: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_step(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_SIGMOID: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_GELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_GELU_QUICK: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_SILU: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_silu_back(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_EXP: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, tensor, tensor->grad), - zero_table, acc_table); - } - } break; - default: - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_GET_REL_POS: - case GGML_OP_ADD_REL_POS: - case GGML_OP_RWKV_WKV6: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: - case GGML_OP_MAP_CUSTOM1: - case GGML_OP_MAP_CUSTOM2: - case GGML_OP_MAP_CUSTOM3: - { - GGML_ABORT("fatal error"); // not supported - } - case GGML_OP_CROSS_ENTROPY_LOSS: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_cross_entropy_loss_back(ctx, - src0, - src1, - tensor->grad), - zero_table, acc_table); - } - GGML_ASSERT(!src1->grad && "backward pass for labels not implemented"); - } break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - GGML_ABORT("fatal error"); // not supported + case GGML_OP_UNARY: { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_ABS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SGN: { + // noop + } break; + case GGML_UNARY_OP_NEG: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_UNARY_OP_STEP: { + // noop + } break; + case GGML_UNARY_OP_RELU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SILU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad)); + } + } break; + case GGML_UNARY_OP_EXP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad)); + } + } break; + default: { + fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", + __func__, ggml_unary_op_name(ggml_get_unary_op(tensor))); + GGML_ABORT("fatal error"); + } break; } - case GGML_OP_OPT_STEP_ADAMW: - { - GGML_ABORT("fatal error"); // not supported + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad)); } - case GGML_OP_NONE: - { - // nop - } break; + GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); + } break; + case GGML_OP_NONE: { + // noop + } break; case GGML_OP_COUNT: - { - GGML_ABORT("fatal error"); - } + default: { + fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); + GGML_ABORT("fatal error"); + } break; } - for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (tensor->src[i] && tensor->src[i]->grad) { - GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); - } - } + GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0])); + GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1])); + GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != GGML_OP_NONE) { - //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - // check if already visited if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) { return; @@ -6207,18 +5586,42 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * ggml_build_forward_impl(cgraph, tensor, true); } -void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate) { - GGML_ASSERT(gf->n_nodes > 0); - GGML_ASSERT(gf->grads); +void ggml_build_backward_expand( + struct ggml_context * ctx_static, + struct ggml_context * ctx_compute, + struct ggml_cgraph * cgraph, + bool accumulate) { + GGML_ASSERT(cgraph->n_nodes > 0); + GGML_ASSERT(cgraph->grads); + GGML_ASSERT(cgraph->grad_accs); + + const int n_nodes_f = cgraph->n_nodes; - for (int i = 0; i < gf->n_nodes; ++i) { - struct ggml_tensor * node = gf->nodes[i]; + const size_t hash_size = ggml_hash_size(2*cgraph->size); + memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *)); + bool * grads_needed = calloc(hash_size, sizeof(bool)); + + { + bool any_params = false; + bool any_loss = false; + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; + any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM); + any_loss = any_loss || (node->flags & GGML_TENSOR_FLAG_LOSS); + } + GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?"); + GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?"); + } + + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; if (node->type == GGML_TYPE_I32) { continue; } - bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM; + bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM; bool ignore_src[GGML_MAX_SRC] = {false}; switch (node->op) { // gradients in node->src[0] for one reason or another have no effect on output gradients @@ -6246,14 +5649,14 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * break; } for (int j = 0; j < GGML_MAX_SRC; ++j) { - if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) { + if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) { continue; } GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16); - needs_grad = true; + node_needs_grad = true; break; } - if (!needs_grad) { + if (!node_needs_grad) { continue; } @@ -6261,73 +5664,21 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); - // create a new tensor with the same type and shape as the node and set it as grad - node->grad = ggml_dup_tensor(ctx, node); - } - - // keep tables of original gradients for replacement/accumulation logic - struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size); - struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size); - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->grad) { - { - const size_t insert_result = ggml_hash_insert(&zero_table, node->grad); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - } - - // only gradients of trainable parameters should be accumulated - if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) { - const size_t insert_result = ggml_hash_insert(&acc_table, node->grad); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - } + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node); + cgraph->grad_accs[igrad] = cgraph->grads[igrad]; } + grads_needed[igrad] = true; } - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_tensor * node = gf->nodes[i]; - + for (int i = n_nodes_f - 1; i >= 0; --i) { // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - if (node->grad) { - ggml_compute_backward(ctx, node, &zero_table, &acc_table); - } + ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); } - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_expand(gb, node->grad); - } - } - - ggml_hash_set_free(&zero_table); - ggml_hash_set_free(&acc_table); -} - -void ggml_build_opt_adamw( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); - ggml_build_forward_expand(gb, opt_step); - } - } + free(grads_needed); } static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { @@ -6345,7 +5696,8 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) { incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys if (grads) { - incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs } incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); @@ -6371,10 +5723,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz void * p = cgraph + 1; - struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); // check that we allocated the correct amount of memory @@ -6386,12 +5740,17 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.n_leafs =*/ 0, /*.nodes =*/ nodes_ptr, /*.grads =*/ grads_ptr, + /*.grad_accs =*/ grad_accs_ptr, /*.leafs =*/ leafs_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, }; ggml_hash_set_reset(&cgraph->visited_hash_set); + if (grads) { + memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *)); + } return cgraph; } @@ -6407,6 +5766,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.n_leafs =*/ 0, /*.nodes =*/ cgraph0->nodes + i0, /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, + /*.grad_accs =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL, /*.leafs =*/ NULL, /*.hash_table =*/ { 0, NULL, NULL }, /*.order =*/ cgraph0->order, @@ -6432,19 +5792,23 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { dst->nodes[i] = src->nodes[i]; } - if (src->grads) { - GGML_ASSERT(dst->grads != NULL); - for (int i = 0; i < src->n_nodes; ++i) { - dst->grads[i] = src->grads[i]; - } - } - for (size_t i = 0; i < src->visited_hash_set.size; ++i) { // copy all hashset keys (tensors) that are in use if (ggml_bitset_get(src->visited_hash_set.used, i)) { ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); } } + + if (src->grads) { + GGML_ASSERT(dst->grads != NULL); + GGML_ASSERT(dst->grad_accs != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + } } struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { @@ -6470,29 +5834,36 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) { GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; + struct ggml_tensor * node = cgraph->nodes[i]; + struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node); + + if (node->op == GGML_OP_OPT_STEP_ADAMW) { + // clear momenta + if (node->src[2]->data) { + ggml_set_zero(node->src[2]); + } + if (node->src[3]->data) { + ggml_set_zero(node->src[3]); + } + } // initial gradients of loss should be 1, 0 otherwise - if (node->grad) { + if (grad_acc) { if (node->flags & GGML_TENSOR_FLAG_LOSS) { - GGML_ASSERT(node->grad->buffer); - GGML_ASSERT(node->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_scalar(node)); + GGML_ASSERT(grad_acc->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_scalar(grad_acc)); const float onef = 1.0f; - ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad)); + if (grad_acc->buffer) { + ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float)); + } else { + GGML_ASSERT(grad_acc->data); + *((float *) grad_acc->data) = onef; + } } else { - ggml_set_zero(node->grad); + ggml_set_zero(grad_acc); } } - - GGML_ASSERT(node); - if (node->op == GGML_OP_OPT_STEP_ADAMW) { - // set iteration to 1 and clear momenta - ggml_set_op_params_i32(node, 0, 1); - ggml_set_zero(node->src[2]); - ggml_set_zero(node->src[3]); - } } } @@ -6530,7 +5901,7 @@ void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tenso cgraph->n_nodes++; } -struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) { +struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) { for (int i = 0; i < cgraph->n_leafs; i++) { struct ggml_tensor * leaf = cgraph->leafs[i]; @@ -6550,6 +5921,16 @@ struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const ch return NULL; } +struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL; +} + +struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL; +} + void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_LOG_INFO("=== GRAPH ===\n"); @@ -6560,7 +5941,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", i, node->ne[0], node->ne[1], node->ne[2], - ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); + ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : + ggml_graph_get_grad(cgraph, node) ? "g" : " "); } GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); @@ -6595,8 +5977,9 @@ static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * parent = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent); - if (parent->grad == node) { + if (grad == node) { return parent; } } @@ -6636,6 +6019,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph for (int i = 0; i < gb->n_nodes; i++) { struct ggml_tensor * node = gb->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(gb, node); if (ggml_graph_get_parent(gb, node) != NULL) { continue; @@ -6643,7 +6027,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->flags & GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { + } else if (grad) { if (ggml_graph_find(gf, node)) { snprintf(color, sizeof(color), "green"); } else { @@ -6670,8 +6054,8 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); } - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(node->grad->op)); + if (grad) { + fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(grad->op)); } else { fprintf(fp, "\"; ]\n"); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f0a3386a0..e1ee785fa 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -176,23 +176,14 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") endif() # -# test-grad0 +# test-opt -set(TEST_TARGET test-grad0) +set(TEST_TARGET test-opt) add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) target_link_libraries(${TEST_TARGET} PRIVATE ggml) add_test(NAME ${TEST_TARGET} COMMAND $) set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") -# -# test-opt - -# set(TEST_TARGET test-opt) -# add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) -# target_link_libraries(${TEST_TARGET} PRIVATE ggml) -# add_test(NAME ${TEST_TARGET} COMMAND $) -# set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") - # # test-quantize-fns @@ -268,36 +259,6 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml) add_test(NAME ${TEST_TARGET} COMMAND $) set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") -# -# test1 - -set(TEST_TARGET test1) -add_executable(${TEST_TARGET} ${TEST_TARGET}.c) -target_link_libraries(${TEST_TARGET} PRIVATE ggml) -if (MSVC) - target_link_options(${TEST_TARGET} PRIVATE "/STACK: 8388608") # 8MB -endif() -add_test(NAME ${TEST_TARGET} COMMAND $) -set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") - -# -# test2 - -# set(TEST_TARGET test2) -# add_executable(${TEST_TARGET} ${TEST_TARGET}.c) -# target_link_libraries(${TEST_TARGET} PRIVATE ggml) -# add_test(NAME ${TEST_TARGET} COMMAND $) -# set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") - -# -# test3 - -# set(TEST_TARGET test3) -# add_executable(${TEST_TARGET} ${TEST_TARGET}.c) -# target_link_libraries(${TEST_TARGET} PRIVATE ggml) -# add_test(NAME ${TEST_TARGET} COMMAND $) -# set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") - # # test-pool diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6618d03d1..f8a59b6df 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -811,11 +811,11 @@ struct test_case { ggml_build_forward_expand(gf, out); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx, gf, gb, false); + ggml_build_backward_expand(ctx, ctx, gb, false); if (expect.size() != 1 || expect[0] != 0.0f) { GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf)); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || t->grad->op != GGML_OP_NONE); + GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE); } } @@ -862,7 +862,13 @@ struct test_case { const char * bn = ggml_backend_name(backend); const int64_t ne = ggml_nelements(t); - std::vector ga = tensor_to_float(t->grad); + std::vector ga; + struct ggml_tensor * grad = ggml_graph_get_grad(gb, t); + if (grad) { + ga = tensor_to_float(grad); + } else { + ga.resize(ne); // default value is 0.0f + } for (int64_t i = 0; i < ne; ++i) { // gradient algebraic // check for nans @@ -2500,6 +2506,35 @@ struct test_sum_rows : public test_case { } }; +// GGML_OP_MEAN +struct test_mean : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_mean(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(ctx, a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_mean(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + float grad_eps() override { + return 0.1f * ne[0]*ne[1]*ne[2]*ne[3]; + } +}; + // GGML_OP_UPSCALE struct test_upscale : public test_case { const ggml_type type; @@ -2834,24 +2869,14 @@ struct test_cross_entropy_loss : public test_case { struct test_opt_step_adamw : public test_case { const ggml_type type; const std::array ne; - const float alpha; - const float beta1; - const float beta2; - const float eps; - const float wd; std::string vars() override { - return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd); + return VARS_TO_STR2(type, ne); } test_opt_step_adamw(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 5, 4, 3}, - float alpha = 1e-3f, - float beta1 = 0.9f, - float beta2 = 0.999f, - float eps = 1e-8f, - float wd = 0.0f) - : type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {} + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); @@ -2861,7 +2886,16 @@ struct test_opt_step_adamw : public test_case { ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); ggml_set_name(grad, "grad"); - ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd); + ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_name(grad_m, "grad_m"); + + ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_name(grad_v, "grad_v"); + + ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7); + ggml_set_name(adamw_params, "adamw_params"); + + ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params); ggml_set_name(out, "out"); return out; @@ -2869,7 +2903,7 @@ struct test_opt_step_adamw : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values. + init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values. } } @@ -3735,6 +3769,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); + test_cases.emplace_back(new test_mean()); test_cases.emplace_back(new test_upscale()); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true)); test_cases.emplace_back(new test_upscale_ext()); @@ -3766,9 +3801,7 @@ static std::vector> make_test_cases_eval() { } test_cases.emplace_back(new test_cross_entropy_loss()); - for (float wd : {0.0f, 1e-2f}) { - test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd)); - } + test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); // these tests are disabled to save execution time, but they can be handy for debugging #if 0 @@ -3938,6 +3971,8 @@ int main(int argc, char ** argv) { ggml_backend_free(backend); } + ggml_quantize_free(); + printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count()); if (n_ok != ggml_backend_dev_count()) { @@ -3945,8 +3980,6 @@ int main(int argc, char ** argv) { return 1; } - ggml_quantize_free(); - printf("\033[1;32mOK\033[0m\n"); return 0; } diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp deleted file mode 100644 index c712dba7f..000000000 --- a/tests/test-grad0.cpp +++ /dev/null @@ -1,1684 +0,0 @@ -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows -#include "ggml.h" -#include "ggml-cpu.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Wdouble-promotion" -#endif - -#define MAX_NARGS 3 - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -#define GGML_SILU_FP16 - -// -// logging -// - -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) printf(__VA_ARGS__) - -static float frand(void) { - return (float)rand()/(float)RAND_MAX; -} - -static int irand(int n) { - if (n == 0) return 0; - return rand()%n; -} - -static void get_random_dims(int64_t * dims, int ndims) { - dims[0] = dims[1] = dims[2] = dims[3] = 1; - - for (int i = 0; i < ndims; i++) { - dims[i] = 1 + irand(4); - } -} - -static struct ggml_tensor * get_random_tensor_f32( - struct ggml_context * ctx0, - int ndims, - int64_t ne[], - float fmin, - float fmax) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne); - - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin; - } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - } - } - break; - default: - assert(false); - } - - return result; -} - -static struct ggml_tensor * get_random_tensor_f16( - struct ggml_context * ctx0, - int ndims, - int64_t ne[], - float fmin, - float fmax) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne); - - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - } - } - } - break; - default: - assert(false); - } - - return result; -} - -static struct ggml_tensor * get_random_tensor_i32( - struct ggml_context * ctx0, - int ndims, - int64_t ne[], - int32_t imin, - int32_t imax) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne); - - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i0] = irand(imax - imin) + imin; - } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin; - } - } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin; - } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin; - } - } - } - } - break; - default: - assert(false); - } - - return result; -} - -static bool check_gradient( - const char * op_name, - struct ggml_context * ctx0, - struct ggml_tensor * x[], - struct ggml_tensor * f, - int ndims, - int nargs, - float eps, - float max_error_abs, - float max_error_rel, - std::vector expected_vals) { - - static int n_threads = -1; - if (n_threads < 0) { - n_threads = GGML_DEFAULT_N_THREADS; - - const char *env = getenv("GGML_N_THREADS"); - if (env) { - n_threads = atoi(env); - } - - printf("GGML_N_THREADS = %d\n", n_threads); - } - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, f); - ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - - ggml_graph_reset(gb); - if (f->grad) { - ggml_set_f32(f->grad, 1.0f); - } - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - // ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot"); - // ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot"); - - for (int i = 0; i < nargs; ++i) { - bool all_g0_bad = true; - const int nelements = ggml_nelements(x[i]); - for (int k = 0; k < nelements; ++k) { - // Calculate gradient numerically: - const float x0 = ggml_get_f32_1d(x[i], k); - const float xm = x0 - eps; - const float xp = x0 + eps; - ggml_set_f32_1d(x[i], k, xp); - - ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - - const double f0 = ggml_get_f32_1d(f, 0); - - ggml_set_f32_1d(x[i], k, xm); - - ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - - const double f1 = ggml_get_f32_1d(f, 0); - const double g0 = (f0 - f1)/(2.0*(double) eps); - - // The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU). - // In such cases, provide a vector of expected values and skip the comparison for failed calculations. - if (!expected_vals.empty()) { - bool matches_any = false; - for (const double & ev : expected_vals) { - const double error_abs = std::fabs(g0 - ev); - if (error_abs > max_error_abs) { - continue; - } - const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0; - if (error_rel > max_error_rel) { - continue; - } - matches_any = true; - break; - } - if (!matches_any) { - continue; - } - } - all_g0_bad = false; - - ggml_set_f32_1d(x[i], k, x0); - - // compute gradient using backward graph - ggml_graph_reset(gb); - if (f->grad) { - ggml_set_f32(f->grad, 1.0f); - } - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - const double g1 = ggml_get_f32_1d(x[i]->grad, k); - - const double error_abs = fabs(g0 - g1); - const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0; - - if (error_abs > max_error_abs || error_rel > max_error_rel) { - printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n", - op_name, ndims, i, k, x0, xm, xp, f0, f1, g0, g1, eps, error_abs, error_rel); - //assert(false); - return false; - } - } - if (all_g0_bad) { - printf("%s: numerical calculation of the gradient failed for all values\n", op_name); - return false; - } - } - - return true; -} - -// TODO: clean-up this .. -static bool check_mat_mul( - const struct ggml_tensor * y, - const struct ggml_tensor * x0, - const struct ggml_tensor * x1) { - float * dst = (float *) y->data; - float * src0 = (float *) x0->data; - float * src1 = (float *) x1->data; - - const int nc = x0->ne[1]; - const int nr = x1->ne[1]; - const int nk = x0->ne[0]; - - GGML_PRINT_DEBUG("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk); - - GGML_PRINT_DEBUG("x0:\n"); - for (int j = 0; j < x0->ne[1]; ++j) { - for (int i = 0; i < x0->ne[0]; ++i) { - GGML_PRINT_DEBUG("%6.3f ", src0[j*nk + i]); - } - GGML_PRINT_DEBUG("\n"); - } - GGML_PRINT_DEBUG("\n"); - - GGML_PRINT_DEBUG("x1:\n"); - for (int j = 0; j < x1->ne[1]; ++j) { - for (int i = 0; i < x1->ne[0]; ++i) { - GGML_PRINT_DEBUG("%6.3f ", src1[j*nk + i]); - } - GGML_PRINT_DEBUG("\n"); - } - GGML_PRINT_DEBUG("\n"); - - GGML_PRINT_DEBUG("y: n_dims = %d, (%lld, %lld)\n", y->n_dims, y->ne[0], y->ne[1]); - for (int j = 0; j < y->ne[1]; ++j) { - for (int i = 0; i < y->ne[0]; ++i) { - GGML_PRINT_DEBUG("%6.3f ", dst[j*nr + i]); - } - GGML_PRINT_DEBUG("\n"); - } - - for (int i = 0; i < nr; ++i) { - for (int j = 0; j < nc; ++j) { - float sum = 0.0f; - - for (int k = 0; k < nk; ++k) { - sum += src0[j*nk + k]*src1[i*nk + k]; - } - - if (fabsf(dst[i*nc + j] - sum) > 1e-5f) { - fprintf(stderr, "check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum); - assert(false); - return false; - } - } - } - - return true; -} - -#define NUM_PERMUTATIONS (4*3*2*1) - -int main(int argc, const char ** argv) { - struct ggml_init_params params = { - /* .mem_size = */ 256*1024*1024, - /* .mem_buffer = */ NULL, - /* .no_alloc = */ false, - }; - - int64_t ne[4]; - - int all_permutations[4 * NUM_PERMUTATIONS]; - { - int count = 0; - for (int ax0=0; ax0<4; ++ax0) { - for (int ax1=0; ax1<4; ++ax1) { - if (ax1 == ax0) continue; - for (int ax2=0; ax2<4; ++ax2) { - if (ax2 == ax0) continue; - if (ax2 == ax1) continue; - for (int ax3=0; ax3<4; ++ax3) { - if (ax3 == ax0) continue; - if (ax3 == ax1) continue; - if (ax3 == ax2) continue; - assert(count < NUM_PERMUTATIONS); - all_permutations[count*4+0] = ax0; - all_permutations[count*4+1] = ax1; - all_permutations[count*4+2] = ax2; - all_permutations[count*4+3] = ax3; - ++count; - } - } - } - } - } - - unsigned seed_iter = 1; - - // original loop: 1000 - int niter = 4; - const char *env = getenv("GGML_NLOOP"); - if (env != NULL) { - niter = atoi(env); - } - if (argc > 1) { - niter = atoi(argv[1]); - } - for (int iter = 0; iter < niter; ++iter) { - srand(seed_iter); - seed_iter = rand(); - unsigned seed = rand(); - - printf("test-grad0: iter:%d/%d\n", (iter+1), niter); - struct ggml_context * ctx0 = ggml_init(params); - - get_random_dims(ne, 4); - - struct ggml_tensor * x[MAX_NARGS]; - - // add f32 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - - check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {}); - } - } - - // add f16 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - - check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {}); - } - } - - // sub - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1])); - - check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // mul - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1])); - - check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // div - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1])); - - check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {}); - } - } - - // sqr - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0])); - - check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // sqrt - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0])); - - check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {}); - } - } - - // log - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0])); - - check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {}); - } - } - - // sum - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, x[0]); - - check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - - // sum_rows - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0]))); - - check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); - } - } - - // mean, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0])); - - check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // argmax - if (0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0])); - - check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // repeat - { - srand(seed); - int64_t ne2[4]; - get_random_dims(ne2, 4); - - ne2[0] = ne[0] * ne2[0]; - ne2[1] = ne[1] * ne2[1]; - ne2[2] = 1; - ne2[3] = 1; - - const int nargs = 1; - for (int ndims = 1; ndims <= 2; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1])))); - - check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); - } - } - - // repeat back - { - srand(seed); - int64_t ne2[4]; - get_random_dims(ne2, 4); - - ne2[0] = ne[0] * ne2[0]; - ne2[1] = ne[1] * ne2[1]; - ne2[2] = 1; - ne2[3] = 1; - - const int nargs = 1; - for (int ndims = 1; ndims <= 2; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0])))); - - check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); - } - } - - // abs - { - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0])); - - check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0}); - } - } - - // sgn - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0])); - - check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); - } - } - - // neg - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0])); - - check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // step - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0])); - - check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); - } - } - - // tanh, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0])); - - check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // mul_mat - { - srand(seed); - const int nargs = 2; - - for (int ndims = 2; ndims <= 4; ++ndims) { - int max_nrep = (ndims >= 3) ? 2 : 1; - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) { - for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) { - { - int64_t ne2[4]; - get_random_dims(ne2, 4); - ne2[0] = ne[0]; - ne2[2] = nrep2 * ne[2]; - ne2[3] = nrep3 * ne[3]; - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - } - - ggml_set_param(ctx0, x[0]); - ggml_set_param(ctx0, x[1]); - - struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]); - struct ggml_tensor * f = ggml_sum(ctx0, m); - - GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims); - - check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - if (ndims == 2) { - // check_mat_mul does not support ndims > 2 - check_mat_mul(m, x[1], x[0]); - } - } - } - } - } - - // elu, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0])); - - check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // relu - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0])); - - check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0}); - } - } - - // gelu, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0])); - - check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // silu - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0])); - -#ifdef GGML_SILU_FP16 - // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds. - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {}); -#else - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); -#endif - } - } - - // rms_norm - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f)); - - check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {}); - } - } - - // scale - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - const float s = -1.0f + 2.0f*frand(); - - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s)); - - check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // cpy f32 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - // x[1] is overwritten by x[0], so the gradients don't propagate to x[1] - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - - check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // cpy f16 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - // x[1] is overwritten by x[0], so the gradients don't propagate to x[1] - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - - check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); - } - } - - // reshape (1d->nd) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - int64_t ne2[4]; - ne2[0] = 1; - ne2[1] = 1; - ne2[2] = 1; - ne2[3] = 1; - for (int i = 0; i < ndims; ++i) { - ne2[0] *= ne[i]; - } - x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // reshape (nd->1d) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - int64_t ne2[4]; - ne2[0] = 1; - ne2[1] = 1; - ne2[2] = 1; - ne2[3] = 1; - for (int i = 0; i < ndims; ++i) { - ne2[0] *= ne[i]; - } - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 1d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - - const int nargs = 2; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 1); - while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 1); - } - - x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1])); - const int offset = irand(max_offset) * ggml_element_size(x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 2d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 2; - for (int ndims = 2; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 2); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 2); - } - - x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - const int offset = offsets[0] + offsets[1]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 3d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 2; - for (int ndims = 3; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 3); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 3); - } - - x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - offsets[2] = irand(max_offsets[2]) * x[0]->nb[2]; - const int offset = offsets[0] + offsets[1] + offsets[2]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 4d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 2; - for (int ndims = 4; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 4); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[3] > ne[3]) || (ne2[0]*ne2[1]*ne2[2]*ne2[3] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 4); - } - - x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]); - max_offsets[3] = MAX(0, x[0]->ne[3] - x[1]->ne[3]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - offsets[2] = irand(max_offsets[2]) * x[0]->nb[2]; - offsets[3] = irand(max_offsets[3]) * x[0]->nb[3]; - const int offset = offsets[0] + offsets[1] + offsets[2] + offsets[3]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // set_1d - { - srand(seed); - int64_t ne2[4]; - - const int nargs = 2; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 1); - while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 1); - } - - x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1])); - const int offset = irand(max_offset) * ggml_element_size(x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset)); - - check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // set_2d - { - srand(seed); - int64_t ne2[4]; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 1; - for (int ndims = 2; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 2); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 2); - } - - x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - const int offset = offsets[0] + offsets[1]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset)); - - check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // view_1d - { - srand(seed); - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - ggml_set_param(ctx0, x[0]); - - const int k0 = irand(ggml_nelements(x[0])); - const int k1 = irand(ggml_nelements(x[0])); - const int i0 = MIN(k0, k1); - const int i1 = MAX(k0, k1); - - const int offset = i0 * sizeof(float); - const int nelem = i1 - i0; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset)); - - check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // view_2d - { - srand(seed); - int64_t ne2[4]; - int64_t nb2[4]; - - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - get_random_dims(ne2, 2); - while (ne2[0]*ne2[1] > ggml_nelements(x[0])) { - get_random_dims(ne2, 2); - } - const int count = ne2[0]*ne2[1]; - - nb2[0] = sizeof(float); - nb2[1] = nb2[0]*ne2[0]; - - ggml_set_param(ctx0, x[0]); - - const int max_offset = ggml_nelements(x[0]) - count; - const int offset = irand(max_offset+1) * sizeof(float); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset)); - - check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // view_3d - { - srand(seed); - int64_t ne2[4] = {1,1,1,1}; - int64_t nb2[4] = {0,0,0,0}; - - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - get_random_dims(ne2, 3); - while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) { - get_random_dims(ne2, 3); - } - const int count = ne2[0]*ne2[1]*ne2[2]; - - nb2[0] = sizeof(float); - nb2[1] = nb2[0]*ne2[0]; - nb2[2] = nb2[1]*ne2[1]; - - ggml_set_param(ctx0, x[0]); - - const int max_offset = ggml_nelements(x[0]) - count; - const int offset = irand(max_offset+1) * sizeof(float); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset)); - - check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // permute - { - srand(seed); - int64_t ne2[4]; - - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) - { - // ggml_permute will set axes of dimensions below n_dims to 1. - // to make ggml_permute work correctly on all axes, - // the input tensor needs maximal n_dim of 4. - for (int i=0; i finite differences should not work - // instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0) - struct ggml_tensor * f = ggml_sum(ctx0, - ggml_log(ctx0, - ggml_add1(ctx0, - ggml_scale(ctx0, - ggml_soft_max(ctx0, x[0]), - 1.0f - eps), - ggml_new_f32(ctx0, eps)))); - - check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {}); - // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf. - // this may result in different gradients too finite differences. - // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause. - // if only the table lookup causes gradients to differ this is acceptable. - } - } - - // cross_entropy_loss - { - srand(seed); - const int nargs = 1; - - int64_t ne2[4]; - get_random_dims(ne2, 4); - - for (int ndims = 1; ndims <= 4; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f); - // the second argument to cross_entropy_loss must sum up to 1 for each row - int nr = ggml_nrows(x[1]); - int nc = ggml_nelements(x[1]) / nr; - for (int ir = 0; ir < nr; ++ir) { - float sum = 0; - for (int ic = 0; ic < nc; ++ic) { - sum += ((float *) x[1]->data)[ic + ir*nc]; - } - for (int ic = 0; ic < nc; ++ic) { - ((float *) x[1]->data)[ic + ir*nc] /= sum; - } - } - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]); - - check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // rope f32 - { - srand(seed); - const int nargs = 1; - - int64_t ne2[4]; - get_random_dims(ne2, 4); - ne2[0] += ne2[0] % 2; - int n_rot = ne2[0]; - - for (int ndims = 3; ndims <= 4; ++ndims) { - for (int mode = 0; mode < 4; ++mode) { - for (int n_past = 1; n_past < ne2[2]; ++n_past) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - - struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); - for (int i = 0; i < ne2[2]; ++i) { - ((int32_t *) p->data)[i] = n_past + i; - } - - ggml_set_param(ctx0, x[0]); - - const bool skip_past = (mode & 1); - if (skip_past) { - // we have no past, so this would have to work on uninitialized memory. - // we only test the gradients here; - // skip_past should have no influence on gradient computation. - // so when other modes work, we assume that this does as well. - continue; - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); - - GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); - } - } - } - } - - // rope f16 - { - srand(seed); - const int nargs = 1; - - int64_t ne2[4]; - get_random_dims(ne2, 4); - ne2[0] += ne2[0] % 2; - int n_rot = ne2[0]; - - for (int ndims = 3; ndims <= 4; ++ndims) { - for (int mode = 0; mode < 4; ++mode) { - for (int n_past = 1; n_past < ne2[2]; ++n_past) { - x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f); - - struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); - for (int i = 0; i < ne2[2]; ++i) { - ((int32_t *) p->data)[i] = n_past + i; - } - - ggml_set_param(ctx0, x[0]); - - const bool skip_past = (mode & 1); - if (skip_past) { - // we have no past, so this would have to work on uninitialized memory. - // we only test the gradients here; - // skip_past should have no influence on gradient computation. - // so when other modes work, we assume that this does as well. - continue; - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); - - GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); - } - } - } - } - - // im2col f32 - { - srand(seed); - const int nargs = 1; - const int ndims = 4; - - for (const bool is_2D : {false, true}) { - int64_t ne0[ndims]; - int64_t ne1[ndims]; - get_random_dims(ne0, ndims); - get_random_dims(ne1, ndims); - - // // Ensure that the output is not zero-sized: - ne1[0] += 8; - ne1[1] += 8; - - if (is_2D) { - ne1[2] = ne0[2]; - } else { - ne1[1] = ne0[1]; - ne0[3] = 1; - ne1[3] = 1; - } - - // The order of arguments is swapped because the first tensor is only used for its shape. - x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f); - x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f); - - ggml_set_param(ctx0, x[0]); - - const int s0 = 1 + irand(2); - const int s1 = is_2D ? 1 + irand(2) : 0; - const int p0 = 0 + irand(2); - const int p1 = is_2D ? 0 + irand(2) : 0; - const int d0 = 1 + irand(2); - const int d1 = is_2D ? 1 + irand(2) : 0; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32)); - - GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1); - check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); - } - } - - // pool_2d f32 - { - srand(seed); - const int nargs = 1; - const int ndims = 4; - - for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { - int64_t ne0[ndims]; - get_random_dims(ne0, ndims); - - ne0[0] += 8; - ne0[1] += 8; - - x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f); - - ggml_set_param(ctx0, x[0]); - - const int k0 = 2 + irand(2); - const int k1 = 2 + irand(2); - const int s0 = 2 + irand(2); - const int s1 = 2 + irand(2); - const int p0 = 0 + irand(2); - const int p1 = 0 + irand(2); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1)); - - GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n", - op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1); - std::vector expected_vals; - if (op == GGML_OP_POOL_MAX) { - expected_vals.push_back(0.0); - expected_vals.push_back(1.0); - } - check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals); - } - } - - // flash_attn f32 - // TODO: adapt to ggml_flash_attn_ext() changes - //{ - // srand(seed); - // const int nargs = 3; - - // int64_t ne2[4]; - - // get_random_dims(ne2, 4); - // int64_t D = ne2[0]; - // int64_t N = ne2[1]; - // int64_t M = ne2[2] + N; - // int64_t B = ne2[3]; - - // for (int masked = 0; masked <= 1; ++masked) { - // for (int ndims = 2; ndims <= 4; ++ndims) { - // int max_nrep = (ndims >= 3) ? 2 : 1; - // for (int nrep = 1; nrep < max_nrep; ++nrep) { - // int64_t neq[4] = { D, N, B*nrep, ne[3] }; - // int64_t nek[4] = { D, M, B, ne[3] }; - // int64_t nev[4] = { M, D, B, ne[3] }; - // if (ndims == 2) { - // neq[2] = 1; neq[3] = 1; - // nek[2] = 1; nek[3] = 1; - // nev[2] = 1; nev[3] = 1; - // } else if (ndims == 3) { - // neq[3] = 1; - // nek[3] = 1; - // nev[3] = 1; - // } - // x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f); - // x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f); - // x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f); - // ggml_set_param(ctx0, x[0]); - // ggml_set_param(ctx0, x[1]); - // ggml_set_param(ctx0, x[2]); - - // struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - - // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {}); - // } - // } - // } - //} - - ggml_free(ctx0); - } - - return 0; -} diff --git a/tests/test-mul-mat0.c b/tests/test-mul-mat0.c index 5a6453c60..6e561efe3 100644 --- a/tests/test-mul-mat0.c +++ b/tests/test-mul-mat0.c @@ -97,15 +97,15 @@ bool check_gradient( float max_error_abs, float max_error_rel) { const int n_threads = 1; + ggml_set_loss(f); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); ggml_build_forward_expand(gf, f); struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); + ggml_build_backward_expand(ctx0, ctx0, gb, false); ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); + ggml_graph_reset(gb); ggml_graph_compute_with_ctx(ctx0, gb, n_threads); ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot"); @@ -132,11 +132,10 @@ bool check_gradient( set_element(x[i], k, x0); // compute gradient using backward graph - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); + ggml_graph_reset(gb); ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - const float g1 = get_element(x[i]->grad, k); + const float g1 = get_element(ggml_graph_get_grad(gb, x[i]), k); const float error_abs = fabsf(g0 - g1); const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0; diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index 546ca230b..4abe85c74 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -1,181 +1,892 @@ #include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ggml-opt.h" #include -#include -#include -#include - -#define MAX_NARGS 2 - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Wdouble-promotion" -#endif - -// -// logging -// -#define GGML_DEBUG 0 -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) printf(__VA_ARGS__) - - -static float frand(void) { - return (float)rand()/(float)RAND_MAX; -} - -static struct ggml_tensor * get_random_tensor( - struct ggml_context * ctx0, int ndims, int64_t ne[], float fmin, float fmax -) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne); - - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin; +#include +#include +#include +#include +#include + +static bool almost_equal(const double a, const double b, const double atol) { + return fabs(a - b) < atol; +} + +constexpr int64_t ne_datapoint = 2; +constexpr int64_t ne_label = 1; +constexpr int64_t ndata = 6; + +struct helper_ctx_data { + std::vector datasets_supervised; + std::vector data_batch; + std::vector labels_batch; + + ggml_opt_dataset_t dataset_unsupervised; + struct ggml_context * ctx_static; + struct ggml_context * ctx_compute; + struct ggml_opt_params opt_params; + ggml_opt_context_t opt_ctx; + struct ggml_tensor * inputs; + struct ggml_tensor * weights; + struct ggml_tensor * outputs; + ggml_backend_buffer_t buf; + ggml_opt_result_t result; + ggml_opt_result_t result2; +}; + +// These default values make it easier to check optimization results vs. expected values. +static ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) { + ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata); + result.adamw.alpha = 1.0f; + result.adamw.beta1 = 0.0f; + result.adamw.beta2 = 0.0f; + result.adamw.eps = 0.0f; + return result; +} + +static helper_ctx_data helper_get_ctx_data( + ggml_backend_sched_t backend_sched, + ggml_backend_t backend, + const bool init_opt_ctx = true, + const bool optimizer_defaults = true, + int64_t nbatch_logical = 1, + int64_t nbatch_physical = 1, + enum ggml_opt_loss_type loss_type = GGML_OPT_LOSS_TYPE_SUM) { + std::vector datasets(ndata); + for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) { + ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard); + + float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset)); + float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); + + for (int64_t idata = 0; idata < ndata; ++idata) { + for (int64_t id = 0; id < ne_datapoint; ++id) { + data[ idata*ne_datapoint + id] = 16*idata + id; } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } + for (int64_t il = 0; il < ne_label; ++il) { + labels[idata*ne_label + il] = 16*(16*idata + il); } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + + datasets[ndata_shard-1] = dataset; + } + + ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1); + + float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised)); + + for (int64_t idata = 0; idata < ndata; ++idata) { + data[idata] = idata; + } + + struct ggml_context * ctx_static; + struct ggml_context * ctx_compute; + { + struct ggml_init_params params = { + /*.mem_size =*/ (2*ndata + 2)*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_static = ggml_init(params); + } + { + struct ggml_init_params params = { + /*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute = ggml_init(params); + } + + std::vector data_batch(ndata); + std::vector labels_batch(ndata); + for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) { + data_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_datapoint); + labels_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_label); + } + + struct ggml_tensor * inputs = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, nbatch_physical); + ggml_set_name(inputs, "inputs"); + + struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); + ggml_set_name(weights, "weights"); + ggml_set_param(ctx_static, weights); + + struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights); + + struct ggml_tensor * outputs = ggml_scale(ctx_compute, intermediary, 1.0f); + ggml_set_name(outputs, "outputs"); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend); + const float w0 = float(ndata)/2; + ggml_backend_tensor_set(weights, &w0, 0, sizeof(float)); + + GGML_ASSERT(nbatch_logical % nbatch_physical == 0); + const int32_t opt_period = nbatch_logical / nbatch_physical; + + struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + opt_params.opt_period = opt_period; + if (!optimizer_defaults) { + opt_params.get_opt_pars = helper_get_test_opt_pars; + } + ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr; + + ggml_opt_result_t result = ggml_opt_result_init(); + ggml_opt_result_t result2 = ggml_opt_result_init(); + + return {datasets, data_batch, labels_batch, dataset_unsupervised, ctx_static, ctx_compute, opt_params, opt_ctx, inputs, weights, outputs, buf, result, result2}; +} + +static void helper_free_ctx_data(struct helper_ctx_data ctx_data) { + ggml_opt_result_free(ctx_data.result); + ggml_opt_result_free(ctx_data.result2); + ggml_opt_free(ctx_data.opt_ctx); + ggml_backend_buffer_free(ctx_data.buf); + ggml_free(ctx_data.ctx_static); + ggml_free(ctx_data.ctx_compute); + for (ggml_opt_dataset_t dataset : ctx_data.datasets_supervised) { + ggml_opt_dataset_free(dataset); + } + ggml_opt_dataset_free(ctx_data.dataset_unsupervised); +} + +static void helper_after_test( + const char * func, const bool high_level, const std::string options, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + printf(" %s(high_level=%s%s, subtest=%s): ", + func, high_level ? "yes" : "no", options.c_str(), subtest.c_str()); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; +} + +static std::pair test_dataset(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend); + + for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) { + ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1]; + + if (shuffle) { + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + } + + for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) { + if (ndata_batch % ndata_shard != 0) { + continue; + } + bool subtest_ok = true; + + struct ggml_tensor * data_batch = cd.data_batch[ndata_batch-1]; + struct ggml_tensor * labels_batch = cd.labels_batch[ndata_batch-1]; + + std::vector data(ggml_nelements( data_batch)); + std::vector labels(ggml_nelements(labels_batch)); + + std::vector idata_shuffled; + const int64_t nbatches = ndata / ndata_batch; + for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, data_batch, labels_batch, ibatch); + + ggml_backend_tensor_get( data_batch, data.data(), 0, ggml_nbytes( data_batch)); + ggml_backend_tensor_get(labels_batch, labels.data(), 0, ggml_nbytes(labels_batch)); + + for (int64_t idata_batch = 0; idata_batch < ndata_batch; ++idata_batch) { + const int64_t idata = ibatch*ndata_batch + idata_batch; + const int64_t idata_found = data[idata_batch*ne_datapoint] / 16; + subtest_ok = subtest_ok && (shuffle || idata_found == idata); + idata_shuffled.push_back(idata_found); + + for (int64_t id = 0; id < ne_datapoint; ++id) { + if (data[ idata_batch*ne_datapoint + id] != 16*idata_found + id) { + subtest_ok = false; + } + } + for (int64_t il = 0; il < ne_label; ++il) { + if (labels[idata_batch*ne_label + il] != 16*(16*idata_found + il)) { + subtest_ok = false; + } } } } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } + + if (!shuffle || ndata % ndata_batch == 0) { + const int ndata_max = (ndata / ndata_batch) * ndata_batch; + + for (int64_t idata = 0; subtest_ok && idata < ndata_max; ++idata) { + int ninstances = 0; + for (int64_t id : idata_shuffled) { + ninstances += id == idata; } + if (ninstances != 1) { + subtest_ok = false; + } + } + } + + printf(" %s(shuffle=%s, ndata_shard=%" PRId64 ", ndata_batch=%" PRId64 "): ", + __func__, shuffle ? "yes" : "no", ndata_shard, ndata_batch); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + } + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static std::pair test_grad(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, + /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1); + + std::vector grad_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + grad_history[idata] = NAN; + } + + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float)); + } + + { + bool subtest_ok = true; + for (int idata = 0; idata < ndata; ++idata) { + if (grad_history[idata] != idata + 1) { + subtest_ok = false; + } + } + printf(" %s(): ", __func__); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static void helper_after_test_forward_backward( + const char * func, const bool high_level, const bool shuffle, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + std::string options = ", shuffle="; + options += shuffle ? "yes" : "no"; + helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass); +} + +static std::pair test_forward_backward( + ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false); + struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx); + + std::vector loss_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + loss_history[idata] = NAN; + } + + { + int64_t ndata; + ggml_opt_result_ndata(cd.result, &ndata); + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc); + helper_after_test_forward_backward(__func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass); + } + + if (high_level) { + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + if (shuffle) { + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + } + ggml_opt_epoch(cd.opt_ctx, dataset, nullptr, cd.result, 0, nullptr, nullptr); + } else { + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + } + + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == ndata/2; + helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass); + } + { + int64_t ndata; + ggml_opt_result_ndata(cd.result, &ndata); + bool subtest_ok = ndata == 6; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_forward_backward(__func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass); + } + + float w0; + ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float)); + for (int i = 0; i < 10; ++i) { + ggml_opt_forward_backward(cd.opt_ctx, nullptr); + } + ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float)); + + ggml_opt_reset(cd.opt_ctx, /*optimizer =*/ false); + ggml_opt_result_reset(cd.result); + + for (int64_t idata = 0; idata < ndata; ++idata) { + loss_history[idata] = NAN; + } + + if (high_level) { + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + if (shuffle) { + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + } + ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr); + } else { + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + } + + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == -ndata/2; + helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass); + } + { + int64_t ndata; + ggml_opt_result_ndata(cd.result, &ndata); + bool subtest_ok = ndata == 6; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_forward_backward(__func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass); + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static std::pair test_epoch_vs_fit(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int ntest = 0; + int npass = 0; + + float weights_epoch; + float weights_fit; + + { + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true); + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr); + + ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights)); + helper_free_ctx_data(cd); + } + { + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ false); + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + + ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset, + GGML_OPT_LOSS_TYPE_SUM, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true); + + ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights)); + helper_free_ctx_data(cd); + } + + const bool subtest_ok = weights_epoch == weights_fit; + + printf(" %s(): ", __func__); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + + return std::make_pair(npass, ntest); +} + +static void helper_after_test_idata_split( + const char * func, const bool high_level, const int epoch, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + std::string options = ", epoch="; + options += std::to_string(epoch); + helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass); +} + +static std::pair test_idata_split(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false); + struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx); + const int idata_split = ndata * 2/3; + + std::vector loss_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + loss_history[idata] = NAN; + } + + for (int epoch = 1; epoch <= 4; ++epoch) { + if (high_level) { + ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr); + } else { + int idata = 0; + for (; idata < idata_split; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + for (; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward(cd.opt_ctx, cd.result2); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + } + + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == ndata/2 - epoch*idata_split; + helper_after_test_idata_split(__func__, high_level, epoch, "weights", subtest_ok, ntest, npass); + } + { + int64_t ndata_result; + ggml_opt_result_ndata(cd.result, &ndata_result); + bool subtest_ok = ndata_result == idata_split; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0; + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_idata_split(__func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass); + } + { + int64_t ndata_result; + ggml_opt_result_ndata(cd.result2, &ndata_result); + bool subtest_ok = ndata_result == ndata - idata_split; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result2, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_idata_split(__func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass); + } + + ggml_opt_result_reset(cd.result); + ggml_opt_result_reset(cd.result2); + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static void helper_after_test_gradient_accumulation( + const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + std::string options = ", nbatch_physical="; + options += std::to_string(nbatch_physical); + options += ", loss_type="; + options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? "mean" : "sum"; + options += ", epoch="; + options += std::to_string(epoch); + helper_after_test(func, false, options, subtest, subtest_ok, ntest, npass); +} + +static std::pair test_gradient_accumulation( + ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data( + backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type); + struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx); + + std::vector grad_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + grad_history[idata] = NAN; + } + + for (int epoch = 1; epoch <= 4; ++epoch) { + if (nbatch_physical == 1) { + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float)); + } + } else if (nbatch_physical == 2) { + for (int idata = 0; idata < ndata; idata += 2) { + const float idataf[2] = {float(idata + 0), float(idata + 1)}; + ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + + grad_history[idata + 0] = 0.0f; + ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float)); + } + } else { + GGML_ASSERT(false); + } + + { + GGML_ASSERT(ndata == 6); + constexpr double atol = 1e-6; + bool subtest_ok = true; + if (loss_type == GGML_OPT_LOSS_TYPE_SUM) { + if (nbatch_physical == 1) { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0, atol); + } else { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0, atol); } + subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol); + } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) { + if (nbatch_physical == 1) { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0/ndata, atol); + } else { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0/ndata, atol); + } + subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol); + } else { + GGML_ASSERT(false); + } + helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass); + } + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == (ndata/2) - epoch; + helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass); + } + { + int64_t ndata_result; + ggml_opt_result_ndata(cd.result, &ndata_result); + bool subtest_ok = ndata_result == ndata/nbatch_physical; + + double loss; + ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr); + if (loss_type == GGML_OPT_LOSS_TYPE_SUM) { + subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0); + } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) { + subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6); + } else { + GGML_ASSERT(false); } - break; - default: - assert(false); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass); + } + + ggml_opt_result_reset(cd.result); } + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) { + ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata); + result.adamw.alpha = 0.1f; return result; } -int main(void) { - struct ggml_init_params params = { - /* .mem_size = */ 1024*1024*1024, - /* .mem_buffer = */ NULL, - /* .no_alloc = */ false, - }; +static std::pair test_regression(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int ntest = 0; + int npass = 0; - struct ggml_context * ctx = ggml_init(params); + // Test for simple regression with f(x) = a*x + b - int64_t ne1[4] = {4, 128, 1, 1}; - int64_t ne2[4] = {4, 256, 1, 1}; - int64_t ne3[4] = {128, 256, 1, 1}; + constexpr int64_t ndata_regression = 201; + constexpr float a_true = 1.2f; + constexpr float b_true = 3.4f; - struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1); - struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + std::mt19937 gen(12345); + std::normal_distribution nd{0.0f, 0.1f}; - struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1); + ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression); - struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b); - struct ggml_tensor * d = ggml_sub(ctx, c, ab); - struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d)); + float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset)); + float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); - struct ggml_cgraph * ge = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(ge, e); - ggml_graph_reset(ge); + constexpr float x_min = -100.0f; + constexpr float x_max = 100.0f; - ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1); + for (int64_t idata = 0; idata < ndata_regression; ++idata) { + const float x = x_min + (x_max - x_min) * idata/(ndata_regression-1); + const float y = a_true*x + b_true + nd(gen); - const float fe = ggml_get_f32_1d(e, 0); - printf("%s: e = %.4f\n", __func__, fe); + data[idata] = x; + labels[idata] = y; + } - struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); + struct ggml_context * ctx_static; + struct ggml_context * ctx_compute; + { + struct ggml_init_params params = { + /*.mem_size =*/ 3*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_static = ggml_init(params); + } + { + struct ggml_init_params params = { + /*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute = ggml_init(params); + } - ggml_opt(ctx, opt_params, e); + // The first dimension is the dimension of the datapoints, the second dimension is the number of datapoints. + struct ggml_tensor * x = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 1, ndata_regression); + ggml_set_name(x, "x"); + + struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); + ggml_set_name(a, "a"); + ggml_set_param(ctx_static, a); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); + ggml_set_name(b, "b"); + ggml_set_param(ctx_static, b); + + struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b); + ggml_set_name(f, "f"); + ggml_set_param(ctx_static, f); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend); + const float a0 = 1.0f; + const float b0 = 3.0f; + ggml_backend_tensor_set(a, &a0, 0, sizeof(float)); + ggml_backend_tensor_set(b, &b0, 0, sizeof(float)); + + ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, + helper_get_regression_opt_pars, 100, ndata_regression, 0.0f, true); + + { + float a_fit; + ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float)); + float b_fit; + ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float)); + const bool subtest_ok = almost_equal(a_fit, a_true, 1e-2) && almost_equal(b_fit, b_true, 1e-2); + printf(" %s(subtest=weights): ", __func__); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + } - ggml_graph_reset(ge); + ggml_backend_buffer_free(buf); + ggml_free(ctx_static); + ggml_opt_dataset_free(dataset); - ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1); + return std::make_pair(npass, ntest); +} - const float fe_opt = ggml_get_f32_1d(e, 0); - printf("%s: original e = %.4f\n", __func__, fe); - printf("%s: optimized e = %.4f\n", __func__, fe_opt); +static std::pair test_backend(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int npass = 0; + int ntest = 0; - const bool success = (fe_opt <= fe); - assert(success); + for (bool shuffle : {false, true}) { + std::pair partial = test_dataset(backend_sched, backend, shuffle); + npass += partial.first; + ntest += partial.second; + } + { + std::pair partial = test_grad(backend_sched, backend); + npass += partial.first; + ntest += partial.second; + } + for (bool high_level : {false, true}){ + for (bool shuffle : {false, true}) { + if (!high_level && shuffle) { + continue; + } - ggml_free(ctx); - return success ? 0 : -1; + std::pair partial = test_forward_backward(backend_sched, backend, high_level, shuffle); + npass += partial.first; + ntest += partial.second; + } + } + { + std::pair partial = test_epoch_vs_fit(backend_sched, backend); + npass += partial.first; + ntest += partial.second; + } + for (bool high_level : {false, true}){ + std::pair partial = test_idata_split(backend_sched, backend, high_level); + npass += partial.first; + ntest += partial.second; + } + for (int32_t nbatch_physical : {2, 1}) { + for (enum ggml_opt_loss_type loss_type : {GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN}) { + std::pair partial = test_gradient_accumulation(backend_sched, backend, nbatch_physical, loss_type); + npass += partial.first; + ntest += partial.second; + } + } + { + std::pair partial = test_regression(backend_sched, backend); + npass += partial.first; + ntest += partial.second; + } + + return std::make_pair(npass, ntest); } -// int64_t ne1[4] = {4, 128, 1, 1}; -// int64_t ne2[4] = {4, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 25890.9375 -// main: optimized e = 10094.7031 -// int64_t ne1[4] = {8, 128, 1, 1}; -// int64_t ne2[4] = {8, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 39429.5078 -// main: optimized e = 9275.8936 +int main(void) { + const size_t dev_count = ggml_backend_dev_count(); + printf("Testing %zu devices\n\n", dev_count); + size_t n_ok = 0; + + std::vector devs; + std::vector backends; -// int64_t ne1[4] = {16, 128, 1, 1}; -// int64_t ne2[4] = {16, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 68371.1328 -// main: optimized e = 7854.4502 + for (size_t i = 0; i < dev_count; ++i) { + devs.push_back(ggml_backend_dev_get(i)); + ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL); + GGML_ASSERT(backend != NULL); -// int64_t ne1[4] = {32, 128, 1, 1}; -// int64_t ne2[4] = {32, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 126061.1953 -// main: optimized e = 5451.0166 + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2); + } + + backends.push_back(backend); + } -// int64_t ne1[4] = {4, 1024, 1, 1}; -// int64_t ne2[4] = {4, 2048, 1, 1};; -// int64_t ne3[4] = {1024, 2048, 1, 1}; -// main: original e = 1620817.8750 -// main: optimized e = 698387.6875 + for (size_t i = 0; i < dev_count; ++i) { + // Put the backend to be tested in front so that it's prioritized: + std::vector backends_modded = {backends[i]}; + backends_modded.insert(backends_modded.end(), backends.begin(), backends.end()); -// another run on M1 -// int64_t ne1[4] = {4, 1024, 1, 1}; -// int64_t ne2[4] = {4, 2048, 1, 1};; -// int64_t ne3[4] = {1024, 2048, 1, 1}; -// main: original e = 1629595.6250 -// main: optimized e = 698169.1250 + ggml_backend_sched_t backend_sched = ggml_backend_sched_new( + backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false); -// int64_t ne1[4] = {32, 1024, 1, 1}; -// int64_t ne2[4] = {32, 2048, 1, 1};; -// int64_t ne3[4] = {1024, 2048, 1, 1}; -// main: original e = 8146770.5000 -// main: optimized e = 651119.1250 + printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i])); + printf(" Device description: %s\n", ggml_backend_dev_description(devs[i])); + size_t free, total; // NOLINT + ggml_backend_dev_memory(devs[i], &free, &total); + printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024); + printf("\n"); + + std::pair result = test_backend(backend_sched, backends[i]); + + printf(" %d/%d tests passed\n", result.first, result.second); + printf(" Backend %s: ", ggml_backend_name(backends[i])); + if (result.first == result.second) { + printf("\033[1;32mOK\033[0m\n"); + n_ok++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + + printf("\n"); + + ggml_backend_sched_free(backend_sched); + } + + for (ggml_backend_t backend : backends) { + ggml_backend_free(backend); + } + + printf("%zu/%zu backends passed\n", n_ok, dev_count); + if (n_ok != dev_count) { + printf("\033[1;31mFAIL\033[0m\n"); + return 1; + } + printf("\033[1;32mOK\033[0m\n"); + return 0; +} diff --git a/tests/test1.c b/tests/test1.c deleted file mode 100644 index 1a2db19c2..000000000 --- a/tests/test1.c +++ /dev/null @@ -1,459 +0,0 @@ -#include "ggml.h" -#include "ggml-cpu.h" - -#include -#include - -int main(int argc, const char ** argv) { - const int n_threads = 2; - - struct ggml_init_params params = { - .mem_size = 128*1024*1024, - .mem_buffer = NULL, - .no_alloc = false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - { - struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - - ggml_set_param(ctx0, x); - - struct ggml_tensor * a = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - struct ggml_tensor * b = ggml_mul(ctx0, x, x); - struct ggml_tensor * f = ggml_mul(ctx0, b, a); - - // a*x^2 - // 2*a*x - - ggml_print_objects(ctx0); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, f); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_set_f32(x, 2.0f); - ggml_set_f32(a, 3.0f); - - ggml_graph_reset(gf); - ggml_set_f32(f->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("f = %f\n", ggml_get_f32_1d(f, 0)); - printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0)); - - GGML_ASSERT(ggml_get_f32_1d(f, 0) == 12.0f); - GGML_ASSERT(ggml_get_f32_1d(x->grad, 0) == 12.0f); - - ggml_set_f32(x, 3.0f); - - ggml_graph_reset(gf); - ggml_set_f32(f->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("f = %f\n", ggml_get_f32_1d(f, 0)); - printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0)); - - GGML_ASSERT(ggml_get_f32_1d(f, 0) == 27.0f); - GGML_ASSERT(ggml_get_f32_1d(x->grad, 0) == 18.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-1-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-1-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - struct ggml_tensor * x3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - - ggml_set_f32(x1, 3.0f); - ggml_set_f32(x2, 1.0f); - ggml_set_f32(x3, 0.0f); - - ggml_set_param(ctx0, x1); - ggml_set_param(ctx0, x2); - - struct ggml_tensor * y = ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2)); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, y); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0)); - printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == 12.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 7.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 3.0f); - - struct ggml_tensor * g1 = x1->grad; - struct ggml_tensor * g2 = x2->grad; - - struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb); - - ggml_build_backward_expand(ctx0, gb, gbb, false); - - ggml_graph_reset(gb); - ggml_set_f32(g1->grad, 1.0f); - ggml_set_f32(g2->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gbb, n_threads); - - printf("H * [1, 1] = [ %f %f ]\n", ggml_get_f32_1d(x1->grad, 0), ggml_get_f32_1d(x2->grad, 0)); - - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 3.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 1.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-2-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-2-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - - ggml_set_param(ctx0, x1); - ggml_set_param(ctx0, x2); - - struct ggml_tensor * y = ggml_mul(ctx0, ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2)), x1); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, y); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_set_f32(x1, 3.0f); - ggml_set_f32(x2, 4.0f); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0)); - printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == 63.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 51.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 9.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-3-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-3-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - struct ggml_tensor * x3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - - ggml_set_param(ctx0, x1); - ggml_set_param(ctx0, x2); - ggml_set_param(ctx0, x3); - - struct ggml_tensor * y = ggml_mul(ctx0, ggml_mul(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x2, x2)), x3); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, y); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_set_f32(x1, 1.0f); - ggml_set_f32(x2, 2.0f); - ggml_set_f32(x3, 3.0f); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0)); - printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0)); - printf("df/dx3 = %f\n", ggml_get_f32_1d(x3->grad, 0)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == 12.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 24.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 12.0f); - GGML_ASSERT(ggml_get_f32_1d(x3->grad, 0) == 4.0f); - - struct ggml_tensor * g1 = x1->grad; - struct ggml_tensor * g2 = x2->grad; - struct ggml_tensor * g3 = x3->grad; - - struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb); - - ggml_build_backward_expand(ctx0, gb, gbb, false); - - ggml_graph_reset(gb); - ggml_set_f32(g1->grad, 1.0f); - ggml_set_f32(g2->grad, 1.0f); - ggml_set_f32(g3->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gbb, n_threads); - - printf("H * [1, 1, 1] = [ %f %f %f ]\n", - ggml_get_f32_1d(x1->grad, 0), - ggml_get_f32_1d(x2->grad, 0), - ggml_get_f32_1d(x3->grad, 0)); - - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 56.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 34.0f); - GGML_ASSERT(ggml_get_f32_1d(x3->grad, 0) == 12.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-4-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-4-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - - ggml_set_param(ctx0, x1); - ggml_set_param(ctx0, x2); - - struct ggml_tensor * y = ggml_sum(ctx0, ggml_mul(ctx0, x1, x2)); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, y); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_set_f32(x1, 3.0f); - ggml_set_f32(x2, 5.0f); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f %f %f\n", - ggml_get_f32_1d(x1->grad, 0), - ggml_get_f32_1d(x1->grad, 1), - ggml_get_f32_1d(x1->grad, 2)); - printf("df/dx2 = %f %f %f\n", - ggml_get_f32_1d(x2->grad, 0), - ggml_get_f32_1d(x2->grad, 1), - ggml_get_f32_1d(x2->grad, 2)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == 45.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 5.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 3.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == 5.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == 3.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == 5.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == 3.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-5-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-5-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - - ggml_set_param(ctx0, x1); - ggml_set_param(ctx0, x2); - - struct ggml_tensor * y = - ggml_sum(ctx0, - ggml_add(ctx0, - ggml_mul(ctx0, x1, x2), - ggml_mul(ctx0, - ggml_repeat(ctx0, ggml_new_f32(ctx0, -2.0f), x1), - ggml_mul(ctx0, x1, x1) - ) - ) - ); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, y); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_set_f32(x1, 3.0f); - ggml_set_f32(x2, 5.0f); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f %f %f\n", - ggml_get_f32_1d(x1->grad, 0), - ggml_get_f32_1d(x1->grad, 1), - ggml_get_f32_1d(x1->grad, 2)); - printf("df/dx2 = %f %f %f\n", - ggml_get_f32_1d(x2->grad, 0), - ggml_get_f32_1d(x2->grad, 1), - ggml_get_f32_1d(x2->grad, 2)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == -9.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == -7.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == -7.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == -7.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 3.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == 3.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == 3.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-6-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-6-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - - ggml_set_param(ctx0, x1); - ggml_set_param(ctx0, x2); - - struct ggml_tensor * y = - ggml_sum(ctx0, - ggml_sub(ctx0, - ggml_mul(ctx0, x1, x2), - ggml_mul(ctx0, - ggml_mul(ctx0, x1, x1), - ggml_repeat(ctx0, ggml_new_f32(ctx0, -2.0f), x1) - ) - ) - ); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, y); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_set_f32(x1, 3.0f); - ggml_set_f32(x2, 5.0f); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f %f %f\n", - ggml_get_f32_1d(x1->grad, 0), - ggml_get_f32_1d(x1->grad, 1), - ggml_get_f32_1d(x1->grad, 2)); - printf("df/dx2 = %f %f %f\n", - ggml_get_f32_1d(x2->grad, 0), - ggml_get_f32_1d(x2->grad, 1), - ggml_get_f32_1d(x2->grad, 2)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == 99.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 17.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == 17.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == 17.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 3.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == 3.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == 3.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-7-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-7-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3); - - ggml_set_param(ctx0, x1); - ggml_set_param(ctx0, x2); - - struct ggml_tensor * y = - ggml_abs(ctx0, - ggml_sub(ctx0, x1, x2) - ); - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, y); - struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_set_f32(x1, 3.0f); - ggml_set_f32(x2, 5.0f); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f %f %f\n", - ggml_get_f32_1d(x1->grad, 0), - ggml_get_f32_1d(x1->grad, 1), - ggml_get_f32_1d(x1->grad, 2)); - printf("df/dx2 = %f %f %f\n", - ggml_get_f32_1d(x2->grad, 0), - ggml_get_f32_1d(x2->grad, 1), - ggml_get_f32_1d(x2->grad, 2)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == 2.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == -1.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == -1.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == -1.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 1.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == 1.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == 1.0f); - - ggml_set_f32(x1, 7.0f); - ggml_set_f32(x2, 5.0f); - - ggml_graph_reset(gf); - ggml_set_f32(y->grad, 1.0f); - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - printf("y = %f\n", ggml_get_f32_1d(y, 0)); - printf("df/dx1 = %f %f %f\n", - ggml_get_f32_1d(x1->grad, 0), - ggml_get_f32_1d(x1->grad, 1), - ggml_get_f32_1d(x1->grad, 2)); - printf("df/dx2 = %f %f %f\n", - ggml_get_f32_1d(x2->grad, 0), - ggml_get_f32_1d(x2->grad, 1), - ggml_get_f32_1d(x2->grad, 2)); - - GGML_ASSERT(ggml_get_f32_1d(y, 0) == 2.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 1.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == 1.0f); - GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == 1.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == -1.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == -1.0f); - GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == -1.0f); - - ggml_graph_dump_dot(gf, NULL, "test1-8-forward.dot"); - ggml_graph_dump_dot(gb, gf, "test1-8-backward.dot"); - } - - ggml_free(ctx0); - - return 0; -} diff --git a/tests/test1.zig b/tests/test1.zig deleted file mode 100644 index 815d81438..000000000 --- a/tests/test1.zig +++ /dev/null @@ -1,450 +0,0 @@ -const std = @import("std"); -const c = @cImport({ - @cInclude("ggml.h"); -}); - -pub fn main() !void { - const n_threads = 2; - - const params = .{ - .mem_size = 128 * 1024 * 1024, - .mem_buffer = null, - .no_alloc = false, - }; - - const ctx0 = c.ggml_init(params); - defer c.ggml_free(ctx0); - - { - const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - - c.ggml_set_param(ctx0, x); - - const a = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - const b = c.ggml_mul(ctx0, x, x); - const f = c.ggml_mul(ctx0, b, a); - - // a*x^2 - // 2*a*x - - c.ggml_print_objects(ctx0); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, f); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - _ = c.ggml_set_f32(x, 2.0); - _ = c.ggml_set_f32(a, 3.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(f.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)}); - std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)}); - - try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 12.0); - try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 12.0); - - _ = c.ggml_set_f32(x, 3.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(f.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)}); - std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)}); - - try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 27.0); - try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 18.0); - - c.ggml_graph_dump_dot(gf, null, "test1-1-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-1-backward.dot"); - } - - ///////////////////////////////////////////////////////////// - - { - const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - - _ = c.ggml_set_f32(x1, 3.0); - _ = c.ggml_set_f32(x2, 1.0); - _ = c.ggml_set_f32(x3, 0.0); - - c.ggml_set_param(ctx0, x1); - c.ggml_set_param(ctx0, x2); - - const y = c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2)); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, y); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)}); - std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)}); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 7.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); - - const g1 = x1.*.grad; - const g2 = x2.*.grad; - - const gbb = c.ggml_graph_dup(ctx0, @constCast(gb)); - - c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true); - - c.ggml_graph_reset(@constCast(gb)); - _ = c.ggml_set_f32(g1.*.grad, 1.0); - _ = c.ggml_set_f32(g2.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads); - - std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0) }); - - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 3.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0); - - c.ggml_graph_dump_dot(gf, null, "test1-2-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-2-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - - c.ggml_set_param(ctx0, x1); - c.ggml_set_param(ctx0, x2); - - const y = c.ggml_mul(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2)), x1); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, y); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - _ = c.ggml_set_f32(x1, 3.0); - _ = c.ggml_set_f32(x2, 4.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)}); - std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)}); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 63.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 51.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 9.0); - - c.ggml_graph_dump_dot(gf, null, "test1-3-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-3-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1); - - c.ggml_set_param(ctx0, x1); - c.ggml_set_param(ctx0, x2); - c.ggml_set_param(ctx0, x3); - - const y = c.ggml_mul(ctx0, c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x2, x2)), x3); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, y); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - _ = c.ggml_set_f32(x1, 1.0); - _ = c.ggml_set_f32(x2, 2.0); - _ = c.ggml_set_f32(x3, 3.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)}); - std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)}); - std.debug.print("df/dx3 = {d:.6}\n", .{c.ggml_get_f32_1d(x3.*.grad, 0)}); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 24.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 12.0); - try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 4.0); - - const g1 = x1.*.grad; - const g2 = x2.*.grad; - const g3 = x3.*.grad; - - const gbb = c.ggml_graph_dup(ctx0, @constCast(gb)); - - c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true); - - c.ggml_graph_reset(@constCast(gb)); - _ = c.ggml_set_f32(g1.*.grad, 1.0); - _ = c.ggml_set_f32(g2.*.grad, 1.0); - _ = c.ggml_set_f32(g3.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads); - - std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n", .{ - c.ggml_get_f32_1d(x1.*.grad, 0), - c.ggml_get_f32_1d(x2.*.grad, 0), - c.ggml_get_f32_1d(x3.*.grad, 0), - }); - - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 56.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 34.0); - try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 12.0); - - c.ggml_graph_dump_dot(gf, null, "test1-4-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-4-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - - c.ggml_set_param(ctx0, x1); - c.ggml_set_param(ctx0, x2); - - const y = c.ggml_sum(ctx0, c.ggml_mul(ctx0, x1, x2)); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, y); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - _ = c.ggml_set_f32(x1, 3.0); - _ = c.ggml_set_f32(x2, 5.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x1.*.grad, 0), - c.ggml_get_f32_1d(x1.*.grad, 1), - c.ggml_get_f32_1d(x1.*.grad, 2), - }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x2.*.grad, 0), - c.ggml_get_f32_1d(x2.*.grad, 1), - c.ggml_get_f32_1d(x2.*.grad, 2), - }); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 45.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 5.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 5.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 5.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0); - - c.ggml_graph_dump_dot(gf, null, "test1-5-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-5-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - - c.ggml_set_param(ctx0, x1); - c.ggml_set_param(ctx0, x2); - - const y = - c.ggml_sum(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1), c.ggml_mul(ctx0, x1, x1)))); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, y); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - _ = c.ggml_set_f32(x1, 3.0); - _ = c.ggml_set_f32(x2, 5.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x1.*.grad, 0), - c.ggml_get_f32_1d(x1.*.grad, 1), - c.ggml_get_f32_1d(x1.*.grad, 2), - }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x2.*.grad, 0), - c.ggml_get_f32_1d(x2.*.grad, 1), - c.ggml_get_f32_1d(x2.*.grad, 2), - }); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == -9.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -7.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -7.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -7.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0); - - c.ggml_graph_dump_dot(gf, null, "test1-6-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-6-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - - c.ggml_set_param(ctx0, x1); - c.ggml_set_param(ctx0, x2); - - const y = - c.ggml_sum(ctx0, c.ggml_sub(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1)))); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, y); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - _ = c.ggml_set_f32(x1, 3.0); - _ = c.ggml_set_f32(x2, 5.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x1.*.grad, 0), - c.ggml_get_f32_1d(x1.*.grad, 1), - c.ggml_get_f32_1d(x1.*.grad, 2), - }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x2.*.grad, 0), - c.ggml_get_f32_1d(x2.*.grad, 1), - c.ggml_get_f32_1d(x2.*.grad, 2), - }); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 99.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 17.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 17.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 17.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0); - - c.ggml_graph_dump_dot(gf, null, "test1-7-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-7-backward.dot"); - } - - /////////////////////////////////////////////////////////////// - - { - const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3); - - c.ggml_set_param(ctx0, x1); - c.ggml_set_param(ctx0, x2); - - const y = - c.ggml_abs(ctx0, c.ggml_sub(ctx0, x1, x2)); - - const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true); - c.ggml_build_forward_expand(gf, y); - const gb = c.ggml_graph_dup(ctx0, @constCast(gf)); - c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false); - - _ = c.ggml_set_f32(x1, 3.0); - _ = c.ggml_set_f32(x2, 5.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x1.*.grad, 0), - c.ggml_get_f32_1d(x1.*.grad, 1), - c.ggml_get_f32_1d(x1.*.grad, 2), - }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x2.*.grad, 0), - c.ggml_get_f32_1d(x2.*.grad, 1), - c.ggml_get_f32_1d(x2.*.grad, 2), - }); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -1.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -1.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -1.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 1.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 1.0); - - _ = c.ggml_set_f32(x1, 7.0); - _ = c.ggml_set_f32(x2, 5.0); - - c.ggml_graph_reset(@constCast(gf)); - _ = c.ggml_set_f32(y.*.grad, 1.0); - - _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads); - - std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)}); - std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x1.*.grad, 0), - c.ggml_get_f32_1d(x1.*.grad, 1), - c.ggml_get_f32_1d(x1.*.grad, 2), - }); - std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{ - c.ggml_get_f32_1d(x2.*.grad, 0), - c.ggml_get_f32_1d(x2.*.grad, 1), - c.ggml_get_f32_1d(x2.*.grad, 2), - }); - - try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 1.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 1.0); - try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 1.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == -1.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == -1.0); - try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == -1.0); - - c.ggml_graph_dump_dot(gf, null, "test1-8-forward.dot"); - c.ggml_graph_dump_dot(gb, gf, "test1-8-backward.dot"); - } - - _ = try std.io.getStdIn().reader().readByte(); -} diff --git a/tests/test2.c b/tests/test2.c deleted file mode 100644 index 76c044f4b..000000000 --- a/tests/test2.c +++ /dev/null @@ -1,182 +0,0 @@ -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows -#include "ggml.h" -#include "ggml-cpu.h" - -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -bool is_close(float a, float b, float epsilon) { - return fabs(a - b) < epsilon; -} - -int main(int argc, const char ** argv) { - struct ggml_init_params params = { - .mem_size = 128*1024*1024, - .mem_buffer = NULL, - .no_alloc = false, - }; - - //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); - //opt_params.adam.alpha = 0.01f; - - struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_LBFGS); - - // original threads: 8 - int nthreads = 8; - const char *env = getenv("GGML_NTHREADS"); - if (env != NULL) { - nthreads = atoi(env); - } - if (argc > 1) { - nthreads = atoi(argv[1]); - } - opt_params.n_threads = nthreads; - printf("test2: n_threads:%d\n", opt_params.n_threads); - - const float xi[] = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f , 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, }; - float yi[] = { 15.0f, 25.0f, 35.0f, 45.0f, 55.0f, 65.0f, 75.0f, 85.0f, 95.0f, 105.0f, }; - - const int n = sizeof(xi)/sizeof(xi[0]); - - struct ggml_context * ctx0 = ggml_init(params); - - struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n); - struct ggml_tensor * y = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n); - - for (int i = 0; i < n; i++) { - ((float *) x->data)[i] = xi[i]; - ((float *) y->data)[i] = yi[i]; - } - - { - struct ggml_tensor * t0 = ggml_new_f32(ctx0, 0.0f); - struct ggml_tensor * t1 = ggml_new_f32(ctx0, 0.0f); - - // initialize auto-diff parameters: - ggml_set_param(ctx0, t0); - ggml_set_param(ctx0, t1); - - // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n) - struct ggml_tensor * f = - ggml_div(ctx0, - ggml_sum(ctx0, - ggml_sqr(ctx0, - ggml_sub(ctx0, - ggml_add(ctx0, - ggml_mul(ctx0, x, ggml_repeat(ctx0, t1, x)), - ggml_repeat(ctx0, t0, x)), - y) - ) - ), - ggml_new_f32(ctx0, 2.0f*n)); - - enum ggml_opt_result res = ggml_opt(NULL, opt_params, f); - - printf("t0 = %f\n", ggml_get_f32_1d(t0, 0)); - printf("t1 = %f\n", ggml_get_f32_1d(t1, 0)); - - GGML_ASSERT(res == GGML_OPT_RESULT_OK); - - GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0), 5.0f, 1e-3f)); - GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-3f)); - } - - { - struct ggml_tensor * t0 = ggml_new_f32(ctx0, -1.0f); - struct ggml_tensor * t1 = ggml_new_f32(ctx0, 9.0f); - - ggml_set_param(ctx0, t0); - ggml_set_param(ctx0, t1); - - // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n - struct ggml_tensor * f = - ggml_mul(ctx0, - ggml_new_f32(ctx0, 1.0/(2*n)), - ggml_sum(ctx0, - ggml_abs(ctx0, - ggml_sub(ctx0, - ggml_add(ctx0, - ggml_mul(ctx0, x, ggml_repeat(ctx0, t1, x)), - ggml_repeat(ctx0, t0, x)), - y) - ) - ) - ); - - - enum ggml_opt_result res = ggml_opt(NULL, opt_params, f); - - GGML_ASSERT(res == GGML_OPT_RESULT_OK); - GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0), 5.0f, 1e-2f)); - GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-2f)); - } - - { - struct ggml_tensor * t0 = ggml_new_f32(ctx0, 5.0f); - struct ggml_tensor * t1 = ggml_new_f32(ctx0, -4.0f); - - ggml_set_param(ctx0, t0); - ggml_set_param(ctx0, t1); - - // f = t0^2 + t1^2 - struct ggml_tensor * f = - ggml_add(ctx0, - ggml_sqr(ctx0, t0), - ggml_sqr(ctx0, t1) - ); - - enum ggml_opt_result res = ggml_opt(NULL, opt_params, f); - - GGML_ASSERT(res == GGML_OPT_RESULT_OK); - GGML_ASSERT(is_close(ggml_get_f32_1d(f, 0), 0.0f, 1e-3f)); - GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0), 0.0f, 1e-3f)); - GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 0.0f, 1e-3f)); - } - - ///////////////////////////////////////// - - { - struct ggml_tensor * t0 = ggml_new_f32(ctx0, -7.0f); - struct ggml_tensor * t1 = ggml_new_f32(ctx0, 8.0f); - - ggml_set_param(ctx0, t0); - ggml_set_param(ctx0, t1); - - // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2 - struct ggml_tensor * f = - ggml_add(ctx0, - ggml_sqr(ctx0, - ggml_sub(ctx0, - ggml_add(ctx0, - t0, - ggml_mul(ctx0, t1, ggml_new_f32(ctx0, 2.0f))), - ggml_new_f32(ctx0, 7.0f) - ) - ), - ggml_sqr(ctx0, - ggml_sub(ctx0, - ggml_add(ctx0, - ggml_mul(ctx0, t0, ggml_new_f32(ctx0, 2.0f)), - t1), - ggml_new_f32(ctx0, 5.0f) - ) - ) - ); - - enum ggml_opt_result res = ggml_opt(NULL, opt_params, f); - - GGML_ASSERT(res == GGML_OPT_RESULT_OK); - GGML_ASSERT(is_close(ggml_get_f32_1d(f, 0), 0.0f, 1e-3f)); - GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0), 1.0f, 1e-3f)); - GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 3.0f, 1e-3f)); - } - - ggml_free(ctx0); - - return 0; -} diff --git a/tests/test2.zig b/tests/test2.zig deleted file mode 100644 index 783eba6e6..000000000 --- a/tests/test2.zig +++ /dev/null @@ -1,123 +0,0 @@ -const std = @import("std"); -const Thread = std.Thread; -const c = @cImport({ - @cInclude("ggml.h"); -}); - -fn is_close(a: f32, b: f32, epsilon: f32) bool { - return @abs(a - b) < epsilon; -} - -pub fn main() !void { - const params = .{ - .mem_size = 128 * 1024 * 1024, - .mem_buffer = null, - .no_alloc = false, - }; - - var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS); - - const nthreads = try Thread.getCpuCount(); - opt_params.n_threads = @intCast(nthreads); - std.debug.print("test2: n_threads:{}\n", .{opt_params.n_threads}); - - const xi = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 }; - const yi = [_]f32{ 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0 }; - - const n = xi.len; - - const ctx0 = c.ggml_init(params); - defer c.ggml_free(ctx0); - - const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n); - const y = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n); - - for (0..n) |i| { - const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data)); - x_data_pointer[i] = xi[i]; - const y_data_pointer: [*]f32 = @ptrCast(@alignCast(y.*.data)); - y_data_pointer[i] = yi[i]; - } - - { - const t0 = c.ggml_new_f32(ctx0, 0.0); - const t1 = c.ggml_new_f32(ctx0, 0.0); - - // initialize auto-diff parameters: - _ = c.ggml_set_param(ctx0, t0); - _ = c.ggml_set_param(ctx0, t1); - - // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n) - const f = - c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y))), c.ggml_new_f32(ctx0, @as(f32, 2.0) * n)); - - const res = c.ggml_opt(null, opt_params, f); - - std.debug.print("t0 = {d:.6}\n", .{c.ggml_get_f32_1d(t0, 0)}); - std.debug.print("t1 = {d:.6}\n", .{c.ggml_get_f32_1d(t1, 0)}); - - try std.testing.expect(res == c.GGML_OPT_RESULT_OK); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-3)); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-3)); - } - - { - const t0 = c.ggml_new_f32(ctx0, -1.0); - const t1 = c.ggml_new_f32(ctx0, 9.0); - - _ = c.ggml_set_param(ctx0, t0); - _ = c.ggml_set_param(ctx0, t1); - - // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n - const f = - c.ggml_mul(ctx0, c.ggml_new_f32(ctx0, @as(f32, 1.0) / (2 * n)), c.ggml_sum(ctx0, c.ggml_abs(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y)))); - - const res = c.ggml_opt(null, opt_params, f); - - try std.testing.expect(res == c.GGML_OPT_RESULT_OK); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-2)); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-2)); - } - - { - const t0 = c.ggml_new_f32(ctx0, 5.0); - const t1 = c.ggml_new_f32(ctx0, -4.0); - - _ = c.ggml_set_param(ctx0, t0); - _ = c.ggml_set_param(ctx0, t1); - - // f = t0^2 + t1^2 - const f = - c.ggml_add(ctx0, c.ggml_sqr(ctx0, t0), c.ggml_sqr(ctx0, t1)); - - const res = c.ggml_opt(null, opt_params, f); - - try std.testing.expect(res == c.GGML_OPT_RESULT_OK); - try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3)); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 0.0, 1e-3)); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 0.0, 1e-3)); - } - - ///////////////////////////////////////// - - { - const t0 = c.ggml_new_f32(ctx0, -7.0); - const t1 = c.ggml_new_f32(ctx0, 8.0); - - _ = c.ggml_set_param(ctx0, t0); - _ = c.ggml_set_param(ctx0, t1); - - // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2 - const f = - c.ggml_add(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, t0, c.ggml_mul(ctx0, t1, c.ggml_new_f32(ctx0, 2.0))), c.ggml_new_f32(ctx0, 7.0))), c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, t0, c.ggml_new_f32(ctx0, 2.0)), t1), c.ggml_new_f32(ctx0, 5.0)))); - - const res = c.ggml_opt(null, opt_params, f); - - try std.testing.expect(res == c.GGML_OPT_RESULT_OK); - try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3)); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 1.0, 1e-3)); - try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 3.0, 1e-3)); - } - - _ = try std.io.getStdIn().reader().readByte(); -} diff --git a/tests/test3.c b/tests/test3.c deleted file mode 100644 index d1e9fcc61..000000000 --- a/tests/test3.c +++ /dev/null @@ -1,95 +0,0 @@ -#include "ggml.h" - -#include -#include -#include - -bool is_close(float a, float b, float epsilon) { - return fabs(a - b) < epsilon; -} - -int main(int argc, const char ** argv) { - struct ggml_init_params params = { - .mem_size = 1024*1024*1024, - .mem_buffer = NULL, - .no_alloc = false, - }; - - //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); - struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_LBFGS); - - opt_params.n_threads = (argc > 1) ? atoi(argv[1]) : 8; - - const int NP = 1 << 12; - const int NF = 1 << 8; - - struct ggml_context * ctx0 = ggml_init(params); - - struct ggml_tensor * F = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, NF, NP); - struct ggml_tensor * l = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NP); - - // regularization weight - struct ggml_tensor * lambda = ggml_new_f32(ctx0, 1e-5f); - - srand(0); - - for (int j = 0; j < NP; j++) { - const float ll = j < NP/2 ? 1.0f : -1.0f; - ((float *)l->data)[j] = ll; - - for (int i = 0; i < NF; i++) { - ((float *)F->data)[j*NF + i] = ((ll > 0 && i < NF/2 ? 1.0f : ll < 0 && i >= NF/2 ? 1.0f : 0.0f) + ((float)rand()/(float)RAND_MAX - 0.5f)*0.1f)/(0.5f*NF); - } - } - - { - // initial guess - struct ggml_tensor * x = ggml_set_f32(ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NF), 0.0f); - - ggml_set_param(ctx0, x); - - // f = sum[(fj*x - l)^2]/n + lambda*|x^2| - struct ggml_tensor * f = - ggml_add(ctx0, - ggml_div(ctx0, - ggml_sum(ctx0, - ggml_sqr(ctx0, - ggml_sub(ctx0, - ggml_mul_mat(ctx0, F, x), - l) - ) - ), - ggml_new_f32(ctx0, (float)NP) - ), - ggml_mul(ctx0, - ggml_sum(ctx0, ggml_sqr(ctx0, x)), - lambda) - ); - - enum ggml_opt_result res = ggml_opt(NULL, opt_params, f); - - GGML_ASSERT(res == GGML_OPT_RESULT_OK); - - // print results - for (int i = 0; i < 16; i++) { - printf("x[%3d] = %g\n", i, ((float *)x->data)[i]); - } - printf("...\n"); - for (int i = NF - 16; i < NF; i++) { - printf("x[%3d] = %g\n", i, ((float *)x->data)[i]); - } - printf("\n"); - - for (int i = 0; i < NF; ++i) { - if (i < NF/2) { - GGML_ASSERT(is_close(((float *)x->data)[i], 1.0f, 1e-2f)); - } else { - GGML_ASSERT(is_close(((float *)x->data)[i], -1.0f, 1e-2f)); - } - } - } - - ggml_free(ctx0); - - return 0; -} diff --git a/tests/test3.zig b/tests/test3.zig deleted file mode 100644 index ecaf1b014..000000000 --- a/tests/test3.zig +++ /dev/null @@ -1,87 +0,0 @@ -const std = @import("std"); -const Thread = std.Thread; -const c = @cImport({ - @cInclude("stdlib.h"); - @cInclude("ggml.h"); -}); - -fn is_close(a: f32, b: f32, epsilon: f32) bool { - return @abs(a - b) < epsilon; -} - -pub fn main() !void { - const params = .{ - .mem_size = 128 * 1024 * 1024, - .mem_buffer = null, - .no_alloc = false, - }; - - var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS); - - const nthreads = try Thread.getCpuCount(); - opt_params.n_threads = @intCast(nthreads); - - const NP = 1 << 12; - const NF = 1 << 8; - - const ctx0 = c.ggml_init(params); - defer c.ggml_free(ctx0); - - const F = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_F32, NF, NP); - const l = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NP); - - // regularization weight - const lambda = c.ggml_new_f32(ctx0, 1e-5); - - c.srand(0); - - const l_data_pointer: [*]f32 = @ptrCast(@alignCast(l.*.data)); - const f_data_pointer: [*]f32 = @ptrCast(@alignCast(F.*.data)); - for (0..NP) |j| { - const ll = if (j < NP / 2) @as(f32, 1.0) else @as(f32, -1.0); - l_data_pointer[j] = ll; - - for (0..NF) |i| { - const c_rand: f32 = @floatFromInt(c.rand()); - f_data_pointer[j * NF + i] = - ((if (ll > 0 and i < NF / 2) @as(f32, 1.0) else if (ll < 0 and i >= NF / 2) @as(f32, 1.0) else @as(f32, 0.0)) + - (c_rand / c.RAND_MAX - 0.5) * 0.1) / (0.5 * NF); - } - } - - { - // initial guess - const x = c.ggml_set_f32(c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NF), 0.0); - - c.ggml_set_param(ctx0, x); - - // f = sum[(fj*x - l)^2]/n + lambda*|x^2| - const f = - c.ggml_add(ctx0, c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_mul_mat(ctx0, F, x), l))), c.ggml_new_f32(ctx0, @as(f32, NP))), c.ggml_mul(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, x)), lambda)); - - const res = c.ggml_opt(null, opt_params, f); - - try std.testing.expect(res == c.GGML_OPT_RESULT_OK); - - const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data)); - // print results - for (0..16) |i| { - std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] }); - } - std.debug.print("...\n", .{}); - for (NF - 16..NF) |i| { - std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] }); - } - std.debug.print("\n", .{}); - - for (0..NF) |i| { - if (i < NF / 2) { - try std.testing.expect(is_close(x_data_pointer[i], 1.0, 1e-2)); - } else { - try std.testing.expect(is_close(x_data_pointer[i], -1.0, 1e-2)); - } - } - } - - _ = try std.io.getStdIn().reader().readByte(); -}