Skip to content

Commit

Permalink
update js client to have ids optional
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Sep 2, 2024
1 parent d95f1dc commit 8999e52
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 40 deletions.
15 changes: 8 additions & 7 deletions clients/js/src/ChromaClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { AdminClient } from "./AdminClient";
import { authOptionsToAuthProvider, ClientAuthProvider } from "./auth";
import { chromaFetch } from "./ChromaFetch";
import { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction";
import { ChromaConnectionError, ChromaServerError } from "./Errors";
import {
Configuration,
ApiApi as DefaultApi,
Expand All @@ -18,7 +17,6 @@ import type {
CreateCollectionParams,
DeleteCollectionParams,
DeleteParams,
Embedding,
Embeddings,
GetCollectionParams,
GetOrCreateCollectionParams,
Expand All @@ -33,6 +31,7 @@ import type {
} from "./types";
import {
prepareRecordRequest,
prepareRecordRequestWithIDsOptional,
toArray,
toArrayOfArrays,
validateTenantDatabase,
Expand Down Expand Up @@ -397,7 +396,7 @@ export class ChromaClient {
/**
* Add items to the collection
* @param {Object} params - The parameters for the query.
* @param {ID | IDs} [params.ids] - IDs of the items to add.
* @param {ID | IDs} [params.ids] - Optional IDs of the items to add.
* @param {Embedding | Embeddings} [params.embeddings] - Optional embeddings of the items to add.
* @param {Metadata | Metadatas} [params.metadatas] - Optional metadata of the items to add.
* @param {Document | Documents} [params.documents] - Optional documents of the items to add.
Expand All @@ -416,18 +415,20 @@ export class ChromaClient {
async addRecords(
collection: Collection,
params: AddRecordsParams,
): Promise<void> {
): Promise<AddResponse> {
await this.init();

await this.api.add(
const resp = (await this.api.add(
collection.id,
// TODO: For some reason the auto generated code requires metadata to be defined here.
(await prepareRecordRequest(
(await prepareRecordRequestWithIDsOptional(
params,
collection.embeddingFunction,
)) as GeneratedApi.AddEmbedding,
this.api.options,
);
)) as AddResponse;

return resp;
}

/**
Expand Down
60 changes: 58 additions & 2 deletions clients/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ export type MultiQueryResponse = {

export type QueryResponse = SingleQueryResponse | MultiQueryResponse;

export type AddResponse = {};
/**
 * Shape that the result of an add call is cast to.
 * `ids` is the list of record IDs reported back by the add request; the
 * accompanying test expects it to be populated even when the caller passed
 * no ids of its own.
 */
export type AddResponse = {
ids: IDs;
};

export interface Collection {
name: string;
Expand Down Expand Up @@ -164,13 +166,28 @@ export type BaseRecordOperationParams = {
documents?: Document | Documents;
};

/**
 * Loosest record-operation payload in which `ids` may be omitted.
 * Every field is optional and accepts either the single-value or the
 * array form of its type.
 */
export type BaseRecordOperationParamsWithIDsOptional = {
ids?: ID | IDs;
embeddings?: Embedding | Embeddings;
metadatas?: Metadata | Metadatas;
documents?: Document | Documents;
};

export type SingleRecordOperationParams = BaseRecordOperationParams & {
ids: ID;
embeddings?: Embedding;
metadatas?: Metadata;
documents?: Document;
};

/**
 * Single-record payload with optional `ids`: intersects the base type and
 * narrows every field to its single-value form.
 * NOTE(review): no other alias visible in this part of the file references
 * this type — confirm whether the Single*WithOptionalIDs aliases were meant
 * to build on it.
 */
export type SingleRecordOperationParamsWithIDsOptional =
BaseRecordOperationParamsWithIDsOptional & {
ids?: ID;
embeddings?: Embedding;
metadatas?: Metadata;
documents?: Document;
};

type SingleEmbeddingRecordOperationParams = SingleRecordOperationParams & {
embeddings: Embedding;
};
Expand All @@ -183,13 +200,31 @@ export type SingleAddRecordOperationParams =
| SingleEmbeddingRecordOperationParams
| SingleContentRecordOperationParams;

/**
 * Single-record add payload that must carry one embedding; `ids` optional.
 * NOTE(review): this intersects the Base type, so `ids`, `metadatas` and
 * `documents` still accept array forms here. The Multi counterpart builds on
 * the Multi-specific type instead — confirm whether the asymmetry is intended.
 */
type SingleEmbeddingRecordOperationParamsWithOptionalIDs =
BaseRecordOperationParamsWithIDsOptional & {
embeddings: Embedding;
};

/**
 * Single-record add payload that must carry one document; `ids` optional.
 * NOTE(review): this intersects the Base type, so `ids`, `embeddings` and
 * `metadatas` still accept array forms here. The Multi counterpart builds on
 * the Multi-specific type instead — confirm whether the asymmetry is intended.
 */
type SingleContentRecordOperationParamsWithOptionalIDs =
BaseRecordOperationParamsWithIDsOptional & {
documents: Document;
};

export type MultiRecordOperationParams = BaseRecordOperationParams & {
ids: IDs;
embeddings?: Embeddings;
metadatas?: Metadatas;
documents?: Documents;
};

/**
 * Multi-record payload with optional `ids`: intersects the base type and
 * narrows every field to its array form. This is also the shape returned by
 * the request-normalizing helpers that accept missing ids.
 */
export type MultiRecordOperationParamsWithIDsOptional =
BaseRecordOperationParamsWithIDsOptional & {
ids?: IDs;
embeddings?: Embeddings;
metadatas?: Metadatas;
documents?: Documents;
};

type MultiEmbeddingRecordOperationParams = MultiRecordOperationParams & {
embeddings: Embeddings;
};
Expand All @@ -198,15 +233,36 @@ type MultiContentRecordOperationParams = MultiRecordOperationParams & {
documents: Documents;
};

/** Multi-record add payload in which `embeddings` is required and `ids` stays optional. */
type MultiEmbeddingRecordOperationParamsWithOptionalIDs =
MultiRecordOperationParamsWithIDsOptional & {
embeddings: Embeddings;
};

/** Multi-record add payload in which `documents` is required and `ids` stays optional. */
type MultiContentRecordOperationParamsWithOptionalIDs =
MultiRecordOperationParamsWithIDsOptional & {
documents: Documents;
};

/**
 * Single-record add payload with optional `ids`: the caller must supply
 * at least an embedding or a document.
 */
export type SingleAddRecordOperationParamsWithOptionalIDs =
| SingleEmbeddingRecordOperationParamsWithOptionalIDs
| SingleContentRecordOperationParamsWithOptionalIDs;

/**
 * Multi-record add payload with optional `ids`: the caller must supply
 * at least the embeddings array or the documents array.
 */
export type MultiAddRecordsOperationParamsWithOptionalIDs =
| MultiEmbeddingRecordOperationParamsWithOptionalIDs
| MultiContentRecordOperationParamsWithOptionalIDs;

export type MultiAddRecordsOperationParams =
| MultiEmbeddingRecordOperationParams
| MultiContentRecordOperationParams;

/**
 * Parameters accepted when adding records: either the single-record or the
 * multi-record payload, both of which allow `ids` to be omitted.
 */
export type AddRecordsParams =
| SingleAddRecordOperationParamsWithOptionalIDs
| MultiAddRecordsOperationParamsWithOptionalIDs;

export type UpsertRecordsParams =
| SingleAddRecordOperationParams
| MultiAddRecordsOperationParams;

export type UpsertRecordsParams = AddRecordsParams;
export type UpdateRecordsParams =
| MultiRecordOperationParams
| SingleRecordOperationParams;
Expand Down
93 changes: 70 additions & 23 deletions clients/js/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ import { ChromaConnectionError } from "./Errors";
import { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction";
import {
AddRecordsParams,
BaseRecordOperationParams,
BaseRecordOperationParamsWithIDsOptional,
Collection,
Embeddings,
Documents,
Metadata,
MultiRecordOperationParams,
MultiRecordOperationParamsWithIDsOptional,
UpdateRecordsParams,
UpsertRecordsParams,
} from "./types";

// a function to convert a non-Array object to an Array
Expand Down Expand Up @@ -82,10 +86,10 @@ export function isBrowser() {
}

function arrayifyParams(
params: BaseRecordOperationParams,
): MultiRecordOperationParams {
params: BaseRecordOperationParamsWithIDsOptional,
): MultiRecordOperationParamsWithIDsOptional {
return {
ids: toArray(params.ids),
ids: params.ids !== undefined ? toArray(params.ids) : undefined,
embeddings: params.embeddings
? toArrayOfArrays(params.embeddings)
: undefined,
Expand All @@ -97,16 +101,72 @@ function arrayifyParams(
}

export async function prepareRecordRequest(
reqParams: AddRecordsParams | UpdateRecordsParams,
reqParams: UpsertRecordsParams | UpdateRecordsParams,
embeddingFunction: IEmbeddingFunction,
update?: true,
): Promise<MultiRecordOperationParams> {
const { ids, embeddings, metadatas, documents } = arrayifyParams(reqParams);
const {
ids = [],
embeddings,
metadatas,
documents,
} = arrayifyParams(reqParams);

if (!embeddings && !documents && !update) {
throw new Error("embeddings and documents cannot both be undefined");
}

validateIDs(ids);

const embeddingsArray = await computeEmbeddings(
embeddingFunction,
embeddings,
documents,
update,
);

return {
ids,
metadatas,
documents,
embeddings: embeddingsArray,
};
}

/**
 * Normalizes the parameters of an add request in which `ids` may be omitted.
 *
 * All fields are converted to their array form, the ids (when given) are
 * validated, and embeddings are taken from the request or computed from the
 * documents through `embeddingFunction`.
 *
 * @param reqParams - Add parameters in single or multi form; `ids` optional.
 * @param embeddingFunction - Used to embed `documents` when no embeddings are supplied.
 * @returns The array-form payload; `ids` stays `undefined` when the caller gave none.
 * @throws Error when neither embeddings nor documents are supplied, when the
 *   supplied ids fail validation, when embeddings cannot be produced, or when
 *   the supplied fields do not all have the same length.
 */
export async function prepareRecordRequestWithIDsOptional(
  reqParams: AddRecordsParams,
  embeddingFunction: IEmbeddingFunction,
): Promise<MultiRecordOperationParamsWithIDsOptional> {
  const { ids, embeddings, metadatas, documents } = arrayifyParams(reqParams);

  if (!embeddings && !documents) {
    throw new Error("embeddings and documents cannot both be undefined");
  }

  // ids are optional for add; only validate them when the caller supplied some.
  if (ids) {
    validateIDs(ids);
  }

  const embeddingsArray = await computeEmbeddings(
    embeddingFunction,
    embeddings,
    documents,
  );

  // Every field that is present must describe the same number of records.
  // Because ids may be absent, compare the present fields against each other
  // rather than against ids.length.
  let expectedLength: number | undefined;
  for (const field of [ids, embeddingsArray, metadatas, documents]) {
    if (field === undefined) {
      continue;
    }
    if (expectedLength === undefined) {
      expectedLength = field.length;
    } else if (field.length !== expectedLength) {
      throw new Error(
        "ids, embeddings, metadatas, and documents must all be the same length",
      );
    }
  }

  return {
    ids,
    metadatas,
    documents,
    embeddings: embeddingsArray,
  };
}

async function computeEmbeddings(
embeddingFunction: IEmbeddingFunction,
embeddings?: Embeddings,
documents?: Documents,
update?: true,
): Promise<Embeddings | undefined> {
const embeddingsArray = embeddings
? embeddings
: documents
Expand All @@ -117,6 +177,10 @@ export async function prepareRecordRequest(
throw new Error("Failed to generate embeddings for your request.");
}

return embeddingsArray;
}

function validateIDs(ids: string[]) {
for (let i = 0; i < ids.length; i += 1) {
if (typeof ids[i] !== "string") {
throw new Error(
Expand All @@ -125,16 +189,6 @@ export async function prepareRecordRequest(
}
}

if (
(embeddingsArray !== undefined && ids.length !== embeddingsArray.length) ||
(metadatas !== undefined && ids.length !== metadatas.length) ||
(documents !== undefined && ids.length !== documents.length)
) {
throw new Error(
"ids, embeddings, metadatas, and documents must all be the same length",
);
}

const uniqueIds = new Set(ids);
if (uniqueIds.size !== ids.length) {
const duplicateIds = ids.filter(
Expand All @@ -144,13 +198,6 @@ export async function prepareRecordRequest(
`ID's must be unique, found duplicates for: ${duplicateIds}`,
);
}

return {
ids,
metadatas,
documents,
embeddings: embeddingsArray,
};
}

function notifyUserOfLegacyMethod(newMethod: string) {
Expand Down
20 changes: 12 additions & 8 deletions clients/js/test/add.collections.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,27 @@ describe("add collections", () => {
const ids = IDS.concat(["test1"]);
const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]);
const metadatas = METADATAS.concat([{ test: "test1", float_value: 0.1 }]);
try {
expect(async () => {
await client.addRecords(collection, { ids, embeddings, metadatas });
} catch (e: any) {
expect(e.message).toMatch("duplicates");
}
}).rejects.toThrow("found duplicates");
});

// Adds records without passing `ids` and checks that the add response
// reports four ids back (one per embedding in the request — presumably the
// EMBEDDINGS fixture holds three entries; confirm against the test data).
test("It should generate IDs if not provided", async () => {
const collection = await client.createCollection({ name: "test" });
const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]);
const metadatas = METADATAS.concat([{ test: "test1", float_value: 0.1 }]);
const resp = await client.addRecords(collection, { embeddings, metadatas });
expect(resp.ids.length).toEqual(4);
});

// Adding a record whose embedding is an empty array must be rejected.
test("should error on empty embedding", async () => {
  const collection = await client.createCollection({ name: "test" });
  const ids = ["id1"];
  const embeddings = [[]];
  const metadatas = [{ test: "test1", float_value: 0.1 }];
  // The `.rejects` assertion is itself a promise: it has to be awaited,
  // otherwise the test finishes before the assertion settles and passes
  // even when addRecords does not reject.
  await expect(
    client.addRecords(collection, { ids, embeddings, metadatas }),
  ).rejects.toThrow("got empty embedding at pos");
});

if (!process.env.OLLAMA_SERVER_URL) {
Expand Down

0 comments on commit 8999e52

Please sign in to comment.