Skip to content

Commit

Permalink
[ENH]: Added cohere-ai 7.0.0 support in package.json (#1460)
Browse files Browse the repository at this point in the history
#1445 

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 -  Support for cohere-ai SDK 7+
## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `yarn test`

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
tazarov authored Dec 19, 2023
1 parent e4f7bba commit 2366196
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 49 deletions.
2 changes: 1 addition & 1 deletion clients/js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
},
"peerDependencies": {
"@google/generative-ai": "^0.1.1",
"cohere-ai": "^5.0.0 || ^6.0.0",
"cohere-ai": "^5.0.0 || ^6.0.0 || ^7.0.0",
"openai": "^3.0.0 || ^4.0.0"
},
"peerDependenciesMeta": {
Expand Down
157 changes: 109 additions & 48 deletions clients/js/src/embeddings/CohereEmbeddingFunction.ts
Original file line number Diff line number Diff line change
@@ -1,61 +1,122 @@
import { IEmbeddingFunction } from "./IEmbeddingFunction";

let CohereAiApi: any;

export class CohereEmbeddingFunction implements IEmbeddingFunction {
private api_key: string;
private model: string;
private cohereAiApi?: any;

constructor({ cohere_api_key, model }: { cohere_api_key: string, model?: string }) {
// we used to construct the client here, but we need to async import the types
// for the openai npm package, and the constructor can not be async
this.api_key = cohere_api_key;
this.model = model || "large";
}
interface CohereAIAPI {
createEmbedding: (params: {
model: string;
input: string[];
}) => Promise<number[][]>;
}

private async loadClient() {
if(this.cohereAiApi) return;
try {
// eslint-disable-next-line global-require,import/no-extraneous-dependencies
const { cohere } = await CohereEmbeddingFunction.import();
CohereAiApi = cohere;
CohereAiApi.init(this.api_key);
} catch (_a) {
// @ts-ignore
if (_a.code === 'MODULE_NOT_FOUND') {
throw new Error("Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`");
}
throw _a; // Re-throw other errors
}
this.cohereAiApi = CohereAiApi;
}
class CohereAISDK56 implements CohereAIAPI {
private cohereClient: any;
private apiKey: string;

public async generate(texts: string[]) {
constructor(configuration: { apiKey: string }) {
this.apiKey = configuration.apiKey;
}

await this.loadClient();
private async loadClient() {
if (this.cohereClient) return;
//@ts-ignore
const { default: cohere } = await import("cohere-ai");
// @ts-ignore
cohere.init(this.apiKey);
this.cohereClient = cohere;
}

const response = await this.cohereAiApi.embed({
texts: texts,
model: this.model,
});
public async createEmbedding(params: {
model: string;
input: string[];
}): Promise<number[][]> {
await this.loadClient();
return await this.cohereClient
.embed({
texts: params.input,
model: params.model,
})
.then((response: any) => {
return response.body.embeddings;
}
});
}
}

class CohereAISDK7 implements CohereAIAPI {
private cohereClient: any;
private apiKey: string;

constructor(configuration: { apiKey: string }) {
this.apiKey = configuration.apiKey;
}

private async loadClient() {
if (this.cohereClient) return;
//@ts-ignore
const cohere = await import("cohere-ai").then((cohere) => {
return cohere;
});
// @ts-ignore
this.cohereClient = new cohere.CohereClient({
token: this.apiKey,
});
}

public async createEmbedding(params: {
model: string;
input: string[];
}): Promise<number[][]> {
await this.loadClient();
return await this.cohereClient
.embed({ texts: params.input, model: params.model })
.then((response: any) => {
return response.embeddings;
});
}
}

export class CohereEmbeddingFunction implements IEmbeddingFunction {
private cohereAiApi?: CohereAIAPI;
private model: string;
private apiKey: string;
constructor({
cohere_api_key,
model,
}: {
cohere_api_key: string;
model?: string;
}) {
this.model = model || "large";
this.apiKey = cohere_api_key;
}

/** @ignore */
static async import(): Promise<{
private async initCohereClient() {
if (this.cohereAiApi) return;
try {
// @ts-ignore
this.cohereAiApi = await import("cohere-ai").then((cohere) => {
// @ts-ignore
cohere: typeof import("cohere-ai");
}> {
try {
// @ts-ignore
const { default: cohere } = await import("cohere-ai");
return { cohere };
} catch (e) {
throw new Error(
"Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`"
);
if (cohere.CohereClient) {
return new CohereAISDK7({ apiKey: this.apiKey });
} else {
return new CohereAISDK56({ apiKey: this.apiKey });
}
});
} catch (e) {
// @ts-ignore
if (e.code === "MODULE_NOT_FOUND") {
throw new Error(
"Please install the cohere-ai package to use the CohereEmbeddingFunction, `npm install -S cohere-ai`"
);
}
throw e;
}
}

public async generate(texts: string[]): Promise<number[][]> {
await this.initCohereClient();
// @ts-ignore
return await this.cohereAiApi.createEmbedding({
model: this.model,
input: texts,
});
}
}

0 comments on commit 2366196

Please sign in to comment.