-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Default Embedding Function for JS (#1382)
End DX: `npm install chromadb chromadb-default-embed`. `chromadb-default-embed` is a fork of `@xenova/transformers`, kept to maintain stability. *** Motivation - good defaults are good DX - good defaults lower the barrier to getting started - currently JS usage is gated on having an API key (seems bad) `npm install --save chromadb @xenova/transformers` - We want to use the same EF as `python` - `all-MiniLM-L6-v2` - We want to keep our default package size small (currently 4.5 MB) - We want a happy path for devs getting started - they shouldn't need to create any accounts or get API keys - `@xenova/transformers` is great, but it's huge - ~250 MB! - so we can't bundle it by default - To have a happy path, but keep the bundle size small - we just ask users to run `npm install --save chromadb @xenova/transformers` to install chroma. We can add a comment like `// optional default embedding function` I also evaluated `https://github.com/visheratin/web-ai` - which is small (~8 MB), but I don't think it supports this model yet (though it is potentially possible) - and https://github.com/microsoft/onnxruntime/tree/main, which is also massive (over 100 MB). I confirmed that if you just install `chromadb` and pass `OpenAIEmbeddingFunction` (or other) - it doesn't complain or yell at you that you have a missing dep. If you try to use the default and don't have `@xenova/transformers` installed, it will tell you to install it. Todo - [ ] test no require warnings in `nextjs` - this has been an issue in the past Thoughts about this DX?
- Loading branch information
1 parent
920ebf3
commit fca3426
Showing
4 changed files
with
215 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import { IEmbeddingFunction } from "./IEmbeddingFunction"; | ||
|
||
/**
 * Embedding function backed by the optional `chromadb-default-embed` package.
 *
 * The package is imported lazily on the first `generate` call, and the
 * feature-extraction pipeline is built once per instance and then reused.
 */
export class DefaultEmbeddingFunction implements IEmbeddingFunction {
  // In-flight or resolved pipeline; reset to null when construction fails.
  private pipelinePromise?: Promise<any> | null;
  // The `pipeline` factory exported by the optional dependency.
  private transformersApi: any;
  private model: string;
  private revision: string;
  private quantized: boolean;
  private progress_callback: Function | null;

  /**
   * DefaultEmbeddingFunction constructor.
   * @param options The configuration options.
   * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
   * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
   * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
   * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
   */
  constructor({
    model = "Xenova/all-MiniLM-L6-v2",
    revision = "main",
    quantized = false,
    progress_callback = null,
  }: {
    model?: string;
    revision?: string;
    quantized?: boolean;
    progress_callback?: Function | null;
  } = {}) {
    this.model = model;
    this.revision = revision;
    this.quantized = quantized;
    this.progress_callback = progress_callback;
  }

  /**
   * Computes one embedding per input text.
   * @param texts The texts to embed.
   * @returns The result of calling `tolist()` on the pipeline output.
   * @throws If the optional dependency cannot be loaded, or if building or
   * running the pipeline fails.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    await this.loadClient();

    const pipe = await this.getPipeline();
    const output = await pipe(texts, { pooling: "mean", normalize: true });
    return output.tolist();
  }

  /**
   * Returns the cached pipeline promise, creating it on first use.
   * Concurrent callers share the same in-flight promise.
   */
  private getPipeline(): Promise<any> {
    if (!this.pipelinePromise) {
      const pending = this.createPipeline();
      // Forget a failed attempt so a later call can retry instead of
      // replaying the same rejection forever.
      pending.catch(() => {
        if (this.pipelinePromise === pending) {
          this.pipelinePromise = null;
        }
      });
      this.pipelinePromise = pending;
    }
    return this.pipelinePromise;
  }

  /** Builds the feature-extraction pipeline from the configured options. */
  private async createPipeline(): Promise<any> {
    // Call through a local so the factory is invoked without a `this` binding.
    const createPipeline = this.transformersApi;
    return createPipeline("feature-extraction", this.model, {
      quantized: this.quantized,
      revision: this.revision,
      progress_callback: this.progress_callback,
    });
  }

  /**
   * Loads the `pipeline` factory from the optional dependency, once per instance.
   * @throws An install hint when the module cannot be found; other errors are re-thrown.
   */
  private async loadClient(): Promise<void> {
    if (this.transformersApi) return;
    try {
      const { pipeline } = await DefaultEmbeddingFunction.import();
      this.transformersApi = pipeline;
    } catch (e: unknown) {
      const code =
        typeof e === "object" && e !== null
          ? (e as { code?: unknown }).code
          : undefined;
      // CommonJS reports MODULE_NOT_FOUND, ES modules ERR_MODULE_NOT_FOUND.
      if (code === "MODULE_NOT_FOUND" || code === "ERR_MODULE_NOT_FOUND") {
        throw new Error("Please install the chromadb-default-embed package to use the DefaultEmbeddingFunction, `npm install -S chromadb-default-embed`");
      }
      throw e; // Re-throw other errors
    }
  }

  /** @ignore */
  static async import(): Promise<{
    // The dependency is optional, so its types may be absent at compile time.
    pipeline: any;
  }> {
    try {
      // @ts-ignore
      const { pipeline } = await import("chromadb-default-embed");
      return { pipeline };
    } catch (e) {
      const hint: Error & { cause?: unknown } = new Error(
        "Please install chromadb-default-embed as a dependency with, e.g. `yarn add chromadb-default-embed`"
      );
      // Keep the underlying failure reachable for debugging.
      hint.cause = e;
      throw hint;
    }
  }
}
99 changes: 99 additions & 0 deletions
99
clients/js/src/embeddings/TransformersEmbeddingFunction.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import { IEmbeddingFunction } from "./IEmbeddingFunction"; | ||
|
||
/**
 * Embedding function backed by the optional `@xenova/transformers` package.
 *
 * The package is imported lazily on the first `generate` call, and the
 * feature-extraction pipeline is built once per instance and then reused.
 */
export class TransformersEmbeddingFunction implements IEmbeddingFunction {
  // In-flight or resolved pipeline; reset to null when construction fails.
  private pipelinePromise?: Promise<any> | null;
  // The `pipeline` factory exported by the optional dependency.
  private transformersApi: any;
  private model: string;
  private revision: string;
  private quantized: boolean;
  private progress_callback: Function | null;

  /**
   * TransformersEmbeddingFunction constructor.
   * @param options The configuration options.
   * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
   * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
   * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
   * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
   */
  constructor({
    model = "Xenova/all-MiniLM-L6-v2",
    revision = "main",
    quantized = false,
    progress_callback = null,
  }: {
    model?: string;
    revision?: string;
    quantized?: boolean;
    progress_callback?: Function | null;
  } = {}) {
    this.model = model;
    this.revision = revision;
    this.quantized = quantized;
    this.progress_callback = progress_callback;
  }

  /**
   * Computes one embedding per input text.
   * @param texts The texts to embed.
   * @returns The result of calling `tolist()` on the pipeline output.
   * @throws If the optional dependency cannot be loaded, or if building or
   * running the pipeline fails.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    await this.loadClient();

    const pipe = await this.getPipeline();
    const output = await pipe(texts, { pooling: "mean", normalize: true });
    return output.tolist();
  }

  /**
   * Returns the cached pipeline promise, creating it on first use.
   * Concurrent callers share the same in-flight promise.
   */
  private getPipeline(): Promise<any> {
    if (!this.pipelinePromise) {
      const pending = this.createPipeline();
      // Forget a failed attempt so a later call can retry instead of
      // replaying the same rejection forever.
      pending.catch(() => {
        if (this.pipelinePromise === pending) {
          this.pipelinePromise = null;
        }
      });
      this.pipelinePromise = pending;
    }
    return this.pipelinePromise;
  }

  /** Builds the feature-extraction pipeline from the configured options. */
  private async createPipeline(): Promise<any> {
    // Call through a local so the factory is invoked without a `this` binding.
    const createPipeline = this.transformersApi;
    return createPipeline("feature-extraction", this.model, {
      quantized: this.quantized,
      revision: this.revision,
      progress_callback: this.progress_callback,
    });
  }

  /**
   * Loads the `pipeline` factory from the optional dependency, once per instance.
   * @throws An install hint when the module cannot be found; other errors are re-thrown.
   */
  private async loadClient(): Promise<void> {
    if (this.transformersApi) return;
    try {
      const { pipeline } = await TransformersEmbeddingFunction.import();
      this.transformersApi = pipeline;
    } catch (e: unknown) {
      const code =
        typeof e === "object" && e !== null
          ? (e as { code?: unknown }).code
          : undefined;
      // CommonJS reports MODULE_NOT_FOUND, ES modules ERR_MODULE_NOT_FOUND.
      if (code === "MODULE_NOT_FOUND" || code === "ERR_MODULE_NOT_FOUND") {
        throw new Error("Please install the @xenova/transformers package to use the TransformersEmbeddingFunction, `npm install -S @xenova/transformers`");
      }
      throw e; // Re-throw other errors
    }
  }

  /** @ignore */
  static async import(): Promise<{
    // The dependency is optional, so its types may be absent at compile time.
    pipeline: any;
  }> {
    try {
      // @ts-ignore
      const { pipeline } = await import("@xenova/transformers");
      return { pipeline };
    } catch (e) {
      const hint: Error & { cause?: unknown } = new Error(
        "Please install @xenova/transformers as a dependency with, e.g. `yarn add @xenova/transformers`"
      );
      // Keep the underlying failure reachable for debugging.
      hint.cause = e;
      throw hint;
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters