-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH]: Added cohere-ai 7.0.0 support in package.json (#1460)
#1445 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Support for cohere-ai SDK 7+ ## Test plan *How are these changes tested?* - [x] Tests pass locally with `yarn test` ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
- Loading branch information
Showing
2 changed files
with
110 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,122 @@ | ||
import { IEmbeddingFunction } from "./IEmbeddingFunction"; | ||
|
||
let CohereAiApi: any; | ||
|
||
export class CohereEmbeddingFunction implements IEmbeddingFunction { | ||
private api_key: string; | ||
private model: string; | ||
private cohereAiApi?: any; | ||
|
||
constructor({ cohere_api_key, model }: { cohere_api_key: string, model?: string }) { | ||
// we used to construct the client here, but we need to async import the types | ||
// for the openai npm package, and the constructor can not be async | ||
this.api_key = cohere_api_key; | ||
this.model = model || "large"; | ||
} | ||
interface CohereAIAPI { | ||
createEmbedding: (params: { | ||
model: string; | ||
input: string[]; | ||
}) => Promise<number[][]>; | ||
} | ||
|
||
private async loadClient() { | ||
if(this.cohereAiApi) return; | ||
try { | ||
// eslint-disable-next-line global-require,import/no-extraneous-dependencies | ||
const { cohere } = await CohereEmbeddingFunction.import(); | ||
CohereAiApi = cohere; | ||
CohereAiApi.init(this.api_key); | ||
} catch (_a) { | ||
// @ts-ignore | ||
if (_a.code === 'MODULE_NOT_FOUND') { | ||
throw new Error("Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`"); | ||
} | ||
throw _a; // Re-throw other errors | ||
} | ||
this.cohereAiApi = CohereAiApi; | ||
} | ||
class CohereAISDK56 implements CohereAIAPI { | ||
private cohereClient: any; | ||
private apiKey: string; | ||
|
||
public async generate(texts: string[]) { | ||
constructor(configuration: { apiKey: string }) { | ||
this.apiKey = configuration.apiKey; | ||
} | ||
|
||
await this.loadClient(); | ||
private async loadClient() { | ||
if (this.cohereClient) return; | ||
//@ts-ignore | ||
const { default: cohere } = await import("cohere-ai"); | ||
// @ts-ignore | ||
cohere.init(this.apiKey); | ||
this.cohereClient = cohere; | ||
} | ||
|
||
const response = await this.cohereAiApi.embed({ | ||
texts: texts, | ||
model: this.model, | ||
}); | ||
public async createEmbedding(params: { | ||
model: string; | ||
input: string[]; | ||
}): Promise<number[][]> { | ||
await this.loadClient(); | ||
return await this.cohereClient | ||
.embed({ | ||
texts: params.input, | ||
model: params.model, | ||
}) | ||
.then((response: any) => { | ||
return response.body.embeddings; | ||
} | ||
}); | ||
} | ||
} | ||
|
||
class CohereAISDK7 implements CohereAIAPI { | ||
private cohereClient: any; | ||
private apiKey: string; | ||
|
||
constructor(configuration: { apiKey: string }) { | ||
this.apiKey = configuration.apiKey; | ||
} | ||
|
||
private async loadClient() { | ||
if (this.cohereClient) return; | ||
//@ts-ignore | ||
const cohere = await import("cohere-ai").then((cohere) => { | ||
return cohere; | ||
}); | ||
// @ts-ignore | ||
this.cohereClient = new cohere.CohereClient({ | ||
token: this.apiKey, | ||
}); | ||
} | ||
|
||
public async createEmbedding(params: { | ||
model: string; | ||
input: string[]; | ||
}): Promise<number[][]> { | ||
await this.loadClient(); | ||
return await this.cohereClient | ||
.embed({ texts: params.input, model: params.model }) | ||
.then((response: any) => { | ||
return response.embeddings; | ||
}); | ||
} | ||
} | ||
|
||
export class CohereEmbeddingFunction implements IEmbeddingFunction { | ||
private cohereAiApi?: CohereAIAPI; | ||
private model: string; | ||
private apiKey: string; | ||
constructor({ | ||
cohere_api_key, | ||
model, | ||
}: { | ||
cohere_api_key: string; | ||
model?: string; | ||
}) { | ||
this.model = model || "large"; | ||
this.apiKey = cohere_api_key; | ||
} | ||
|
||
/** @ignore */ | ||
static async import(): Promise<{ | ||
private async initCohereClient() { | ||
if (this.cohereAiApi) return; | ||
try { | ||
// @ts-ignore | ||
this.cohereAiApi = await import("cohere-ai").then((cohere) => { | ||
// @ts-ignore | ||
cohere: typeof import("cohere-ai"); | ||
}> { | ||
try { | ||
// @ts-ignore | ||
const { default: cohere } = await import("cohere-ai"); | ||
return { cohere }; | ||
} catch (e) { | ||
throw new Error( | ||
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`" | ||
); | ||
if (cohere.CohereClient) { | ||
return new CohereAISDK7({ apiKey: this.apiKey }); | ||
} else { | ||
return new CohereAISDK56({ apiKey: this.apiKey }); | ||
} | ||
}); | ||
} catch (e) { | ||
// @ts-ignore | ||
if (e.code === "MODULE_NOT_FOUND") { | ||
throw new Error( | ||
"Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`" | ||
); | ||
} | ||
throw e; | ||
} | ||
} | ||
|
||
public async generate(texts: string[]): Promise<number[][]> { | ||
await this.initCohereClient(); | ||
// @ts-ignore | ||
return await this.cohereAiApi.createEmbedding({ | ||
model: this.model, | ||
input: texts, | ||
}); | ||
} | ||
} |