Skip to content

Commit

Permalink
Default Embedding Function for JS (#1382)
Browse files Browse the repository at this point in the history
End dx: `npm install chromadb chromadb-default-embed`

`chromadb-default-embed` is a fork of `@xenova/transformers`, kept separately to
maintain stability

***

Motivation
- good defaults are good DX
- good defaults lower the barrier to getting started
- currently JS usage is gated on having an API key (seems bad) 

`npm install --save chromadb @xenova/transformers` 

- We want to use the same EF as `python` - `all-MiniLM-L6-v2`
- We want to keep our default package size small (currently 4.5mb)
- We want a happy path for devs getting started - they shouldn't need to
create any accounts or get API keys
- `@xenova/transformers` is great, but it's huge - ~250MB! - so we can't
by default bundle it
- To have a happy path, but keep the bundle size small - we just ask
users to run `npm install --save chromadb @xenova/transformers` to
install chroma. we can add a comment like `// optional default embedding
function`

I also evaluated `https://github.com/visheratin/web-ai` - which is small
(~8MB), but I don't think it supports this model yet (though it is
potentially possible) - and
https://github.com/microsoft/onnxruntime/tree/main, which is also
massive (over 100MB).

I confirmed that if you just install `chromadb` and pass
`OpenAIEmbeddingFunction` (or other) - it doesn't complain or yell at
you that you have a missing dep. If you try to use the default and
don't have `@xenova/transformers` installed, it will tell you to install it.

Todo
- [ ] test no require warnings in `nextjs` - this has been an issue in
the past

Thoughts about this DX?
  • Loading branch information
jeffchuber authored Jan 5, 2024
1 parent 920ebf3 commit fca3426
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 2 deletions.
11 changes: 11 additions & 0 deletions clients/js/src/ChromaClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
ClientAuthProtocolAdapter,
IsomorphicFetchClientAuthProtocolAdapter
} from "./auth";
import { DefaultEmbeddingFunction } from './embeddings/DefaultEmbeddingFunction';
import { AdminClient } from './AdminClient';

const DEFAULT_TENANT = "default_tenant"
Expand Down Expand Up @@ -144,6 +145,11 @@ export class ChromaClient {
metadata,
embeddingFunction
}: CreateCollectionParams): Promise<Collection> {

if (embeddingFunction === undefined) {
embeddingFunction = new DefaultEmbeddingFunction();
}

const newCollection = await this.api
.createCollection(this.tenant, this.database, {
name,
Expand Down Expand Up @@ -185,6 +191,11 @@ export class ChromaClient {
metadata,
embeddingFunction
}: GetOrCreateCollectionParams): Promise<Collection> {

if (embeddingFunction === undefined) {
embeddingFunction = new DefaultEmbeddingFunction();
}

const newCollection = await this.api
.createCollection(this.tenant, this.database, {
name,
Expand Down
99 changes: 99 additions & 0 deletions clients/js/src/embeddings/DefaultEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

// The `chromadb-default-embed` package is an optional dependency: it is only
// imported (dynamically) the first time an instance needs to embed something.

/**
 * Embedding function backed by the optional `chromadb-default-embed` package.
 *
 * The package is imported lazily on the first call to `generate`, and the
 * feature-extraction pipeline it builds is cached on the instance so that the
 * model is constructed once rather than on every call.
 */
export class DefaultEmbeddingFunction implements IEmbeddingFunction {
  /** In-flight or resolved pipeline; `null`/`undefined` until first use or after a failed build. */
  private pipelinePromise?: Promise<any> | null;
  /** The `pipeline` factory exported by `chromadb-default-embed`, once loaded. */
  private transformersApi: any;
  private model: string;
  private revision: string;
  private quantized: boolean;
  private progress_callback: Function | null;

  /**
   * DefaultEmbeddingFunction constructor.
   * @param options The configuration options.
   * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
   * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
   * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
   * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
   */
  constructor({
    model = "Xenova/all-MiniLM-L6-v2",
    revision = "main",
    quantized = false,
    progress_callback = null,
  }: {
    model?: string;
    revision?: string;
    quantized?: boolean;
    progress_callback?: Function | null;
  } = {}) {
    this.model = model;
    this.revision = revision;
    this.quantized = quantized;
    this.progress_callback = progress_callback;
  }

  /**
   * Computes one embedding per input text.
   *
   * @param texts The texts to embed.
   * @returns The pipeline output converted with `tolist()`; the pipeline is
   *   invoked with mean pooling and normalization enabled.
   * @throws Error if `chromadb-default-embed` cannot be imported, or whatever
   *   error the pipeline construction / invocation raises.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    await this.loadClient();

    const pipe = await this.getPipeline();
    const output = await pipe(texts, { pooling: "mean", normalize: true });
    return output.tolist();
  }

  /**
   * Returns the cached feature-extraction pipeline, building it on first use.
   * Concurrent callers share the same in-flight promise.
   */
  private getPipeline(): Promise<any> {
    if (this.pipelinePromise) {
      return this.pipelinePromise;
    }

    // Call through a local so the factory is invoked as a plain function.
    const createPipeline = this.transformersApi;
    const pending: Promise<any> = (async () =>
      createPipeline("feature-extraction", this.model, {
        quantized: this.quantized,
        revision: this.revision,
        progress_callback: this.progress_callback,
      }))();

    this.pipelinePromise = pending;
    // Do not keep a rejected promise around: a later call should be able to
    // retry (e.g. after a transient download failure).
    pending.catch(() => {
      if (this.pipelinePromise === pending) {
        this.pipelinePromise = null;
      }
    });
    return pending;
  }

  /**
   * Loads the `pipeline` factory from `chromadb-default-embed` once per
   * instance. Errors carrying `code === 'MODULE_NOT_FOUND'` are replaced by
   * an install hint; all other errors are re-thrown unchanged.
   */
  private async loadClient(): Promise<void> {
    if (this.transformersApi) return;
    try {
      const { pipeline } = await DefaultEmbeddingFunction.import();
      this.transformersApi = pipeline;
    } catch (err) {
      const code = (err as { code?: unknown } | null | undefined)?.code;
      if (code === "MODULE_NOT_FOUND") {
        throw new Error("Please install the chromadb-default-embed package to use the DefaultEmbeddingFunction, `npm install -S chromadb-default-embed`");
      }
      throw err; // Re-throw other errors
    }
  }

  /** @ignore */
  static async import(): Promise<{
    // @ts-ignore
    pipeline: typeof import("chromadb-default-embed");
  }> {
    try {
      // @ts-ignore
      const { pipeline } = await import("chromadb-default-embed");
      return { pipeline };
    } catch (e) {
      // NOTE(review): every import failure is reported as "not installed",
      // which hides the underlying error — confirm this is intended.
      throw new Error(
        "Please install chromadb-default-embed as a dependency with, e.g. `yarn add chromadb-default-embed`"
      );
    }
  }
}
99 changes: 99 additions & 0 deletions clients/js/src/embeddings/TransformersEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

// The `@xenova/transformers` package is an optional dependency: it is only
// imported (dynamically) the first time an instance needs to embed something.

/**
 * Embedding function backed by the optional `@xenova/transformers` package.
 *
 * The package is imported lazily on the first call to `generate`, and the
 * feature-extraction pipeline it builds is cached on the instance so that the
 * model is constructed once rather than on every call.
 */
export class TransformersEmbeddingFunction implements IEmbeddingFunction {
  /** In-flight or resolved pipeline; `null`/`undefined` until first use or after a failed build. */
  private pipelinePromise?: Promise<any> | null;
  /** The `pipeline` factory exported by `@xenova/transformers`, once loaded. */
  private transformersApi: any;
  private model: string;
  private revision: string;
  private quantized: boolean;
  private progress_callback: Function | null;

  /**
   * TransformersEmbeddingFunction constructor.
   * @param options The configuration options.
   * @param options.model The model to use to calculate embeddings. Defaults to 'Xenova/all-MiniLM-L6-v2', which is an ONNX port of `sentence-transformers/all-MiniLM-L6-v2`.
   * @param options.revision The specific model version to use (can be a branch, tag name, or commit id). Defaults to 'main'.
   * @param options.quantized Whether to load the 8-bit quantized version of the model. Defaults to `false`.
   * @param options.progress_callback If specified, this function will be called during model construction, to provide the user with progress updates.
   */
  constructor({
    model = "Xenova/all-MiniLM-L6-v2",
    revision = "main",
    quantized = false,
    progress_callback = null,
  }: {
    model?: string;
    revision?: string;
    quantized?: boolean;
    progress_callback?: Function | null;
  } = {}) {
    this.model = model;
    this.revision = revision;
    this.quantized = quantized;
    this.progress_callback = progress_callback;
  }

  /**
   * Computes one embedding per input text.
   *
   * @param texts The texts to embed.
   * @returns The pipeline output converted with `tolist()`; the pipeline is
   *   invoked with mean pooling and normalization enabled.
   * @throws Error if `@xenova/transformers` cannot be imported, or whatever
   *   error the pipeline construction / invocation raises.
   */
  public async generate(texts: string[]): Promise<number[][]> {
    await this.loadClient();

    const pipe = await this.getPipeline();
    const output = await pipe(texts, { pooling: "mean", normalize: true });
    return output.tolist();
  }

  /**
   * Returns the cached feature-extraction pipeline, building it on first use.
   * Concurrent callers share the same in-flight promise.
   */
  private getPipeline(): Promise<any> {
    if (this.pipelinePromise) {
      return this.pipelinePromise;
    }

    // Call through a local so the factory is invoked as a plain function.
    const createPipeline = this.transformersApi;
    const pending: Promise<any> = (async () =>
      createPipeline("feature-extraction", this.model, {
        quantized: this.quantized,
        revision: this.revision,
        progress_callback: this.progress_callback,
      }))();

    this.pipelinePromise = pending;
    // Do not keep a rejected promise around: a later call should be able to
    // retry (e.g. after a transient download failure).
    pending.catch(() => {
      if (this.pipelinePromise === pending) {
        this.pipelinePromise = null;
      }
    });
    return pending;
  }

  /**
   * Loads the `pipeline` factory from `@xenova/transformers` once per
   * instance. Errors carrying `code === 'MODULE_NOT_FOUND'` are replaced by
   * an install hint; all other errors are re-thrown unchanged.
   */
  private async loadClient(): Promise<void> {
    if (this.transformersApi) return;
    try {
      const { pipeline } = await TransformersEmbeddingFunction.import();
      this.transformersApi = pipeline;
    } catch (err) {
      const code = (err as { code?: unknown } | null | undefined)?.code;
      if (code === "MODULE_NOT_FOUND") {
        throw new Error("Please install the @xenova/transformers package to use the TransformersEmbeddingFunction, `npm install -S @xenova/transformers`");
      }
      throw err; // Re-throw other errors
    }
  }

  /** @ignore */
  static async import(): Promise<{
    // @ts-ignore
    pipeline: typeof import("@xenova/transformers");
  }> {
    try {
      // @ts-ignore
      const { pipeline } = await import("@xenova/transformers");
      return { pipeline };
    } catch (e) {
      // NOTE(review): every import failure is reported as "not installed",
      // which hides the underlying error — confirm this is intended.
      throw new Error(
        "Please install @xenova/transformers as a dependency with, e.g. `yarn add @xenova/transformers`"
      );
    }
  }
}
8 changes: 6 additions & 2 deletions clients/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ export { ChromaClient } from './ChromaClient';
export { AdminClient } from './AdminClient';
export { CloudClient } from './CloudClient';
export { Collection } from './Collection';

export { IEmbeddingFunction } from './embeddings/IEmbeddingFunction';
export { OpenAIEmbeddingFunction } from './embeddings/OpenAIEmbeddingFunction';
export { CohereEmbeddingFunction } from './embeddings/CohereEmbeddingFunction';
export { TransformersEmbeddingFunction } from './embeddings/TransformersEmbeddingFunction';
export { DefaultEmbeddingFunction } from './embeddings/DefaultEmbeddingFunction';
export { HuggingFaceEmbeddingServerFunction } from './embeddings/HuggingFaceEmbeddingServerFunction';
export { JinaEmbeddingFunction } from './embeddings/JinaEmbeddingFunction';
export { GoogleGenerativeAiEmbeddingFunction } from './embeddings/GoogleGeminiEmbeddingFunction';

export {
IncludeEnum,
GetParams,
Expand Down Expand Up @@ -37,5 +43,3 @@ export {
PeekParams,
DeleteParams
} from './types';
export { HuggingFaceEmbeddingServerFunction } from './embeddings/HuggingFaceEmbeddingServerFunction';
export { JinaEmbeddingFunction } from './embeddings/JinaEmbeddingFunction';

0 comments on commit fca3426

Please sign in to comment.