Skip to content

Commit

Permalink
refactored normalization and updated text model size
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavkgrd committed Jan 9, 2024
1 parent 876b639 commit 6ae6233
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions src/services/clipService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const IMAGE_MODEL_SIZE_IN_BYTES = {
};
// Byte sizes recorded for the text model, keyed by model format (`ggml`, `onnx`).
// NOTE(review): this is a diff view, so `onnx` is listed twice — presumably the
// first entry is the value removed by the commit and the second is the value it
// adds ("updated text model size"); confirm against the file at this commit.
const TEXT_MODEL_SIZE_IN_BYTES = {
ggml: 127853440, // 121.9 MB,
onnx: 254069585, // 242.3 MB
onnx: 64173509, // 61.2 MB
};

const MODEL_SAVE_FOLDER = 'models';
Expand Down Expand Up @@ -266,15 +266,7 @@ export async function computeONNXImageEmbedding(
};
const results = await imageSession.run(feeds);
const imageEmbedding = results['output'].data; // Float32Array
let imageNormalization = 0;
for (let index = 0; index < imageEmbedding.length; index++) {
imageNormalization += imageEmbedding[index] * imageEmbedding[index];
}
for (let index = 0; index < imageEmbedding.length; index++) {
imageEmbedding[index] =
imageEmbedding[index] / Math.sqrt(imageNormalization);
}
return imageEmbedding;
return normalizeEmbedding(imageEmbedding);
} catch (err) {
logErrorSentry(err, 'Error in computeImageEmbedding');
throw err;
Expand Down Expand Up @@ -344,7 +336,7 @@ export async function computeONNXTextEmbedding(
};
const results = await imageSession.run(feeds);
const embedVec = results['output'].data; // Float32Array
return embedVec;
return normalizeEmbedding(embedVec);
} catch (err) {
if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
Expand Down Expand Up @@ -431,21 +423,19 @@ export const computeClipMatchScore = async (
throw Error('imageEmbedding and textEmbedding length mismatch');
}
let score = 0;
let imageNormalization = 0;
let textNormalization = 0;

for (let index = 0; index < imageEmbedding.length; index++) {
imageNormalization += imageEmbedding[index] * imageEmbedding[index];
textNormalization += textEmbedding[index] * textEmbedding[index];
}
for (let index = 0; index < imageEmbedding.length; index++) {
imageEmbedding[index] =
imageEmbedding[index] / Math.sqrt(imageNormalization);
textEmbedding[index] =
textEmbedding[index] / Math.sqrt(textNormalization);
}
for (let index = 0; index < imageEmbedding.length; index++) {
score += imageEmbedding[index] * textEmbedding[index];
}
return score;
};

/**
 * Scales `embedding` in place so that its L2 (Euclidean) norm is 1, and
 * returns the same array instance.
 *
 * A vector whose norm is 0 (empty, or all components zero) has no direction
 * and cannot be normalized; it is returned unchanged instead of being filled
 * with `NaN` by a 0 / 0 division.
 *
 * @param embedding Vector to normalize. It is mutated.
 * @returns The same `embedding` array, normalized in place.
 */
export const normalizeEmbedding = (embedding: Float32Array): Float32Array => {
    let sumOfSquares = 0;
    for (let index = 0; index < embedding.length; index++) {
        sumOfSquares += embedding[index] * embedding[index];
    }

    // Loop-invariant: compute the norm once rather than once per component.
    const norm = Math.sqrt(sumOfSquares);
    if (norm === 0) {
        // Dividing would turn every component into NaN and silently poison
        // any similarity score computed from this embedding.
        return embedding;
    }

    for (let index = 0; index < embedding.length; index++) {
        embedding[index] = embedding[index] / norm;
    }
    return embedding;
};

0 comments on commit 6ae6233

Please sign in to comment.