Skip to content

Commit

Permalink
build embeddings database
Browse files Browse the repository at this point in the history
  • Loading branch information
Green-Sky committed Jun 17, 2023
1 parent e757b33 commit 304c3e6
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
build/
.vscode
.cache/
*.swp
models/*.bin
tests/*.txt
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,5 @@ if (BUILD_SHARED_LIBS)
endif()

add_subdirectory(examples)
# add_subdirectory(models)
# add_subdirectory(models)

5 changes: 4 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
add_subdirectory(./image-search)

add_executable(main main.cpp)
target_link_libraries(main PRIVATE clip ggml)
target_link_libraries(main PRIVATE clip ggml)

11 changes: 11 additions & 0 deletions examples/image-search/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
set(CXX_STANDARD_REQUIRED ON)

add_executable(image-search-build
usearch/index.hpp
build.cpp
)

target_link_libraries(image-search-build PRIVATE clip ggml)

target_compile_features(image-search-build PUBLIC cxx_std_17)

155 changes: 155 additions & 0 deletions examples/image-search/build.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#include "clip.h"
#include "usearch/index.hpp"

#include <fstream>
#include <filesystem>

struct my_app_params {
int32_t n_threads {4};
std::string model = "../models/ggml-model-f16.bin";
std::vector<std::string> image_directories;
};

void my_print_help(int argc, char **argv, my_app_params &params) {
printf("Usage: %s [options] dir/with/pictures [more/dirs]\n", argv[0]);
printf("\nOptions:");
printf(" -h, --help: Show this message and exit\n");
printf(" -m <path>, --model <path>: path to model. Default: %s\n", params.model.c_str());
printf(" -t N, --threads N: Number of threads to use for inference. Default: %d\n", params.n_threads);
}

// returns success
bool my_app_params_parse(int argc, char **argv, my_app_params &params) {
bool invalid_param = false;
for (int i = 1; i < argc; i++) {

std::string arg = argv[i];

if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.model = argv[i];
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_threads = std::stoi(argv[i]);
} else if (arg == "-h" || arg == "--help") {
my_print_help(argc, argv, params);
exit(0);
} else if (arg.starts_with('-')) {
if (i != 0) {
printf("%s: unrecognized argument: %s\n", __func__, arg.c_str());
return false;
}
} else {
// assume image directory
params.image_directories.push_back(argv[i]);
}
}

return !(invalid_param || params.image_directories.empty());
}

bool is_image_file_extension(std::string_view ext) {
if (ext == ".jpg") return true;
if (ext == ".JPG") return true;

if (ext == ".jpeg") return true;
if (ext == ".JPEG") return true;

if (ext == ".gif") return true;
if (ext == ".GIF") return true;

if (ext == ".png") return true;
if (ext == ".PNG") return true;

// TODO: add all the supported files for std_image

return false;
}

int main(int argc, char** argv) {
my_app_params params;
if (!my_app_params_parse(argc, argv, params)) {
my_print_help(argc, argv, params);
return 1;
}

auto clip_ctx = clip_model_load(params.model.c_str());
if (!clip_ctx) {
printf("%s: Unable to load model from %s\n", __func__, params.model.c_str());
return 1;
}

std::vector<std::string> image_file_index;
unum::usearch::index_gt<unum::usearch::cos_gt<float>> embd_index;

const size_t vec_dim = clip_ctx->vision_model.hparams.projection_dim;

size_t label = 0;

// search for images in path and create embedding
for (const auto& base_dir : params.image_directories) {
fprintf(stdout, "%s: starting base dir scan of '%s'\n", __func__, base_dir.c_str());

for (auto const& dir_entry : std::filesystem::recursive_directory_iterator(base_dir)) {
if (!dir_entry.is_regular_file()) {
continue;
}

// check for image file
const auto& ext = dir_entry.path().extension();
if (ext.empty()) {
continue;
}
if (!is_image_file_extension(ext.c_str())) {
continue;
}

std::string img_path {dir_entry.path()};
fprintf(stdout, "%s: found image file '%s'\n", __func__, img_path.c_str());

clip_image_u8 img0;
clip_image_f32 img_res;
if (!clip_image_load_from_file(img_path, img0)) {
fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, img_path.c_str());
continue;
}

clip_image_preprocess(&img0, &img_res);

std::vector<float> vec(vec_dim);
if (!clip_image_encode(clip_ctx, params.n_threads, img_res, vec.data())) {
fprintf(stderr, "%s: failed to encode image from '%s'\n", __func__, img_path.c_str());
continue;
}

if (embd_index.capacity() == embd_index.size()) {
embd_index.reserve(embd_index.size() + 32);
}

// add the image
embd_index.add(label++, {vec.data(), vec.size()});
image_file_index.push_back(std::filesystem::canonical(dir_entry.path()));
}
}

clip_free(clip_ctx);

// save to disk

embd_index.save("images.usearch");
// TODO: figrue out why usearch leaks

std::ofstream image_file_index_file("images.paths", std::ios::binary | std::ios::trunc);
for (const auto& i_path : image_file_index) {
image_file_index_file << i_path << "\n";
}

return 0;
}

0 comments on commit 304c3e6

Please sign in to comment.