From 23661968d0996c745a398bbbfefb91d2b4ab2ba2 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 19 Dec 2023 19:51:52 +0200 Subject: [PATCH] [ENH]: Added cohere-ai 7.0.0 support in package.json (#1460) #1445 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Support for cohere-ai SDK 7+ ## Test plan *How are these changes tested?* - [x] Tests pass locally with `yarn test` ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* --- clients/js/package.json | 2 +- .../src/embeddings/CohereEmbeddingFunction.ts | 157 ++++++++++++------ 2 files changed, 110 insertions(+), 49 deletions(-) diff --git a/clients/js/package.json b/clients/js/package.json index f470efe3e37..6fdf88885a9 100644 --- a/clients/js/package.json +++ b/clients/js/package.json @@ -77,7 +77,7 @@ }, "peerDependencies": { "@google/generative-ai": "^0.1.1", - "cohere-ai": "^5.0.0 || ^6.0.0", + "cohere-ai": "^5.0.0 || ^6.0.0 || ^7.0.0", "openai": "^3.0.0 || ^4.0.0" }, "peerDependenciesMeta": { diff --git a/clients/js/src/embeddings/CohereEmbeddingFunction.ts b/clients/js/src/embeddings/CohereEmbeddingFunction.ts index 7d30c416223..2efe45a77c5 100644 --- a/clients/js/src/embeddings/CohereEmbeddingFunction.ts +++ b/clients/js/src/embeddings/CohereEmbeddingFunction.ts @@ -1,61 +1,122 @@ import { IEmbeddingFunction } from "./IEmbeddingFunction"; -let CohereAiApi: any; - -export class CohereEmbeddingFunction implements IEmbeddingFunction { - private api_key: string; - private model: string; - private cohereAiApi?: any; - - constructor({ cohere_api_key, model }: { cohere_api_key: string, model?: string }) { - // we used to construct the client here, but we need to async import the types - // for the openai npm package, and the constructor can not be async - this.api_key = cohere_api_key; - this.model = model || "large"; - } +interface CohereAIAPI { + createEmbedding: (params: { + model: string; + input: string[]; + }) => Promise; +} - private async loadClient() { - if(this.cohereAiApi) return; - try { - // eslint-disable-next-line global-require,import/no-extraneous-dependencies - const { cohere } = await CohereEmbeddingFunction.import(); - CohereAiApi = cohere; - CohereAiApi.init(this.api_key); - } catch (_a) { - // @ts-ignore - if (_a.code === 'MODULE_NOT_FOUND') { - throw new Error("Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`"); - } - throw _a; // Re-throw other errors - } - this.cohereAiApi = CohereAiApi; - } +class CohereAISDK56 implements CohereAIAPI { + private cohereClient: any; + private apiKey: string; - public async generate(texts: string[]) { + constructor(configuration: { apiKey: string }) { + this.apiKey = configuration.apiKey; + } - await this.loadClient(); + private async loadClient() { + if (this.cohereClient) return; + //@ts-ignore + const { default: cohere } = await import("cohere-ai"); + // @ts-ignore + cohere.init(this.apiKey); + this.cohereClient = cohere; + } - const response = await this.cohereAiApi.embed({ - texts: texts, - model: this.model, - }); + public async createEmbedding(params: { + model: string; + input: string[]; + }): Promise { + await this.loadClient(); + return await this.cohereClient + .embed({ + texts: params.input, + model: params.model, + }) + .then((response: any) => { return response.body.embeddings; - } + }); + } +} + +class CohereAISDK7 implements CohereAIAPI { + private cohereClient: any; + private apiKey: string; + + constructor(configuration: { apiKey: string }) { + this.apiKey = configuration.apiKey; + } + + private async loadClient() { + if (this.cohereClient) return; + //@ts-ignore + const cohere = await import("cohere-ai").then((cohere) => { + return cohere; + }); + // @ts-ignore + this.cohereClient = new cohere.CohereClient({ + token: this.apiKey, + }); + } + + public async createEmbedding(params: { + model: string; + input: string[]; + }): Promise { + await this.loadClient(); + return await this.cohereClient + .embed({ texts: params.input, model: params.model }) + .then((response: any) => { + return response.embeddings; + }); + } +} + +export class CohereEmbeddingFunction implements IEmbeddingFunction { + private cohereAiApi?: CohereAIAPI; + private model: string; + private apiKey: string; + constructor({ + cohere_api_key, + model, + }: { + cohere_api_key: string; + model?: string; + }) { + this.model = model || "large"; + this.apiKey = cohere_api_key; + } - /** @ignore */ - static async import(): Promise<{ + private async initCohereClient() { + if (this.cohereAiApi) return; + try { + // @ts-ignore + this.cohereAiApi = await import("cohere-ai").then((cohere) => { // @ts-ignore - cohere: typeof import("cohere-ai"); - }> { - try { - // @ts-ignore - const { default: cohere } = await import("cohere-ai"); - return { cohere }; - } catch (e) { - throw new Error( - "Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`" - ); + if (cohere.CohereClient) { + return new CohereAISDK7({ apiKey: this.apiKey }); + } else { + return new CohereAISDK56({ apiKey: this.apiKey }); } + }); + } catch (e) { + // @ts-ignore + if (e.code === "MODULE_NOT_FOUND") { + throw new Error( + "Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`" + ); + } + throw e; } + } + public async generate(texts: string[]): Promise { + await this.initCohereClient(); + // @ts-ignore + return await this.cohereAiApi.createEmbedding({ + model: this.model, + input: texts, + }); + } }