Skip to content

Commit

Permalink
Fix the bug in the text-generation example
Browse files Browse the repository at this point in the history
  • Loading branch information
NingW101 committed Aug 28, 2024
1 parent d5a8f87 commit d045f22
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 17 deletions.
2 changes: 1 addition & 1 deletion examples/demo-site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"vite": "^4.3.2"
},
"dependencies": {
"@xenova/transformers": "^2.0.0-alpha.3",
"@huggingface/transformers": "^3.0.0-alpha.10",
"chart.js": "^4.3.0",
"prismjs": "^1.29.0"
}
Expand Down
8 changes: 4 additions & 4 deletions examples/demo-site/src/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -535,16 +535,16 @@ <h2 class="fw-bolder">Quick tour</h2>
</div>
<div class="mb-3">
<h5 class="mb-2">Installation</h5>
To install via <a href="https://www.npmjs.com/package/@xenova/transformers">NPM</a>, run:
<pre><code class="language-bash">npm i @xenova/transformers</code></pre>
To install via <a href="https://www.npmjs.com/package/@huggingface/transformers">NPM</a>, run:
<pre><code class="language-bash">npm i @huggingface/transformers</code></pre>

Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN
or static hosting. For example, using
<a href="https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules">ES Modules</a>,
you can import the library with:

<pre><code class="language-html">&lt;script type=&quot;module&quot;&gt;
import { pipeline } from &#39;https://cdn.jsdelivr.net/npm/@xenova/transformers&#39;;
import { pipeline } from &#39;https://cdn.jsdelivr.net/npm/@huggingface/transformers&#39;;
&lt;/script&gt;</code></pre>
</div>

Expand All @@ -564,7 +564,7 @@ <h5 class="mb-2">Basic example</h5>

</div>
<div class="col-lg-6 mb-4 mb-lg-0">
<pre><code class="language-js">import { pipeline } from '@xenova/transformers';
<pre><code class="language-js">import { pipeline } from '@huggingface/transformers';

// Allocate a pipeline for sentiment-analysis
let pipe = await pipeline('sentiment-analysis');
Expand Down
2 changes: 1 addition & 1 deletion examples/demo-site/src/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ worker.addEventListener('message', (event) => {
CODE_BLOCKS[target].update(message.data);
break;
default: // is textbox
elem.value = message.data
elem.value += message.data
break;
}

Expand Down
22 changes: 12 additions & 10 deletions examples/demo-site/src/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// Needed to ensure the UI thread is not blocked when running //
/////////////////////////////////////////////////////////////////

import { pipeline, env } from "@xenova/transformers";
import { pipeline, env } from "@huggingface/transformers";
env.allowLocalModels = false;

// Define task function mapping
Expand Down Expand Up @@ -59,13 +59,14 @@ class PipelineFactory {
* @param {*} progressCallback
* @returns {Promise}
*/
static getInstance(progressCallback = null) {
static getInstance(progressCallback = null, model_file_name = "") {
if (this.task === null || this.model === null) {
throw Error("Must set task and model")
}
if (this.instance === null) {
this.instance = pipeline(this.task, this.model, {
progress_callback: progressCallback
progress_callback: progressCallback,
model_file_name: model_file_name,
});
}

Expand Down Expand Up @@ -182,21 +183,22 @@ async function text_generation(data) {
task: 'text-generation',
data: data
});
})
}, "decoder_model_merged"
)

let text = data.text.trim();

return await pipeline(text, {
...data.generation,
callback_function: function (beams) {
const decodedText = pipeline.tokenizer.decode(beams[0].output_token_ids, {
skip_special_tokens: true,
})
...data.generation }, {
callback_function: function (decodedText) {
const postProcessor = (text) => {
return text.replace('<|endoftext|>','');
}

self.postMessage({
type: 'update',
target: data.elementIdToUpdate,
data: decodedText
data: postProcessor(decodedText)
});
}
})
Expand Down
19 changes: 18 additions & 1 deletion src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ import {
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';

import {TextStreamer} from "./generation/streamers.js"


/**
* @typedef {string | RawImage | URL} ImageInput
Expand Down Expand Up @@ -984,7 +986,13 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
}

/** @type {TextGenerationPipelineCallback} */
async _call(texts, generate_kwargs = {}) {
async _call(texts, generate_kwargs = {}, {
skip_prompt = false,
callback_function = null,
token_callback_function = null,
decode_kwargs = {},
...kwargs
}) {
let isBatched = false;
let isChatInput = false;

Expand Down Expand Up @@ -1030,8 +1038,17 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
truncation: true,
});

const streamer = new TextStreamer(this.tokenizer, {
skip_prompt:skip_prompt,
callback_function:callback_function,
token_callback_function:token_callback_function,
decode_kwargs : decode_kwargs,
...kwargs
});

const outputTokenIds = /** @type {Tensor} */(await this.model.generate({
...text_inputs,
streamer:streamer,
...generate_kwargs
}));

Expand Down

0 comments on commit d045f22

Please sign in to comment.