Implement batch inference for image encoding (#30)
* WIP: Implement batch inference

* WIP: use broadcastable mul_mat

* Batched Conv2D is working

* Batched Conv2D is working

* Fix concat

* Batched output normalization

* Full batch inference is working

* Full batch inference is working

* Sync ggml

* Sync ggml

* Sync ggml

* add multithreaded batched image preprocessing

* add multithreaded batched image preprocessing

* Update batch preprocess function signature

* Implement batch inference in benchmark util

* sync ggml

* set n_threads as const

* Sync ggml
monatis authored Jul 12, 2023
1 parent 84de1c7 commit 018df28
Showing 6 changed files with 212 additions and 93 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -22,7 +22,7 @@ endif()

# general
option(CLIP_STATIC "CLIP: static link libraries" OFF)
-option(CLIP_BUILD_TEST "CLIP: build tests" ${CLIP_STANDALONE})
+option(CLIP_BUILD_TESTS "CLIP: build tests" ${CLIP_STANDALONE})
option(CLIP_BUILD_EXAMPLES "CLIP: build examples" ${CLIP_STANDALONE})
option(CLIP_BUILD_IMAGE_SEARCH "CLIP: build image-search" OFF)
option(CLIP_NATIVE "CLIP: enable -march=native flag" ON)
4 changes: 4 additions & 0 deletions README.md
@@ -15,6 +15,10 @@ This repo is aimed at powering useful applications based on such models on compu

clip.cpp also has a short startup time compared to large ML frameworks, which makes it suitable for serverless deployments where the cold start is an issue.

+## Hot topics
+- 07/12/2023: Batch inference support for image encoding.
+- 07/11/2023: Semantic image search [example](examples/image-search/README.md) directly in C++.

## Note about image preprocessing
PIL resizes with a two-pass, convolution-based bicubic interpolation and applies antialiasing; in PyTorch, antialiasing is optional. Implementing preprocessing that matches their results numerically takes some extra attention. However, I found that linear interpolation is good enough both for comparing different embeddings produced by this implementation and for comparing an embedding from this implementation with one from Transformers. So let's use it until we craft a proper bicubic interpolation. The sketch below shows what such a linear resize amounts to.
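
A minimal sketch of bilinear resizing for a packed 8-bit RGB buffer. This is illustrative only: `bilinear_resize_rgb` and its signature are invented for this example; the library's actual resizing lives in `clip_image_preprocess` in clip.cpp.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Bilinear resize of a packed RGB (HWC) uint8 image: for each destination pixel,
// blend the four nearest source pixels by their fractional distances.
void bilinear_resize_rgb(const std::vector<uint8_t> &src, int sw, int sh,
                         std::vector<uint8_t> &dst, int dw, int dh)
{
    dst.resize((size_t)dw * dh * 3);
    const float x_ratio = dw > 1 ? (float)(sw - 1) / (dw - 1) : 0.0f;
    const float y_ratio = dh > 1 ? (float)(sh - 1) / (dh - 1) : 0.0f;

    for (int y = 0; y < dh; y++)
    {
        for (int x = 0; x < dw; x++)
        {
            const float sx = x * x_ratio, sy = y * y_ratio;
            const int x0 = (int)sx, y0 = (int)sy;
            const int x1 = std::min(x0 + 1, sw - 1), y1 = std::min(y0 + 1, sh - 1);
            const float fx = sx - x0, fy = sy - y0;

            for (int k = 0; k < 3; k++)
            {
                // interpolate along x on the two neighboring rows, then along y
                const float top = src[3 * (y0 * sw + x0) + k] * (1 - fx) + src[3 * (y0 * sw + x1) + k] * fx;
                const float bot = src[3 * (y1 * sw + x0) + k] * (1 - fx) + src[3 * (y1 * sw + x1) + k] * fx;
                dst[3 * ((size_t)y * dw + x) + k] = (uint8_t)(top * (1 - fy) + bot * fy + 0.5f);
            }
        }
    }
}
```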

207 changes: 158 additions & 49 deletions clip.cpp
@@ -5,6 +5,8 @@
#include <iostream>
#include <regex>
#include <fstream>
+#include <pthread.h>

#include "ggml/ggml.h"
#include "clip.h"

@@ -23,7 +25,7 @@ size_t get_mem_req_by_size(const size_t n_tensors, const int n_image_positions)
case 397: // base
if (n_image_positions == 50) // patch size = 32
{
-return 8 * mb;
+return 12 * mb;
}
else // patch size = 16
{
@@ -54,7 +56,7 @@ size_t get_scr_buf_req_by_size(const size_t n_tensors, const int n_positions)
case 397:
if (n_positions <= 50)
{
-return 16 * mb;
+return 32 * mb;
}
else
{
@@ -252,6 +254,77 @@ bool clip_image_preprocess(const clip_ctx *ctx, const clip_image_u8 *img, clip_i

return true;
}
+// Holds one worker thread's slice of the batch: the first input/output image, the slice length, and the shared context
+typedef struct
+{
+const clip_image_u8 *inputs;
+clip_image_f32 *resized;
+int n_images;
+const clip_ctx *ctx;
+} ImageData;
+
+// Thread entry point: preprocess every image in the slice
+void *preprocess_image(void *arg)
+{
+ImageData *imageData = static_cast<ImageData *>(arg);
+for (int i = 0; i < imageData->n_images; i++)
+{
+clip_image_preprocess(imageData->ctx, &imageData->inputs[i], &imageData->resized[i]);
+}
+
+pthread_exit(NULL);
+}
+
+// Batch-preprocess multiple images, distributing them across up to n_threads worker threads
+void clip_image_batch_preprocess(const clip_ctx *ctx, const int n_threads, const std::vector<clip_image_u8> &img_inputs, std::vector<clip_image_f32> &imgs_resized)
+{
+GGML_ASSERT(img_inputs.size() == imgs_resized.size());
+const int num_threads = std::min(n_threads, static_cast<int>(img_inputs.size()));
+
+if (num_threads <= 1)
+{
+// Single-threaded case
+for (size_t i = 0; i < img_inputs.size(); i++)
+{
+clip_image_preprocess(ctx, &img_inputs[i], &imgs_resized[i]);
+}
+return;
+}
+
+// Multi-threaded case: divide the images among the threads
+const int images_per_thread = static_cast<int>(img_inputs.size()) / num_threads;
+
+std::vector<pthread_t> threads(num_threads);
+std::vector<ImageData> imageData(num_threads);
+
+for (int t = 0; t < num_threads; t++)
+{
+const int start_index = t * images_per_thread;
+const int end_index = (t == num_threads - 1) ? static_cast<int>(img_inputs.size()) : start_index + images_per_thread;
+
+// Each thread gets a contiguous slice of the batch
+imageData[t].inputs = &img_inputs[start_index];
+imageData[t].resized = &imgs_resized[start_index];
+imageData[t].n_images = end_index - start_index;
+imageData[t].ctx = ctx;
+
+pthread_create(&threads[t], NULL, preprocess_image, static_cast<void *>(&imageData[t]));
+}
+
+// Wait for all threads to finish
+for (int t = 0; t < num_threads; t++)
+{
+pthread_join(threads[t], NULL);
+}
+}
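
A minimal call-site sketch for the new entry point (illustrative only: `ctx` comes from `clip_model_load`, and `load_u8_images` is a placeholder for whatever produces the `clip_image_u8` batch):

```cpp
// Hypothetical usage: preprocess a whole batch on 4 worker threads.
std::vector<clip_image_u8> inputs = load_u8_images();  // placeholder loader
std::vector<clip_image_f32> resized(inputs.size());    // outputs are pre-sized by the caller
clip_image_batch_preprocess(ctx, 4, inputs, resized);
```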

struct clip_ctx *clip_model_load(const char *fname, const int verbosity = 1)
{
@@ -840,7 +913,6 @@ bool clip_text_encode(

struct ggml_context *ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
-gf.n_threads = n_threads;

static size_t scr0_size = get_scr_buf_req_by_size(ctx->text_model.tensors.size() + ctx->vision_model.tensors.size(), N);
static void *scr0 = malloc(scr0_size);
@@ -991,7 +1063,7 @@ bool clip_text_encode(

// run the computation
ggml_build_forward_expand(&gf, embeddings);
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

// print
#ifdef CLIP_DEBUG
@@ -1058,6 +1130,17 @@ bool clip_image_encode(
int n_threads,
const clip_image_f32 &img,
float *vec)
+{
+std::vector<clip_image_f32> imgs;
+imgs.push_back(img);
+return clip_image_batch_encode(ctx, n_threads, imgs, vec);
+}
+
+bool clip_image_batch_encode(
+const clip_ctx *ctx,
+int n_threads,
+const std::vector<clip_image_f32> &imgs,
+float *vec)
{
const auto &model = ctx->vision_model;
const auto &hparams = model.hparams;
@@ -1072,6 +1155,7 @@ bool clip_image_encode(
const int n_layer = hparams.n_layer;
const int n_intermediate = hparams.n_intermediate;
const int projection_dim = hparams.projection_dim;
+int batch_size = imgs.size();

auto &buf_compute = ctx->buf_compute;

@@ -1083,51 +1167,60 @@

struct ggml_context *ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
-gf.n_threads = n_threads;

static size_t scr0_size = get_scr_buf_req_by_size(ctx->text_model.tensors.size() + ctx->vision_model.tensors.size(), num_positions);
static void *scr0 = malloc(scr0_size);

-struct ggml_tensor *inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, 1);
+struct ggml_tensor *inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);

{
-float *data = (float *)ggml_get_data(inp);
+float *data = (float *)ggml_get_data(inp_raw);

-const int nx = img.nx;
-const int ny = img.ny;
-const int n = nx * ny;
-
-GGML_ASSERT(nx == image_size && ny == image_size);
-
-for (int k = 0; k < 3; k++)
-{
-for (int y = 0; y < ny; y++)
-{
-for (int x = 0; x < nx; x++)
-{
-data[k * n + y * nx + x] = img.data[3 * (y * nx + x) + k];
-}
-}
-}
+for (int b = 0; b < batch_size; b++)
+{
+const int nx = imgs[b].nx;
+const int ny = imgs[b].ny;
+GGML_ASSERT(nx == image_size && ny == image_size);
+
+const int n = nx * ny;
+
+for (int k = 0; k < 3; k++)
+{
+for (int y = 0; y < ny; y++)
+{
+for (int x = 0; x < nx; x++)
+{
+data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].data[3 * (y * nx + x) + k];
+}
+}
+}
+}
}
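// Net layout of inp_raw: images are contiguous, each stored channel-major (CHW):
// offset(b, k, y, x) = (b * 3 + k) * (nx * ny) + y * nx + x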

-inp = ggml_conv_2d_sk_p0(ctx0, model.patch_embeddings, inp);
-inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
-inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+struct ggml_tensor *inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));

// concat class_embeddings and patch_embeddings
-struct ggml_tensor *embeddings = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_size, num_positions);
+struct ggml_tensor *embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);

ggml_set_zero(embeddings);
-embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-embeddings = ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], ggml_element_size(model.class_embedding) * hidden_size);
+struct ggml_tensor *temp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size);
+
+embeddings = ggml_acc(ctx0, embeddings, ggml_repeat(ctx0, model.class_embedding, temp), embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+embeddings = ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);

struct ggml_tensor *positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
for (int i = 0; i < num_positions; i++)
{
ggml_set_i32_1d(positions, i, i);
}

-embeddings = ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
+embeddings = ggml_add(ctx0, embeddings, ggml_repeat(ctx0, ggml_get_rows(ctx0, model.position_embeddings, positions), embeddings));

// pre-layernorm
{
@@ -1145,6 +1238,8 @@ bool clip_image_encode(
{
struct ggml_tensor *cur = embeddings; // embeddings = residual, cur = hidden_states

+const size_t nb_q_w = model.layers[il].q_w->nb[0];

ggml_set_scratch(ctx0, {0, scr0_size, scr0});

// layernorm1
@@ -1160,44 +1255,48 @@ bool clip_image_encode(

// self-attention
{

struct ggml_tensor *Q = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur),
-ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
+ggml_mul_mat(ctx0, model.layers[il].q_w,
+cur));

Q = ggml_scale_inplace(ctx0, Q, ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head)));
-Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, 1);
+Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head);
+Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);

-struct ggml_tensor *K =
-ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur),
-ggml_mul_mat(ctx0, model.layers[il].k_w, cur));
+struct ggml_tensor *K = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur),
+ggml_mul_mat(ctx0, model.layers[il].k_w,
+cur));

-K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, 1);
+K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head);
+K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);

-struct ggml_tensor *V =
-ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur),
-ggml_mul_mat(ctx0, model.layers[il].v_w, cur));
-V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, 1);
+struct ggml_tensor *V = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur),
+ggml_mul_mat(ctx0, model.layers[il].v_w,
+cur));
+
+V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head);
+V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
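// Shapes at this point: Q and K are [d_head, num_positions, n_head * batch_size] and
// V is [num_positions, d_head, n_head * batch_size], so each matmul below runs once per head per image.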

struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_soft_max_inplace(ctx0, KQ);
struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ);
-KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, 1);
+KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
KQV = ggml_cont(ctx0, ggml_permute(ctx0, KQV, 0, 2, 1, 3));

cur = ggml_cpy(ctx0,
KQV,
-ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_size, num_positions));
+ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size));
}

// attention output
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].o_b, cur),
-ggml_mul_mat(ctx0, model.layers[il].o_w, cur));
+ggml_mul_mat(ctx0, model.layers[il].o_w,
+cur));

// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, embeddings);
@@ -1236,14 +1335,17 @@ bool clip_image_encode(

// residual 2
cur = ggml_add(ctx0, embeddings, cur);
-// ggml_set_name(cur, "check");

embeddings = cur;
}

// get the output of cls token, e.g., 0th index
-struct ggml_tensor *cls = ggml_new_i32(ctx0, 0);
-embeddings = ggml_get_rows(ctx0, embeddings, cls);
+struct ggml_tensor *cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
+for (int b = 0; b < batch_size; b++)
+{
+ggml_set_i32_1d(cls, b, b * num_positions);
+}
+embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);

// post-layernorm
{
@@ -1262,15 +1364,21 @@
embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);

// normalize output embeddings
-ggml_tensor *length = ggml_sqrt(ctx0,
-ggml_sum(ctx0, ggml_sqr(ctx0, embeddings)));
-embeddings = ggml_scale_inplace(ctx0, embeddings, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
-
-ggml_set_name(embeddings, "check");
+struct ggml_tensor *output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
+
+for (int b = 0; b < batch_size; b++)
+{
+struct ggml_tensor *embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
+ggml_tensor *length = ggml_sqrt(ctx0,
+ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
+embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
+output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
+}
+ggml_set_name(output, "check");

// run the computation
-ggml_build_forward_expand(&gf, embeddings);
-ggml_graph_compute(ctx0, &gf);
+ggml_build_forward_expand(&gf, output);
+ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

// print
#ifdef CLIP_DEBUG
@@ -1313,6 +1421,7 @@ bool clip_image_encode(
};

auto *t = ggml_get_tensor(ctx0, "check");
+// auto t = inp_raw;
if (t->type == GGML_TYPE_F32)
{
print_t_f32(t);
@@ -1326,7 +1435,7 @@ bool clip_image_encode(
printf("used_mem = %zu\n", ggml_used_mem(ctx0));
#endif

-memcpy(vec, ggml_get_data_f32(embeddings), sizeof(float) * projection_dim);
+memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);

ggml_free(ctx0);

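Putting the new API together, a minimal end-to-end sketch (illustrative only: the model path and `load_u8_images` are placeholders, `projection_dim` comes from the loaded model's hparams, and the output buffer must hold `projection_dim` floats per image):

```cpp
// Hypothetical driver: preprocess and encode a batch in one graph evaluation.
struct clip_ctx *ctx = clip_model_load("models/model.bin", 1); // example path
std::vector<clip_image_u8> inputs = load_u8_images();          // placeholder loader
std::vector<clip_image_f32> resized(inputs.size());
clip_image_batch_preprocess(ctx, 4, inputs, resized);

std::vector<float> vec(projection_dim * inputs.size()); // one embedding per image
clip_image_batch_encode(ctx, 4, resized, vec.data());
```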
