Add chat completion method (#645)
Supersedes #581. Thanks to @Wauplin, I can import the types from "@huggingface/tasks". I've followed the pattern for `textGeneration` and `textGenerationStream`.

---------

Co-authored-by: coyotte508 <[email protected]>
Co-authored-by: Julien Chaumond <[email protected]>
3 people authored May 13, 2024
1 parent de30544 commit f78bf7a
Showing 14 changed files with 1,535 additions and 84 deletions.
35 changes: 30 additions & 5 deletions README.md
@@ -1,6 +1,6 @@
<p align="center">
<br/>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingfacejs-dark.svg">
<source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingfacejs-light.svg">
<img alt="huggingface javascript library logo" src="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingfacejs-light.svg" width="376" height="59" style="max-width: 100%;">
@@ -56,8 +56,7 @@ This is a collection of JS libraries to interact with the Hugging Face API, with
- [@huggingface/tasks](packages/tasks/README.md): The definition files and source-of-truth for the Hub's main primitives like pipeline tasks, model libraries, etc.



We use modern features to avoid polyfills and dependencies, so the libraries will only work on modern browsers / Node.js >= 18 / Bun / Deno.

The libraries are still very young, please help us by opening issues!

@@ -108,7 +107,6 @@ import { HfAgent } from "npm:@huggingface/agents";
import { createRepo, commit, deleteRepo, listFiles } from "npm:@huggingface/hub"
```


## Usage examples

Get your HF access token in your [account settings](https://huggingface.co/settings/tokens).
@@ -122,6 +120,23 @@ const HF_TOKEN = "hf_...";

const inference = new HfInference(HF_TOKEN);

// Chat completion API
const out = await inference.chatCompletion({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [{ role: "user", content: "Complete this sentence: one plus one equals " }],
max_tokens: 100
});
console.log(out.choices[0].message);

// Streaming chat completion API
for await (const chunk of inference.chatCompletionStream({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [{ role: "user", content: "Complete this sentence: one plus one equals " }],
max_tokens: 100
})) {
console.log(chunk.choices[0].delta.content);
}

// You can also omit "model" to use the recommended model for the task
await inference.translation({
model: 't5-base',
@@ -144,6 +159,17 @@ await inference.imageToText({
// Using your own dedicated inference endpoint: https://hf.co/docs/inference-endpoints/
const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});

// Chat completion
const mistral = inference.endpoint(
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
);
const res = await mistral.chatCompletion({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [{ role: "user", content: "Complete this sentence: one plus one equals " }],
max_tokens: 100,
});
console.log(res.choices[0].message);
```

### @huggingface/hub examples
@@ -200,7 +226,6 @@ const messages = await agent.run("Draw a picture of a cat wearing a top hat. The
console.log(messages);
```


There are more features of course, check each library's README!

## Formatting & testing
149 changes: 122 additions & 27 deletions packages/inference/README.md
@@ -5,7 +5,7 @@ It works with both [Inference API (serverless)](https://huggingface.co/docs/api-

Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).

You can also try out a live [interactive notebook](https://observablehq.com/@huggingface/hello-huggingface-js-inference), see some demos on [hf.co/huggingfacejs](https://huggingface.co/huggingfacejs), or watch a [Scrimba tutorial that explains how Inference Endpoints works](https://scrimba.com/scrim/cod8248f5adfd6e129582c523).

## Getting Started

@@ -30,7 +30,6 @@ import { HfInference } from "https://esm.sh/@huggingface/inference"
import { HfInference } from "npm:@huggingface/inference"
```


### Initialize

```typescript
@@ -43,7 +42,6 @@ const hf = new HfInference('your access token')

Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.
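One way such a proxy could look, as a minimal Node 18+ sketch — the port, upstream URL, and missing auth/rate limiting are all assumptions, not part of this library:

```typescript
import { createServer } from "node:http";

const HF_TOKEN = process.env.HF_TOKEN; // stays on the server, never shipped to the browser

createServer(async (req, res) => {
  // Buffer the incoming request body, if any
  const chunks: Buffer[] = [];
  for await (const chunk of req) chunks.push(chunk as Buffer);

  // Forward the call upstream, attaching the token server-side
  const upstream = await fetch(`https://api-inference.huggingface.co${req.url}`, {
    method: req.method,
    headers: { Authorization: `Bearer ${HF_TOKEN}`, "Content-Type": "application/json" },
    body: chunks.length > 0 ? Buffer.concat(chunks) : undefined,
  });

  res.writeHead(upstream.status, {
    "Content-Type": upstream.headers.get("Content-Type") ?? "application/json",
  });
  res.end(Buffer.from(await upstream.arrayBuffer()));
}).listen(3000);
```

The front end can then point `hf.endpoint(...)` at the proxy instead of holding a token itself.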


#### Tree-shaking

You can import the functions you need directly from the module instead of using the `HfInference` class.
@@ -63,6 +61,85 @@ This will enable tree-shaking by your bundler.
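For illustration, the direct-import style could look like this (a sketch: `textGeneration` is a real export, but the model and inputs here are placeholder assumptions):

```typescript
// Import a single task function instead of the whole HfInference class,
// so bundlers can drop the tasks you don't use.
import { textGeneration } from "@huggingface/inference";

const { generated_text } = await textGeneration({
  accessToken: "hf_...", // standalone task functions take the token in their args
  model: "gpt2",
  inputs: "The answer to the universe is",
});
```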

## Natural Language Processing

### Text Generation

Generates text from an input prompt.

[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-text-generation)

```typescript
await hf.textGeneration({
model: 'gpt2',
inputs: 'The answer to the universe is'
})

for await (const output of hf.textGenerationStream({
model: "google/flan-t5-xxl",
inputs: 'repeat "one two three four"',
parameters: { max_new_tokens: 250 }
})) {
console.log(output.token.text, output.generated_text);
}
```

### Text Generation (Chat Completion API Compatible)

Using the `chatCompletion` method, you can generate text with models compatible with the OpenAI Chat Completion API. All models served by [TGI](https://api-inference.huggingface.co/framework/text-generation-inference) on Hugging Face support the Messages API.

[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-chat-completion)

```typescript
// Non-streaming API
const out = await hf.chatCompletion({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [{ role: "user", content: "Complete this sentence: one plus one equals " }],
max_tokens: 500,
temperature: 0.1,
seed: 0,
});

// Streaming API
let out = "";
for await (const chunk of hf.chatCompletionStream({
model: "mistralai/Mistral-7B-Instruct-v0.2",
messages: [
{ role: "user", content: "Complete the equation 1+1= ,just the answer" },
],
max_tokens: 500,
temperature: 0.1,
seed: 0,
})) {
if (chunk.choices && chunk.choices.length > 0) {
streamed += chunk.choices[0].delta.content;
}
}
```

It's also possible to call Mistral or OpenAI endpoints directly:

```typescript
const openai = new HfInference(OPENAI_TOKEN).endpoint("https://api.openai.com");

let out = "";
for await (const chunk of openai.chatCompletionStream({
model: "gpt-3.5-turbo",
messages: [
{ role: "user", content: "Complete the equation 1+1= ,just the answer" },
],
max_tokens: 500,
temperature: 0.1,
seed: 0,
})) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
}
}

// For Mistral AI:
// endpointUrl: "https://api.mistral.ai"
// model: "mistral-tiny"
```

### Fill Mask

Tries to fill in a hole with a missing word (a token, to be precise).
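A minimal call might look like this (a sketch; the model and its `[MASK]` token are assumptions tied to BERT-style checkpoints):

```typescript
// Ask the model to fill the masked token; returns candidate tokens with scores.
await hf.fillMask({
  model: "bert-base-uncased",
  inputs: "[MASK] world!",
});
```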
@@ -131,27 +208,6 @@ await hf.textClassification({
})
```

### Token Classification

Used for sentence parsing, either grammatical or Named Entity Recognition (NER), to understand keywords contained within text.
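A minimal call could look like this (a sketch; the NER model named here is an assumption — any NER-capable model works):

```typescript
// Returns labeled entity spans (e.g. PER/LOC/ORG) with scores and offsets.
await hf.tokenClassification({
  model: "dbmdz/bert-large-cased-finetuned-conll03-english",
  inputs: "My name is Sarah Jessica Parker but you can call me Jessica",
});
```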
@@ -177,9 +233,9 @@ await hf.translation({
model: 'facebook/mbart-large-50-many-to-many-mmt',
inputs: textToTranslate,
parameters: {
"src_lang": "en_XX",
"tgt_lang": "fr_XX"
}
"src_lang": "en_XX",
"tgt_lang": "fr_XX"
}
})
```

@@ -497,13 +553,52 @@ for await (const output of hf.streamingRequest({
}
```

You can use any Chat Completion API-compatible provider with the `chatCompletion` method.

```typescript
// Chat Completion Example
const MISTRAL_KEY = process.env.MISTRAL_KEY;
const hf = new HfInference(MISTRAL_KEY);
const ep = hf.endpoint("https://api.mistral.ai");
const stream = ep.chatCompletionStream({
model: "mistral-tiny",
messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
});
let out = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
console.log(out);
}
}
```

## Custom Inference Endpoints

Learn more about using your own inference endpoints [here](https://hf.co/docs/inference-endpoints/).

```typescript
const gpt2 = hf.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});

// Chat Completion Example
const ep = hf.endpoint(
"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
);
const stream = ep.chatCompletionStream({
model: "tgi",
messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }],
max_tokens: 500,
temperature: 0.1,
seed: 0,
});
let out = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
console.log(out);
}
}
```

By default, all calls to the inference endpoint will wait until the model is
8 changes: 4 additions & 4 deletions packages/inference/src/HfInference.ts
@@ -14,9 +14,9 @@ type TaskWithNoAccessToken = {
) => ReturnType<Task[key]>;
};

-type TaskWithNoAccessTokenNoModel = {
+type TaskWithNoAccessTokenNoEndpointUrl = {
[key in keyof Task]: (
-args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
+args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
options?: Parameters<Task[key]>[1]
) => ReturnType<Task[key]>;
};
@@ -57,12 +57,12 @@ export class HfInferenceEndpoint {
enumerable: false,
value: (params: RequestArgs, options: Options) =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
-fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
+fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
});
}
}
}

export interface HfInference extends TaskWithNoAccessToken {}

-export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
+export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
24 changes: 17 additions & 7 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,4 +1,5 @@
import type { InferenceTask, Options, RequestArgs } from "../types";
import { omit } from "../utils/omit";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

@@ -22,10 +23,10 @@ export async function makeRequestOptions(
forceTask?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
chatCompletion?: boolean;
}
): Promise<{ url: string; info: RequestInit }> {
-// eslint-disable-next-line @typescript-eslint/no-unused-vars
-const { accessToken, model: _model, ...otherArgs } = args;
+const { accessToken, endpointUrl, ...otherArgs } = args;
let { model } = args;
const {
forceTask: task,
@@ -34,7 +35,7 @@
wait_for_model,
use_cache,
dont_load_model,
-...otherOptions
+chatCompletion,
} = options ?? {};

const headers: Record<string, string> = {};
@@ -77,18 +78,28 @@
headers["X-Load-Model"] = "0";
}

-const url = (() => {
+let url = (() => {
if (endpointUrl && isUrl(model)) {
throw new TypeError("Both model and endpointUrl cannot be URLs");
}
if (isUrl(model)) {
console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
return model;
}

if (endpointUrl) {
return endpointUrl;
}
if (task) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
}

return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
})();

if (chatCompletion && !url.endsWith("/chat/completions")) {
url += "/v1/chat/completions";
}

/**
* For edge runtimes, leave 'credentials' undefined, otherwise cloudflare workers will error
*/
@@ -105,8 +116,7 @@
body: binary
? args.data
: JSON.stringify({
-...otherArgs,
-options: options && otherOptions,
+...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
}),
...(credentials && { credentials }),
signal: options?.signal,
5 changes: 5 additions & 0 deletions packages/inference/src/tasks/custom/request.ts
@@ -11,6 +11,8 @@ export async function request<T>(
task?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
/** Is chat completion compatible */
chatCompletion?: boolean;
}
): Promise<T> {
const { url, info } = await makeRequestOptions(args, options);
@@ -26,6 +28,9 @@
if (!response.ok) {
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
const output = await response.json();
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
}
if (output.error) {
throw new Error(output.error);
}