From 129abcd42c2dc044571e7102586a988aa6936a95 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 24 Aug 2023 02:38:05 +0200 Subject: [PATCH 001/473] [version] Update to 3.0.0-alpha.0 --- README.md | 4 ++-- docs/snippets/2_installation.snippet | 2 +- docs/snippets/4_custom-usage.snippet | 2 +- package-lock.json | 36 ++++++++++++++-------------- package.json | 8 +++---- src/env.js | 2 +- 6 files changed, 27 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index e6f47415f..fb11ec927 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ npm i @xenova/transformers Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with: ```html ``` @@ -125,7 +125,7 @@ Want to jump straight in? Get started with one of our sample applications/templa -By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.5.3/dist/), which should work out-of-the-box. You can customize this as follows: +By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@3.0.0-alpha.0/dist/), which should work out-of-the-box. You can customize this as follows: ### Settings diff --git a/docs/snippets/2_installation.snippet b/docs/snippets/2_installation.snippet index 4d571dca2..49224cb7f 100644 --- a/docs/snippets/2_installation.snippet +++ b/docs/snippets/2_installation.snippet @@ -7,6 +7,6 @@ npm i @xenova/transformers Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with: ```html ``` diff --git a/docs/snippets/4_custom-usage.snippet b/docs/snippets/4_custom-usage.snippet index 9925ae4e7..1f7916206 100644 --- a/docs/snippets/4_custom-usage.snippet +++ b/docs/snippets/4_custom-usage.snippet @@ -1,6 +1,6 @@ -By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@2.5.3/dist/), which should work out-of-the-box. You can customize this as follows: +By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@xenova/transformers@3.0.0-alpha.0/dist/), which should work out-of-the-box. 
You can customize this as follows: ### Settings diff --git a/package-lock.json b/package-lock.json index f61e2102b..e0a36fc07 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,15 +1,15 @@ { "name": "@xenova/transformers", - "version": "2.5.3", + "version": "3.0.0-alpha.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@xenova/transformers", - "version": "2.5.3", + "version": "3.0.0-alpha.0", "license": "Apache-2.0", "dependencies": { - "onnxruntime-web": "1.14.0", + "onnxruntime-web": "1.15.1", "sharp": "^0.32.0" }, "devDependencies": { @@ -26,7 +26,7 @@ "webpack-dev-server": "^4.13.3" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.15.1" } }, "node_modules/@ampproject/remapping": { @@ -5673,14 +5673,14 @@ } }, "node_modules/onnxruntime-common": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz", - "integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew==" + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.15.1.tgz", + "integrity": "sha512-Y89eJ8QmaRsPZPWLaX7mfqhj63ny47rSkQe80hIo+lvBQdrdXYR9VO362xvZulk9DFkCnXmGidprvgJ07bKsIQ==" }, "node_modules/onnxruntime-node": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz", - "integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==", + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.15.1.tgz", + "integrity": "sha512-wzhVELulmrvNoMZw0/HfV+9iwgHX+kPS82nxodZ37WCXmbeo1jp3thamTsNg8MGhxvv4GmEzRum5mo40oqIsqw==", "optional": true, "os": [ "win32", @@ -5688,19 +5688,19 @@ "linux" ], "dependencies": { - "onnxruntime-common": "~1.14.0" + "onnxruntime-common": "~1.15.1" } }, "node_modules/onnxruntime-web": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz", - "integrity": "sha512-Kcqf43UMfW8mCydVGcX9OMXI2VN17c0p6XvR7IPSZzBf/6lteBzXHvcEVWDPmCKuGombl997HgLqj91F11DzXw==", + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.15.1.tgz", + "integrity": "sha512-Ky4AXFLFyiGRu5KQJdDcbhdNcO0f2ND/8IPmTEwcKKIHpCwH6/Q9UoMpcoFz78lxGvnmmy+FFgA/Bs1HjdM6LA==", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", "long": "^4.0.0", "onnx-proto": "^4.0.4", - "onnxruntime-common": "~1.14.0", + "onnxruntime-common": "~1.15.1", "platform": "^1.3.6" } }, @@ -5960,9 +5960,9 @@ } }, "node_modules/protobufjs": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.4.tgz", - "integrity": "sha512-AT+RJgD2sH8phPmCf7OUZR8xGdcJRga4+1cOaXJ64hvcSkVhNcRHOwIxUatPH15+nj59WAGTDv3LSGZPEQbJaQ==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "hasInstallScript": true, "dependencies": { "@protobufjs/aspromise": "^1.1.2", diff --git a/package.json b/package.json index d6af22a00..e36275a0f 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@xenova/transformers", - "version": "2.5.3", + "version": "3.0.0-alpha.0", "description": "State-of-the-art Machine Learning for the web. 
Run 🤗 Transformers directly in your browser, with no need for a server!", "main": "./src/transformers.js", "types": "./types/transformers.d.ts", @@ -38,11 +38,11 @@ }, "homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { - "onnxruntime-web": "1.14.0", + "onnxruntime-web": "1.15.1", "sharp": "^0.32.0" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.15.1" }, "devDependencies": { "@types/jest": "^29.5.1", @@ -59,7 +59,7 @@ }, "overrides": { "semver": "^7.5.4", - "protobufjs": "^7.2.4" + "protobufjs": "^7.2.5" }, "files": [ "src", diff --git a/src/env.js b/src/env.js index 80ea2661c..b522b9c94 100644 --- a/src/env.js +++ b/src/env.js @@ -29,7 +29,7 @@ import url from 'url'; import { ONNX } from './backends/onnx.js'; const { env: onnx_env } = ONNX; -const VERSION = '2.5.3'; +const VERSION = '3.0.0-alpha.0'; // Check if various APIs are available (depends on environment) const WEB_CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self; From 8c9844627f74e1dcfdd35ba9275f4d14e514fa31 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 26 Aug 2023 01:06:17 +0200 Subject: [PATCH 002/473] Fix `SamImageProcessor` --- src/processors.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/processors.js b/src/processors.js index d3ca08017..e7459ef8f 100644 --- a/src/processors.js +++ b/src/processors.js @@ -701,9 +701,8 @@ export class SamImageProcessor extends ImageFeatureExtractor { } let input_points_tensor = new Tensor( - 'int64', - BigInt64Array.from(input_points.flat(Infinity) - .map(x => BigInt(Math.round(x)))), + 'float32', + Float32Array.from(input_points.flat(Infinity)), shape ) @@ -766,7 +765,7 @@ export class SamImageProcessor extends ImageFeatureExtractor { interpolated_mask = interpolated_mask.slice(null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]); // Downscale mask - interpolated_mask = interpolate(mask, original_size, 'bilinear', false); + interpolated_mask = interpolate(interpolated_mask, original_size, 'bilinear', false); if (binarize) { interpolated_mask = new Tensor( @@ -782,6 +781,7 @@ export class SamImageProcessor extends ImageFeatureExtractor { interpolated_masks.push(interpolated_mask); } + // TODO switch to stack let concatenated = cat(interpolated_masks); output_masks.push(concatenated); } From 75ae57b5d317b38debe60c6e1f5919692d608f75 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 26 Aug 2023 01:08:46 +0200 Subject: [PATCH 003/473] Split SAM encoder and decoder --- src/models.js | 156 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 128 insertions(+), 28 deletions(-) diff --git a/src/models.js b/src/models.js index 2d16968bf..218ca36c5 100644 --- a/src/models.js +++ b/src/models.js @@ -94,6 +94,7 @@ class ModelType { }; class EncoderOnlyModelType extends ModelType { }; class EncoderDecoderModelType extends ModelType { }; class Seq2SeqModelType extends EncoderDecoderModelType { }; +class MaskGenerationModelType extends EncoderDecoderModelType { }; // Specialized model type for SAM class DecoderOnlyModelType extends ModelType { }; ////////////////////////////////////////////////// @@ -659,6 +660,13 @@ export class PreTrainedModel extends Callable { getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); + } else if (modelType === MaskGenerationModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + 
constructSession(pretrained_model_name_or_path, 'vision_encoder', options), + constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options), + ]); + } else if (modelType === EncoderDecoderModelType) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), @@ -1930,7 +1938,7 @@ export class BartForConditionalGeneration extends BartPretrainedModel { * Returns the initial beam for generating output text. * @param {Object} inputs The input object containing the encoded input text. * @param {number} numOutputTokens The maximum number of output tokens to generate. - * @param {...any} args Additional arguments to pass to the sequence-to-sequence generation function. + * @param {...any} args Additional arguments to pass to the sequence-to-sequence generation function. * @returns {any} The initial beam for generating output text. */ getStartBeams(inputs, numOutputTokens, ...args) { @@ -2600,7 +2608,7 @@ export class GPT2Model extends GPT2PreTrainedModel { /** * GPT2Model is not compatible with `.generate()`, as it doesn't have a language model head. - * @param {...any} args + * @param {...any} args * @throws {Error} * @returns {Promise} */ @@ -2680,7 +2688,7 @@ export class GPTNeoPreTrainedModel extends PreTrainedModel { export class GPTNeoModel extends GPTNeoPreTrainedModel { /** * - * @param {...any} args + * @param {...any} args * @throws {Error} * @returns {Promise} */ @@ -2756,7 +2764,7 @@ export class GPTBigCodePreTrainedModel extends PreTrainedModel { export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { /** * - * @param {...any} args + * @param {...any} args * @throws {Error} * @returns {Promise} */ @@ -2840,7 +2848,7 @@ export class CodeGenModel extends CodeGenPreTrainedModel { * * @throws {Error} The current model class is not compatible with `.generate()` * - * @param {...any} args Arguments passed to the generate function + * @param {...any} args Arguments passed to the generate function * @returns {Promise} */ async generate(...args) { @@ -2930,8 +2938,7 @@ export class LlamaModel extends LlamaPreTrainedModel { * as it doesn't have a language model head. * * @throws {Error} The current model class is not compatible with `.generate()` - * - * @param {...any} args Arguments passed to the generate function + * @param {...any} args Arguments passed to the generate function * @returns {Promise} */ async generate(...args) { @@ -3068,12 +3075,105 @@ export class DetrSegmentationOutput extends ModelOutput { ////////////////////////////////////////////////// export class SamPreTrainedModel extends PreTrainedModel { } + +/** + * + * **Example:** Prompted-Mask-Generation + * ```javascript + * import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers'; + * + * const model = await SamModel.from_pretrained('Xenova/sam-vit-base'); + * const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base'); + * + * const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png'; + * const raw_image = await RawImage.read(img_url); + * const input_points = [[[450, 600]]] // 2D localization of a window + * + * const inputs = await processor(raw_image, input_points); + * const outputs = await model(inputs); + * + * const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes); + * // [ + * // Tensor { + * // dims: [ 1, 3, 1764, 2646 ], + * // type: 'bool', + * // data: Uint8Array(14002632) [ ... 
], + * // size: 14002632 + * // } + * // ] +* const scores = outputs.iou_scores; + * // Tensor { + * // dims: [ 1, 1, 3 ], + * // type: 'float32', + * // data: Float32Array(3) [ + * // 0.9016823172569275, + * // 0.943422257900238, + * // 0.978232741355896 + * // ], + * // size: 3 + * // } + * ``` + */ export class SamModel extends SamPreTrainedModel { /** - * @param {Object} model_inputs - * @param {Tensor} model_inputs.pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`. - * @param {Tensor} model_inputs.input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. This is used by the prompt encoder to encode the prompt. - * @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`. + * Creates a new instance of the `SamModel` class. + * @param {Object} config The configuration object specifying the hyperparameters and other model settings. + * @param {Object} vision_encoder The ONNX session containing the vision encoder model. + * @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model. + */ + constructor(config, vision_encoder, prompt_encoder_mask_decoder) { + super(config, vision_encoder); + this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder; + } + + /** + * Compute image embeddings and positional image embeddings, given the pixel values of an image. + * @param {Object} model_inputs Object containing the model inputs. + * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`. + * @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings. + */ + async get_image_embeddings({ pixel_values }) { + // in: + // - pixel_values: tensor.float32[batch_size,3,1024,1024] + // + // out: + // - image_embeddings: tensor.float32[batch_size,256,64,64] + // - image_positional_embeddings: tensor.float32[batch_size,256,64,64] + return await encoderForward(this, { pixel_values }) + } + + /** + * @typedef {Object} SamModelInputs Object containing the model inputs. + * @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`. + * These can be obtained using a `SamProcessor`. + * @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. + * This is used by the prompt encoder to encode the prompt. + * @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder. + * @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder. + */ + + /** + * @param {SamModelInputs} model_inputs Object containing the model inputs. + * @returns {Promise} The output of the model. 
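     *
     * (A minimal sketch of reusing precomputed embeddings, assuming `pixel_values` and `input_points`
     * were produced by a `SamProcessor` as in the class-level example above. Since `image_embeddings`
     * and `image_positional_embeddings` are optional inputs, the heavy vision encoder only needs to run
     * once per image.)
     * ```javascript
     * // Run the vision encoder once ...
     * const embeddings = await model.get_image_embeddings({ pixel_values });
     * // ... then reuse its outputs for any number of prompts
     * const outputs = await model({ input_points, ...embeddings });
     * ```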
+ */ + async forward(model_inputs) { + if (!model_inputs.image_embeddings || !model_inputs.image_positional_embeddings) { + // Compute the image embeddings if they are missing + model_inputs = { + ...model_inputs, + ...(await this.get_image_embeddings(model_inputs)) + } + } + // Returns: + // - iou_scores: tensor.float32[batch_size,point_batch_size,3] + // - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256] + return await sessionRun(this.prompt_encoder_mask_decoder, model_inputs); + } + + /** + * Runs the model with the provided inputs + * @param {Object} model_inputs Model inputs + * @returns {Promise} Object containing segmentation outputs */ async _call(model_inputs) { return new SamImageSegmentationOutput(await super._call(model_inputs)); @@ -3106,7 +3206,7 @@ export class MarianPreTrainedModel extends PreTrainedModel { }; export class MarianModel extends MarianPreTrainedModel { /** * - * @param {...any} args + * @param {...any} args * @throws {Error} * @returns {Promise} */ @@ -3185,7 +3285,7 @@ export class M2M100PreTrainedModel extends PreTrainedModel { }; export class M2M100Model extends M2M100PreTrainedModel { /** * - * @param {...any} args + * @param {...any} args * @throws {Error} * @returns {Promise} */ @@ -3399,8 +3499,6 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['mobilebert', MobileBertModel], ['squeezebert', SqueezeBertModel], ['wav2vec2', Wav2Vec2Model], - - ['sam', SamModel], // TODO change to encoder-decoder when model is split correctly ]); const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ @@ -3532,7 +3630,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, EncoderOnlyModelType], [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, EncoderOnlyModelType], - [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MaskGenerationModelType], [MODEL_FOR_CTC_MAPPING_NAMES, EncoderOnlyModelType], [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], ]; @@ -3551,10 +3649,12 @@ for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModel.from_pretrained('bert-base-uncased'); + * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased'); */ export class AutoModel extends PretrainedMixin { - static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY]; + /** @type {Map[]} */ + // @ts-ignore + static MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map(x => x[0]); static BASE_IF_FAIL = true; } @@ -3563,7 +3663,7 @@ export class AutoModel extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english'); + * let model = await AutoModelForSequenceClassification.from_pretrained('Xenova/distilbert-base-uncased-finetuned-sst-2-english'); */ export class AutoModelForSequenceClassification extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]; @@ -3574,7 +3674,7 @@ export class AutoModelForSequenceClassification extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. 
* * @example - * let model = await AutoModelForTokenClassification.from_pretrained('Davlan/distilbert-base-multilingual-cased-ner-hrl'); + * let model = await AutoModelForTokenClassification.from_pretrained('Xenova/distilbert-base-multilingual-cased-ner-hrl'); */ export class AutoModelForTokenClassification extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]; @@ -3585,7 +3685,7 @@ export class AutoModelForTokenClassification extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForSeq2SeqLM.from_pretrained('t5-small'); + * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small'); */ export class AutoModelForSeq2SeqLM extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES]; @@ -3596,7 +3696,7 @@ export class AutoModelForSeq2SeqLM extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForCausalLM.from_pretrained('gpt2'); + * let model = await AutoModelForCausalLM.from_pretrained('Xenova/gpt2'); */ export class AutoModelForCausalLM extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]; @@ -3607,7 +3707,7 @@ export class AutoModelForCausalLM extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForMaskedLM.from_pretrained('bert-base-uncased'); + * let model = await AutoModelForMaskedLM.from_pretrained('Xenova/bert-base-uncased'); */ export class AutoModelForMaskedLM extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]; @@ -3618,7 +3718,7 @@ export class AutoModelForMaskedLM extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad'); + * let model = await AutoModelForQuestionAnswering.from_pretrained('Xenova/distilbert-base-cased-distilled-squad'); */ export class AutoModelForQuestionAnswering extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]; @@ -3629,7 +3729,7 @@ export class AutoModelForQuestionAnswering extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForVision2Seq.from_pretrained('nlpconnect/vit-gpt2-image-captioning'); + * let model = await AutoModelForVision2Seq.from_pretrained('Xenova/vit-gpt2-image-captioning'); */ export class AutoModelForVision2Seq extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]; @@ -3640,7 +3740,7 @@ export class AutoModelForVision2Seq extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. 
* * @example - * let model = await AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224'); + * let model = await AutoModelForImageClassification.from_pretrained('Xenova/vit-base-patch16-224'); */ export class AutoModelForImageClassification extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]; @@ -3651,7 +3751,7 @@ export class AutoModelForImageClassification extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForImageSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic'); + * let model = await AutoModelForImageSegmentation.from_pretrained('Xenova/detr-resnet-50-panoptic'); */ export class AutoModelForImageSegmentation extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]; @@ -3662,7 +3762,7 @@ export class AutoModelForImageSegmentation extends PretrainedMixin { * The chosen model class is determined by the type specified in the model config. * * @example - * let model = await AutoModelForObjectDetection.from_pretrained('facebook/detr-resnet-50'); + * let model = await AutoModelForObjectDetection.from_pretrained('Xenova/detr-resnet-50'); */ export class AutoModelForObjectDetection extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]; From 2e08b744417958b5294a473d6717bb7aeaf137af Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 4 Nov 2023 18:49:01 +0200 Subject: [PATCH 004/473] Only use necessary inputs for prompt_encoder_mask_decoder --- src/models.js | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/models.js b/src/models.js index c457e1de5..a94b09f3a 100644 --- a/src/models.js +++ b/src/models.js @@ -3386,7 +3386,7 @@ export class SamPreTrainedModel extends PreTrainedModel { } * // size: 14002632 * // } * // ] -* const scores = outputs.iou_scores; + * const scores = outputs.iou_scores; * // Tensor { * // dims: [ 1, 1, 3 ], * // type: 'float32', @@ -3452,7 +3452,11 @@ export class SamModel extends SamPreTrainedModel { // Returns: // - iou_scores: tensor.float32[batch_size,point_batch_size,3] // - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256] - return await sessionRun(this.prompt_encoder_mask_decoder, model_inputs); + return await sessionRun(this.prompt_encoder_mask_decoder, { + input_points: model_inputs.input_points, + image_embeddings: model_inputs.image_embeddings, + image_positional_embeddings: model_inputs.image_positional_embeddings, + }); } /** From 08fef47501e87755bac3fa6670b0fa00ded4a92d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 4 Nov 2023 23:50:57 +0200 Subject: [PATCH 005/473] Update to onnxruntime v1.16.1 --- package-lock.json | 51 ++++++++++++++++++----------------------------- package.json | 4 ++-- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/package-lock.json b/package-lock.json index fd8534bac..a71e019d1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,7 +9,7 @@ "version": "3.0.0-alpha.0", "license": "Apache-2.0", "dependencies": { - "onnxruntime-web": "1.15.1", + "onnxruntime-web": "1.16.1", "sharp": "^0.32.0" }, "devDependencies": { @@ -26,7 +26,7 @@ "webpack-dev-server": "^4.13.3" }, "optionalDependencies": { - "onnxruntime-node": "1.15.1" + "onnxruntime-node": "1.16.1" } }, "node_modules/@ampproject/remapping": { @@ -5310,9 +5310,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - 
"resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==" + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" }, "node_modules/lru-cache": { "version": "6.0.0", @@ -5736,23 +5736,15 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/onnx-proto": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-4.0.4.tgz", - "integrity": "sha512-aldMOB3HRoo6q/phyB6QRQxSt895HNNw82BNyZ2CMh4bjeKv7g/c+VpAFtJuEMVfYLMbRx61hbuqnKceLeDcDA==", - "dependencies": { - "protobufjs": "^6.8.8" - } - }, "node_modules/onnxruntime-common": { - "version": "1.15.1", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.15.1.tgz", - "integrity": "sha512-Y89eJ8QmaRsPZPWLaX7mfqhj63ny47rSkQe80hIo+lvBQdrdXYR9VO362xvZulk9DFkCnXmGidprvgJ07bKsIQ==" + "version": "1.16.1", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.16.1.tgz", + "integrity": "sha512-dmKye7bL4/aKhF561h+o9yw1hCCcGfYRN1BoycXm+WUjWhAVlGkP6JdBcRk7MQ1qX/ASFk+8Ibl+yVgCTSP0Fg==" }, "node_modules/onnxruntime-node": { - "version": "1.15.1", - "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.15.1.tgz", - "integrity": "sha512-wzhVELulmrvNoMZw0/HfV+9iwgHX+kPS82nxodZ37WCXmbeo1jp3thamTsNg8MGhxvv4GmEzRum5mo40oqIsqw==", + "version": "1.16.1", + "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.16.1.tgz", + "integrity": "sha512-o/0zHhfViD1UF91o+ATbYD9QanirdgVnSZK4GVcNaSYNBIhYsiHtIJ0yjnNIA2ZGWoTKCtmR+kcTLliCGXaucw==", "optional": true, "os": [ "win32", @@ -5760,20 +5752,20 @@ "linux" ], "dependencies": { - "onnxruntime-common": "~1.15.1" + "onnxruntime-common": "~1.16.1" } }, "node_modules/onnxruntime-web": { - "version": "1.15.1", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.15.1.tgz", - "integrity": "sha512-Ky4AXFLFyiGRu5KQJdDcbhdNcO0f2ND/8IPmTEwcKKIHpCwH6/Q9UoMpcoFz78lxGvnmmy+FFgA/Bs1HjdM6LA==", + "version": "1.16.1", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.16.1.tgz", + "integrity": "sha512-4MkQvqXgQCYdxGgoPdprDH2AWoU9MaceXJj7peSFjkG5keK/3FdG7s/iiBaHZDa3Kv0jTRxAr/OXFrSZ/hl+vA==", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", - "long": "^4.0.0", - "onnx-proto": "^4.0.4", - "onnxruntime-common": "~1.15.1", - "platform": "^1.3.6" + "long": "^5.2.3", + "onnxruntime-common": "~1.16.1", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" } }, "node_modules/open": { @@ -6054,11 +6046,6 @@ "node": ">=12.0.0" } }, - "node_modules/protobufjs/node_modules/long": { - "version": "5.2.3", - "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", - "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" - }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", diff --git a/package.json b/package.json index 7fd1e70ea..467765189 100644 --- a/package.json +++ b/package.json @@ -38,11 +38,11 @@ }, "homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { - "onnxruntime-web": "1.15.1", + "onnxruntime-web": "1.16.1", "sharp": "^0.32.0" }, "optionalDependencies": { - "onnxruntime-node": "1.15.1" + 
"onnxruntime-node": "1.16.1" }, "devDependencies": { "@types/jest": "^29.5.1", From ee24941ddd7492358238c49a217b08fcd57e836f Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 5 Nov 2023 19:46:13 +0200 Subject: [PATCH 006/473] Binarize mask with Uint8Array data --- src/processors.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/processors.js b/src/processors.js index 954d699fb..9b9014a96 100644 --- a/src/processors.js +++ b/src/processors.js @@ -887,7 +887,7 @@ export class SamImageProcessor extends ImageFeatureExtractor { if (binarize) { interpolated_mask = new Tensor( 'bool', - Array.from(interpolated_mask.data).map(x => x > mask_threshold), + Uint8Array.from(interpolated_mask.data.map(x => +(x > mask_threshold))), interpolated_mask.dims ) } From 8913e9a3682727549ff746cbf9a68e9a649d9802 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 5 Nov 2023 19:52:45 +0200 Subject: [PATCH 007/473] Allow for separate computation of reshaped input points --- src/processors.js | 71 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/src/processors.js b/src/processors.js index 9b9014a96..59156fbcc 100644 --- a/src/processors.js +++ b/src/processors.js @@ -773,20 +773,18 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor { */ export class SamImageProcessor extends ImageFeatureExtractor { + /** - * @param {any[]} images The URL(s) of the image(s) to extract features from. - * @param {*} input_points A 3D or 4D array, representing the input points provided by the user. - * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1. - * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`. - * @returns {Promise} + * + * @param {*} input_points + * @param {HeightWidth[]} original_sizes + * @param {HeightWidth[]} reshaped_input_sizes + * @returns {Tensor} */ - async _call(images, input_points) { - let { - pixel_values, - original_sizes, - reshaped_input_sizes, - } = await super._call(images); + reshape_input_points(input_points, original_sizes, reshaped_input_sizes) { + // Make deep copy to avoid altering user's input + input_points = structuredClone(input_points); let shape = calculateDimensions(input_points); // TODO: add support for 2D input_points @@ -817,18 +815,31 @@ export class SamImageProcessor extends ImageFeatureExtractor { } } - let input_points_tensor = new Tensor( + return new Tensor( 'float32', Float32Array.from(input_points.flat(Infinity)), shape ) - // TODO: allowed to be floats? - // let input_points_tensor = new Tensor( - // 'float32', - // Float32Array.from(input_points.flat(Infinity)), - // shape - // ) + } + /** + * @param {any[]} images The URL(s) of the image(s) to extract features from. + * @param {*} input_points A 3D or 4D array, representing the input points provided by the user. + * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1. + * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`. 
+ * @returns {Promise} + */ + async _call(images, input_points) { + // TODO allow user to use preprocessed images + const { + pixel_values, + original_sizes, + reshaped_input_sizes, + } = await super._call(images); + + const input_points_tensor = this.reshape_input_points( + input_points, original_sizes, reshaped_input_sizes + ); return { pixel_values, @@ -906,6 +917,30 @@ export class SamImageProcessor extends ImageFeatureExtractor { return output_masks; } + + /** + * Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + * @param {RawImage} image Input original image + * @param {number} target_size Target size of the resized image + * @param {Object} options Options for generating crop boxes + * @param {number} [options.crop_n_layers] If >0, mask prediction will be run again on crops of the image. + * Sets the number of layers to run, where each layer has 2**i_layer number of image crops. + * @param {number} [options.overlap_ratio] Sets the degree to which crops overlap. In the first crop layer, + * crops will overlap by this fraction of the image length. Later layers with more crops scale down this overlap. + * @param {number} [options.points_per_crop] Number of points to sample from each crop. + * @param {number} [options.crop_n_points_downscale_factor] The number of points-per-side sampled in layer n is + * scaled down by crop_n_points_downscale_factor**n. + * @returns {Object} An object containing the crop boxes, number of points per crop, cropped images, and input labels. + */ + generate_crop_boxes(image, target_size, { + crop_n_layers = 0, + overlap_ratio = 512 / 1500, + points_per_crop = 32, + crop_n_points_downscale_factor = 1, + } = {}) { + // TODO: Implement + // return { crop_boxes, points_per_crop, cropped_images, input_labels } + } } From 1acb6356f7961202c97173b0e42d260498cf3343 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 5 Nov 2023 19:58:58 +0200 Subject: [PATCH 008/473] Update `calculateDimensions` typing --- src/utils/core.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/core.js b/src/utils/core.js index 72e198789..3b9a5d856 100644 --- a/src/utils/core.js +++ b/src/utils/core.js @@ -121,7 +121,7 @@ export function exists(x) { * Calculates the dimensions of a nested array. * * @param {Array} arr The nested array to calculate dimensions for. - * @returns {Array} An array containing the dimensions of the input array. + * @returns {number[]} An array containing the dimensions of the input array. 
*/ export function calculateDimensions(arr) { const dimensions = []; From 7e115dca762cb185e54a590080eb583b8ee096dc Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 5 Nov 2023 19:59:28 +0200 Subject: [PATCH 009/473] Implement additional helper functions for `RawImage` --- src/utils/image.js | 49 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/utils/image.js b/src/utils/image.js index 0016de8dc..b96ae01bd 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -16,6 +16,7 @@ import { env } from '../env.js'; import sharp from 'sharp'; const BROWSER_ENV = typeof self !== 'undefined'; +const WEBWORKER_ENV = BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope'; let createCanvasFunction; let ImageDataClass; @@ -156,6 +157,21 @@ export class RawImage { } } + /** + * Helper method to create a new Image from a tensor + * @param {import('./tensor.js').Tensor} tensor + */ + static fromTensor(tensor, channel_format = 'CHW') { + if (channel_format === 'CHW') { + tensor = tensor.transpose(1, 2, 0); + } else if (channel_format === 'HWC') { + // Do nothing + } else { + throw new Error(`Unsupported channel format: ${channel_format}`); + } + return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]); + } + /** * Convert the image to grayscale format. * @returns {RawImage} `this` to support chaining. @@ -487,6 +503,15 @@ export class RawImage { } } + async toBlob(type = 'image/png', quality = 1) { + if (!BROWSER_ENV) { + throw new Error('toBlob() is only supported in browser environments.') + } + + const canvas = this.toCanvas(); + return await canvas.convertToBlob({ type, quality }); + } + toCanvas() { if (!BROWSER_ENV) { throw new Error('toCanvas() is only supported in browser environments.') @@ -563,14 +588,15 @@ export class RawImage { save(path) { if (BROWSER_ENV) { + if (WEBWORKER_ENV) { + throw new Error('Unable to save an image from a Web Worker.') + } + const extension = path.split('.').pop().toLowerCase(); const mime = this._CONTENT_TYPE_MAP[extension] ?? 'image/png'; - // Convert image to canvas - const canvas = this.toCanvas(); - // Convert the canvas content to a data URL - const dataURL = canvas.toDataURL(mime); + const dataURL = this.toDataURL(mime); // Create an anchor element with the data URL as the href attribute const downloadLink = document.createElement('a'); @@ -594,6 +620,21 @@ export class RawImage { } } + /** + * Convert the image to a data URL. + * @param {string} mime The MIME type of the image. + * @returns {string} The data URL. 
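     *
     * Note: like `toCanvas()`, this requires a browser environment, and it additionally cannot be called
     * from a Web Worker (see the check below). For example, `image.toDataURL('image/jpeg')` returns a
     * JPEG-encoded data URL.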
+ */ + toDataURL(mime = 'image/png') { + if (!BROWSER_ENV || WEBWORKER_ENV) { + throw new Error('toDataURL() is only supported in browser environments.') + } + // Convert image to canvas + const canvas = this.toCanvas(); + + // Convert the canvas content to a data URL + return canvas.toDataURL(mime); + } toSharp() { if (BROWSER_ENV) { throw new Error('toSharp() is only supported in server-side environments.') From 55066b1b5b3a410edeae4b5e321b22cb47584f6d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 5 Nov 2023 19:59:43 +0200 Subject: [PATCH 010/473] Implement tensor multiplication function --- src/utils/tensor.js | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index f5c6dff83..887b78122 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -160,6 +160,28 @@ export class Tensor extends ONNXTensor { return this; } + + /** + * Return a new Tensor with every element multiplied by a constant. + * @param {number} val The value to multiply by. + * @returns {Tensor} The new tensor. + */ + mul(val) { + return this.clone().mul_(val); + } + + /** + * Multiply the tensor by a constant in place. + * @param {number} val The value to multiply by. + * @returns {Tensor} Returns `this`. + */ + mul_(val) { + for (let i = 0; i < this.data.length; ++i) { + this.data[i] *= val; + } + return this; + } + clone() { return new Tensor(this.type, this.data.slice(), this.dims.slice()); } From 49e669a3ee4efe20006bbd88c2ecb9dd09d2d3ea Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 7 Nov 2023 00:50:52 +0200 Subject: [PATCH 011/473] Do padding after rescaling/normalizing --- src/processors.js | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/src/processors.js b/src/processors.js index 59156fbcc..73707e5c6 100644 --- a/src/processors.js +++ b/src/processors.js @@ -339,17 +339,8 @@ export class ImageFeatureExtractor extends FeatureExtractor { /** @type {HeightWidth} */ let reshaped_input_size = [image.height, image.width]; - // TODO is it okay to pad before rescaling/normalizing? 
- if (this.do_pad && this.pad_size) { - let left = 0; - let right = this.pad_size.width - image.width; - let top = 0; - let bottom = this.pad_size.height - image.height; - - image = await image.pad([left, right, top, bottom]); - } - - const pixelData = Float32Array.from(image.data); + let pixelData = Float32Array.from(image.data); + let imgDims = [image.height, image.width, image.channels]; if (this.do_rescale) { for (let i = 0; i < pixelData.length; ++i) { @@ -379,10 +370,34 @@ export class ImageFeatureExtractor extends FeatureExtractor { } } + // do padding after rescaling/normalizing + if (this.do_pad && this.pad_size) { + + const paddedPixelData = new Float32Array(this.pad_size.width * this.pad_size.height * image.channels); + + // Copy the original image into the padded image + for (let i = 0; i < image.height; ++i) { + const a = i * this.pad_size.width; + const b = i * image.width; + for (let j = 0; j < image.width; ++j) { + const c = (a + j) * image.channels; + const d = (b + j) * image.channels; + for (let k = 0; k < image.channels; ++k) { + paddedPixelData[c + k] = pixelData[d + k]; + } + } + } + + // Update pixel data and image dimensions + pixelData = paddedPixelData; + imgDims = [this.pad_size.height, this.pad_size.width, image.channels] + } + + // Create HWC tensor + const img = new Tensor('float32', pixelData, imgDims); + // convert to channel dimension format: - let imgDims = [image.height, image.width, image.channels]; - let img = new Tensor('float32', pixelData, imgDims); - let transposed = transpose(img, [2, 0, 1]); // hwc -> chw + const transposed = transpose(img, [2, 0, 1]); // hwc -> chw return { original_size: [srcHeight, srcWidth], From 190367398ecd0d75b0b4f4988ae4e71dae4519de Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 7 Nov 2023 00:51:10 +0200 Subject: [PATCH 012/473] Minor reformatting --- src/utils/image.js | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/utils/image.js b/src/utils/image.js index b96ae01bd..0c91f2c69 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -286,10 +286,10 @@ export class RawImage { // TODO use `resample` in browser environment // Store number of channels before resizing - let numChannels = this.channels; + const numChannels = this.channels; // Create canvas object for this image - let canvas = this.toCanvas(); + const canvas = this.toCanvas(); // Actually perform resizing using the canvas API const ctx = createCanvasFunction(width, height).getContext('2d'); @@ -297,8 +297,11 @@ export class RawImage { // Draw image to context, resizing in the process ctx.drawImage(canvas, 0, 0, width, height); + // Extract the resized data + const imageData = ctx.getImageData(0, 0, width, height).data; + // Create image from the resized data - let resizedImage = new RawImage(ctx.getImageData(0, 0, width, height).data, width, height, 4); + const resizedImage = new RawImage(imageData, width, height, 4); // Convert back so that image has the same number of channels as before return resizedImage.convert(numChannels); From 704d95dfba495a7e5f56731abf3feb552740a8d1 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 7 Nov 2023 00:51:27 +0200 Subject: [PATCH 013/473] Update allowed types for `min` and `max` functions --- src/utils/maths.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/maths.js b/src/utils/maths.js index 8bf5c7e95..47b998ff0 100644 --- a/src/utils/maths.js +++ b/src/utils/maths.js @@ -232,7 +232,7 @@ export function magnitude(arr) { /** * Returns the 
value and index of the minimum element in an array. - * @param {number[]} arr array of numbers. + * @param {number[]|TypedArray} arr array of numbers. * @returns {number[]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin] * @throws {Error} If array is empty. */ @@ -252,7 +252,7 @@ export function min(arr) { /** * Returns the value and index of the maximum element in an array. - * @param {number[]} arr array of numbers. + * @param {number[]|TypedArray} arr array of numbers. * @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax] * @throws {Error} If array is empty. */ From 66da1301ea6952d37b8f0bf7630d644a904beb03 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 7 Nov 2023 03:38:43 +0200 Subject: [PATCH 014/473] Update compatibility for webgpu EP Requires dynamic imports --- src/backends/onnx.js | 107 ++++++++++++++++++++++++++++++++++++++----- src/env.js | 23 ++++------ src/models.js | 35 ++++---------- src/utils/tensor.js | 19 ++++---- 4 files changed, 121 insertions(+), 63 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index a06beb0e3..c82c42345 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -16,30 +16,32 @@ * @module backends/onnx */ +import path from 'path'; +import { env, RUNNING_LOCALLY } from '../env.js'; + // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`. // In either case, we select the default export if it exists, otherwise we use the named export. import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web'; /** @type {module} The ONNX runtime module. */ -export let ONNX; - -export const executionProviders = [ - // 'webgpu', - 'wasm' -]; +let ONNX; -if (typeof process !== 'undefined' && process?.release?.name === 'node') { - // Running in a node-like environment. - ONNX = ONNX_NODE.default ?? ONNX_NODE; +const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; +const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node' - // Add `cpu` execution provider, with higher precedence that `wasm`. - executionProviders.unshift('cpu'); +const ONNX_MODULES = new Map(); +if (USE_ONNXRUNTIME_NODE) { + ONNX = ONNX_NODE.default ?? ONNX_NODE; + ONNX_MODULES.set('node', ONNX); } else { - // Running in a browser-environment + // @ts-ignore ONNX = ONNX_WEB.default ?? ONNX_WEB; + ONNX_MODULES.set('web', ONNX); + // Running in a browser-environment + // TODO: Check if 1.16.1 fixes this issue. // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x). // As a temporary fix, we disable it for now. // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644 @@ -48,3 +50,84 @@ if (typeof process !== 'undefined' && process?.release?.name === 'node') { ONNX.env.wasm.simd = false; } } + +/** + * Create an ONNX inference session, with fallback support if an operation is not supported. + * @param {Uint8Array} buffer + * @returns {Promise} The ONNX inference session. + */ +export async function createInferenceSession(buffer) { + let executionProviders; + let InferenceSession; + if (USE_ONNXRUNTIME_NODE) { + const ONNX_NODE = ONNX_MODULES.get('node'); + InferenceSession = ONNX_NODE.InferenceSession; + executionProviders = ['cpu']; + + } else if (WEBGPU_AVAILABLE && env.experimental.useWebGPU) { + // Only import the WebGPU version if the user enables the experimental flag. 
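        // (Illustrative opt-in, based on the `env.experimental.useWebGPU` flag introduced in this patch:
        //     import { env } from '@xenova/transformers';
        //     env.experimental.useWebGPU = true; // set before loading any model
        // so that this branch is taken and the WebGPU build is imported lazily below.)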
+ let ONNX_WEBGPU = ONNX_MODULES.get('webgpu'); + if (ONNX_WEBGPU === undefined) { + ONNX_WEBGPU = await import('onnxruntime-web/webgpu'); + ONNX_MODULES.set('webgpu', ONNX_WEBGPU) + } + + InferenceSession = ONNX_WEBGPU.InferenceSession; + + // If WebGPU is available and the user enables the experimental flag, try to use the WebGPU execution provider. + executionProviders = ['webgpu', 'wasm']; + + ONNX_WEBGPU.env = env.backends.onnx; + + } else { + const ONNX_WEB = ONNX_MODULES.get('web'); + InferenceSession = ONNX_WEB.InferenceSession; + executionProviders = ['wasm']; + env.backends.onnx = ONNX_MODULES.get('web').env + } + + try { + return await InferenceSession.create(buffer, { + executionProviders, + }); + } catch (err) { + // If the execution provided was only wasm, throw the error + if (executionProviders.length === 1 && executionProviders[0] === 'wasm') { + throw err; + } + + console.warn(err); + console.warn( + 'Something went wrong during model construction (most likely a missing operation). ' + + 'Using `wasm` as a fallback. ' + ) + return await InferenceSession.create(buffer, { + executionProviders: ['wasm'] + }); + } +} + +/** + * Check if an object is an ONNX tensor. + * @param {any} x The object to check + * @returns {boolean} Whether the object is an ONNX tensor. + */ +export function isONNXTensor(x) { + for (const module of ONNX_MODULES.values()) { + if (x instanceof module.Tensor) { + return true; + } + } + return false; +} + +// Set path to wasm files. This is needed when running in a web worker. +// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths +// We use remote wasm files by default to make it easier for newer users. +// In practice, users should probably self-host the necessary .wasm files. +ONNX.env.wasm.wasmPaths = RUNNING_LOCALLY + ? path.join(env.__dirname, '/dist/') + : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`; + +// Expose ONNX environment variables to `env.backends.onnx` +env.backends.onnx = ONNX.env; diff --git a/src/env.js b/src/env.js index b522b9c94..b4c6d5bcc 100644 --- a/src/env.js +++ b/src/env.js @@ -26,9 +26,6 @@ import fs from 'fs'; import path from 'path'; import url from 'url'; -import { ONNX } from './backends/onnx.js'; -const { env: onnx_env } = ONNX; - const VERSION = '3.0.0-alpha.0'; // Check if various APIs are available (depends on environment) @@ -36,7 +33,7 @@ const WEB_CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self; const FS_AVAILABLE = !isEmpty(fs); // check if file system is available const PATH_AVAILABLE = !isEmpty(path); // check if path is available -const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE; +export const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE; const __dirname = RUNNING_LOCALLY ? path.dirname(path.dirname(url.fileURLToPath(import.meta.url))) @@ -53,14 +50,6 @@ const localModelPath = RUNNING_LOCALLY ? path.join(__dirname, DEFAULT_LOCAL_MODEL_PATH) : DEFAULT_LOCAL_MODEL_PATH; -// Set path to wasm files. This is needed when running in a web worker. -// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths -// We use remote wasm files by default to make it easier for newer users. -// In practice, users should probably self-host the necessary .wasm files. -onnx_env.wasm.wasmPaths = RUNNING_LOCALLY - ? path.join(__dirname, '/dist/') - : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${VERSION}/dist/`; - /** * Global variable used to control execution. 
This provides users a simple way to configure Transformers.js. @@ -83,16 +72,24 @@ onnx_env.wasm.wasmPaths = RUNNING_LOCALLY * @property {Object} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which * implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache */ + export const env = { /////////////////// Backends settings /////////////////// + // NOTE: These will be populated later by the backends themselves. backends: { // onnxruntime-web/onnxruntime-node - onnx: onnx_env, + onnx: {}, // TensorFlow.js tfjs: {}, }, + /////////////////// Experimental settings /////////////////// + experimental: { + // Whether to use the experimental WebGPU backend for ONNX.js. + useWebGPU: false, + }, + __dirname, version: VERSION, diff --git a/src/models.js b/src/models.js index a94b09f3a..7a788f5bc 100644 --- a/src/models.js +++ b/src/models.js @@ -80,9 +80,8 @@ import { Tensor, } from './utils/tensor.js'; -import { executionProviders, ONNX } from './backends/onnx.js'; +import { createInferenceSession, isONNXTensor } from './backends/onnx.js'; import { medianFilter } from './transformers.js'; -const { InferenceSession, Tensor: ONNXTensor } = ONNX; ////////////////////////////////////////////////// // Model types: used internally @@ -111,38 +110,19 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map(); * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. * @param {string} fileName The name of the model file. * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A Promise that resolves to an InferenceSession object. + * @returns {Promise} A Promise that resolves to an InferenceSession object. * @private */ async function constructSession(pretrained_model_name_or_path, fileName, options) { // TODO add option for user to force specify their desired execution provider let modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`; let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); - - try { - return await InferenceSession.create(buffer, { - executionProviders, - }); - } catch (err) { - // If the execution provided was only wasm, throw the error - if (executionProviders.length === 1 && executionProviders[0] === 'wasm') { - throw err; - } - - console.warn(err); - console.warn( - 'Something went wrong during model construction (most likely a missing operation). ' + - 'Using `wasm` as a fallback. ' - ) - return await InferenceSession.create(buffer, { - executionProviders: ['wasm'] - }); - } + return await createInferenceSession(buffer); } /** * Validate model inputs - * @param {InferenceSession} session The InferenceSession object that will be run. + * @param {Object} session The InferenceSession object that will be run. * @param {Object} inputs The inputs to check. * @returns {Promise} A Promise that resolves to the checked inputs. * @throws {Error} If any inputs are missing. @@ -182,7 +162,7 @@ async function validateInputs(session, inputs) { * - If additional inputs are passed, they will be ignored. * - If inputs are missing, an error will be thrown. * - * @param {InferenceSession} session The InferenceSession object to run. + * @param {Object} session The InferenceSession object to run. * @param {Object} inputs An object that maps input names to input tensors. 
* @returns {Promise} A Promise that resolves to an object that maps output names to output tensors. * @private @@ -209,7 +189,7 @@ async function sessionRun(session, inputs) { */ function replaceTensors(obj) { for (let prop in obj) { - if (obj[prop] instanceof ONNXTensor) { + if (isONNXTensor(obj[prop])) { obj[prop] = new Tensor(obj[prop]); } else if (typeof obj[prop] === 'object') { replaceTensors(obj[prop]); @@ -639,7 +619,8 @@ export class PreTrainedModel extends Callable { let promises = []; for (let key of Object.keys(this)) { let item = this[key]; - if (item instanceof InferenceSession) { + // TODO improve check for ONNX session + if (item?.handler?.dispose !== undefined) { promises.push(item.handler.dispose()) } } diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 887b78122..6382cbf67 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -7,8 +7,6 @@ * @module utils/tensor */ -import { ONNX } from '../backends/onnx.js'; - import { interpolate_data, transpose_data @@ -19,22 +17,21 @@ import { * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray */ -/** @type {Object} */ -const ONNXTensor = ONNX.Tensor; - -export class Tensor extends ONNXTensor { +export class Tensor { /** * Create a new Tensor or copy an existing Tensor. - * @param {[string, DataArray, number[]]|[ONNXTensor]} args + * @param {[string, DataArray, number[]]|Object} args */ constructor(...args) { - if (args[0] instanceof ONNX.Tensor) { + if (args.length === 1) { // Create shallow copy - super(args[0].type, args[0].data, args[0].dims); + Object.assign(this, args[0]); } else { - // Create new - super(...args); + // Create new tensor + this.type = args[0]; + this.data = args[1]; + this.dims = args[2]; } return new Proxy(this, { From 27fb0dcd7745927269d0ae3de9bc55ddc564887d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 7 Nov 2023 04:36:59 +0200 Subject: [PATCH 015/473] Fix assignment issue --- src/backends/onnx.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index c82c42345..2bbd7092c 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -77,7 +77,7 @@ export async function createInferenceSession(buffer) { // If WebGPU is available and the user enables the experimental flag, try to use the WebGPU execution provider. 
executionProviders = ['webgpu', 'wasm']; - ONNX_WEBGPU.env = env.backends.onnx; + Object.assign(ONNX_WEBGPU.env, env.backends.onnx); } else { const ONNX_WEB = ONNX_MODULES.get('web'); From 32bcaf7c9c5b6ccf190941a56335633364fb150b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 26 Nov 2023 17:56:44 +0200 Subject: [PATCH 016/473] Update dependency versions --- package-lock.json | 98 ++++++++++++++++++++++++++++++++++------------- package.json | 6 +-- 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/package-lock.json b/package-lock.json index a71e019d1..928e4b18f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,8 +9,8 @@ "version": "3.0.0-alpha.0", "license": "Apache-2.0", "dependencies": { - "onnxruntime-web": "1.16.1", - "sharp": "^0.32.0" + "onnxruntime-web": "1.16.3", + "sharp": "^0.32.6" }, "devDependencies": { "@types/jest": "^29.5.1", @@ -26,7 +26,7 @@ "webpack-dev-server": "^4.13.3" }, "optionalDependencies": { - "onnxruntime-node": "1.16.1" + "onnxruntime-node": "1.16.3" } }, "node_modules/@ampproject/remapping": { @@ -2017,6 +2017,11 @@ "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", "dev": true }, + "node_modules/b4a": { + "version": "1.6.4", + "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz", + "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw==" + }, "node_modules/babel-jest": { "version": "29.6.1", "resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-29.6.1.tgz", @@ -3014,9 +3019,9 @@ } }, "node_modules/detect-libc": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.1.tgz", - "integrity": "sha512-463v3ZeIrcWtdgIg6vI6XUncguvr2TnGl4SzDXinkt9mSLpBJKXT3mW6xT3VQdDN11+WVs29pgvivTc4Lp8v+w==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz", + "integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==", "engines": { "node": ">=8" } @@ -3415,6 +3420,11 @@ "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", "dev": true }, + "node_modules/fast-fifo": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", + "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==" + }, "node_modules/fast-glob": { "version": "3.2.12", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", @@ -5616,9 +5626,9 @@ } }, "node_modules/node-addon-api": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.0.0.tgz", - "integrity": "sha512-GyHvgPvUXBvAkXa0YvYnhilSB1A+FRYMpIVggKzPZqdaZfevZOuzfWzyvgzOwRLHBeo/MMswmJFsrNF4Nw1pmA==" + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz", + "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA==" }, "node_modules/node-forge": { "version": "1.3.1", @@ -5737,14 +5747,14 @@ } }, "node_modules/onnxruntime-common": { - "version": "1.16.1", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.16.1.tgz", - "integrity": "sha512-dmKye7bL4/aKhF561h+o9yw1hCCcGfYRN1BoycXm+WUjWhAVlGkP6JdBcRk7MQ1qX/ASFk+8Ibl+yVgCTSP0Fg==" + "version": "1.16.3", + "resolved": 
"https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.16.3.tgz", + "integrity": "sha512-ZZfFzEqBf6YIGwB9PtBLESHI53jMXA+/hn+ACVUbEfPuK2xI5vMGpLPn+idpwCmHsKJNRzRwqV12K+6TQj6tug==" }, "node_modules/onnxruntime-node": { - "version": "1.16.1", - "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.16.1.tgz", - "integrity": "sha512-o/0zHhfViD1UF91o+ATbYD9QanirdgVnSZK4GVcNaSYNBIhYsiHtIJ0yjnNIA2ZGWoTKCtmR+kcTLliCGXaucw==", + "version": "1.16.3", + "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.16.3.tgz", + "integrity": "sha512-6T2pjwg5ik74VnI1IXFzxvPAm2UCo+vNNsDGbMP+A2q6GZPMYai2pMA17g3YMUvgOZLwsjWBUwNIlP4QaVRFlA==", "optional": true, "os": [ "win32", @@ -5752,18 +5762,18 @@ "linux" ], "dependencies": { - "onnxruntime-common": "~1.16.1" + "onnxruntime-common": "~1.16.3" } }, "node_modules/onnxruntime-web": { - "version": "1.16.1", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.16.1.tgz", - "integrity": "sha512-4MkQvqXgQCYdxGgoPdprDH2AWoU9MaceXJj7peSFjkG5keK/3FdG7s/iiBaHZDa3Kv0jTRxAr/OXFrSZ/hl+vA==", + "version": "1.16.3", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.16.3.tgz", + "integrity": "sha512-8O1xCG/RcNQNYYWvdiQJSNpncVg78OVOFeV6MYs/jx++/b12oje8gYUzKqz9wR/sXiX/8TCvdyHgEjj5gQGKUg==", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", "long": "^5.2.3", - "onnxruntime-common": "~1.16.1", + "onnxruntime-common": "~1.16.3", "platform": "^1.3.6", "protobufjs": "^7.2.4" } @@ -6137,6 +6147,11 @@ } ] }, + "node_modules/queue-tick": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz", + "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag==" + }, "node_modules/randombytes": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/randombytes/-/randombytes-2.1.0.tgz", @@ -6676,18 +6691,18 @@ } }, "node_modules/sharp": { - "version": "0.32.0", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.0.tgz", - "integrity": "sha512-yLAypVcqj1toSAqRSwbs86nEzfyZVDYqjuUX8grhFpeij0DDNagKJXELS/auegDBRDg1XBtELdOGfo2X1cCpeA==", + "version": "0.32.6", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", + "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", - "detect-libc": "^2.0.1", - "node-addon-api": "^6.0.0", + "detect-libc": "^2.0.2", + "node-addon-api": "^6.1.0", "prebuild-install": "^7.1.1", - "semver": "^7.3.8", + "semver": "^7.5.4", "simple-get": "^4.0.1", - "tar-fs": "^2.1.1", + "tar-fs": "^3.0.4", "tunnel-agent": "^0.6.0" }, "engines": { @@ -6697,6 +6712,26 @@ "url": "https://opencollective.com/libvips" } }, + "node_modules/sharp/node_modules/tar-fs": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz", + "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==", + "dependencies": { + "mkdirp-classic": "^0.5.2", + "pump": "^3.0.0", + "tar-stream": "^3.1.5" + } + }, + "node_modules/sharp/node_modules/tar-stream": { + "version": "3.1.6", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz", + "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==", + "dependencies": { + "b4a": "^1.6.4", + "fast-fifo": "^1.2.0", + "streamx": "^2.15.0" 
+ } + }, "node_modules/shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", @@ -7013,6 +7048,15 @@ "node": ">=0.10.0" } }, + "node_modules/streamx": { + "version": "2.15.5", + "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.5.tgz", + "integrity": "sha512-9thPGMkKC2GctCzyCUjME3yR03x2xNo0GPKGkRw2UMYN+gqWa9uqpyNWhmsNCutU5zHmkUum0LsCRQTXUgUCAg==", + "dependencies": { + "fast-fifo": "^1.1.0", + "queue-tick": "^1.0.1" + } + }, "node_modules/string_decoder": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", diff --git a/package.json b/package.json index 467765189..9163bdec0 100644 --- a/package.json +++ b/package.json @@ -38,11 +38,11 @@ }, "homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { - "onnxruntime-web": "1.16.1", - "sharp": "^0.32.0" + "onnxruntime-web": "1.16.3", + "sharp": "^0.32.6" }, "optionalDependencies": { - "onnxruntime-node": "1.16.1" + "onnxruntime-node": "1.16.3" }, "devDependencies": { "@types/jest": "^29.5.1", From aa57f009ea9e2a09b4e3f7bc363c3695de9bdd8d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 2 Dec 2023 18:13:33 +0200 Subject: [PATCH 017/473] post-merge cleanup --- src/backends/onnx.js | 13 +++++++++-- src/models.js | 8 +++---- src/utils/image.js | 54 ++++++++++++++++++++++---------------------- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index a7f9bc7f4..c3f051a05 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -24,8 +24,8 @@ import { env, RUNNING_LOCALLY } from '../env.js'; import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web'; -/** @type {import('onnxruntime-web')} The ONNX runtime module. */ -export let ONNX; +/** @type {import('onnxruntime-web')|import('onnxruntime-node')} The ONNX runtime module. */ +let ONNX; const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node' @@ -121,6 +121,15 @@ export function isONNXTensor(x) { return false; } +/** + * Check if ONNX's WASM backend is being proxied. + * @returns {boolean} Whether ONNX's WASM backend is being proxied. + */ +export function isONNXProxy() { + // TODO: Update this when allowing non-WASM backends. + return ONNX.env.wasm.proxy; +} + // Set path to wasm files. This is needed when running in a web worker. // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths // We use remote wasm files by default to make it easier for newer users. diff --git a/src/models.js b/src/models.js index 807de8d42..fe46d152a 100644 --- a/src/models.js +++ b/src/models.js @@ -85,7 +85,7 @@ import { Tensor, } from './utils/tensor.js'; -import { createInferenceSession, isONNXTensor } from './backends/onnx.js'; +import { createInferenceSession, isONNXTensor, isONNXProxy } from './backends/onnx.js'; import { medianFilter } from './transformers.js'; ////////////////////////////////////////////////// @@ -127,9 +127,9 @@ async function constructSession(pretrained_model_name_or_path, fileName, options /** * Validate model inputs - * @param {InferenceSession} session The InferenceSession object that will be run. + * @param {Object} session The InferenceSession object that will be run. * @param {Object} inputs The inputs to check. - * @returns {Promise} A Promise that resolves to the checked inputs. 
+ * @returns {Record} The checked inputs. * @throws {Error} If any inputs are missing. * @private */ @@ -152,7 +152,7 @@ function validateInputs(session, inputs) { // NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker // boundary, transferring ownership to the worker and invalidating the tensor. // So, in this case, we simply sacrifice a clone for it. - checkedInputs[inputName] = env.wasm.proxy ? tensor.clone() : tensor; + checkedInputs[inputName] = isONNXProxy() ? tensor.clone() : tensor; } if (missingInputs.length > 0) { throw new Error( diff --git a/src/utils/image.js b/src/utils/image.js index 1b482f504..ac0fb5dce 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -39,7 +39,7 @@ if (BROWSER_ENV) { const metadata = await img.metadata(); const rawChannels = metadata.channels; - let { data, info } = await img.raw().toBuffer({ resolveWithObject: true }); + const { data, info } = await img.raw().toBuffer({ resolveWithObject: true }); const newImage = new RawImage(new Uint8ClampedArray(data), info.width, info.height, info.channels); if (rawChannels !== undefined && rawChannels !== info.channels) { @@ -128,11 +128,11 @@ export class RawImage { * @returns {Promise} The image object. */ static async fromURL(url) { - let response = await getFile(url); + const response = await getFile(url); if (response.status !== 200) { throw new Error(`Unable to read image from "${url}" (${response.status} ${response.statusText})`); } - let blob = await response.blob(); + const blob = await response.blob(); return this.fromBlob(blob); } @@ -144,7 +144,7 @@ export class RawImage { static async fromBlob(blob) { if (BROWSER_ENV) { // Running in environment with canvas - let img = await loadImageFunction(blob); + const img = await loadImageFunction(blob); const ctx = createCanvasFunction(img.width, img.height).getContext('2d'); @@ -155,7 +155,7 @@ export class RawImage { } else { // Use sharp.js to read (and possible resize) the image. 
- let img = sharp(await blob.arrayBuffer()); + const img = sharp(await blob.arrayBuffer()); return await loadImageFunction(img); } @@ -185,7 +185,7 @@ export class RawImage { return this; } - let newData = new Uint8ClampedArray(this.width * this.height * 1); + const newData = new Uint8ClampedArray(this.width * this.height * 1); switch (this.channels) { case 3: // rgb to grayscale case 4: // rgba to grayscale @@ -212,7 +212,7 @@ export class RawImage { return this; } - let newData = new Uint8ClampedArray(this.width * this.height * 3); + const newData = new Uint8ClampedArray(this.width * this.height * 3); switch (this.channels) { case 1: // grayscale to rgb @@ -245,7 +245,7 @@ export class RawImage { return this; } - let newData = new Uint8ClampedArray(this.width * this.height * 4); + const newData = new Uint8ClampedArray(this.width * this.height * 4); switch (this.channels) { case 1: // grayscale to rgba @@ -290,10 +290,10 @@ export class RawImage { // TODO use `resample` in browser environment // Store number of channels before resizing - let numChannels = this.channels; + const numChannels = this.channels; // Create canvas object for this image - let canvas = this.toCanvas(); + const canvas = this.toCanvas(); // Actually perform resizing using the canvas API const ctx = createCanvasFunction(width, height).getContext('2d'); @@ -302,7 +302,7 @@ export class RawImage { ctx.drawImage(canvas, 0, 0, width, height); // Create image from the resized data - let resizedImage = new RawImage(ctx.getImageData(0, 0, width, height).data, width, height, 4); + const resizedImage = new RawImage(ctx.getImageData(0, 0, width, height).data, width, height, 4); // Convert back so that image has the same number of channels as before return resizedImage.convert(numChannels); @@ -361,13 +361,13 @@ export class RawImage { if (BROWSER_ENV) { // Store number of channels before padding - let numChannels = this.channels; + const numChannels = this.channels; // Create canvas object for this image - let canvas = this.toCanvas(); + const canvas = this.toCanvas(); - let newWidth = this.width + left + right; - let newHeight = this.height + top + bottom; + const newWidth = this.width + left + right; + const newHeight = this.height + top + bottom; // Create a new canvas of the desired size. const ctx = createCanvasFunction(newWidth, newHeight).getContext('2d'); @@ -379,7 +379,7 @@ export class RawImage { ); // Create image from the padded data - let paddedImage = new RawImage( + const paddedImage = new RawImage( ctx.getImageData(0, 0, newWidth, newHeight).data, newWidth, newHeight, 4); @@ -387,7 +387,7 @@ export class RawImage { return paddedImage.convert(numChannels); } else { - let img = this.toSharp().extend({ left, right, top, bottom }); + const img = this.toSharp().extend({ left, right, top, bottom }); return await loadImageFunction(img); } } @@ -451,16 +451,16 @@ export class RawImage { } // Determine bounds of the image in the new canvas - let width_offset = (this.width - crop_width) / 2; - let height_offset = (this.height - crop_height) / 2; + const width_offset = (this.width - crop_width) / 2; + const height_offset = (this.height - crop_height) / 2; if (BROWSER_ENV) { // Store number of channels before resizing - let numChannels = this.channels; + const numChannels = this.channels; // Create canvas object for this image - let canvas = this.toCanvas(); + const canvas = this.toCanvas(); // Create a new canvas of the desired size. This is needed since if the // image is too small, we need to pad it with black pixels. 
@@ -490,7 +490,7 @@ export class RawImage { ); // Create image from the resized data - let resizedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4); + const resizedImage = new RawImage(ctx.getImageData(0, 0, crop_width, crop_height).data, crop_width, crop_height, 4); // Convert back so that image has the same number of channels as before return resizedImage.convert(numChannels); @@ -510,8 +510,8 @@ export class RawImage { } else if (width_offset <= 0 && height_offset <= 0) { // Cropped image lies entirely outside the original image, // so we add padding - let top = Math.floor(-height_offset); - let left = Math.floor(-width_offset); + const top = Math.floor(-height_offset); + const left = Math.floor(-width_offset); img = img.extend({ top: top, left: left, @@ -575,13 +575,13 @@ export class RawImage { // Clone, and convert data to RGBA before drawing to canvas. // This is because the canvas API only supports RGBA - let cloned = this.clone().rgba(); + const cloned = this.clone().rgba(); // Create canvas object for the cloned image - let clonedCanvas = createCanvasFunction(cloned.width, cloned.height); + const clonedCanvas = createCanvasFunction(cloned.width, cloned.height); // Draw image to context - let data = new ImageDataClass(cloned.data, cloned.width, cloned.height); + const data = new ImageDataClass(cloned.data, cloned.width, cloned.height); clonedCanvas.getContext('2d').putImageData(data, 0, 0); return clonedCanvas; From f3509e590eb28056f29b3561b4138245aab09edf Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 2 Dec 2023 20:10:45 +0200 Subject: [PATCH 018/473] Support webgpu build --- webpack.config.js | 129 +++++++++++++++++++++++++++++++--------------- 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/webpack.config.js b/webpack.config.js index 6adcfd9d9..e234c7e3e 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -5,47 +5,92 @@ import path from 'path'; const __dirname = path.dirname(fileURLToPath(import.meta.url)); -export default { - mode: 'development', - devtool: 'source-map', - entry: { - // include dist in entry point so that when running dev server, - // we can access the files with /dist/... - 'dist/transformers': './src/transformers.js', - 'dist/transformers.min': './src/transformers.js', - }, - output: { - filename: '[name].js', - path: __dirname, - library: { - type: 'module', +/** + * Helper function to create webpack configurations. + * @param {Object} options Options for creating a webpack target. + * @param {string} options.name Name of output file. + * @param {string} options.suffix Suffix of output file. + * @param {string} options.format Format of output file. + * @param {string} options.type Type of library. + * @param {string} options.dynamicImportMode Dynamic import mode. + * @returns {import('webpack').Configuration} One webpack target. + */ +function buildConfig({ + name = '', + suffix = '.js', + type = 'module', // 'module' | 'commonjs' + dynamicImportMode = undefined, // 'eager' | undefined +} = {}) { + const outputModule = type === 'module'; + + return { + mode: 'development', + devtool: 'source-map', + entry: { + // include dist in entry point so that when running dev server, + // we can access the files with /dist/... 
+ [`dist/transformers${name}`]: './src/transformers.js', + [`dist/transformers${name}.min`]: './src/transformers.js', + }, + output: { + filename: `[name]${suffix}`, + path: __dirname, + library: { + type, + }, + }, + plugins: [ + // Copy .wasm files to dist folder + new CopyWebpackPlugin({ + patterns: [ + { + from: 'node_modules/onnxruntime-web/dist/*.wasm', + to: 'dist/[name][ext]' + }, + ], + }), + ], + optimization: { + minimize: true, + minimizer: [new TerserPlugin({ + test: new RegExp(`\\.min\\${suffix}$`), + extractComments: false, + })], + }, + experiments: { + outputModule, }, - }, - plugins: [ - // Copy .wasm files to dist folder - new CopyWebpackPlugin({ - patterns: [ - { - from: 'node_modules/onnxruntime-web/dist/*.wasm', - to: 'dist/[name][ext]' - }, - ], - }), - ], - optimization: { - minimize: true, - minimizer: [new TerserPlugin({ - test: /\.min\.js$/, - extractComments: false, - })], - }, - devServer: { - static: { - directory: __dirname + module: { + parser: { + javascript: { + dynamicImportMode, + } + } }, - port: 8080 - }, - experiments: { - outputModule: true, - }, -}; + + // Development server + devServer: { + static: { + directory: __dirname + }, + port: 8080 + }, + }; +} + + +export default [ + buildConfig({ + type: 'module', + }), + buildConfig({ + name: '.webgpu', + type: 'module', + dynamicImportMode: 'eager', + }), + // TODO: + // buildConfig({ + // suffix: '.cjs', + // type: 'commonjs', + // }), +]; From f8bc912d31b011c5997c253686974338e327ac7b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 2 Dec 2023 23:39:03 +0200 Subject: [PATCH 019/473] Remove `stream/web` import --- package.json | 3 +-- src/utils/hub.js | 6 ------ 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/package.json b/package.json index a8fba9d9f..5faba71c2 100644 --- a/package.json +++ b/package.json @@ -73,8 +73,7 @@ "path": false, "url": false, "sharp": false, - "onnxruntime-node": false, - "stream/web": false + "onnxruntime-node": false }, "publishConfig": { "access": "public" diff --git a/src/utils/hub.js b/src/utils/hub.js index a6219bc25..944f87f71 100644 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -7,16 +7,10 @@ import fs from 'fs'; import path from 'path'; -import stream from 'stream/web'; import { env } from '../env.js'; import { dispatchCallback } from './core.js'; -if (!globalThis.ReadableStream) { - // @ts-ignore - globalThis.ReadableStream = stream.ReadableStream; // ReadableStream is not a global with Node 16 -} - /** * @typedef {Object} PretrainedOptions Options for loading a pretrained model. * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files). 
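The next patches in this series ("fixes for ort-1.17", "let application pass session options to the runtime", and the later "Cleanup" commit) thread user-supplied ONNX Runtime session options from `pipeline()` down to `InferenceSession.create()`. A minimal usage sketch of the resulting API follows; the model id and option values are illustrative placeholders rather than anything taken from these patches:

```js
import { pipeline } from '@xenova/transformers';

// `session_options` is forwarded as-is to onnxruntime's InferenceSession.create().
// If `executionProviders` is omitted, the library falls back to its defaults
// ('wasm' on the web, 'cpu' under onnxruntime-node, or 'webgpu' + 'wasm' when the
// experimental WebGPU flag is enabled).
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
    session_options: {
        executionProviders: ['wasm'],
    },
});

const embeddings = await extractor('Session options are passed through to the runtime.', {
    pooling: 'mean',
    normalize: true,
});
```
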
From da20dcc9d9eafebc4c9d07af1e8ab555de755e17 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 20 Feb 2024 16:58:00 -0800 Subject: [PATCH 020/473] fixes for ort-1.17 --- package-lock.json | 57 +++++++++++++++++---------------------------- package.json | 4 ++-- src/models.js | 11 +++++++-- src/utils/tensor.js | 39 ++++++++++++++++++++++++------- 4 files changed, 63 insertions(+), 48 deletions(-) diff --git a/package-lock.json b/package-lock.json index c953c7232..49472b742 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,7 +10,7 @@ "license": "Apache-2.0", "dependencies": { "@huggingface/jinja": "^0.1.0", - "onnxruntime-web": "1.14.0", + "onnxruntime-web": "1.17.0", "sharp": "^0.32.0" }, "devDependencies": { @@ -27,7 +27,7 @@ "webpack-dev-server": "^4.13.3" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.17.0" } }, "node_modules/@ampproject/remapping": { @@ -5322,9 +5322,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==" + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" }, "node_modules/lru-cache": { "version": "6.0.0", @@ -5748,23 +5748,15 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/onnx-proto": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-4.0.4.tgz", - "integrity": "sha512-aldMOB3HRoo6q/phyB6QRQxSt895HNNw82BNyZ2CMh4bjeKv7g/c+VpAFtJuEMVfYLMbRx61hbuqnKceLeDcDA==", - "dependencies": { - "protobufjs": "^6.8.8" - } - }, "node_modules/onnxruntime-common": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz", - "integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew==" + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.0.tgz", + "integrity": "sha512-Vq1remJbCPITjDMJ04DA7AklUTnbYUp4vbnm6iL7ukSt+7VErH0NGYfekRSTjxxurEtX7w41PFfnQlE6msjPJw==" }, "node_modules/onnxruntime-node": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz", - "integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==", + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.17.0.tgz", + "integrity": "sha512-pRxdqSP3a6wtiFVkVX1V3/gsEMwBRUA9D2oYmcN3cjF+j+ILS+SIY2L7KxdWapsG6z64i5rUn8ijFZdIvbojBg==", "optional": true, "os": [ "win32", @@ -5772,20 +5764,20 @@ "linux" ], "dependencies": { - "onnxruntime-common": "~1.14.0" + "onnxruntime-common": "1.17.0" } }, "node_modules/onnxruntime-web": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz", - "integrity": "sha512-Kcqf43UMfW8mCydVGcX9OMXI2VN17c0p6XvR7IPSZzBf/6lteBzXHvcEVWDPmCKuGombl997HgLqj91F11DzXw==", + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.17.0.tgz", + "integrity": "sha512-O5IZrnJ4ABMmgttdcuG/y3z8WT0zMieCeh/4Eq3lf3CeLwKLoPno38WbAvDiRUkfKjXUyu2mw532YIuGi61YJA==", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", - "long": "^4.0.0", - 
"onnx-proto": "^4.0.4", - "onnxruntime-common": "~1.14.0", - "platform": "^1.3.6" + "long": "^5.2.3", + "onnxruntime-common": "1.17.0", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" } }, "node_modules/open": { @@ -6044,9 +6036,9 @@ } }, "node_modules/protobufjs": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.4.tgz", - "integrity": "sha512-AT+RJgD2sH8phPmCf7OUZR8xGdcJRga4+1cOaXJ64hvcSkVhNcRHOwIxUatPH15+nj59WAGTDv3LSGZPEQbJaQ==", + "version": "7.2.6", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.6.tgz", + "integrity": "sha512-dgJaEDDL6x8ASUZ1YqWciTRrdOuYNzoOf27oHNfdyvKqHr5i0FV7FSLU+aIeFjyFgVxrpTOtQUi0BLLBymZaBw==", "hasInstallScript": true, "dependencies": { "@protobufjs/aspromise": "^1.1.2", @@ -6066,11 +6058,6 @@ "node": ">=12.0.0" } }, - "node_modules/protobufjs/node_modules/long": { - "version": "5.2.3", - "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", - "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" - }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", diff --git a/package.json b/package.json index ff14f5c97..fa650c620 100644 --- a/package.json +++ b/package.json @@ -38,12 +38,12 @@ }, "homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { - "onnxruntime-web": "1.14.0", + "onnxruntime-web": "1.17.0", "sharp": "^0.32.0", "@huggingface/jinja": "^0.1.0" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.17.0" }, "devDependencies": { "@types/jest": "^29.5.1", diff --git a/src/models.js b/src/models.js index 46c112a0b..f52202dd4 100644 --- a/src/models.js +++ b/src/models.js @@ -203,9 +203,16 @@ function validateInputs(session, inputs) { async function sessionRun(session, inputs) { const checkedInputs = validateInputs(session, inputs); try { - // @ts-ignore - let output = await session.run(checkedInputs); + // pass the original ort tensor + const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor])); + let output = await session.run(ortFeed); output = replaceTensors(output); + for (const [name, t] of Object.entries(checkedInputs)) { + // if we use gpu buffers for kv_caches, we own them and need to dispose() + if (name.startsWith('past_key_values')) { + t.dispose(); + }; + } return output; } catch (e) { // This usually occurs when the inputs are of the wrong type. diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 819c2dbb6..861d2d0d8 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -17,6 +17,7 @@ import { const DataTypeMap = Object.freeze({ float32: Float32Array, + float16: Uint16Array, float64: Float64Array, string: Array, // string[] int8: Int8Array, @@ -39,16 +40,32 @@ const ONNXTensor = ONNX.Tensor; export class Tensor { /** @type {number[]} Dimensions of the tensor. */ - dims; + get dims() { + // @ts-ignore + return this.ort_tensor.dims; + } + set dims(value) { + // FIXME: ONNXTensor declares dims as readonly so one needs to use the constructor() if dims change. + // @ts-ignore + this.ort_tensor.dims = value; + } /** @type {DataType} Type of the tensor. */ - type; + get type() { + return this.ort_tensor.type; + }; /** @type {DataArray} The data stored in the tensor. */ - data; + get data() { + return this.ort_tensor.data; + } /** @type {number} The number of elements in the tensor. 
*/ - size; + get size() { + return this.ort_tensor.size; + }; + + ort_tensor; /** * Create a new Tensor or copy an existing Tensor. @@ -56,16 +73,15 @@ export class Tensor { */ constructor(...args) { if (args[0] instanceof ONNXTensor) { - // Create shallow copy - Object.assign(this, args[0]); - + this.ort_tensor = args[0]; } else { // Create new tensor - Object.assign(this, new ONNXTensor( + const t = new ONNXTensor( /** @type {DataType} */(args[0]), /** @type {Exclude} */(args[1]), args[2] - )); + ); + this.ort_tensor = t; } return new Proxy(this, { @@ -89,6 +105,11 @@ export class Tensor { }); } + dispose() { + this.ort_tensor.dispose(); + // this.ort_tensor = undefined; + } + /** * Returns an iterator object for iterating over the tensor data in row-major order. * If the tensor has more than one dimension, the iterator will yield subarrays. From 9e461b61c167327730e97b21298ad664934ff086 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 6 Mar 2024 16:32:14 -0800 Subject: [PATCH 021/473] let application pass session options to the runtime --- package-lock.json | 63 ++++++++++++++--------------------------------- package.json | 8 +++--- src/models.js | 40 +++++++++++++++++++----------- src/pipelines.js | 2 ++ src/utils/hub.js | 1 + 5 files changed, 50 insertions(+), 64 deletions(-) mode change 100644 => 100755 src/models.js mode change 100644 => 100755 src/pipelines.js mode change 100644 => 100755 src/utils/hub.js diff --git a/package-lock.json b/package-lock.json index a2e839770..ed3757e9d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,7 +10,7 @@ "license": "Apache-2.0", "dependencies": { "@huggingface/jinja": "^0.2.0", - "onnxruntime-web": "1.14.0", + "onnxruntime-web": "^1.17.1", "sharp": "^0.32.0" }, "devDependencies": { @@ -27,7 +27,7 @@ "webpack-dev-server": "^4.13.3" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.17.1" } }, "node_modules/@ampproject/remapping": { @@ -5322,9 +5322,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==" + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" }, "node_modules/lru-cache": { "version": "6.0.0", @@ -5748,44 +5748,22 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/onnx-proto": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-4.0.4.tgz", - "integrity": "sha512-aldMOB3HRoo6q/phyB6QRQxSt895HNNw82BNyZ2CMh4bjeKv7g/c+VpAFtJuEMVfYLMbRx61hbuqnKceLeDcDA==", - "dependencies": { - "protobufjs": "^6.8.8" - } - }, "node_modules/onnxruntime-common": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz", - "integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew==" - }, - "node_modules/onnxruntime-node": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz", - "integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==", - "optional": true, - "os": [ - "win32", - "darwin", - "linux" - ], - "dependencies": { - "onnxruntime-common": "~1.14.0" - } + "version": "1.17.1", + 
"resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.1.tgz", + "integrity": "sha512-6wLNhpn+1hnsKN+jq6ulqUEJ61TdRmyFkGCvtRNnZkAupH8Yfr805UeNxjl9jtiX9B1q48pq6Q/67fEFpxT7Dw==" }, "node_modules/onnxruntime-web": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz", - "integrity": "sha512-Kcqf43UMfW8mCydVGcX9OMXI2VN17c0p6XvR7IPSZzBf/6lteBzXHvcEVWDPmCKuGombl997HgLqj91F11DzXw==", + "version": "1.17.1", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.17.1.tgz", + "integrity": "sha512-EotY9uJU4xFY/ZVZ2Zrl2OZmBcbTVTWn/2OOh4cCWODPwtsYN2xeJYgoz8LfCgZSrhenGg0q4ceYUWATXqEsYQ==", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", - "long": "^4.0.0", - "onnx-proto": "^4.0.4", - "onnxruntime-common": "~1.14.0", - "platform": "^1.3.6" + "long": "^5.2.3", + "onnxruntime-common": "1.17.1", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" } }, "node_modules/open": { @@ -6044,9 +6022,9 @@ } }, "node_modules/protobufjs": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.4.tgz", - "integrity": "sha512-AT+RJgD2sH8phPmCf7OUZR8xGdcJRga4+1cOaXJ64hvcSkVhNcRHOwIxUatPH15+nj59WAGTDv3LSGZPEQbJaQ==", + "version": "7.2.6", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.6.tgz", + "integrity": "sha512-dgJaEDDL6x8ASUZ1YqWciTRrdOuYNzoOf27oHNfdyvKqHr5i0FV7FSLU+aIeFjyFgVxrpTOtQUi0BLLBymZaBw==", "hasInstallScript": true, "dependencies": { "@protobufjs/aspromise": "^1.1.2", @@ -6066,11 +6044,6 @@ "node": ">=12.0.0" } }, - "node_modules/protobufjs/node_modules/long": { - "version": "5.2.3", - "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", - "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" - }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", diff --git a/package.json b/package.json index ff91eab0e..81faa6dac 100644 --- a/package.json +++ b/package.json @@ -38,12 +38,12 @@ }, "homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { - "onnxruntime-web": "1.14.0", - "sharp": "^0.32.0", - "@huggingface/jinja": "^0.2.0" + "@huggingface/jinja": "^0.2.0", + "onnxruntime-web": "^1.17.1", + "sharp": "^0.32.0" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.17.1" }, "devDependencies": { "@types/jest": "^29.5.1", diff --git a/src/models.js b/src/models.js old mode 100644 new mode 100755 index abb9c3bbd..89ec63df3 --- a/src/models.js +++ b/src/models.js @@ -123,23 +123,29 @@ async function constructSession(pretrained_model_name_or_path, fileName, options let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); try { - return await InferenceSession.create(buffer, { - executionProviders, - }); - } catch (err) { - // If the execution provided was only wasm, throw the error - if (executionProviders.length === 1 && executionProviders[0] === 'wasm') { - throw err; + let opt = options.session_options || {}; + + // use default execution providers if application did not specify one + if (opt.executionProviders === undefined) { + opt.executionProviders = executionProviders; } - console.warn(err); - console.warn( - 'Something went wrong during model construction (most likely a missing operation). ' + - 'Using `wasm` as a fallback. 
' - ) - return await InferenceSession.create(buffer, { - executionProviders: ['wasm'] - }); + // handle onnx external data files + if (opt.externalData !== undefined) { + for (let i = 0; i < opt.externalData.length; i++) { + const ext = opt.externalData[i]; + // if the external data is a string, fetch the file and replace the string with its content + if (typeof ext.data === "string") { + const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options); + ext.data = ext_buffer; + } + } + } + return await InferenceSession.create(buffer, opt); + } catch (err) { + // if the session fails, let the application handle it. Ie. if webgpu fails and we + // fallback to wasm, let the application decide if we want to use a quantized model, etc. + throw err; } } @@ -741,6 +747,7 @@ export class PreTrainedModel extends Callable { local_files_only = false, revision = 'main', model_file_name = null, + session_options = {}, } = {}) { let options = { @@ -751,6 +758,7 @@ export class PreTrainedModel extends Callable { local_files_only, revision, model_file_name, + session_options, } const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this); @@ -5348,6 +5356,7 @@ export class PretrainedMixin { local_files_only = false, revision = 'main', model_file_name = null, + session_options = {}, } = {}) { let options = { @@ -5358,6 +5367,7 @@ export class PretrainedMixin { local_files_only, revision, model_file_name, + session_options, } config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); if (!options.config) { diff --git a/src/pipelines.js b/src/pipelines.js old mode 100644 new mode 100755 index 25dfb5875..fc77c5ae7 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3019,6 +3019,7 @@ export async function pipeline( cache_dir = null, local_files_only = false, revision = 'main', + session_options = {}, } = {} ) { // Helper method to construct pipeline @@ -3046,6 +3047,7 @@ export async function pipeline( cache_dir, local_files_only, revision, + session_options, } const classes = new Map([ diff --git a/src/utils/hub.js b/src/utils/hub.js old mode 100644 new mode 100755 index 93617674c..5b3fe3da3 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -30,6 +30,7 @@ if (!globalThis.ReadableStream) { * since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. * NOTE: This setting is ignored for local requests. * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models. + * @property {{}} [session_options={}] Session options passed to the runtime. */ class FileResponse { From ed1ba9b1712d60e826d08ed5207d0ebbf015bbf5 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 6 Mar 2024 16:39:37 -0800 Subject: [PATCH 022/473] allow float16 for llm kv-cache --- src/models.js | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/models.js b/src/models.js index 89ec63df3..c9d4f82db 100755 --- a/src/models.js +++ b/src/models.js @@ -1304,6 +1304,8 @@ export class PreTrainedModel extends Callable { } else { // TODO support batches (i.e., batch_size > 1) const batch_size = 1; + const dtype = this.config.precision || 'float32'; + const empty = (dtype === 'float16') ? new Uint16Array() : []; // @ts-ignore if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? 
true)) { @@ -1313,10 +1315,10 @@ export class PreTrainedModel extends Callable { let decoder_dims = [batch_size, this.num_decoder_heads, 0, this.decoder_dim_kv]; // @ts-ignore for (let i = 0; i < this.num_decoder_layers; ++i) { - decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims) - decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims) - decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims) - decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims) + decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor(dtype, empty, encoder_dims) + decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor(dtype, empty, encoder_dims) + decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor(dtype, empty, decoder_dims) + decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor(dtype, empty, decoder_dims) } } else if (this.config.model_type === 'falcon') { // NOTE: Custom implementation for Falcon @@ -1324,15 +1326,15 @@ export class PreTrainedModel extends Callable { let dims = [batch_size * this.num_heads, 0, this.dim_kv] // @ts-ignore for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) + decoderFeeds[`past_key_values.${i}.key`] = new Tensor(dtype, empty, dims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor(dtype, empty, dims) } } else if (this.config.multi_query) { // e.g., for `gpt_bigcode` // @ts-ignore let dims = [batch_size * this.num_heads, 0, 2 * this.dim_kv] // @ts-ignore for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims) + decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor(dtype, empty, dims) } } else if (this.config.model_type === 'bloom') { // NOTE: Custom implementation for Bloom @@ -1343,16 +1345,16 @@ export class PreTrainedModel extends Callable { let valueDims = [batch_size * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64] // @ts-ignore for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims) + decoderFeeds[`past_key_values.${i}.key`] = new Tensor(dtype, empty, keyDims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor(dtype, empty, valueDims) } } else { // Decoder-only // @ts-ignore let dims = [batch_size, this.num_heads, 0, this.dim_kv] // @ts-ignore for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) + decoderFeeds[`past_key_values.${i}.key`] = new Tensor(dtype, empty, dims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor(dtype, empty, dims) } } } From 62c2c3f8a8c892803eba94ac63dc3423b8963c90 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 8 Mar 2024 00:19:13 +0200 Subject: [PATCH 023/473] Update package-lock.json --- package-lock.json | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/package-lock.json b/package-lock.json index 0332d9f7d..5218cd464 100644 --- a/package-lock.json +++ b/package-lock.json @@ -27,7 +27,7 @@ "webpack-dev-server": "^4.13.3" }, "optionalDependencies": { 
- "onnxruntime-node": "1.17.1" + "onnxruntime-node": "1.17.0" } }, "node_modules/@ampproject/remapping": { @@ -5739,6 +5739,26 @@ "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.1.tgz", "integrity": "sha512-6wLNhpn+1hnsKN+jq6ulqUEJ61TdRmyFkGCvtRNnZkAupH8Yfr805UeNxjl9jtiX9B1q48pq6Q/67fEFpxT7Dw==" }, + "node_modules/onnxruntime-node": { + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.17.0.tgz", + "integrity": "sha512-pRxdqSP3a6wtiFVkVX1V3/gsEMwBRUA9D2oYmcN3cjF+j+ILS+SIY2L7KxdWapsG6z64i5rUn8ijFZdIvbojBg==", + "optional": true, + "os": [ + "win32", + "darwin", + "linux" + ], + "dependencies": { + "onnxruntime-common": "1.17.0" + } + }, + "node_modules/onnxruntime-node/node_modules/onnxruntime-common": { + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.0.tgz", + "integrity": "sha512-Vq1remJbCPITjDMJ04DA7AklUTnbYUp4vbnm6iL7ukSt+7VErH0NGYfekRSTjxxurEtX7w41PFfnQlE6msjPJw==", + "optional": true + }, "node_modules/onnxruntime-web": { "version": "1.17.1", "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.17.1.tgz", From 47efbe33dd49c50727474117922ab606a7a9cd2e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 8 Mar 2024 00:19:19 +0200 Subject: [PATCH 024/473] Update package.json --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index fbe55e28c..43ae7d214 100644 --- a/package.json +++ b/package.json @@ -43,7 +43,7 @@ "@huggingface/jinja": "^0.2.1" }, "optionalDependencies": { - "onnxruntime-node": "1.17.1" + "onnxruntime-node": "1.17.0" }, "devDependencies": { "@types/jest": "^29.5.1", From 37d829c4e28ba4b721e1125846fc47221c62b583 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 8 Mar 2024 18:42:59 +0200 Subject: [PATCH 025/473] Cleanup --- src/backends/onnx.js | 41 +++++------ src/env.js | 7 +- src/models.js | 158 +++++++++++++++++++++++++++++++++++++++---- src/pipelines.js | 3 +- src/utils/hub.js | 12 +++- src/utils/tensor.js | 13 ++-- webpack.config.js | 14 ++-- 7 files changed, 192 insertions(+), 56 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index c3f051a05..2b9b92b17 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -24,7 +24,8 @@ import { env, RUNNING_LOCALLY } from '../env.js'; import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web'; -/** @type {import('onnxruntime-web')|import('onnxruntime-node')} The ONNX runtime module. */ +export { Tensor } from 'onnxruntime-common'; + let ONNX; const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; @@ -41,7 +42,7 @@ if (USE_ONNXRUNTIME_NODE) { ONNX_MODULES.set('web', ONNX); // Running in a browser-environment - // TODO: Check if 1.16.1 fixes this issue. + // TODO: Check if 1.17.1 fixes this issue. // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x). // As a temporary fix, we disable it for now. // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644 @@ -53,16 +54,18 @@ if (USE_ONNXRUNTIME_NODE) { /** * Create an ONNX inference session, with fallback support if an operation is not supported. - * @param {Uint8Array} buffer + * @param {Uint8Array} buffer The ONNX model buffer. + * @param {Object} session_options ONNX inference session options. * @returns {Promise} The ONNX inference session. 
*/ -export async function createInferenceSession(buffer) { +export async function createInferenceSession(buffer, session_options) { let executionProviders; let InferenceSession; if (USE_ONNXRUNTIME_NODE) { const ONNX_NODE = ONNX_MODULES.get('node'); InferenceSession = ONNX_NODE.InferenceSession; executionProviders = ['cpu']; + Object.assign(ONNX_NODE.env, env.backends.onnx); } else if (WEBGPU_AVAILABLE && env.experimental.useWebGPU) { // Only import the WebGPU version if the user enables the experimental flag. @@ -74,37 +77,25 @@ export async function createInferenceSession(buffer) { InferenceSession = ONNX_WEBGPU.InferenceSession; - // If WebGPU is available and the user enables the experimental flag, try to use the WebGPU execution provider. + // If WebGPU is available and the user enables the experimental flag, + // try to use the WebGPU execution provider. executionProviders = ['webgpu', 'wasm']; - Object.assign(ONNX_WEBGPU.env, env.backends.onnx); } else { const ONNX_WEB = ONNX_MODULES.get('web'); InferenceSession = ONNX_WEB.InferenceSession; executionProviders = ['wasm']; - env.backends.onnx = ONNX_MODULES.get('web').env + Object.assign(ONNX_WEB.env, env.backends.onnx); } - try { - return await InferenceSession.create(buffer, { - executionProviders, - }); - } catch (err) { - // If the execution provided was only wasm, throw the error - if (executionProviders.length === 1 && executionProviders[0] === 'wasm') { - throw err; - } - - console.warn(err); - console.warn( - 'Something went wrong during model construction (most likely a missing operation). ' + - 'Using `wasm` as a fallback. ' - ) - return await InferenceSession.create(buffer, { - executionProviders: ['wasm'] - }); + // NOTE: Important to create a clone, since ORT modifies the object. + const options = { + executionProviders, + ...session_options } + + return await InferenceSession.create(buffer, options); } /** diff --git a/src/env.js b/src/env.js index 194b3bd66..3b2724b84 100644 --- a/src/env.js +++ b/src/env.js @@ -29,7 +29,8 @@ import url from 'url'; const VERSION = '3.0.0-alpha.0'; // Check if various APIs are available (depends on environment) -const WEB_CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self; +const BROWSER_ENV = typeof self !== 'undefined'; +const WEB_CACHE_AVAILABLE = BROWSER_ENV && 'caches' in self; const FS_AVAILABLE = !isEmpty(fs); // check if file system is available const PATH_AVAILABLE = !isEmpty(path); // check if path is available @@ -60,7 +61,7 @@ const localModelPath = RUNNING_LOCALLY * If set to `false`, it will have the same effect as setting `local_files_only=true` when loading pipelines, models, tokenizers, processors, etc. * @property {string} remoteHost Host URL to load models from. Defaults to the Hugging Face Hub. * @property {string} remotePathTemplate Path template to fill in and append to `remoteHost` when loading models. - * @property {boolean} allowLocalModels Whether to allow loading of local files, defaults to `true`. + * @property {boolean} allowLocalModels Whether to allow loading of local files, defaults to `false` if running in-browser, and `true` otherwise. * If set to `false`, it will skip the local file check and try to load the model from the remote host. * @property {string} localModelPath Path to load local models from. Defaults to `/models/`. * @property {boolean} useFS Whether to use the file system to load files. By default, it is `true` if available. 
@@ -97,7 +98,7 @@ export const env = { remoteHost: 'https://huggingface.co/', remotePathTemplate: '{model}/resolve/{revision}/', - allowLocalModels: true, + allowLocalModels: !BROWSER_ENV, localModelPath: localModelPath, useFS: FS_AVAILABLE, diff --git a/src/models.js b/src/models.js index bbb472311..6aac18d39 100644 --- a/src/models.js +++ b/src/models.js @@ -81,7 +81,7 @@ import { Tensor, } from './utils/tensor.js'; -import { InferenceSession, isONNXTensor, isONNXProxy } from './backends/onnx.js'; +import { createInferenceSession, isONNXTensor, isONNXProxy } from './backends/onnx.js'; import { medianFilter } from './transformers.js'; ////////////////////////////////////////////////// @@ -110,16 +110,30 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map(); * Constructs an InferenceSession using a model file located at the specified path. * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. * @param {string} fileName The name of the model file. - * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model. + * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. * @returns {Promise} A Promise that resolves to an InferenceSession object. * @private */ async function constructSession(pretrained_model_name_or_path, fileName, options) { - // TODO add option for user to force specify their desired execution provider - let modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`; - let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); + const modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`; + const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); - return await InferenceSession.create(buffer, options.session_options); + let session_options = options.session_options || {}; + + // handle onnx external data files + // TODO: parse external data from config/options + // if (session_options.externalData !== undefined) { + // for (let i = 0; i < session_options.externalData.length; i++) { + // const ext = session_options.externalData[i]; + // // if the external data is a string, fetch the file and replace the string with its content + // if (typeof ext.data === "string") { + // const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options); + // ext.data = ext_buffer; + // } + // } + // } + + return await createInferenceSession(buffer, session_options); } /** @@ -715,7 +729,7 @@ export class PreTrainedModel extends Callable { * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. - * @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model. + * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. * * @returns {Promise} A new instance of the `PreTrainedModel` class. 
*/ @@ -4294,14 +4308,134 @@ export class YolosObjectDetectionOutput extends ModelOutput { ////////////////////////////////////////////////// + + ////////////////////////////////////////////////// export class SamPreTrainedModel extends PreTrainedModel { } + +/** + * Segment Anything Model (SAM) for generating segmentation masks, given an input image + * and optional 2D location and bounding boxes. + * + * **Example:** Perform mask generation w/ `Xenova/sam-vit-base`. + * ```javascript + * import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers'; + * + * const model = await SamModel.from_pretrained('Xenova/sam-vit-base'); + * const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base'); + * + * const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png'; + * const raw_image = await RawImage.read(img_url); + * const input_points = [[[450, 600]]] // 2D localization of a window + * + * const inputs = await processor(raw_image, input_points); + * const outputs = await model(inputs); + * + * const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes); + * // [ + * // Tensor { + * // dims: [ 1, 3, 1764, 2646 ], + * // type: 'bool', + * // data: Uint8Array(14002632) [ ... ], + * // size: 14002632 + * // } + * // ] + * const scores = outputs.iou_scores; + * // Tensor { + * // dims: [ 1, 1, 3 ], + * // type: 'float32', + * // data: Float32Array(3) [ + * // 0.8892380595207214, + * // 0.9311248064041138, + * // 0.983696699142456 + * // ], + * // size: 3 + * // } + * ``` + */ export class SamModel extends SamPreTrainedModel { /** - * @param {Object} model_inputs - * @param {Tensor} model_inputs.pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`. - * @param {Tensor} model_inputs.input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. This is used by the prompt encoder to encode the prompt. - * @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`. + * Creates a new instance of the `SamModel` class. + * @param {Object} config The configuration object specifying the hyperparameters and other model settings. + * @param {Object} vision_encoder The ONNX session containing the vision encoder model. + * @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model. + */ + constructor(config, vision_encoder, prompt_encoder_mask_decoder) { + super(config, vision_encoder); + this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder; + } + + /** + * Compute image embeddings and positional image embeddings, given the pixel values of an image. + * @param {Object} model_inputs Object containing the model inputs. + * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`. + * @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings. + */ + async get_image_embeddings({ pixel_values }) { + // in: + // - pixel_values: tensor.float32[batch_size,3,1024,1024] + // + // out: + // - image_embeddings: tensor.float32[batch_size,256,64,64] + // - image_positional_embeddings: tensor.float32[batch_size,256,64,64] + return await encoderForward(this, { pixel_values }) + } + + /** + * @typedef {Object} SamModelInputs Object containing the model inputs. 
+ * @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`. + * These can be obtained using a `SamProcessor`. + * @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`. + * This is used by the prompt encoder to encode the prompt. + * @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`. + * This is used by the prompt encoder to encode the prompt. There are 4 types of labels: + * - `1`: the point is a point that contains the object of interest + * - `0`: the point is a point that does not contain the object of interest + * - `-1`: the point corresponds to the background + * - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + * @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder. + * @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder. + */ + + /** + * @param {SamModelInputs} model_inputs Object containing the model inputs. + * @returns {Promise} The output of the model. + */ + async forward(model_inputs) { + if (!model_inputs.image_embeddings || !model_inputs.image_positional_embeddings) { + // Compute the image embeddings if they are missing + model_inputs = { + ...model_inputs, + ...(await this.get_image_embeddings(model_inputs)) + } + } + + if (!model_inputs.input_labels) { + // Set default input labels if they are missing + const shape = model_inputs.input_points.dims.slice(0, -1); + const numElements = shape.reduce((a, b) => a * b, 1); + model_inputs.input_labels = new Tensor( + 'int64', + new BigInt64Array(numElements).fill(1n), + shape + ); + } + + // Returns: + // - iou_scores: tensor.float32[batch_size,point_batch_size,3] + // - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256] + return await sessionRun(this.prompt_encoder_mask_decoder, { + input_points: model_inputs.input_points, + input_labels: model_inputs.input_labels, + image_embeddings: model_inputs.image_embeddings, + image_positional_embeddings: model_inputs.image_positional_embeddings, + }); + } + + /** + * Runs the model with the provided inputs + * @param {Object} model_inputs Model inputs + * @returns {Promise} Object containing segmentation outputs */ async _call(model_inputs) { return new SamImageSegmentationOutput(await super._call(model_inputs)); @@ -5671,7 +5805,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], - [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration], [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], diff --git a/src/pipelines.js b/src/pipelines.js index fc77c5ae7..7b73393e3 100755 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3005,7 +3005,7 @@ const TASK_ALIASES = Object.freeze({ * - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`. * - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`. * @param {string} [model=null] The name of the pre-trained model to use. 
If not specified, the default model for the task will be used. - * @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline. + * @param {import('./utils/hub.js').PretrainedModelOptions} [options] Optional parameters for the pipeline. * @returns {Promise} A Pipeline object for the specified task. * @throws {Error} If an unsupported pipeline is requested. */ @@ -3020,6 +3020,7 @@ export async function pipeline( local_files_only = false, revision = 'main', session_options = {}, + // TODO: device option } = {} ) { // Helper method to construct pipeline diff --git a/src/utils/hub.js b/src/utils/hub.js index 61ffc1d4a..34bd07433 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -13,7 +13,6 @@ import { dispatchCallback } from './core.js'; /** * @typedef {Object} PretrainedOptions Options for loading a pretrained model. - * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files). * @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates. * @property {Object} [config=null] Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when: * - The model is a model provided by the library (loaded with the *model id* string of a pretrained model). @@ -23,8 +22,17 @@ import { dispatchCallback } from './core.js'; * @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id, * since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. * NOTE: This setting is ignored for local requests. + */ + +/** + * @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model. + * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files). * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models. - * @property {{}} [session_options={}] Session options passed to the runtime. + * @property {Object} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen. + */ + +/** + * @typedef {PretrainedOptions & ModelSpecificPretrainedOptions} PretrainedModelOptions Options for loading a pretrained model. */ class FileResponse { diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 37df45d61..cd64cc7c2 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -12,6 +12,9 @@ import { transpose_data } from './maths.js'; +import { + Tensor as ONNXTensor, isONNXTensor, +} from '../backends/onnx.js'; const DataTypeMap = Object.freeze({ float32: Float32Array, @@ -34,7 +37,6 @@ const DataTypeMap = Object.freeze({ * @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray */ -const ONNXTensor = ONNX.Tensor; export class Tensor { /** @type {number[]} Dimensions of the tensor. */ @@ -67,19 +69,18 @@ export class Tensor { /** * Create a new Tensor or copy an existing Tensor. 
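// (A small illustrative sketch of the two construction paths the signature below allows; it assumes
// `Tensor` is imported from '@xenova/transformers', and the dtype/dims are chosen arbitrarily.)
//
//   const a = new Tensor('float32', new Float32Array([1, 2, 3, 4]), [2, 2]); // (type, data, dims)
//   const b = new Tensor(a.ort_tensor);                                      // wrap an existing ONNXTensor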
- * @param {[DataType, DataArray, number[]]|[import('onnxruntime-common').Tensor]} args + * @param {[DataType, DataArray, number[]]|[ONNXTensor]} args */ constructor(...args) { - if (args[0] instanceof ONNXTensor) { - this.ort_tensor = args[0]; + if (isONNXTensor(args[0])) { + this.ort_tensor = /** @type {ONNXTensor} */ (args[0]); } else { // Create new tensor - const t = new ONNXTensor( + this.ort_tensor = new ONNXTensor( /** @type {DataType} */(args[0]), /** @type {Exclude} */(args[1]), args[2] ); - this.ort_tensor = t; } return new Proxy(this, { diff --git a/webpack.config.js b/webpack.config.js index e234c7e3e..e5ad664f1 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -60,13 +60,13 @@ function buildConfig({ experiments: { outputModule, }, - module: { - parser: { - javascript: { - dynamicImportMode, - } - } - }, + // module: { + // parser: { + // javascript: { + // dynamicImportMode, + // } + // } + // }, // Development server devServer: { From f385eb3f549f4fa84dcfb4f02963b1b0a38c7919 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 9 Mar 2024 02:35:41 +0200 Subject: [PATCH 026/473] Add guard before `ONNX.env.wasm` --- src/backends/onnx.js | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 2b9b92b17..887669da1 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -121,13 +121,15 @@ export function isONNXProxy() { return ONNX.env.wasm.proxy; } -// Set path to wasm files. This is needed when running in a web worker. -// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths -// We use remote wasm files by default to make it easier for newer users. -// In practice, users should probably self-host the necessary .wasm files. -ONNX.env.wasm.wasmPaths = RUNNING_LOCALLY - ? path.join(env.__dirname, '/dist/') - : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`; +if (ONNX?.env?.wasm) { + // Set path to wasm files. This is needed when running in a web worker. + // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths + // We use remote wasm files by default to make it easier for newer users. + // In practice, users should probably self-host the necessary .wasm files. + ONNX.env.wasm.wasmPaths = RUNNING_LOCALLY + ? 
path.join(env.__dirname, '/dist/') + : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`; +} // Expose ONNX environment variables to `env.backends.onnx` env.backends.onnx = ONNX.env; From f7b0dce5b2ccae55b0b42ea112ebdb939776b6d8 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 9 Mar 2024 02:35:46 +0200 Subject: [PATCH 027/473] Update models.js --- src/models.js | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/models.js b/src/models.js index 6aac18d39..e935f3111 100644 --- a/src/models.js +++ b/src/models.js @@ -5540,8 +5540,6 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['glpn', ['GLPNModel', GLPNModel]], ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]], - - ['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly ]); const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ From e9b8ba82acd698c6fd2c8c25f1038856fda148a8 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 9 Mar 2024 03:06:05 +0200 Subject: [PATCH 028/473] Add webgpu-embedding-benchmark demo --- .../webgpu-embedding-benchmark/.gitignore | 24 ++ .../webgpu-embedding-benchmark/index.html | 46 ++++ examples/webgpu-embedding-benchmark/main.js | 259 ++++++++++++++++++ .../webgpu-embedding-benchmark/package.json | 18 ++ examples/webgpu-embedding-benchmark/style.css | 82 ++++++ .../webgpu-embedding-benchmark/vite.config.js | 6 + 6 files changed, 435 insertions(+) create mode 100644 examples/webgpu-embedding-benchmark/.gitignore create mode 100644 examples/webgpu-embedding-benchmark/index.html create mode 100644 examples/webgpu-embedding-benchmark/main.js create mode 100644 examples/webgpu-embedding-benchmark/package.json create mode 100644 examples/webgpu-embedding-benchmark/style.css create mode 100644 examples/webgpu-embedding-benchmark/vite.config.js diff --git a/examples/webgpu-embedding-benchmark/.gitignore b/examples/webgpu-embedding-benchmark/.gitignore new file mode 100644 index 000000000..a547bf36d --- /dev/null +++ b/examples/webgpu-embedding-benchmark/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/examples/webgpu-embedding-benchmark/index.html b/examples/webgpu-embedding-benchmark/index.html new file mode 100644 index 000000000..74c78664d --- /dev/null +++ b/examples/webgpu-embedding-benchmark/index.html @@ -0,0 +1,46 @@ + + + + + + + Transformers.js | WebGPU Benchmark + + + +

+        🤗 Transformers.js WebGPU Benchmark
+
+        This benchmark measures the execution time of Xenova/all-MiniLM-L6-v2 (bert-based embedding model)
+        using the WASM and WebGPU execution providers across different batch sizes.
+
+        Options
+ + + \ No newline at end of file diff --git a/examples/webgpu-embedding-benchmark/main.js b/examples/webgpu-embedding-benchmark/main.js new file mode 100644 index 000000000..f7b40d23b --- /dev/null +++ b/examples/webgpu-embedding-benchmark/main.js @@ -0,0 +1,259 @@ +import './style.css'; +import { env, AutoModel, ones } from '@xenova/transformers'; +import Chart from 'chart.js/auto'; + +// Throw an error if WebGPU is not supported +if (!navigator.gpu) { + const err = 'WebGPU is not supported by this browser.'; + alert(err) + throw Error(err); +} + +// Proxy the WASM backend to prevent the UI from freezing +env.backends.onnx.wasm.proxy = true; +env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/'; +env.backends.onnx.wasm.numThreads = 1; +env.experimental.useWebGPU = true; + +// Reference the elements that we will need +const ctx = document.getElementById('chart'); +const batchSizes = document.getElementById('batch-sizes'); +const xscale = document.getElementById('x-scale'); +const yscale = document.getElementById('y-scale'); +const sequenceLength = document.getElementById('sequence-length'); +const status = document.getElementById('status'); +const start = document.getElementById('start'); +const stop = document.getElementById('stop'); + +// Benchmark settings +const NUM_WARMUP_STEPS = 3; +const QUANTIZED = false; +const MODEL_ID = 'Xenova/all-MiniLM-L6-v2'; + +// Chart configuration +const config = { + type: 'line', + data: { + labels: [], + datasets: [{ + label: 'WASM', + data: [], + borderColor: 'red', + backgroundColor: 'rgba(255, 0, 0, 0.5)', + }, { + label: 'WebGPU', + data: [], + borderColor: 'blue', + backgroundColor: 'rgba(0, 0, 255, 0.5)', + }] + }, + options: { + responsive: true, + maintainAspectRatio: false, + plugins: { + legend: { + position: 'top', + }, + }, + scales: { + x: { + title: { + display: true, + text: 'Batch size', + }, + min: 1, + }, + y: { + title: { + display: true, + text: 'Time (ms)', + }, + } + } + }, +}; + +const toggleScale = (chart, axis, enabled) => { + chart.options.scales[axis].type = enabled ? 
'logarithmic' : 'linear'; + chart.update(); +} + +xscale.addEventListener('change', () => toggleScale(chart, 'x', xscale.checked)); +yscale.addEventListener('change', () => toggleScale(chart, 'y', yscale.checked)); + +const chart = new Chart(ctx, config); + +status.textContent = 'Loading model...'; + +let model_CPU; +try { + model_CPU = await AutoModel.from_pretrained(MODEL_ID, { + quantized: QUANTIZED, + session_options: { + executionProviders: ['wasm'] + } + }); +} catch (err) { + status.textContent = err.message; + alert(err.message) + throw err; +} + +let model_GPU; +try { + model_GPU = await AutoModel.from_pretrained(MODEL_ID, { + quantized: QUANTIZED, + session_options: { + executionProviders: ['webgpu'] + } + }); +} catch (err) { + status.textContent = err.message; + alert(err.message) + throw err; +} + +let adapterInfo; +try { + // Shouldn't fail since the WebGPU model has loaded successfully + const adapter = await navigator.gpu.requestAdapter(); + adapterInfo = await adapter.requestAdapterInfo(); +} catch (err) { + adapterInfo = {}; +} + +status.textContent = 'Ready'; + +let interrupted = false; +start.addEventListener('click', async () => { + start.disabled = true; + stop.disabled = false; + interrupted = false; + + // Reset + chart.data.labels = []; + for (let i = 0; i < chart.data.datasets; ++i) { + chart.data.datasets[i].data = []; + } + chart.update(); + + const seqLength = parseInt(sequenceLength.value); + + status.textContent = 'Warming up...'; + + const generateDummyInputs = (batch_size) => { + + const inputs = ones([batch_size, seqLength]); + + const model_inputs = { + input_ids: inputs, + attention_mask: inputs, + } + return model_inputs; + } + + // Warm up: This is important for the WebGPU execution provider, which compiles the shaders on first load + for (let i = 0; i < NUM_WARMUP_STEPS; ++i) { + const model_inputs = generateDummyInputs(1); + await model_CPU(model_inputs); + await model_GPU(model_inputs); + } + + status.textContent = 'Running benchmark...'; + + const batch_sizes = batchSizes.value.split(',').map(x => parseInt(x)).filter(x => x); + + for (const batch_size of batch_sizes) { + if (interrupted) break; + + const model_inputs = generateDummyInputs(batch_size); + + let wasmTime; + { // Run WASM + const start = performance.now(); + await model_CPU(model_inputs); + const end = performance.now(); + wasmTime = end - start; + } + + let webGPUTime; + { // Run WebGPU + const start = performance.now(); + await model_GPU(model_inputs); + const end = performance.now(); + webGPUTime = end - start; + } + chart.data.labels.push(batch_size); + chart.data.datasets[0].data.push(wasmTime); + chart.data.datasets[1].data.push(webGPUTime); + chart.update(); + } + + // Calculate max speedup: + if (chart.data.labels.length === 0) return; + + const table = generateResultsTable(chart.data, seqLength); + + const speedup = chart.data.datasets[0].data.at(-1) / chart.data.datasets[1].data.at(-1); + const roundedSpeedup = speedup.toFixed(2); + const params = new URLSearchParams({ + title: `⚡ WebGPU Benchmark Results (${roundedSpeedup}x speedup)`, + description: table.outerHTML, + }); + + const paramsStr = params.toString(); + status.innerHTML = `⚡ Done! WebGPU is ${roundedSpeedup}x faster! 
Share results`; + start.disabled = false; +}); +start.disabled = false; + +stop.addEventListener('click', () => { + status.textContent = 'Stopping...'; + interrupted = true; + stop.disabled = true; +}); + +function generateResultsTable(data, sequence_length) { + const datasets = data.datasets.map(d => d.data); + const batch_sizes = data.labels; + + const container = document.createElement('div'); + + const table = document.createElement('table'); + const thead = table.createTHead(); + const tbody = table.createTBody(); + + // Add header row + const headerRow = thead.insertRow(); + headerRow.insertCell().textContent = 'Batch Size'; + headerRow.insertCell().textContent = `WASM (ms)`; + headerRow.insertCell().textContent = `WebGPU (ms)`; + + // Add data rows + batch_sizes.forEach((batchSize, rowIndex) => { + const row = tbody.insertRow(); + row.insertCell().textContent = batchSize; + datasets.forEach(dataset => { + row.insertCell().textContent = dataset[rowIndex].toFixed(2); + }); + }); + + container.appendChild(table); + + const createBulletPoint = (text) => { + const li = document.createElement('li'); + li.textContent = text; + return li; + } + + // Add other information + const info = document.createElement('ul'); + info.appendChild(createBulletPoint(`Model: ${MODEL_ID}`)); + info.appendChild(createBulletPoint(`Quantized: ${QUANTIZED}`)); + info.appendChild(createBulletPoint(`Sequence length: ${sequence_length}`)); + info.appendChild(createBulletPoint(`Browser: ${navigator.userAgent}`)); + info.appendChild(createBulletPoint(`GPU: vendor=${adapterInfo.vendor}, architecture=${adapterInfo.architecture}, device=${adapterInfo.device}, description=${adapterInfo.description}`)); + container.appendChild(info); + + return container; +} diff --git a/examples/webgpu-embedding-benchmark/package.json b/examples/webgpu-embedding-benchmark/package.json new file mode 100644 index 000000000..d90288d7a --- /dev/null +++ b/examples/webgpu-embedding-benchmark/package.json @@ -0,0 +1,18 @@ +{ + "name": "webgpu-embedding-benchmark", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "devDependencies": { + "vite": "^5.0.12" + }, + "dependencies": { + "@xenova/transformers": "^3.0.0", + "chart.js": "^4.4.2" + } +} diff --git a/examples/webgpu-embedding-benchmark/style.css b/examples/webgpu-embedding-benchmark/style.css new file mode 100644 index 000000000..d5748146d --- /dev/null +++ b/examples/webgpu-embedding-benchmark/style.css @@ -0,0 +1,82 @@ +* { + box-sizing: border-box; + padding: 0; + margin: 0; + font-family: sans-serif; +} + +html, +body { + height: 100%; +} + +body { + padding: 16px 32px; + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; +} + +h1 { + text-align: center; +} + +#status { + min-height: 16px; + margin: 8px 0; +} + +button { + transition: all .25s; + background: rgba(40, 44, 52, 0.05); + border: 1px solid transparent; + border-radius: 6px; + color: #3080d0; + text-decoration: none !important; + display: inline-block; + font-size: 14px; + font-weight: 500; + padding: 8px 16px; + cursor: pointer; + -webkit-user-select: none; + -moz-user-select: none; + user-select: none; +} + +button:disabled { + background: rgba(40, 44, 52, 0.1); + color: #a0a0a0; + cursor: not-allowed; +} + +button:hover { + background: rgba(40, 44, 52, 0.1); +} + +p { + text-align: center; + font-size: 12px; + max-width: 600px; + padding: 8px; +} + +#chart-container { + position: 
relative; + height: 60vh; + width: min(90vw, 800px); + padding-right: 50px; + margin-bottom: 10px; +} + +details { + position: fixed; + background-color: white; + right: 0; + top: 0; + padding: 16px; +} + +summary { + text-align: right; +} \ No newline at end of file diff --git a/examples/webgpu-embedding-benchmark/vite.config.js b/examples/webgpu-embedding-benchmark/vite.config.js new file mode 100644 index 000000000..6c32f52df --- /dev/null +++ b/examples/webgpu-embedding-benchmark/vite.config.js @@ -0,0 +1,6 @@ +import { defineConfig } from 'vite'; +export default defineConfig({ + build: { + target: 'esnext' + } +}); From 69e5f5534093261ed8ec3dad685293a53477af2d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 10 Mar 2024 18:42:24 +0200 Subject: [PATCH 029/473] Add `RawImage.fromCanvas` helper function --- src/utils/image.js | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/utils/image.js b/src/utils/image.js index 152c9d544..71e9d0a89 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -124,6 +124,20 @@ export class RawImage { } } + /** + * Read an image from a canvas. + * @param {HTMLCanvasElement|OffscreenCanvas} canvas The canvas to read the image from. + * @returns {RawImage} The image object. + */ + static fromCanvas(canvas) { + if (!BROWSER_ENV) { + throw new Error('fromCanvas() is only supported in browser environments.') + } + + const ctx = canvas.getContext('2d'); + const data = ctx.getImageData(0, 0, canvas.width, canvas.height).data; + return new RawImage(data, canvas.width, canvas.height, 4); + } /** * Read an image from a URL or file path. From 8b7c2cee1a7ccc26f2dbef701f06e674ce028c5e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Mar 2024 02:17:02 +0200 Subject: [PATCH 030/473] Improve web/node split --- src/backends/onnx.js | 103 ++++++++++++++++++------------------------- src/configs.js | 4 +- webpack.config.js | 25 ++++------- 3 files changed, 55 insertions(+), 77 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 887669da1..36ad8ae52 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -22,36 +22,23 @@ import { env, RUNNING_LOCALLY } from '../env.js'; // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`. // In either case, we select the default export if it exists, otherwise we use the named export. import * as ONNX_NODE from 'onnxruntime-node'; -import * as ONNX_WEB from 'onnxruntime-web'; +import * as ONNX_WEB from 'onnxruntime-web/webgpu'; export { Tensor } from 'onnxruntime-common'; -let ONNX; - const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; -const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node' - -const ONNX_MODULES = new Map(); +const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node'; +let ONNX; if (USE_ONNXRUNTIME_NODE) { ONNX = ONNX_NODE.default ?? ONNX_NODE; - ONNX_MODULES.set('node', ONNX); } else { - // @ts-ignore - ONNX = ONNX_WEB.default ?? ONNX_WEB; - ONNX_MODULES.set('web', ONNX); - - // Running in a browser-environment - // TODO: Check if 1.17.1 fixes this issue. - // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x). - // As a temporary fix, we disable it for now. 
- // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644 - const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent); - if (isIOS) { - ONNX.env.wasm.simd = false; - } + ONNX = ONNX_WEB; } +// @ts-ignore +const InferenceSession = ONNX.InferenceSession; + /** * Create an ONNX inference session, with fallback support if an operation is not supported. * @param {Uint8Array} buffer The ONNX model buffer. @@ -60,33 +47,18 @@ if (USE_ONNXRUNTIME_NODE) { */ export async function createInferenceSession(buffer, session_options) { let executionProviders; - let InferenceSession; if (USE_ONNXRUNTIME_NODE) { - const ONNX_NODE = ONNX_MODULES.get('node'); - InferenceSession = ONNX_NODE.InferenceSession; executionProviders = ['cpu']; - Object.assign(ONNX_NODE.env, env.backends.onnx); - - } else if (WEBGPU_AVAILABLE && env.experimental.useWebGPU) { - // Only import the WebGPU version if the user enables the experimental flag. - let ONNX_WEBGPU = ONNX_MODULES.get('webgpu'); - if (ONNX_WEBGPU === undefined) { - ONNX_WEBGPU = await import('onnxruntime-web/webgpu'); - ONNX_MODULES.set('webgpu', ONNX_WEBGPU) + } else if (env.experimental.useWebGPU) { + // Only use the WebGPU version if the user enables the experimental flag. + if (WEBGPU_AVAILABLE) { + executionProviders = ['webgpu', 'wasm']; + } else { + console.warn('`env.experimental.useWebGPU = true` but WebGPU is not available in this environment. Using WASM as the execution provider.'); + executionProviders = ['wasm']; } - - InferenceSession = ONNX_WEBGPU.InferenceSession; - - // If WebGPU is available and the user enables the experimental flag, - // try to use the WebGPU execution provider. - executionProviders = ['webgpu', 'wasm']; - Object.assign(ONNX_WEBGPU.env, env.backends.onnx); - } else { - const ONNX_WEB = ONNX_MODULES.get('web'); - InferenceSession = ONNX_WEB.InferenceSession; executionProviders = ['wasm']; - Object.assign(ONNX_WEB.env, env.backends.onnx); } // NOTE: Important to create a clone, since ORT modifies the object. @@ -104,32 +76,45 @@ export async function createInferenceSession(buffer, session_options) { * @returns {boolean} Whether the object is an ONNX tensor. */ export function isONNXTensor(x) { - for (const module of ONNX_MODULES.values()) { - if (x instanceof module.Tensor) { - return true; - } - } - return false; + return x instanceof ONNX.Tensor; } -/** - * Check if ONNX's WASM backend is being proxied. - * @returns {boolean} Whether ONNX's WASM backend is being proxied. - */ -export function isONNXProxy() { - // TODO: Update this when allowing non-WASM backends. - return ONNX.env.wasm.proxy; -} +// @ts-ignore +const ONNX_ENV = ONNX?.env; +if (ONNX_ENV?.wasm) { + // Initialize wasm backend with suitable default settings. -if (ONNX?.env?.wasm) { // Set path to wasm files. This is needed when running in a web worker. // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths // We use remote wasm files by default to make it easier for newer users. // In practice, users should probably self-host the necessary .wasm files. - ONNX.env.wasm.wasmPaths = RUNNING_LOCALLY + ONNX_ENV.wasm.wasmPaths = RUNNING_LOCALLY ? path.join(env.__dirname, '/dist/') : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`; + + // Proxy the WASM backend to prevent the UI from freezing + ONNX_ENV.wasm.proxy = true; + // ONNX_ENV.wasm.numThreads = 1; // TODO is this needed? 
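    // (For reference, a minimal sketch of how consumers can override these defaults: ONNX_ENV is
    // exposed as `env.backends.onnx` at the end of this file, so the same keys can be set before
    // loading any model. Paths and values here are only examples.)
    //
    //   import { env } from '@xenova/transformers';
    //   env.backends.onnx.wasm.wasmPaths = '/my/own/dist/'; // e.g. point at self-hosted .wasm files
    //   env.backends.onnx.wasm.proxy = false;               // keep inference on the main thread
    //   env.backends.onnx.wasm.numThreads = 1;              // tune for the target environment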
+ + // Running in a browser-environment + // TODO: Check if 1.17.1 fixes this issue. + // SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x). + // As a temporary fix, we disable it for now. + // For more information, see: https://github.com/microsoft/onnxruntime/issues/15644 + const isIOS = typeof navigator !== 'undefined' && /iP(hone|od|ad).+16_4.+AppleWebKit/.test(navigator.userAgent); + if (isIOS) { + ONNX_ENV.wasm.simd = false; + } +} + +/** + * Check if ONNX's WASM backend is being proxied. + * @returns {boolean} Whether ONNX's WASM backend is being proxied. + */ +export function isONNXProxy() { + // TODO: Update this when allowing non-WASM backends. + return ONNX_ENV?.wasm?.proxy; } // Expose ONNX environment variables to `env.backends.onnx` -env.backends.onnx = ONNX.env; +env.backends.onnx = ONNX_ENV; diff --git a/src/configs.js b/src/configs.js index 4506d2d9c..8a67f853d 100644 --- a/src/configs.js +++ b/src/configs.js @@ -97,10 +97,10 @@ export class PretrainedConfig { * Helper class which is used to instantiate pretrained configs with the `from_pretrained` function. * * @example - * let config = await AutoConfig.from_pretrained('bert-base-uncased'); + * const config = await AutoConfig.from_pretrained('Xenova/bert-base-uncased'); */ export class AutoConfig { - /** @type {PretrainedConfig.from_pretrained} */ + /** @type {typeof PretrainedConfig.from_pretrained} */ static async from_pretrained(...args) { return PretrainedConfig.from_pretrained(...args); } diff --git a/webpack.config.js b/webpack.config.js index e5ad664f1..ab7549932 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -19,10 +19,14 @@ function buildConfig({ name = '', suffix = '.js', type = 'module', // 'module' | 'commonjs' - dynamicImportMode = undefined, // 'eager' | undefined + ignoreModules = [], // 'eager' | undefined } = {}) { const outputModule = type === 'module'; + const alias = Object.fromEntries(ignoreModules.map((module) => { + return [module, false]; + })); + return { mode: 'development', devtool: 'source-map', @@ -60,13 +64,7 @@ function buildConfig({ experiments: { outputModule, }, - // module: { - // parser: { - // javascript: { - // dynamicImportMode, - // } - // } - // }, + resolve: { alias }, // Development server devServer: { @@ -84,13 +82,8 @@ export default [ type: 'module', }), buildConfig({ - name: '.webgpu', - type: 'module', - dynamicImportMode: 'eager', + suffix: '.cjs', + type: 'commonjs', + ignoreModules: ['onnxruntime-web', 'onnxruntime-web/webgpu'], }), - // TODO: - // buildConfig({ - // suffix: '.cjs', - // type: 'commonjs', - // }), ]; From 034b959826ec61618c26db7db11e72cb26c97522 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Mar 2024 02:17:29 +0200 Subject: [PATCH 031/473] Update processors.js --- src/processors.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/processors.js b/src/processors.js index 4dcb4862f..a7df4cd2e 100644 --- a/src/processors.js +++ b/src/processors.js @@ -700,7 +700,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { const pixel_values = stack(imageData.map(x => x.pixel_values), 0); return { - pixel_values: pixel_values, + pixel_values, // Original sizes of images original_sizes: imageData.map(x => x.original_size), From d1b1059c7c6af09ec4592f1e4ed4c0bcc21786be Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Mar 2024 03:03:20 +0200 Subject: [PATCH 032/473] Improve how we set devices --- src/backends/onnx.js | 30 +++++++++++++++++------------- src/env.js | 
6 ------ src/models.js | 41 ++++++++++++++++++++++++++--------------- src/pipelines.js | 3 ++- src/utils/devices.js | 3 +++ src/utils/hub.js | 1 + 6 files changed, 49 insertions(+), 35 deletions(-) create mode 100644 src/utils/devices.js diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 36ad8ae52..d5917d888 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -29,11 +29,20 @@ export { Tensor } from 'onnxruntime-common'; const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node'; +const supportedExecutionProviders = []; +let defaultExecutionProviders; let ONNX; if (USE_ONNXRUNTIME_NODE) { ONNX = ONNX_NODE.default ?? ONNX_NODE; + supportedExecutionProviders.push('cpu'); + defaultExecutionProviders = ['cpu']; } else { ONNX = ONNX_WEB; + if (WEBGPU_AVAILABLE) { + supportedExecutionProviders.push('webgpu'); + } + supportedExecutionProviders.push('wasm'); + defaultExecutionProviders = ['wasm']; } // @ts-ignore @@ -42,23 +51,18 @@ const InferenceSession = ONNX.InferenceSession; /** * Create an ONNX inference session, with fallback support if an operation is not supported. * @param {Uint8Array} buffer The ONNX model buffer. - * @param {Object} session_options ONNX inference session options. + * @param {Object} session_options ONNX inference session options. + * @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on. * @returns {Promise} The ONNX inference session. */ -export async function createInferenceSession(buffer, session_options) { - let executionProviders; - if (USE_ONNXRUNTIME_NODE) { - executionProviders = ['cpu']; - } else if (env.experimental.useWebGPU) { - // Only use the WebGPU version if the user enables the experimental flag. - if (WEBGPU_AVAILABLE) { - executionProviders = ['webgpu', 'wasm']; +export async function createInferenceSession(buffer, session_options, device = null) { + let executionProviders = defaultExecutionProviders; + if (device) { // User has specified a device + if (supportedExecutionProviders.includes(device)) { + executionProviders = [device]; } else { - console.warn('`env.experimental.useWebGPU = true` but WebGPU is not available in this environment. Using WASM as the execution provider.'); - executionProviders = ['wasm']; + throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedExecutionProviders.join(', ')}.`) } - } else { - executionProviders = ['wasm']; } // NOTE: Important to create a clone, since ORT modifies the object. diff --git a/src/env.js b/src/env.js index 3b2724b84..e4a736457 100644 --- a/src/env.js +++ b/src/env.js @@ -84,12 +84,6 @@ export const env = { tfjs: {}, }, - /////////////////// Experimental settings /////////////////// - experimental: { - // Whether to use the experimental WebGPU backend for ONNX.js. 
- useWebGPU: false, - }, - __dirname, version: VERSION, diff --git a/src/models.js b/src/models.js index e935f3111..4219ca9ea 100644 --- a/src/models.js +++ b/src/models.js @@ -6,10 +6,10 @@ * * ```javascript * import { AutoModel, AutoTokenizer } from '@xenova/transformers'; - * + * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); * let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased'); - * + * * let inputs = await tokenizer('I love transformers!'); * let { logits } = await model(inputs); * // Tensor { @@ -28,7 +28,7 @@ * * let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small'); * let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small'); - * + * * let { input_ids } = await tokenizer('translate English to German: I love transformers!'); * let outputs = await model.generate(input_ids); * let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true }); @@ -118,22 +118,29 @@ async function constructSession(pretrained_model_name_or_path, fileName, options const modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`; const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); - let session_options = options.session_options || {}; + const session_options = options.session_options ?? {}; // handle onnx external data files - // TODO: parse external data from config/options - // if (session_options.externalData !== undefined) { - // for (let i = 0; i < session_options.externalData.length; i++) { - // const ext = session_options.externalData[i]; - // // if the external data is a string, fetch the file and replace the string with its content - // if (typeof ext.data === "string") { - // const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options); - // ext.data = ext_buffer; - // } + if (session_options.externalData !== undefined) { + for (let i = 0; i < session_options.externalData.length; i++) { + const ext = session_options.externalData[i]; + // if the external data is a string, fetch the file and replace the string with its content + if (typeof ext.data === "string") { + const ext_buffer = await getModelFile(pretrained_model_name_or_path, ext.data, true, options); + ext.data = ext_buffer; + } + } + } + + // TODO: Add support for preferredOutputLocation + // if (options.device == "webgpu") { + // for (let i = 0; i < config.layers; ++i) { + // options.session_options.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer'; + // options.session_options.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer'; // } // } - return await createInferenceSession(buffer, session_options); + return await createInferenceSession(buffer, session_options, options.device); } /** @@ -198,7 +205,7 @@ async function sessionRun(session, inputs) { try { // pass the original ort tensor const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor])); - let output = await session.run(ortFeed); + let output = await session.run(ortFeed); output = replaceTensors(output); for (const [name, t] of Object.entries(checkedInputs)) { // if we use gpu buffers for kv_caches, we own them and need to dispose() @@ -741,6 +748,7 @@ export class PreTrainedModel extends Callable { local_files_only = false, revision = 'main', model_file_name = null, + device = null, session_options = {}, } = {}) { @@ -752,6 +760,7 @@ export class PreTrainedModel extends Callable { local_files_only, revision, model_file_name, + 
device, session_options, } @@ -5448,6 +5457,7 @@ export class PretrainedMixin { local_files_only = false, revision = 'main', model_file_name = null, + device = null, session_options = {}, } = {}) { @@ -5459,6 +5469,7 @@ export class PretrainedMixin { local_files_only, revision, model_file_name, + device, session_options, } config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); diff --git a/src/pipelines.js b/src/pipelines.js index 7b73393e3..13475c460 100755 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3019,8 +3019,8 @@ export async function pipeline( cache_dir = null, local_files_only = false, revision = 'main', + device= null, session_options = {}, - // TODO: device option } = {} ) { // Helper method to construct pipeline @@ -3048,6 +3048,7 @@ export async function pipeline( cache_dir, local_files_only, revision, + device, session_options, } diff --git a/src/utils/devices.js b/src/utils/devices.js new file mode 100644 index 000000000..8a0a83dca --- /dev/null +++ b/src/utils/devices.js @@ -0,0 +1,3 @@ +/** + * @typedef {'cpu'|'gpu'|'wasm'|'webgpu'|null} DeviceType + */ diff --git a/src/utils/hub.js b/src/utils/hub.js index 34bd07433..4062d4826 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -28,6 +28,7 @@ import { dispatchCallback } from './core.js'; * @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model. * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files). * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models. + * @property {import("./devices.js").DeviceType} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings. * @property {Object} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen. 
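// (A minimal usage sketch combining these options; the checkpoint name is borrowed from the
// webgpu-embedding-benchmark demo earlier in this series, and the values are illustrative.)
//
//   import { AutoModel } from '@xenova/transformers';
//   const model = await AutoModel.from_pretrained('Xenova/all-MiniLM-L6-v2', {
//       quantized: false,     // load the unquantized weights
//       device: 'webgpu',     // a DeviceType as defined in ./devices.js
//       session_options: {},  // forwarded to the onnxruntime InferenceSession
//   });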
*/ From c55d0f01336155259daf0445a8b89bbfb6d128f4 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Mar 2024 03:14:29 +0200 Subject: [PATCH 033/473] Only run tests if not draft --- .github/workflows/tests.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 67f047551..387c00206 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,12 +7,17 @@ on: pull_request: branches: - main - + types: + - opened + - reopened + - synchronize + - ready_for_review env: TESTING_REMOTELY: true jobs: build: + if: github.event.pull_request.draft == false runs-on: ubuntu-latest strategy: From d8b318db526dae22dbf4029a9c88a65636ef693c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 11 Mar 2024 23:38:39 +0200 Subject: [PATCH 034/473] Only continue trying if error is "Unsupported model" Fixes https://github.com/xenova/transformers.js/issues/314 --- src/pipelines.js | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index 13475c460..f9774f803 100755 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3019,7 +3019,7 @@ export async function pipeline( cache_dir = null, local_files_only = false, revision = 'main', - device= null, + device = null, session_options = {}, } = {} ) { @@ -3105,7 +3105,15 @@ async function loadItems(mapping, model, pretrainedOptions) { resolve(await c.from_pretrained(model, pretrainedOptions)); return; } catch (err) { - e = err; + if (err.message?.includes('Unsupported model type')) { + // If the error is due to an unsupported model type, we + // save the error and try the next class. + e = err; + } else { + reject(err); + return; + } + } } reject(e); From 90a5e95d0ad2b176c901be750424739d5426aa56 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Mar 2024 00:12:27 +0200 Subject: [PATCH 035/473] Update sharp.js dependency --- package-lock.json | 836 +++++++++++++++++++++++++++------------------- package.json | 4 +- 2 files changed, 487 insertions(+), 353 deletions(-) diff --git a/package-lock.json b/package-lock.json index 5218cd464..697265108 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,7 +11,7 @@ "dependencies": { "@huggingface/jinja": "^0.2.1", "onnxruntime-web": "1.17.1", - "sharp": "^0.32.0" + "sharp": "^0.33.2" }, "devDependencies": { "@types/jest": "^29.5.1", @@ -744,6 +744,15 @@ "node": ">=10.0.0" } }, + "node_modules/@emnapi/runtime": { + "version": "0.45.0", + "resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-0.45.0.tgz", + "integrity": "sha512-Txumi3td7J4A/xTTwlssKieHKTGl3j4A1tglBx72auZ49YK7ePY6XZricgIg9mnZT4xPfA+UPCUdnhRuEFDL+w==", + "optional": true, + "dependencies": { + "tslib": "^2.4.0" + } + }, "node_modules/@huggingface/jinja": { "version": "0.2.1", "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.1.tgz", @@ -752,6 +761,437 @@ "node": ">=18" } }, + "node_modules/@img/sharp-darwin-arm64": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-darwin-arm64/-/sharp-darwin-arm64-0.33.2.tgz", + "integrity": "sha512-itHBs1rPmsmGF9p4qRe++CzCgd+kFYktnsoR1sbIAfsRMrJZau0Tt1AH9KVnufc2/tU02Gf6Ibujx+15qRE03w==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "glibc": ">=2.26", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + 
"optionalDependencies": { + "@img/sharp-libvips-darwin-arm64": "1.0.1" + } + }, + "node_modules/@img/sharp-darwin-x64": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-darwin-x64/-/sharp-darwin-x64-0.33.2.tgz", + "integrity": "sha512-/rK/69Rrp9x5kaWBjVN07KixZanRr+W1OiyKdXcbjQD6KbW+obaTeBBtLUAtbBsnlTTmWthw99xqoOS7SsySDg==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "glibc": ">=2.26", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-darwin-x64": "1.0.1" + } + }, + "node_modules/@img/sharp-libvips-darwin-arm64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-arm64/-/sharp-libvips-darwin-arm64-1.0.1.tgz", + "integrity": "sha512-kQyrSNd6lmBV7O0BUiyu/OEw9yeNGFbQhbxswS1i6rMDwBBSX+e+rPzu3S+MwAiGU3HdLze3PanQ4Xkfemgzcw==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "macos": ">=11", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-darwin-x64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-darwin-x64/-/sharp-libvips-darwin-x64-1.0.1.tgz", + "integrity": "sha512-eVU/JYLPVjhhrd8Tk6gosl5pVlvsqiFlt50wotCvdkFGf+mDNBJxMh+bvav+Wt3EBnNZWq8Sp2I7XfSjm8siog==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "macos": ">=10.13", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-arm": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm/-/sharp-libvips-linux-arm-1.0.1.tgz", + "integrity": "sha512-FtdMvR4R99FTsD53IA3LxYGghQ82t3yt0ZQ93WMZ2xV3dqrb0E8zq4VHaTOuLEAuA83oDawHV3fd+BsAPadHIQ==", + "cpu": [ + "arm" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.28", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-arm64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-arm64/-/sharp-libvips-linux-arm64-1.0.1.tgz", + "integrity": "sha512-bnGG+MJjdX70mAQcSLxgeJco11G+MxTz+ebxlz8Y3dxyeb3Nkl7LgLI0mXupoO+u1wRNx/iRj5yHtzA4sde1yA==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.26", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-s390x": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linux-s390x/-/sharp-libvips-linux-s390x-1.0.1.tgz", + "integrity": "sha512-3+rzfAR1YpMOeA2zZNp+aYEzGNWK4zF3+sdMxuCS3ey9HhDbJ66w6hDSHDMoap32DueFwhhs3vwooAB2MaK4XQ==", + "cpu": [ + "s390x" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.28", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linux-x64": { + "version": "1.0.1", + "resolved": 
"https://registry.npmjs.org/@img/sharp-libvips-linux-x64/-/sharp-libvips-linux-x64-1.0.1.tgz", + "integrity": "sha512-3NR1mxFsaSgMMzz1bAnnKbSAI+lHXVTqAHgc1bgzjHuXjo4hlscpUxc0vFSAPKI3yuzdzcZOkq7nDPrP2F8Jgw==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.26", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linuxmusl-arm64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-arm64/-/sharp-libvips-linuxmusl-arm64-1.0.1.tgz", + "integrity": "sha512-5aBRcjHDG/T6jwC3Edl3lP8nl9U2Yo8+oTl5drd1dh9Z1EBfzUKAJFUDTDisDjUwc7N4AjnPGfCA3jl3hY8uDg==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "musl": ">=1.2.2", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-libvips-linuxmusl-x64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@img/sharp-libvips-linuxmusl-x64/-/sharp-libvips-linuxmusl-x64-1.0.1.tgz", + "integrity": "sha512-dcT7inI9DBFK6ovfeWRe3hG30h51cBAP5JXlZfx6pzc/Mnf9HFCQDLtYf4MCBjxaaTfjCCjkBxcy3XzOAo5txw==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "musl": ">=1.2.2", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-linux-arm": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm/-/sharp-linux-arm-0.33.2.tgz", + "integrity": "sha512-Fndk/4Zq3vAc4G/qyfXASbS3HBZbKrlnKZLEJzPLrXoJuipFNNwTes71+Ki1hwYW5lch26niRYoZFAtZVf3EGA==", + "cpu": [ + "arm" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.28", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-arm": "1.0.1" + } + }, + "node_modules/@img/sharp-linux-arm64": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-arm64/-/sharp-linux-arm64-0.33.2.tgz", + "integrity": "sha512-pz0NNo882vVfqJ0yNInuG9YH71smP4gRSdeL09ukC2YLE6ZyZePAlWKEHgAzJGTiOh8Qkaov6mMIMlEhmLdKew==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.26", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-arm64": "1.0.1" + } + }, + "node_modules/@img/sharp-linux-s390x": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-linux-s390x/-/sharp-linux-s390x-0.33.2.tgz", + "integrity": "sha512-MBoInDXDppMfhSzbMmOQtGfloVAflS2rP1qPcUIiITMi36Mm5YR7r0ASND99razjQUpHTzjrU1flO76hKvP5RA==", + "cpu": [ + "s390x" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.28", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-s390x": "1.0.1" + } + }, + "node_modules/@img/sharp-linux-x64": { + "version": "0.33.2", + "resolved": 
"https://registry.npmjs.org/@img/sharp-linux-x64/-/sharp-linux-x64-0.33.2.tgz", + "integrity": "sha512-xUT82H5IbXewKkeF5aiooajoO1tQV4PnKfS/OZtb5DDdxS/FCI/uXTVZ35GQ97RZXsycojz/AJ0asoz6p2/H/A==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "glibc": ">=2.26", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linux-x64": "1.0.1" + } + }, + "node_modules/@img/sharp-linuxmusl-arm64": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-arm64/-/sharp-linuxmusl-arm64-0.33.2.tgz", + "integrity": "sha512-F+0z8JCu/UnMzg8IYW1TMeiViIWBVg7IWP6nE0p5S5EPQxlLd76c8jYemG21X99UzFwgkRo5yz2DS+zbrnxZeA==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "musl": ">=1.2.2", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linuxmusl-arm64": "1.0.1" + } + }, + "node_modules/@img/sharp-linuxmusl-x64": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-linuxmusl-x64/-/sharp-linuxmusl-x64-0.33.2.tgz", + "integrity": "sha512-+ZLE3SQmSL+Fn1gmSaM8uFusW5Y3J9VOf+wMGNnTtJUMUxFhv+P4UPaYEYT8tqnyYVaOVGgMN/zsOxn9pSsO2A==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "musl": ">=1.2.2", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + }, + "optionalDependencies": { + "@img/sharp-libvips-linuxmusl-x64": "1.0.1" + } + }, + "node_modules/@img/sharp-wasm32": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-wasm32/-/sharp-wasm32-0.33.2.tgz", + "integrity": "sha512-fLbTaESVKuQcpm8ffgBD7jLb/CQLcATju/jxtTXR1XCLwbOQt+OL5zPHSDMmp2JZIeq82e18yE0Vv7zh6+6BfQ==", + "cpu": [ + "wasm32" + ], + "optional": true, + "dependencies": { + "@emnapi/runtime": "^0.45.0" + }, + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-win32-ia32": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-win32-ia32/-/sharp-win32-ia32-0.33.2.tgz", + "integrity": "sha512-okBpql96hIGuZ4lN3+nsAjGeggxKm7hIRu9zyec0lnfB8E7Z6p95BuRZzDDXZOl2e8UmR4RhYt631i7mfmKU8g==", + "cpu": [ + "ia32" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, + "node_modules/@img/sharp-win32-x64": { + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/@img/sharp-win32-x64/-/sharp-win32-x64-0.33.2.tgz", + "integrity": "sha512-E4magOks77DK47FwHUIGH0RYWSgRBfGdK56kIHSVeB9uIS4pPFr4N2kIVsXdQQo4LzOsENKV5KAhRlRL7eMAdg==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": "^18.17.0 || ^20.3.0 || >=21.0.0", + "npm": ">=9.6.5", + "pnpm": ">=7.1.0", + "yarn": ">=3.2.0" + }, + "funding": { + "url": "https://opencollective.com/libvips" + } + }, "node_modules/@istanbuljs/load-nyc-config": { 
"version": "1.1.0", "resolved": "https://registry.npmjs.org/@istanbuljs/load-nyc-config/-/load-nyc-config-1.1.0.tgz", @@ -2026,11 +2466,6 @@ "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", "dev": true }, - "node_modules/b4a": { - "version": "1.6.4", - "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz", - "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw==" - }, "node_modules/babel-jest": { "version": "29.6.1", "resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-29.6.1.tgz", @@ -2137,25 +2572,6 @@ "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", "dev": true }, - "node_modules/base64-js": { - "version": "1.5.1", - "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", - "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ] - }, "node_modules/batch": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/batch/-/batch-0.6.1.tgz", @@ -2171,16 +2587,6 @@ "node": ">=8" } }, - "node_modules/bl": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz", - "integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==", - "dependencies": { - "buffer": "^5.5.0", - "inherits": "^2.0.4", - "readable-stream": "^3.4.0" - } - }, "node_modules/bluebird": { "version": "3.7.2", "resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.7.2.tgz", @@ -2295,29 +2701,6 @@ "node-int64": "^0.4.0" } }, - "node_modules/buffer": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz", - "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], - "dependencies": { - "base64-js": "^1.3.1", - "ieee754": "^1.1.13" - } - }, "node_modules/buffer-from": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", @@ -2483,11 +2866,6 @@ "fsevents": "~2.3.2" } }, - "node_modules/chownr": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz", - "integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==" - }, "node_modules/chrome-trace-event": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.3.tgz", @@ -2950,20 +3328,6 @@ "ms": "2.0.0" } }, - "node_modules/decompress-response": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz", - "integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==", - "dependencies": { - "mimic-response": "^3.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/dedent": { "version": "0.7.0", "resolved": 
"https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz", @@ -2974,6 +3338,7 @@ "version": "0.6.0", "resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz", "integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==", + "dev": true, "engines": { "node": ">=4.0.0" } @@ -3151,14 +3516,6 @@ "node": ">= 0.8" } }, - "node_modules/end-of-stream": { - "version": "1.4.4", - "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz", - "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==", - "dependencies": { - "once": "^1.4.0" - } - }, "node_modules/enhanced-resolve": { "version": "5.13.0", "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.13.0.tgz", @@ -3350,14 +3707,6 @@ "node": ">= 0.8.0" } }, - "node_modules/expand-template": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz", - "integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==", - "engines": { - "node": ">=6" - } - }, "node_modules/expect": { "version": "29.6.1", "resolved": "https://registry.npmjs.org/expect/-/expect-29.6.1.tgz", @@ -3429,11 +3778,6 @@ "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", "dev": true }, - "node_modules/fast-fifo": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", - "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==" - }, "node_modules/fast-glob": { "version": "3.2.12", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", @@ -3624,11 +3968,6 @@ "node": ">= 0.6" } }, - "node_modules/fs-constants": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz", - "integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==" - }, "node_modules/fs-monkey": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.3.tgz", @@ -3709,11 +4048,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/github-from-package": { - "version": "0.0.0", - "resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz", - "integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==" - }, "node_modules/glob": { "version": "7.2.3", "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", @@ -3992,25 +4326,6 @@ "node": ">=0.10.0" } }, - "node_modules/ieee754": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/ieee754/-/ieee754-1.2.1.tgz", - "integrity": "sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ] - }, "node_modules/ignore": { "version": "5.2.4", "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.4.tgz", @@ -4061,12 +4376,8 @@ "node_modules/inherits": { "version": "2.0.4", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", - "integrity": 
"sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==" - }, - "node_modules/ini": { - "version": "1.3.8", - "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz", - "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==" + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "dev": true }, "node_modules/interpret": { "version": "3.1.1", @@ -5497,17 +5808,6 @@ "node": ">=6" } }, - "node_modules/mimic-response": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz", - "integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/minimalistic-assert": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", @@ -5530,6 +5830,7 @@ "version": "1.2.8", "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", + "dev": true, "funding": { "url": "https://github.com/sponsors/ljharb" } @@ -5546,11 +5847,6 @@ "node": ">=10" } }, - "node_modules/mkdirp-classic": { - "version": "0.5.3", - "resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz", - "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==" - }, "node_modules/mkdirp2": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/mkdirp2/-/mkdirp2-1.0.5.tgz", @@ -5576,11 +5872,6 @@ "multicast-dns": "cli.js" } }, - "node_modules/napi-build-utils": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-1.0.2.tgz", - "integrity": "sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==" - }, "node_modules/natural-compare": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", @@ -5602,22 +5893,6 @@ "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", "dev": true }, - "node_modules/node-abi": { - "version": "3.35.0", - "resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.35.0.tgz", - "integrity": "sha512-jAlSOFR1Bls963NmFwxeQkNTzqjUF0NThm8Le7eRIRGzFUVJuMOFZDLv5Y30W/Oaw+KEebEJLAigwO9gQHoEmw==", - "dependencies": { - "semver": "^7.3.5" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/node-addon-api": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz", - "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA==" - }, "node_modules/node-forge": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.3.1.tgz", @@ -5715,6 +5990,7 @@ "version": "1.4.0", "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "dev": true, "dependencies": { "wrappy": "1" } @@ -5957,31 +6233,6 @@ "resolved": "https://registry.npmjs.org/platform/-/platform-1.3.6.tgz", "integrity": 
"sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==" }, - "node_modules/prebuild-install": { - "version": "7.1.1", - "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz", - "integrity": "sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==", - "dependencies": { - "detect-libc": "^2.0.0", - "expand-template": "^2.0.3", - "github-from-package": "0.0.0", - "minimist": "^1.2.3", - "mkdirp-classic": "^0.5.3", - "napi-build-utils": "^1.0.1", - "node-abi": "^3.3.0", - "pump": "^3.0.0", - "rc": "^1.2.7", - "simple-get": "^4.0.0", - "tar-fs": "^2.0.0", - "tunnel-agent": "^0.6.0" - }, - "bin": { - "prebuild-install": "bin.js" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/pretty-format": { "version": "29.6.1", "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.6.1.tgz", @@ -6072,15 +6323,6 @@ "node": ">= 0.10" } }, - "node_modules/pump": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz", - "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==", - "dependencies": { - "end-of-stream": "^1.1.0", - "once": "^1.3.1" - } - }, "node_modules/punycode": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.0.tgz", @@ -6141,11 +6383,6 @@ } ] }, - "node_modules/queue-tick": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz", - "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag==" - }, "node_modules/randombytes": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/randombytes/-/randombytes-2.1.0.tgz", @@ -6188,20 +6425,6 @@ "node": ">= 0.8" } }, - "node_modules/rc": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", - "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==", - "dependencies": { - "deep-extend": "^0.6.0", - "ini": "~1.3.0", - "minimist": "^1.2.0", - "strip-json-comments": "~2.0.1" - }, - "bin": { - "rc": "cli.js" - } - }, "node_modules/react-is": { "version": "18.2.0", "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.2.0.tgz", @@ -6212,6 +6435,7 @@ "version": "3.6.1", "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", + "dev": true, "dependencies": { "inherits": "^2.0.3", "string_decoder": "^1.1.1", @@ -6478,6 +6702,7 @@ "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true, "funding": [ { "type": "github", @@ -6685,45 +6910,42 @@ } }, "node_modules/sharp": { - "version": "0.32.6", - "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz", - "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==", + "version": "0.33.2", + "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.33.2.tgz", + "integrity": "sha512-WlYOPyyPDiiM07j/UO+E720ju6gtNtHjEGg5vovUk1Lgxyjm2LFO+37Nt/UI3MMh2l6hxTWQWi7qk3cXJTutcQ==", "hasInstallScript": true, "dependencies": { "color": "^4.2.3", "detect-libc": "^2.0.2", - "node-addon-api": "^6.1.0", - 
"prebuild-install": "^7.1.1", - "semver": "^7.5.4", - "simple-get": "^4.0.1", - "tar-fs": "^3.0.4", - "tunnel-agent": "^0.6.0" + "semver": "^7.5.4" }, "engines": { - "node": ">=14.15.0" + "libvips": ">=8.15.1", + "node": "^18.17.0 || ^20.3.0 || >=21.0.0" }, "funding": { "url": "https://opencollective.com/libvips" - } - }, - "node_modules/sharp/node_modules/tar-fs": { - "version": "3.0.4", - "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz", - "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==", - "dependencies": { - "mkdirp-classic": "^0.5.2", - "pump": "^3.0.0", - "tar-stream": "^3.1.5" - } - }, - "node_modules/sharp/node_modules/tar-stream": { - "version": "3.1.6", - "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz", - "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==", - "dependencies": { - "b4a": "^1.6.4", - "fast-fifo": "^1.2.0", - "streamx": "^2.15.0" + }, + "optionalDependencies": { + "@img/sharp-darwin-arm64": "0.33.2", + "@img/sharp-darwin-x64": "0.33.2", + "@img/sharp-libvips-darwin-arm64": "1.0.1", + "@img/sharp-libvips-darwin-x64": "1.0.1", + "@img/sharp-libvips-linux-arm": "1.0.1", + "@img/sharp-libvips-linux-arm64": "1.0.1", + "@img/sharp-libvips-linux-s390x": "1.0.1", + "@img/sharp-libvips-linux-x64": "1.0.1", + "@img/sharp-libvips-linuxmusl-arm64": "1.0.1", + "@img/sharp-libvips-linuxmusl-x64": "1.0.1", + "@img/sharp-linux-arm": "0.33.2", + "@img/sharp-linux-arm64": "0.33.2", + "@img/sharp-linux-s390x": "0.33.2", + "@img/sharp-linux-x64": "0.33.2", + "@img/sharp-linuxmusl-arm64": "0.33.2", + "@img/sharp-linuxmusl-x64": "0.33.2", + "@img/sharp-wasm32": "0.33.2", + "@img/sharp-win32-ia32": "0.33.2", + "@img/sharp-win32-x64": "0.33.2" } }, "node_modules/shebang-command": { @@ -6776,49 +6998,6 @@ "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", "dev": true }, - "node_modules/simple-concat": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz", - "integrity": "sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ] - }, - "node_modules/simple-get": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz", - "integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], - "dependencies": { - "decompress-response": "^6.0.0", - "once": "^1.3.1", - "simple-concat": "^1.0.0" - } - }, "node_modules/simple-swizzle": { "version": "0.2.2", "resolved": "https://registry.npmjs.org/simple-swizzle/-/simple-swizzle-0.2.2.tgz", @@ -7042,19 +7221,11 @@ "node": ">=0.10.0" } }, - "node_modules/streamx": { - "version": "2.15.5", - "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.5.tgz", - "integrity": "sha512-9thPGMkKC2GctCzyCUjME3yR03x2xNo0GPKGkRw2UMYN+gqWa9uqpyNWhmsNCutU5zHmkUum0LsCRQTXUgUCAg==", - 
"dependencies": { - "fast-fifo": "^1.1.0", - "queue-tick": "^1.0.1" - } - }, "node_modules/string_decoder": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", + "dev": true, "dependencies": { "safe-buffer": "~5.2.0" } @@ -7116,14 +7287,6 @@ "node": ">=6" } }, - "node_modules/strip-json-comments": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz", - "integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==", - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/supports-color": { "version": "8.1.1", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", @@ -7188,32 +7351,6 @@ "node": ">=6" } }, - "node_modules/tar-fs": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz", - "integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==", - "dependencies": { - "chownr": "^1.1.1", - "mkdirp-classic": "^0.5.2", - "pump": "^3.0.0", - "tar-stream": "^2.1.4" - } - }, - "node_modules/tar-stream": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz", - "integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==", - "dependencies": { - "bl": "^4.0.3", - "end-of-stream": "^1.4.1", - "fs-constants": "^1.0.0", - "inherits": "^2.0.3", - "readable-stream": "^3.1.1" - }, - "engines": { - "node": ">=6" - } - }, "node_modules/temp-path": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/temp-path/-/temp-path-1.0.0.tgz", @@ -7353,16 +7490,11 @@ "node": ">=0.6" } }, - "node_modules/tunnel-agent": { - "version": "0.6.0", - "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz", - "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==", - "dependencies": { - "safe-buffer": "^5.0.1" - }, - "engines": { - "node": "*" - } + "node_modules/tslib": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", + "optional": true }, "node_modules/type-detect": { "version": "4.0.8", @@ -7493,7 +7625,8 @@ "node_modules/util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", - "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==" + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", + "dev": true }, "node_modules/utils-merge": { "version": "1.0.1", @@ -7997,7 +8130,8 @@ "node_modules/wrappy": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", - "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==" + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "dev": true }, "node_modules/write-file-atomic": { "version": "4.0.2", diff --git a/package.json b/package.json index 43ae7d214..17097f03f 100644 --- a/package.json +++ b/package.json @@ -38,9 +38,9 @@ }, 
"homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { + "@huggingface/jinja": "^0.2.1", "onnxruntime-web": "1.17.1", - "sharp": "^0.32.0", - "@huggingface/jinja": "^0.2.1" + "sharp": "^0.33.2" }, "optionalDependencies": { "onnxruntime-node": "1.17.0" From cd0cd9cade01e53187dce7fb1f5a1bf95aba23fb Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Mar 2024 00:53:04 +0200 Subject: [PATCH 036/473] Fix `onnxruntime-common` import in tests --- tests/init.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/init.js b/tests/init.js index 2fcd8609e..7710c3223 100644 --- a/tests/init.js +++ b/tests/init.js @@ -6,7 +6,7 @@ import * as types from "node:util/types"; // Import onnxruntime-node's default backend import { onnxruntimeBackend } from "onnxruntime-node/dist/backend"; -import ONNX_COMMON from "onnxruntime-common"; +import * as ONNX_COMMON from "onnxruntime-common"; export function init() { // In rare cases (specifically when running unit tests with GitHub actions), possibly due to From b0b5e412beba47bd95f45e88d961a4cdaf5fcc01 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Mar 2024 02:25:57 +0200 Subject: [PATCH 037/473] Update main.js --- examples/webgpu-embedding-benchmark/main.js | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/webgpu-embedding-benchmark/main.js b/examples/webgpu-embedding-benchmark/main.js index f7b40d23b..fa3ce8ce2 100644 --- a/examples/webgpu-embedding-benchmark/main.js +++ b/examples/webgpu-embedding-benchmark/main.js @@ -10,10 +10,8 @@ if (!navigator.gpu) { } // Proxy the WASM backend to prevent the UI from freezing -env.backends.onnx.wasm.proxy = true; env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/'; env.backends.onnx.wasm.numThreads = 1; -env.experimental.useWebGPU = true; // Reference the elements that we will need const ctx = document.getElementById('chart'); @@ -89,9 +87,7 @@ let model_CPU; try { model_CPU = await AutoModel.from_pretrained(MODEL_ID, { quantized: QUANTIZED, - session_options: { - executionProviders: ['wasm'] - } + device: 'webgpu' }); } catch (err) { status.textContent = err.message; From 873c349470108a1803eecbf812ca36eaa05694ce Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Mar 2024 02:38:20 +0200 Subject: [PATCH 038/473] Update main.js --- examples/webgpu-embedding-benchmark/main.js | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/webgpu-embedding-benchmark/main.js b/examples/webgpu-embedding-benchmark/main.js index fa3ce8ce2..2b8a36f86 100644 --- a/examples/webgpu-embedding-benchmark/main.js +++ b/examples/webgpu-embedding-benchmark/main.js @@ -87,7 +87,7 @@ let model_CPU; try { model_CPU = await AutoModel.from_pretrained(MODEL_ID, { quantized: QUANTIZED, - device: 'webgpu' + device: 'wasm', }); } catch (err) { status.textContent = err.message; @@ -99,9 +99,7 @@ let model_GPU; try { model_GPU = await AutoModel.from_pretrained(MODEL_ID, { quantized: QUANTIZED, - session_options: { - executionProviders: ['webgpu'] - } + device: 'webgpu', }); } catch (err) { status.textContent = err.message; From 521cefc5090d2aa17d1da031fb308993b88b66f9 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Mar 2024 14:28:51 +0200 Subject: [PATCH 039/473] Add `dtype` support --- src/models.js | 25 +++++++++++++++++++++++-- src/pipelines.js | 2 ++ src/utils/dtypes.js | 13 +++++++++++++ src/utils/hub.js | 1 + 4 files changed, 39 insertions(+), 2 deletions(-) create mode 
100644 src/utils/dtypes.js diff --git a/src/models.js b/src/models.js index 4219ca9ea..8489f2fd8 100644 --- a/src/models.js +++ b/src/models.js @@ -83,6 +83,7 @@ import { import { createInferenceSession, isONNXTensor, isONNXProxy } from './backends/onnx.js'; import { medianFilter } from './transformers.js'; +import { DATA_TYPES } from './utils/dtypes.js'; ////////////////////////////////////////////////// // Model types: used internally @@ -115,14 +116,30 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map(); * @private */ async function constructSession(pretrained_model_name_or_path, fileName, options) { - const modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`; + + // If options.dtype is specified, we use that for the variant. + // Otherwise, we use options.quantized to determine the variant. + let variant = ''; + if (options.dtype) { + if (!DATA_TYPES.hasOwnProperty(options.dtype)) { + throw new Error(`Invalid dtype: ${options.dtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`); + } + const dtype = DATA_TYPES[options.dtype]; + if (dtype !== 'fp32') { + variant = `_${options.dtype}`; + } + } else if (options.quantized) { + variant = '_quantized'; + } + + const modelFileName = `onnx/${fileName}${variant}.onnx`; const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); const session_options = options.session_options ?? {}; // handle onnx external data files if (session_options.externalData !== undefined) { - for (let i = 0; i < session_options.externalData.length; i++) { + for (let i = 0; i < session_options.externalData.length; ++i) { const ext = session_options.externalData[i]; // if the external data is a string, fetch the file and replace the string with its content if (typeof ext.data === "string") { @@ -749,6 +766,7 @@ export class PreTrainedModel extends Callable { revision = 'main', model_file_name = null, device = null, + dtype = null, session_options = {}, } = {}) { @@ -761,6 +779,7 @@ export class PreTrainedModel extends Callable { revision, model_file_name, device, + dtype, session_options, } @@ -5458,6 +5477,7 @@ export class PretrainedMixin { revision = 'main', model_file_name = null, device = null, + dtype = null, session_options = {}, } = {}) { @@ -5470,6 +5490,7 @@ export class PretrainedMixin { revision, model_file_name, device, + dtype, session_options, } config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); diff --git a/src/pipelines.js b/src/pipelines.js index f9774f803..58ca3df04 100755 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3020,6 +3020,7 @@ export async function pipeline( local_files_only = false, revision = 'main', device = null, + dtype = null, session_options = {}, } = {} ) { @@ -3049,6 +3050,7 @@ export async function pipeline( local_files_only, revision, device, + dtype, session_options, } diff --git a/src/utils/dtypes.js b/src/utils/dtypes.js new file mode 100644 index 000000000..0f025dc59 --- /dev/null +++ b/src/utils/dtypes.js @@ -0,0 +1,13 @@ + +export const DATA_TYPES = Object.freeze({ + fp32: 'fp32', + fp16: 'fp16', + int8: 'int8', + uint8: 'uint8', + + // Aliases (same as torch.float32 and torch.float16) + float32: 'fp32', + float16: 'fp16', +}); + +/** @typedef {keyof typeof DATA_TYPES} DataType */ diff --git a/src/utils/hub.js b/src/utils/hub.js index 4062d4826..b3172d07d 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -29,6 +29,7 @@ import { dispatchCallback } from './core.js'; * @property {boolean?} [quantized=true] Whether to 
load the 8-bit quantized version of the model (only applicable when loading model files). * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models. * @property {import("./devices.js").DeviceType} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings. + * @property {import("./dtypes.js").DataType} [dtype=null] The data type to use for the model. If not specified, the data type will be chosen from the environment settings. * @property {Object} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen. */ From ba21199ee1165f4190cb9e3189a8339773b7fe96 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Mar 2024 16:04:19 +0200 Subject: [PATCH 040/473] Update benchmark space --- .../webgpu-embedding-benchmark/index.html | 30 ++- examples/webgpu-embedding-benchmark/main.js | 205 +++++++++++------- examples/webgpu-embedding-benchmark/style.css | 7 +- 3 files changed, 158 insertions(+), 84 deletions(-) diff --git a/examples/webgpu-embedding-benchmark/index.html b/examples/webgpu-embedding-benchmark/index.html index 74c78664d..6ceeae4ab 100644 --- a/examples/webgpu-embedding-benchmark/index.html +++ b/examples/webgpu-embedding-benchmark/index.html @@ -12,8 +12,7 @@

🤗 Transformers.js WebGPU Benchmark

- This benchmark measures the execution time of Xenova/all-MiniLM-L6-v2 (bert-based embedding model) + This benchmark measures the execution time of BERT-based embedding models using the WASM and WebGPU execution providers across different batch sizes.

@@ -26,6 +25,25 @@

Options +
+ WASM (int8)
+ WASM (fp16)
+ WASM (fp32)
+ + WebGPU (fp16)
+ WebGPU (fp32)
+
+
+
+ + +
@@ -34,13 +52,13 @@

+
- - - - + Log scale (x)
+ Log scale (y)
+ \ No newline at end of file diff --git a/examples/webgpu-embedding-benchmark/main.js b/examples/webgpu-embedding-benchmark/main.js index 2b8a36f86..ebd03fedc 100644 --- a/examples/webgpu-embedding-benchmark/main.js +++ b/examples/webgpu-embedding-benchmark/main.js @@ -19,31 +19,22 @@ const batchSizes = document.getElementById('batch-sizes'); const xscale = document.getElementById('x-scale'); const yscale = document.getElementById('y-scale'); const sequenceLength = document.getElementById('sequence-length'); +const modelID = document.getElementById('model-id'); const status = document.getElementById('status'); const start = document.getElementById('start'); const stop = document.getElementById('stop'); +const tests = document.getElementsByClassName('tests'); // Benchmark settings const NUM_WARMUP_STEPS = 3; -const QUANTIZED = false; -const MODEL_ID = 'Xenova/all-MiniLM-L6-v2'; +const MODEL_CACHE = new Map(); // Chart configuration const config = { type: 'line', data: { labels: [], - datasets: [{ - label: 'WASM', - data: [], - borderColor: 'red', - backgroundColor: 'rgba(255, 0, 0, 0.5)', - }, { - label: 'WebGPU', - data: [], - borderColor: 'blue', - backgroundColor: 'rgba(0, 0, 255, 0.5)', - }] + datasets: [], }, options: { responsive: true, @@ -70,125 +61,178 @@ const config = { } }, }; +const chart = new Chart(ctx, config); -const toggleScale = (chart, axis, enabled) => { +const toggleScale = (axis, enabled) => { chart.options.scales[axis].type = enabled ? 'logarithmic' : 'linear'; chart.update(); } -xscale.addEventListener('change', () => toggleScale(chart, 'x', xscale.checked)); -yscale.addEventListener('change', () => toggleScale(chart, 'y', yscale.checked)); +const getSelectedTests = () => { + return [...tests].filter(x => x.checked); +} -const chart = new Chart(ctx, config); +const updateDatasets = () => { + chart.data.datasets = getSelectedTests().map(test => { + const color = test.getAttribute('data-color'); + return { + label: test.value, + data: [], + borderColor: `rgba(${color}, 1)`, + backgroundColor: `rgba(${color}, 0.5)`, + } + }) + chart.update(); +} +updateDatasets(); +[...tests].forEach(test => test.addEventListener('change', updateDatasets)); -status.textContent = 'Loading model...'; +xscale.addEventListener('change', () => toggleScale('x', xscale.checked)); +yscale.addEventListener('change', () => toggleScale('y', yscale.checked)); -let model_CPU; -try { - model_CPU = await AutoModel.from_pretrained(MODEL_ID, { - quantized: QUANTIZED, - device: 'wasm', - }); -} catch (err) { - status.textContent = err.message; - alert(err.message) - throw err; -} +const generateDummyInputs = (batch_size, seqLength) => { + const inputs = ones([batch_size, seqLength]); -let model_GPU; -try { - model_GPU = await AutoModel.from_pretrained(MODEL_ID, { - quantized: QUANTIZED, - device: 'webgpu', - }); -} catch (err) { - status.textContent = err.message; - alert(err.message) - throw err; + const model_inputs = { + input_ids: inputs, + attention_mask: inputs, + } + return model_inputs; } let adapterInfo; +let gpuHasFp16 = false; try { // Shouldn't fail since the WebGPU model has loaded successfully const adapter = await navigator.gpu.requestAdapter(); adapterInfo = await adapter.requestAdapterInfo(); + gpuHasFp16 = adapter.features.has('shader-f16') } catch (err) { adapterInfo = {}; } +if (!gpuHasFp16) { + const element = document.querySelector('.tests[data-device="webgpu"][data-dtype="fp16"]'); + element.setAttribute('unsupported', true); + element.disabled = true; + element.title = 'This 
device does not support fp16 on WebGPU'; +} status.textContent = 'Ready'; let interrupted = false; start.addEventListener('click', async () => { + const validTests = [...tests].filter(test => !test.getAttribute('unsupported')) + // Update UI start.disabled = true; stop.disabled = false; + batchSizes.disabled = true; + sequenceLength.disabled = true; + modelID.disabled = true; + validTests.forEach(test => test.disabled = true); interrupted = false; + // Get parameters + const model_id = modelID.value; + const batch_sizes = batchSizes.value.split(',').map(x => parseInt(x)).filter(x => x); + const seqLength = parseInt(sequenceLength.value); + const selectedTests = getSelectedTests().map(x => ({ + label: x.value, + dtype: x.getAttribute('data-dtype'), + device: x.getAttribute('data-device'), + })); + // Reset chart.data.labels = []; for (let i = 0; i < chart.data.datasets; ++i) { - chart.data.datasets[i].data = []; + chart.data.datasets[i].data.length = 0; } chart.update(); - const seqLength = parseInt(sequenceLength.value); - - status.textContent = 'Warming up...'; - - const generateDummyInputs = (batch_size) => { + // NOTE: Models must be loaded sequentially (otherwise it will fail due to multiple calls to initWasm()) + const testsToRun = new Map(); + for (const test of selectedTests) { + const { label, dtype, device, quantized } = test; - const inputs = ones([batch_size, seqLength]); + const key = `${model_id}///${label}` - const model_inputs = { - input_ids: inputs, - attention_mask: inputs, + const cached = MODEL_CACHE.get(key); + if (cached) { + testsToRun.set(label, cached); + continue; + } + status.textContent = 'Loading model(s)...'; + + try { + const model = await AutoModel.from_pretrained(model_id, { + quantized, + device, + dtype, + }); + MODEL_CACHE.set(key, model); + testsToRun.set(label, model); + } catch (err) { + status.textContent = err.message; + alert(err.message) + throw err; } - return model_inputs; } + status.textContent = 'Warming up...'; + // Warm up: This is important for the WebGPU execution provider, which compiles the shaders on first load for (let i = 0; i < NUM_WARMUP_STEPS; ++i) { - const model_inputs = generateDummyInputs(1); - await model_CPU(model_inputs); - await model_GPU(model_inputs); + const model_inputs = generateDummyInputs(1, seqLength); + for (const [label, model] of testsToRun) { + await model(model_inputs); + } } status.textContent = 'Running benchmark...'; - const batch_sizes = batchSizes.value.split(',').map(x => parseInt(x)).filter(x => x); - for (const batch_size of batch_sizes) { if (interrupted) break; - const model_inputs = generateDummyInputs(batch_size); + const model_inputs = generateDummyInputs(batch_size, seqLength); - let wasmTime; - { // Run WASM - const start = performance.now(); - await model_CPU(model_inputs); - const end = performance.now(); - wasmTime = end - start; - } + const times = [] - let webGPUTime; - { // Run WebGPU + for (const [label, model] of testsToRun) { const start = performance.now(); - await model_GPU(model_inputs); + await model(model_inputs); const end = performance.now(); - webGPUTime = end - start; + times.push(end - start); } + chart.data.labels.push(batch_size); - chart.data.datasets[0].data.push(wasmTime); - chart.data.datasets[1].data.push(webGPUTime); + for (let i = 0; i < times.length; ++i) { + chart.data.datasets[i].data.push(times[i]); + } chart.update(); } // Calculate max speedup: if (chart.data.labels.length === 0) return; - const table = generateResultsTable(chart.data, seqLength); + const 
testNames = [...testsToRun.keys()]; + const table = generateResultsTable(model_id, testNames, chart.data, seqLength); - const speedup = chart.data.datasets[0].data.at(-1) / chart.data.datasets[1].data.at(-1); + + // Calculate slowest and fastest times + let minMaxTimes = [Infinity, 0]; + let minMaxIndices = [0, 0]; + for (let i = 0; i < chart.data.datasets.length; i++) { + const lastTime = chart.data.datasets[i].data.at(-1); + if (lastTime < minMaxTimes[0]) { + minMaxTimes[0] = lastTime; + minMaxIndices[0] = i; + } + if (lastTime > minMaxTimes[1]) { + minMaxTimes[1] = lastTime; + minMaxIndices[1] = i; + } + } + + const speedup = minMaxTimes[1] / minMaxTimes[0]; const roundedSpeedup = speedup.toFixed(2); const params = new URLSearchParams({ title: `⚡ WebGPU Benchmark Results (${roundedSpeedup}x speedup)`, @@ -196,9 +240,14 @@ start.addEventListener('click', async () => { }); const paramsStr = params.toString(); - status.innerHTML = `⚡ Done! WebGPU is ${roundedSpeedup}x faster! Share results`; + status.innerHTML = `⚡ Done! ${testNames.at(minMaxIndices[0])} is ${roundedSpeedup}x faster than ${testNames.at(minMaxIndices[1])}! ⚡
Share results`; start.disabled = false; + batchSizes.disabled = false; + sequenceLength.disabled = false; + modelID.disabled = false; + validTests.forEach(test => test.disabled = false); }); + start.disabled = false; stop.addEventListener('click', () => { @@ -207,7 +256,8 @@ stop.addEventListener('click', () => { stop.disabled = true; }); -function generateResultsTable(data, sequence_length) { +function generateResultsTable(model_id, testNames, data, sequence_length) { + const datasets = data.datasets.map(d => d.data); const batch_sizes = data.labels; @@ -220,8 +270,9 @@ function generateResultsTable(data, sequence_length) { // Add header row const headerRow = thead.insertRow(); headerRow.insertCell().textContent = 'Batch Size'; - headerRow.insertCell().textContent = `WASM (ms)`; - headerRow.insertCell().textContent = `WebGPU (ms)`; + testNames.forEach(model => { + headerRow.insertCell().textContent = model; + }); // Add data rows batch_sizes.forEach((batchSize, rowIndex) => { @@ -242,8 +293,8 @@ function generateResultsTable(data, sequence_length) { // Add other information const info = document.createElement('ul'); - info.appendChild(createBulletPoint(`Model: ${MODEL_ID}`)); - info.appendChild(createBulletPoint(`Quantized: ${QUANTIZED}`)); + info.appendChild(createBulletPoint(`Model: ${model_id}`)); + info.appendChild(createBulletPoint(`Tests run: ${testNames.join(', ')}`)); info.appendChild(createBulletPoint(`Sequence length: ${sequence_length}`)); info.appendChild(createBulletPoint(`Browser: ${navigator.userAgent}`)); info.appendChild(createBulletPoint(`GPU: vendor=${adapterInfo.vendor}, architecture=${adapterInfo.architecture}, device=${adapterInfo.device}, description=${adapterInfo.description}`)); diff --git a/examples/webgpu-embedding-benchmark/style.css b/examples/webgpu-embedding-benchmark/style.css index d5748146d..9253d75e3 100644 --- a/examples/webgpu-embedding-benchmark/style.css +++ b/examples/webgpu-embedding-benchmark/style.css @@ -25,6 +25,7 @@ h1 { #status { min-height: 16px; margin: 8px 0; + text-align: center; } button { @@ -79,4 +80,8 @@ details { summary { text-align: right; -} \ No newline at end of file +} + +hr { + margin: 8px 0; +} From 9371df8548c93f2f0cede61af856ce38306d8f9e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Mar 2024 16:28:06 +0200 Subject: [PATCH 041/473] Fix visual chart bug --- examples/webgpu-embedding-benchmark/main.js | 70 +++++++++++---------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/examples/webgpu-embedding-benchmark/main.js b/examples/webgpu-embedding-benchmark/main.js index ebd03fedc..97c306c76 100644 --- a/examples/webgpu-embedding-benchmark/main.js +++ b/examples/webgpu-embedding-benchmark/main.js @@ -30,39 +30,42 @@ const NUM_WARMUP_STEPS = 3; const MODEL_CACHE = new Map(); // Chart configuration -const config = { - type: 'line', - data: { - labels: [], - datasets: [], - }, - options: { - responsive: true, - maintainAspectRatio: false, - plugins: { - legend: { - position: 'top', - }, +const initChart = () => { + const config = { + type: 'line', + data: { + labels: [], + datasets: [], }, - scales: { - x: { - title: { - display: true, - text: 'Batch size', + options: { + responsive: true, + maintainAspectRatio: false, + plugins: { + legend: { + position: 'top', }, - min: 1, }, - y: { - title: { - display: true, - text: 'Time (ms)', + scales: { + x: { + title: { + display: true, + text: 'Batch size', + }, + min: 1, }, + y: { + title: { + display: true, + text: 'Time (ms)', + }, + } } - } - }, -}; -const 
chart = new Chart(ctx, config); - + }, + }; + const chart = new Chart(ctx, config); + return chart; +} +let chart = initChart(); const toggleScale = (axis, enabled) => { chart.options.scales[axis].type = enabled ? 'logarithmic' : 'linear'; chart.update(); @@ -142,18 +145,16 @@ start.addEventListener('click', async () => { })); // Reset - chart.data.labels = []; - for (let i = 0; i < chart.data.datasets; ++i) { - chart.data.datasets[i].data.length = 0; - } - chart.update(); + chart.destroy(); + chart = initChart(); + updateDatasets(); // NOTE: Models must be loaded sequentially (otherwise it will fail due to multiple calls to initWasm()) const testsToRun = new Map(); for (const test of selectedTests) { const { label, dtype, device, quantized } = test; - const key = `${model_id}///${label}` + const key = `${model_id}///${label}`; const cached = MODEL_CACHE.get(key); if (cached) { @@ -242,6 +243,7 @@ start.addEventListener('click', async () => { const paramsStr = params.toString(); status.innerHTML = `⚡ Done! ${testNames.at(minMaxIndices[0])} is ${roundedSpeedup}x faster than ${testNames.at(minMaxIndices[1])}! ⚡
Share results`; start.disabled = false; + stop.disabled = true; batchSizes.disabled = false; sequenceLength.disabled = false; modelID.disabled = false; From d24f764846cbe3d3b025162700b6ede23939814d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 13 Mar 2024 18:55:27 +0200 Subject: [PATCH 042/473] Update examples/webgpu-embedding-benchmark/main.js Co-authored-by: Victor Nogueira --- examples/webgpu-embedding-benchmark/main.js | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/webgpu-embedding-benchmark/main.js b/examples/webgpu-embedding-benchmark/main.js index 97c306c76..bdf731395 100644 --- a/examples/webgpu-embedding-benchmark/main.js +++ b/examples/webgpu-embedding-benchmark/main.js @@ -9,7 +9,6 @@ if (!navigator.gpu) { throw Error(err); } -// Proxy the WASM backend to prevent the UI from freezing env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/'; env.backends.onnx.wasm.numThreads = 1; From 91c570c86b63822dcaf350398cae62035539f750 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 13 Mar 2024 19:43:27 +0200 Subject: [PATCH 043/473] Create video background removal demo --- .../.gitignore | 24 ++++ .../index.html | 43 ++++++ .../webgpu-video-background-removal/main.js | 128 ++++++++++++++++++ .../package.json | 17 +++ .../webgpu-video-background-removal/style.css | 87 ++++++++++++ .../vite.config.js | 6 + 6 files changed, 305 insertions(+) create mode 100644 examples/webgpu-video-background-removal/.gitignore create mode 100644 examples/webgpu-video-background-removal/index.html create mode 100644 examples/webgpu-video-background-removal/main.js create mode 100644 examples/webgpu-video-background-removal/package.json create mode 100644 examples/webgpu-video-background-removal/style.css create mode 100644 examples/webgpu-video-background-removal/vite.config.js diff --git a/examples/webgpu-video-background-removal/.gitignore b/examples/webgpu-video-background-removal/.gitignore new file mode 100644 index 000000000..a547bf36d --- /dev/null +++ b/examples/webgpu-video-background-removal/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/examples/webgpu-video-background-removal/index.html b/examples/webgpu-video-background-removal/index.html new file mode 100644 index 000000000..59e0a8428 --- /dev/null +++ b/examples/webgpu-video-background-removal/index.html @@ -0,0 +1,43 @@ + + + + + + + Transformers.js | Real-time background removal + + + +

+ Real-time background removal w/ + 🤗 Transformers.js +

+

+ Runs locally in your browser, powered by + MODNet +

+
+ + + +
+
+
+ + () +
+ +
+
+ + () +
+ +
+
+ + + + + + \ No newline at end of file diff --git a/examples/webgpu-video-background-removal/main.js b/examples/webgpu-video-background-removal/main.js new file mode 100644 index 000000000..620f21afb --- /dev/null +++ b/examples/webgpu-video-background-removal/main.js @@ -0,0 +1,128 @@ +import './style.css'; + +import { env, AutoModel, AutoProcessor, RawImage } from '@xenova/transformers'; + +env.backends.onnx.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/'; +env.backends.onnx.wasm.numThreads = 1; + +// Reference the elements that we will need +const status = document.getElementById('status'); +const container = document.getElementById('container'); +const canvas = document.getElementById('canvas'); +const outputCanvas = document.getElementById('output-canvas'); +const video = document.getElementById('video'); +const sizeSlider = document.getElementById('size'); +const sizeLabel = document.getElementById('size-value'); +const scaleSlider = document.getElementById('scale'); +const scaleLabel = document.getElementById('scale-value'); + +function setStreamSize(width, height) { + video.width = outputCanvas.width = canvas.width = Math.round(width); + video.height = outputCanvas.height = canvas.height = Math.round(height); +} + +status.textContent = 'Loading model...'; + +// Load model and processor +const model_id = 'Xenova/modnet'; +let model; +try { + model = await AutoModel.from_pretrained(model_id, { + device: 'webgpu', + dtype: 'fp32', // TODO: add fp16 support + }); +} catch (err) { + status.textContent = err.message; + alert(err.message) + throw err; +} + +const processor = await AutoProcessor.from_pretrained(model_id); + +// Set up controls +let size = 256; +processor.feature_extractor.size = { shortest_edge: size }; +sizeSlider.addEventListener('input', () => { + size = Number(sizeSlider.value); + processor.feature_extractor.size = { shortest_edge: size }; + sizeLabel.textContent = size; +}); +sizeSlider.disabled = false; + +let scale = 0.5; +scaleSlider.addEventListener('input', () => { + scale = Number(scaleSlider.value); + setStreamSize(video.videoWidth * scale, video.videoHeight * scale); + scaleLabel.textContent = scale; +}); +scaleSlider.disabled = false; + +status.textContent = 'Ready'; + +let isProcessing = false; +let previousTime; +const context = canvas.getContext('2d', { willReadFrequently: true }); +const outputContext = outputCanvas.getContext('2d', { willReadFrequently: true }); +function updateCanvas() { + const { width, height } = canvas; + + if (!isProcessing) { + isProcessing = true; + (async function () { + // Read the current frame from the video + context.drawImage(video, 0, 0, width, height); + const currentFrame = context.getImageData(0, 0, width, height); + const image = new RawImage(currentFrame.data, width, height, 4); + + // Pre-process image + const inputs = await processor(image); + + // Predict alpha matte + const { output } = await model({ input: inputs.pixel_values }); + + const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(width, height); + + // Update alpha channel + const outPixelData = currentFrame; + for (let i = 0; i < mask.data.length; ++i) { + outPixelData.data[4 * i + 3] = mask.data[i]; + } + outputContext.putImageData(outPixelData, 0, 0); + + if (previousTime !== undefined) { + const fps = 1000 / (performance.now() - previousTime); + status.textContent = `FPS: ${fps.toFixed(2)}`; + } + previousTime = performance.now(); + + isProcessing = false; + })(); + } + + 
window.requestAnimationFrame(updateCanvas); +} + +// Start the video stream +navigator.mediaDevices.getUserMedia( + { video: true }, // Ask for video +).then((stream) => { + // Set up the video and canvas elements. + video.srcObject = stream; + video.play(); + + const videoTrack = stream.getVideoTracks()[0]; + const { width, height } = videoTrack.getSettings(); + + setStreamSize(width * scale, height * scale); + + // Set container width and height depending on the image aspect ratio + const ar = width / height; + const [cw, ch] = (ar > 720 / 405) ? [720, 720 / ar] : [405 * ar, 405]; + container.style.width = `${cw}px`; + container.style.height = `${ch}px`; + + // Start the animation loop + setTimeout(updateCanvas, 50); +}).catch((error) => { + alert(error); +}); diff --git a/examples/webgpu-video-background-removal/package.json b/examples/webgpu-video-background-removal/package.json new file mode 100644 index 000000000..9ebe47afe --- /dev/null +++ b/examples/webgpu-video-background-removal/package.json @@ -0,0 +1,17 @@ +{ + "name": "webgpu-video-background-removal", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "devDependencies": { + "vite": "^5.0.12" + }, + "dependencies": { + "@xenova/transformers": "^3.0.0" + } +} diff --git a/examples/webgpu-video-background-removal/style.css b/examples/webgpu-video-background-removal/style.css new file mode 100644 index 000000000..a86729e1c --- /dev/null +++ b/examples/webgpu-video-background-removal/style.css @@ -0,0 +1,87 @@ +* { + box-sizing: border-box; + padding: 0; + margin: 0; + font-family: sans-serif; +} + +html, +body { + height: 100%; +} + +body { + padding: 16px 32px; +} + +body, +#container { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; +} + +#controls { + display: flex; + padding: 1rem; + gap: 1rem; +} + +#controls>div { + text-align: center; +} + +h1, +h4 { + text-align: center; +} + +h4 { + margin-top: 0.5rem; +} + +#container { + position: relative; + width: 720px; + height: 405px; + max-width: 100%; + max-height: 100%; + border: 2px dashed #D1D5DB; + border-radius: 0.75rem; + overflow: hidden; + margin-top: 1rem; + background-size: 100% 100%; + background-position: center; + background-repeat: no-repeat; +} + +#overlay, +canvas { + position: absolute; + width: 100%; + height: 100%; +} + +#status { + min-height: 16px; + margin: 8px 0; +} + +.bounding-box { + position: absolute; + box-sizing: border-box; + border: solid 2px; +} + +.bounding-box-label { + color: white; + position: absolute; + font-size: 12px; + margin: -16px 0 0 -2px; + padding: 1px; +} + +#video, #canvas { + display: none; +} diff --git a/examples/webgpu-video-background-removal/vite.config.js b/examples/webgpu-video-background-removal/vite.config.js new file mode 100644 index 000000000..6c32f52df --- /dev/null +++ b/examples/webgpu-video-background-removal/vite.config.js @@ -0,0 +1,6 @@ +import { defineConfig } from 'vite'; +export default defineConfig({ + build: { + target: 'esnext' + } +}); From bf8c8d5fe5acb195541040d87d589f74eada4434 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 16 Mar 2024 17:15:26 +0200 Subject: [PATCH 044/473] Update package.json and package-lock.json --- package-lock.json | 13 ++++++++++--- package.json | 1 + 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/package-lock.json b/package-lock.json index 697265108..b6c78783d 100644 --- a/package-lock.json +++ 
b/package-lock.json @@ -15,6 +15,7 @@ }, "devDependencies": { "@types/jest": "^29.5.1", + "@webgpu/types": "^0.1.40", "catharsis": "github:xenova/catharsis", "copy-webpack-plugin": "^11.0.0", "jest": "^29.5.0", @@ -2206,6 +2207,12 @@ "@xtuc/long": "4.2.2" } }, + "node_modules/@webgpu/types": { + "version": "0.1.40", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.40.tgz", + "integrity": "sha512-/BBkHLS6/eQjyWhY2H7Dx5DHcVrS2ICj9owvSRdgtQT6KcafLZA86tPze0xAOsd4FbsYKCUBUQyNi87q7gV7kw==", + "dev": true + }, "node_modules/@webpack-cli/configtest": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/@webpack-cli/configtest/-/configtest-2.0.1.tgz", @@ -3931,9 +3938,9 @@ "integrity": "sha512-c7CZADjRcl6j0PlvFy0ZqXQ67qSEZfrVPynmnL+2zPc+NtMvrF8Y0QceMo7QqnSPc7+uWjUIAbvCQ5WIKlMVdQ==" }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { diff --git a/package.json b/package.json index 17097f03f..ad7cd7ab4 100644 --- a/package.json +++ b/package.json @@ -47,6 +47,7 @@ }, "devDependencies": { "@types/jest": "^29.5.1", + "@webgpu/types": "^0.1.40", "catharsis": "github:xenova/catharsis", "copy-webpack-plugin": "^11.0.0", "jest": "^29.5.0", From 8370ac7297eee4d26ed607c07184349f3e4de1be Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 16 Mar 2024 17:17:46 +0200 Subject: [PATCH 045/473] Apply threshold to object detection models --- src/processors.js | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/processors.js b/src/processors.js index a7df4cd2e..0b3e5232c 100644 --- a/src/processors.js +++ b/src/processors.js @@ -115,10 +115,13 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes = // This is the background class, skip it continue; } - indices.push(maxIndex); - // Compute softmax over classes probs = softmax(logit.data); + + if (probs[maxIndex] < threshold) { + continue; + } + indices.push(maxIndex); } for (const index of indices) { From 375ba7cd6fbb8c5d758400cc4cccbd945c8b40e4 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 16 Mar 2024 17:18:08 +0200 Subject: [PATCH 046/473] Only pad if not already the correct size --- src/processors.js | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/processors.js b/src/processors.js index 0b3e5232c..b80ba90e0 100644 --- a/src/processors.js +++ b/src/processors.js @@ -649,7 +649,10 @@ export class ImageFeatureExtractor extends FeatureExtractor { } // do padding after rescaling/normalizing - if (this.do_pad && this.pad_size) { + if (this.do_pad && this.pad_size && ( + // only pad if not already the correct size + (this.pad_size.width !== image.width || this.pad_size.height !== image.height) + )) { const paddedPixelData = new Float32Array(this.pad_size.width * this.pad_size.height * image.channels); From 29cc6f5aa621802f06976f4afb0ecfc3f273aaed Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 16 Mar 2024 17:18:18 +0200 Subject: [PATCH 047/473] Typo --- src/processors.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/processors.js b/src/processors.js index 
b80ba90e0..f54de171c 100644 --- a/src/processors.js +++ b/src/processors.js @@ -650,7 +650,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { // do padding after rescaling/normalizing if (this.do_pad && this.pad_size && ( - // only pad if not already the correct size + // only pad if not already the correct size (this.pad_size.width !== image.width || this.pad_size.height !== image.height) )) { From 6b1b9cd145afde5680a6b6c54f87ee3ab674e0d1 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 18 Mar 2024 16:00:50 +0200 Subject: [PATCH 048/473] Update jsconfig.json --- jsconfig.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jsconfig.json b/jsconfig.json index 5430d98f2..d3e3df397 100644 --- a/jsconfig.json +++ b/jsconfig.json @@ -7,7 +7,7 @@ // Tells the compiler to check JS files "checkJs": true, "target": "esnext", - "module": "esnext", + "module": "nodenext", "moduleResolution": "nodenext", }, "typeAcquisition": { From 6dde38a2372e8db18d9e1fad5ba732905ee835cc Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 19 Mar 2024 00:08:40 +0200 Subject: [PATCH 049/473] Improve default selection of dtypes and devices --- src/backends/onnx.js | 42 +++++++++++++++++++++---------- src/models.js | 59 +++++++++++++++++++++++++------------------- src/utils/devices.js | 10 +++++++- src/utils/dtypes.js | 36 ++++++++++++++++++++++++--- src/utils/hub.js | 1 - 5 files changed, 104 insertions(+), 44 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index d5917d888..710557648 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -29,7 +29,10 @@ export { Tensor } from 'onnxruntime-common'; const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node'; +/** @type {import('../utils/devices.js').DeviceType[]} */ const supportedExecutionProviders = []; + +/** @type {import('../utils/devices.js').DeviceType[]} */ let defaultExecutionProviders; let ONNX; if (USE_ONNXRUNTIME_NODE) { @@ -49,25 +52,32 @@ if (USE_ONNXRUNTIME_NODE) { const InferenceSession = ONNX.InferenceSession; /** - * Create an ONNX inference session, with fallback support if an operation is not supported. - * @param {Uint8Array} buffer The ONNX model buffer. - * @param {Object} session_options ONNX inference session options. + * Map a device to the execution providers to use for the given device. * @param {import("../utils/devices.js").DeviceType} [device=null] (Optional) The device to run the inference on. - * @returns {Promise} The ONNX inference session. + * @returns {import("../utils/devices.js").DeviceType[]} The execution providers to use for the given device. */ -export async function createInferenceSession(buffer, session_options, device = null) { +export function deviceToExecutionProviders(device) { + // TODO: Use mapping from device to execution providers for overloaded devices (e.g., 'gpu' or 'cpu'). let executionProviders = defaultExecutionProviders; if (device) { // User has specified a device - if (supportedExecutionProviders.includes(device)) { - executionProviders = [device]; - } else { + if (!supportedExecutionProviders.includes(device)) { throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedExecutionProviders.join(', ')}.`) } + executionProviders = [device]; } + return executionProviders; +} + +/** + * Create an ONNX inference session. + * @param {Uint8Array} buffer The ONNX model buffer. 
+ * @param {Object} session_options ONNX inference session options. + * @returns {Promise} The ONNX inference session. + */ +export async function createInferenceSession(buffer, session_options) { // NOTE: Important to create a clone, since ORT modifies the object. const options = { - executionProviders, ...session_options } @@ -92,13 +102,19 @@ if (ONNX_ENV?.wasm) { // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths // We use remote wasm files by default to make it easier for newer users. // In practice, users should probably self-host the necessary .wasm files. - ONNX_ENV.wasm.wasmPaths = RUNNING_LOCALLY - ? path.join(env.__dirname, '/dist/') - : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`; + // ONNX_ENV.wasm.wasmPaths = RUNNING_LOCALLY + // ? path.join(env.__dirname, '/dist/') + // : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`; + // TODO: update this before release + ONNX_ENV.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/'; // Proxy the WASM backend to prevent the UI from freezing ONNX_ENV.wasm.proxy = true; - // ONNX_ENV.wasm.numThreads = 1; // TODO is this needed? + + // https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated + if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) { + ONNX_ENV.wasm.numThreads = 1; + } // Running in a browser-environment // TODO: Check if 1.17.1 fixes this issue. diff --git a/src/models.js b/src/models.js index 8489f2fd8..e1a45781d 100644 --- a/src/models.js +++ b/src/models.js @@ -42,6 +42,18 @@ import { AutoConfig, } from './configs.js'; +import { + deviceToExecutionProviders, + createInferenceSession, + isONNXTensor, + isONNXProxy, +} from './backends/onnx.js'; +import { + DATA_TYPES, + DEFAULT_DEVICE_DTYPE_MAPPING, + DEFAULT_DTYPE_SUFFIX_MAPPING, FP16_SUPPORTED, +} from './utils/dtypes.js'; + import { Callable, isIntegralNumber, @@ -81,9 +93,7 @@ import { Tensor, } from './utils/tensor.js'; -import { createInferenceSession, isONNXTensor, isONNXProxy } from './backends/onnx.js'; -import { medianFilter } from './transformers.js'; -import { DATA_TYPES } from './utils/dtypes.js'; +import { medianFilter } from './utils/maths.js'; ////////////////////////////////////////////////// // Model types: used internally @@ -117,26 +127,29 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map(); */ async function constructSession(pretrained_model_name_or_path, fileName, options) { - // If options.dtype is specified, we use that for the variant. - // Otherwise, we use options.quantized to determine the variant. - let variant = ''; - if (options.dtype) { - if (!DATA_TYPES.hasOwnProperty(options.dtype)) { - throw new Error(`Invalid dtype: ${options.dtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`); - } - const dtype = DATA_TYPES[options.dtype]; - if (dtype !== 'fp32') { - variant = `_${options.dtype}`; - } - } else if (options.quantized) { - variant = '_quantized'; + // If the device is not specified, we use the default (supported) execution providers. + const executionProviders = deviceToExecutionProviders(options.device); + + // If options.dtype is specified, we use it to choose the suffix for the model file. + // Otherwise, we use the default dtype for the device. + const dtype = options.dtype ?? DEFAULT_DEVICE_DTYPE_MAPPING[executionProviders[0]]; + if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(dtype)) { + throw new Error(`Invalid dtype: ${dtype}. 
Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`); + } else if (dtype === DATA_TYPES.fp16 && !FP16_SUPPORTED) { + throw new Error(`The device does not support fp16.`); } - const modelFileName = `onnx/${fileName}${variant}.onnx`; + // Construct the model file name + const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[dtype]; + const modelFileName = `onnx/${fileName}${suffix}.onnx`; + const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); const session_options = options.session_options ?? {}; + // Overwrite `executionProviders` if not specified + session_options.executionProviders ??= executionProviders; + // handle onnx external data files if (session_options.externalData !== undefined) { for (let i = 0; i < session_options.externalData.length; ++i) { @@ -157,7 +170,7 @@ async function constructSession(pretrained_model_name_or_path, fileName, options // } // } - return await createInferenceSession(buffer, session_options, options.device); + return await createInferenceSession(buffer, session_options); } /** @@ -758,7 +771,6 @@ export class PreTrainedModel extends Callable { * @returns {Promise} A new instance of the `PreTrainedModel` class. */ static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, progress_callback = null, config = null, cache_dir = null, @@ -771,7 +783,6 @@ export class PreTrainedModel extends Callable { } = {}) { let options = { - quantized, progress_callback, config, cache_dir, @@ -4988,9 +4999,9 @@ export class SpeechT5Model extends SpeechT5PreTrainedModel { }; * const processor = await AutoProcessor.from_pretrained('Xenova/speecht5_tts'); * * // Load the models - * // NOTE: We use the unquantized versions as they are more accurate - * const model = await SpeechT5ForTextToSpeech.from_pretrained('Xenova/speecht5_tts', { quantized: false }); - * const vocoder = await SpeechT5HifiGan.from_pretrained('Xenova/speecht5_hifigan', { quantized: false }); + * // NOTE: We use the full-precision versions as they are more accurate + * const model = await SpeechT5ForTextToSpeech.from_pretrained('Xenova/speecht5_tts', { dtype: 'fp32' }); + * const vocoder = await SpeechT5HifiGan.from_pretrained('Xenova/speecht5_hifigan', { dtype: 'fp32' }); * * // Load speaker embeddings from URL * const speaker_embeddings_data = new Float32Array( @@ -5469,7 +5480,6 @@ export class PretrainedMixin { /** @type {PreTrainedModel.from_pretrained} */ static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, progress_callback = null, config = null, cache_dir = null, @@ -5482,7 +5492,6 @@ export class PretrainedMixin { } = {}) { let options = { - quantized, progress_callback, config, cache_dir, diff --git a/src/utils/devices.js b/src/utils/devices.js index 8a0a83dca..6ca0d00d1 100644 --- a/src/utils/devices.js +++ b/src/utils/devices.js @@ -1,3 +1,11 @@ + +export const DEVICE_TYPES = Object.freeze({ + cpu: 'cpu', + gpu: 'gpu', + wasm: 'wasm', + webgpu: 'webgpu', +}); + /** - * @typedef {'cpu'|'gpu'|'wasm'|'webgpu'|null} DeviceType + * @typedef {keyof typeof DEVICE_TYPES} DeviceType */ diff --git a/src/utils/dtypes.js b/src/utils/dtypes.js index 0f025dc59..268c8b60e 100644 --- a/src/utils/dtypes.js +++ b/src/utils/dtypes.js @@ -1,13 +1,41 @@ +import { DEVICE_TYPES } from "./devices.js"; + +// TODO: Use the adapter from `env.backends.onnx.webgpu.adapter` to check for `shader-f16` support, +// when available in https://github.com/microsoft/onnxruntime/pull/19940. 
+// For more information, see https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753 +async function isFp16Supported() { + try { + const adapter = await navigator.gpu.requestAdapter(); + return adapter.features.has('shader-f16'); + } catch (e) { + return false + } +} +export const FP16_SUPPORTED = await isFp16Supported(); export const DATA_TYPES = Object.freeze({ fp32: 'fp32', fp16: 'fp16', + q8: 'q8', int8: 'int8', uint8: 'uint8', - - // Aliases (same as torch.float32 and torch.float16) - float32: 'fp32', - float16: 'fp16', }); /** @typedef {keyof typeof DATA_TYPES} DataType */ + +const defaultGpuDtype = FP16_SUPPORTED ? DATA_TYPES.fp16 : DATA_TYPES.fp32; +export const DEFAULT_DEVICE_DTYPE_MAPPING = Object.freeze({ + [DEVICE_TYPES.cpu]: DATA_TYPES.q8, + [DEVICE_TYPES.gpu]: defaultGpuDtype, + [DEVICE_TYPES.wasm]: DATA_TYPES.q8, + [DEVICE_TYPES.webgpu]: defaultGpuDtype, +}); + +/** @type {Record} */ +export const DEFAULT_DTYPE_SUFFIX_MAPPING = Object.freeze({ + [DATA_TYPES.fp32]: '', + [DATA_TYPES.fp16]: '_fp16', + [DATA_TYPES.int8]: '_int8', + [DATA_TYPES.uint8]: '_uint8', + [DATA_TYPES.q8]: '_quantized', +}); diff --git a/src/utils/hub.js b/src/utils/hub.js index b3172d07d..32407bed6 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -26,7 +26,6 @@ import { dispatchCallback } from './core.js'; /** * @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model. - * @property {boolean?} [quantized=true] Whether to load the 8-bit quantized version of the model (only applicable when loading model files). * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models. * @property {import("./devices.js").DeviceType} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings. * @property {import("./dtypes.js").DataType} [dtype=null] The data type to use for the model. If not specified, the data type will be chosen from the environment settings. From 9efbada2b97fdb22fb0775ebbe30d93371cb83a3 Mon Sep 17 00:00:00 2001 From: Hans Date: Tue, 19 Mar 2024 06:27:24 +0800 Subject: [PATCH 050/473] Support 4-bit quantization in conversion script (#637) Co-authored-by: Joshua Lochner --- scripts/convert.py | 92 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 24 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index 3a2b223e8..f911f9244 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from typing import Optional, Set from tqdm import tqdm +from enum import Enum from transformers import ( AutoConfig, @@ -14,11 +15,14 @@ import onnx from optimum.exporters.onnx import main_export, export_models +from optimum.onnx.graph_transformations import check_and_save_model from optimum.exporters.tasks import TasksManager from onnxruntime.quantization import ( quantize_dynamic, QuantType ) +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer +from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer DEFAULT_QUANTIZE_PARAMS = { 'per_channel': True, @@ -134,6 +138,11 @@ 'unispeech-sat', ] +class QuantMode(Enum): + Q8 = 'q8' + Q4 = 'q4' + BNB4 = 'bnb4' + @dataclass class ConversionArguments: @@ -158,6 +167,12 @@ class ConversionArguments: "help": "Whether to quantize the model." 
} ) + quantize_mode: QuantMode = field( + default=QuantMode.Q8, + metadata={ + "help": "Quantization mode to use. Options are: int4, int8, bnb4" + } + ) output_parent_dir: str = field( default='./models/', metadata={ @@ -256,7 +271,7 @@ def traverse_graph(graph): return operators -def quantize(model_names_or_paths, **quantize_kwargs): +def quantize(mode, model_names_or_paths, **quantize_kwargs): """ Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU @@ -270,11 +285,13 @@ def quantize(model_names_or_paths, **quantize_kwargs): quantize_config = dict( **quantize_kwargs, + quantize_mode=mode.value, per_model_config={} ) + directory_path = os.path.dirname(model_names_or_paths[0]) + for model in tqdm(model_names_or_paths, desc='Quantizing'): - directory_path = os.path.dirname(model) file_name_without_extension = os.path.splitext( os.path.basename(model))[0] @@ -289,28 +306,54 @@ def quantize(model_names_or_paths, **quantize_kwargs): loaded_model = onnx.load_model(model) op_types = get_operators(loaded_model) - weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8 - - quantize_dynamic( - model_input=model, - model_output=os.path.join( - directory_path, f'{file_name_without_extension}_quantized.onnx'), - - weight_type=weight_type, - optimize_model=False, - - # TODO allow user to specify these - # op_types_to_quantize=['MatMul', 'Add', 'Conv'], - extra_options=dict( - EnableSubgraph=True - ), - **quantize_kwargs - ) - quantize_config['per_model_config'][file_name_without_extension] = dict( - op_types=list(op_types), - weight_type=str(weight_type), - ) + save_path = os.path.join(directory_path, f'{file_name_without_extension}_quantized.onnx') + + if mode == QuantMode.Q8: + weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8 + + del loaded_model + quantize_dynamic( + model_input=model, + model_output=save_path, + weight_type=weight_type, + #optimize_model=False, + + # TODO allow user to specify these + # op_types_to_quantize=['MatMul', 'Add', 'Conv'], + extra_options=dict( + EnableSubgraph=True + ), + ) + + quantize_config['per_model_config'][file_name_without_extension] = dict( + op_types=list(op_types), + weight_type=str(weight_type), + ) + + elif mode == QuantMode.Q4: + quant = MatMul4BitsQuantizer( + model=loaded_model, + block_size=quantize_kwargs.get('block_size', 32), + accuracy_level=quantize_kwargs.get('accuracy_level'), + is_symmetric=quantize_kwargs.get('is_symmetric', True), + ) + quant.process() + check_and_save_model(quant.model.model, save_path) + del quant + + elif mode == QuantMode.BNB4: + quant = MatMulBnb4Quantizer( + model=loaded_model, + block_size=quantize_kwargs.get('block_size', 64), + quant_type=quantize_kwargs.get('quant_type', MatMulBnb4Quantizer.NF4), + ) + quant.process() + check_and_save_model(quant.model.model, save_path) + del quant + + else: + raise ValueError(f'Invalid quantization mode: {mode}') # Save quantization config with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp: @@ -375,6 +418,7 @@ def main(): opset=conv_args.opset, device=conv_args.device, trust_remote_code=conv_args.trust_remote_code, + legacy=True, **custom_kwargs, ) @@ -518,7 +562,7 @@ def main(): if conv_args.reduce_range is not None: quantize_config['reduce_range'] = conv_args.reduce_range - quantize([ + quantize(QuantMode(conv_args.quantize_mode), [ os.path.join(output_model_folder, x) for x in os.listdir(output_model_folder) if x.endswith('.onnx') and not 
x.endswith('_quantized.onnx') From ce75a0f832da02f47caa34a2e87591426aa88832 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 19 Mar 2024 00:44:36 +0200 Subject: [PATCH 051/473] Update conversion script --- scripts/convert.py | 122 +++++++++++++++++++++++++++------------ scripts/requirements.txt | 9 +-- 2 files changed, 89 insertions(+), 42 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index f911f9244..90ffc062b 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -138,8 +138,13 @@ 'unispeech-sat', ] + class QuantMode(Enum): + # F32 = 'fp32' + FP16 = 'fp16' Q8 = 'q8' + QI8 = 'int8' + QU8 = 'uint8' Q4 = 'q4' BNB4 = 'bnb4' @@ -168,9 +173,9 @@ class ConversionArguments: } ) quantize_mode: QuantMode = field( - default=QuantMode.Q8, + default=None, metadata={ - "help": "Quantization mode to use. Options are: int4, int8, bnb4" + "help": f"Quantization mode to use. Options are: {', '.join([x.value for x in QuantMode])}" } ) output_parent_dir: str = field( @@ -285,18 +290,21 @@ def quantize(mode, model_names_or_paths, **quantize_kwargs): quantize_config = dict( **quantize_kwargs, - quantize_mode=mode.value, per_model_config={} ) directory_path = os.path.dirname(model_names_or_paths[0]) + outputs = [] for model in tqdm(model_names_or_paths, desc='Quantizing'): file_name_without_extension = os.path.splitext( os.path.basename(model))[0] # NOTE: - # As of 2023/04/20, the current latest version of onnxruntime-web is 1.14.0, and does not support INT8 weights for Conv layers. + # As of 2024/03/18, the current latest version of onnxruntime-web is 1.17.1, and does not support INT8 weights for Conv layers. + # If you attempt to run a model with INT8 weights for Conv layers, you will get an error like: + # `Can't create a session. ERROR_CODE: 9, ERROR_MESSAGE: Could not find an implementation for ConvInteger(10) node with name '/.../Conv_quant'` + # # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web. # # As per docs, signed weight type (QInt8) is faster on most CPUs, so, we use that unless the model contains a Conv layer. 
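(As a usage sketch for the `quantize_mode` option introduced above: assuming the script is still run as a module, `python -m scripts.convert`, and that `HfArgumentParser` exposes the `ConversionArguments` fields by name, a 4-bit conversion would look roughly like the following; the model id is illustrative.)

```bash
# Convert a model to ONNX and quantize it with the 4-bit MatMul quantizer (QuantMode.Q4 = 'q4')
python -m scripts.convert --model_id bert-base-uncased --quantize --quantize_mode q4
```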
@@ -305,19 +313,27 @@ def quantize(mode, model_names_or_paths, **quantize_kwargs): # - https://github.com/microsoft/onnxruntime/issues/2339 loaded_model = onnx.load_model(model) - op_types = get_operators(loaded_model) + suffix = 'quantized' if mode == QuantMode.Q8 else mode.value + save_path = os.path.join( + directory_path, + f'{file_name_without_extension}_{suffix}.onnx', + ) - save_path = os.path.join(directory_path, f'{file_name_without_extension}_quantized.onnx') + if mode in (QuantMode.Q8, QuantMode.QI8, QuantMode.QU8): - if mode == QuantMode.Q8: - weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8 + op_types = get_operators(loaded_model) + if mode == QuantMode.Q8: + weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8 + elif mode == QuantMode.QI8: + weight_type = QuantType.QInt8 + else: # mode == QuantMode.QU8: + weight_type = QuantType.QUInt8 del loaded_model quantize_dynamic( model_input=model, model_output=save_path, weight_type=weight_type, - #optimize_model=False, # TODO allow user to specify these # op_types_to_quantize=['MatMul', 'Add', 'Conv'], @@ -334,30 +350,33 @@ def quantize(mode, model_names_or_paths, **quantize_kwargs): elif mode == QuantMode.Q4: quant = MatMul4BitsQuantizer( model=loaded_model, - block_size=quantize_kwargs.get('block_size', 32), + block_size=quantize_kwargs.get('block_size', 128), accuracy_level=quantize_kwargs.get('accuracy_level'), is_symmetric=quantize_kwargs.get('is_symmetric', True), ) quant.process() check_and_save_model(quant.model.model, save_path) del quant - + elif mode == QuantMode.BNB4: quant = MatMulBnb4Quantizer( model=loaded_model, block_size=quantize_kwargs.get('block_size', 64), - quant_type=quantize_kwargs.get('quant_type', MatMulBnb4Quantizer.NF4), + quant_type=quantize_kwargs.get( + 'quant_type', MatMulBnb4Quantizer.NF4), ) quant.process() check_and_save_model(quant.model.model, save_path) del quant - + + elif mode == QuantMode.FP16: + pass else: raise ValueError(f'Invalid quantization mode: {mode}') - # Save quantization config - with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp: - json.dump(quantize_config, fp, indent=4) + outputs.append(save_path) + + return quantize_config, outputs def main(): @@ -382,10 +401,11 @@ def main(): # Saving the model config config = AutoConfig.from_pretrained(model_id, **from_pretrained_kwargs) - custom_kwargs={} + custom_kwargs = {} if conv_args.custom_onnx_configs is not None: if conv_args.task == 'auto': - raise Exception('`--task` must be set when exporting with `--custom_onnx_configs`') + raise Exception( + '`--task` must be set when exporting with `--custom_onnx_configs`') custom_onnx_configs = json.loads(conv_args.custom_onnx_configs) for key in custom_onnx_configs: @@ -398,14 +418,16 @@ def main(): tokenizer = None try: # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, **from_pretrained_kwargs) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_id, **from_pretrained_kwargs) # To avoid inserting all chat templates into tokenizers.js, we save the chat template # to the tokenizer_config.json file, and load it when the tokenizer is loaded. 
if getattr(tokenizer, 'chat_template', None) is None and \ - getattr(tokenizer, 'use_default_system_prompt', False): + getattr(tokenizer, 'use_default_system_prompt', False): # No chat template specified, and we use the default - setattr(tokenizer, 'chat_template', tokenizer.default_chat_template) + setattr(tokenizer, 'chat_template', + tokenizer.default_chat_template) except KeyError: pass # No Tokenizer @@ -442,7 +464,8 @@ def main(): elif config.model_type == 'esm': from .extra.esm import generate_fast_tokenizer fast_tokenizer = generate_fast_tokenizer(tokenizer) - fast_tokenizer.save(os.path.join(output_model_folder, 'tokenizer.json')) + fast_tokenizer.save(os.path.join( + output_model_folder, 'tokenizer.json')) elif config.model_type == 'whisper': if conv_args.output_attentions: @@ -452,14 +475,14 @@ def main(): **get_main_export_kwargs(config, "automatic-speech-recognition") ) - elif config.model_type in ('wav2vec2', 'wav2vec2-bert', 'hubert', 'unispeech' , 'unispeech-sat'): + elif config.model_type in ('wav2vec2', 'wav2vec2-bert', 'hubert', 'unispeech', 'unispeech-sat'): if tokenizer is not None: from .extra.wav2vec2 import generate_tokenizer_json tokenizer_json = generate_tokenizer_json(tokenizer) with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp: json.dump(tokenizer_json, fp, indent=4) - + elif config.model_type == 'vits': if tokenizer is not None: from .extra.vits import generate_tokenizer_json @@ -467,10 +490,11 @@ def main(): with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp: json.dump(tokenizer_json, fp, indent=4) - + elif config.model_type == 'speecht5': # TODO allow user to specify vocoder path - export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"} + export_kwargs["model_kwargs"] = { + "vocoder": "microsoft/speecht5_hifigan"} if tokenizer is not None: from .extra.speecht5 import generate_tokenizer_json @@ -501,8 +525,10 @@ def main(): from .extra.clip import CLIPTextModelWithProjectionOnnxConfig, CLIPVisionModelWithProjectionOnnxConfig from transformers.models.clip import CLIPTextModelWithProjection, CLIPVisionModelWithProjection - text_model = CLIPTextModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs) - vision_model = CLIPVisionModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs) + text_model = CLIPTextModelWithProjection.from_pretrained( + model_id, **from_pretrained_kwargs) + vision_model = CLIPVisionModelWithProjection.from_pretrained( + model_id, **from_pretrained_kwargs) export_models( models_and_onnx_configs={ @@ -517,8 +543,10 @@ def main(): from .extra.siglip import SiglipTextModelOnnxConfig, SiglipVisionModelOnnxConfig from transformers.models.siglip import SiglipTextModel, SiglipVisionModel - text_model = SiglipTextModel.from_pretrained(model_id, **from_pretrained_kwargs) - vision_model = SiglipVisionModel.from_pretrained(model_id, **from_pretrained_kwargs) + text_model = SiglipTextModel.from_pretrained( + model_id, **from_pretrained_kwargs) + vision_model = SiglipVisionModel.from_pretrained( + model_id, **from_pretrained_kwargs) export_models( models_and_onnx_configs={ @@ -546,8 +574,10 @@ def main(): # ) else: - raise Exception(f'Unable to export {config.model_type} model with `--split_modalities`.') + raise Exception( + f'Unable to export {config.model_type} model with `--split_modalities`.') + os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True) # Step 2. 
(optional, recommended) quantize the converted model for fast inference and to reduce model size. if conv_args.quantize: @@ -562,14 +592,29 @@ def main(): if conv_args.reduce_range is not None: quantize_config['reduce_range'] = conv_args.reduce_range - quantize(QuantMode(conv_args.quantize_mode), [ - os.path.join(output_model_folder, x) - for x in os.listdir(output_model_folder) - if x.endswith('.onnx') and not x.endswith('_quantized.onnx') - ], **quantize_config) + final_quantize_configs = {} + quantize_modes = [x.value for x in QuantMode] \ + if conv_args.quantize_mode is None else [conv_args.quantize_mode] + for quantize_mode in quantize_modes: + + final_quantize_config, quantized_paths = quantize(QuantMode(quantize_mode), [ + os.path.join(output_model_folder, x) + for x in os.listdir(output_model_folder) + if x.endswith('.onnx') + ], **quantize_config) + + final_quantize_configs[quantize_mode] = final_quantize_config + + for path in quantized_paths: + file_name = os.path.basename(path) + shutil.move(path, + os.path.join(output_model_folder, 'onnx', file_name)) + + # Save quantization config + with open(os.path.join(output_model_folder, 'quantize_config.json'), 'w') as fp: + json.dump(final_quantize_configs, fp, indent=4) # Step 3. Move .onnx files to the 'onnx' subfolder - os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True) for file in os.listdir(output_model_folder): if file.endswith(('.onnx', '.onnx_data')): shutil.move(os.path.join(output_model_folder, file), @@ -580,7 +625,8 @@ def main(): from transformers import GenerationConfig from .extra.whisper import get_alignment_heads - generation_config = GenerationConfig.from_pretrained(model_id, **from_pretrained_kwargs) + generation_config = GenerationConfig.from_pretrained( + model_id, **from_pretrained_kwargs) generation_config.alignment_heads = get_alignment_heads(config) generation_config.save_pretrained(output_model_folder) diff --git a/scripts/requirements.txt b/scripts/requirements.txt index f0b3867ae..3adf88030 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -1,5 +1,6 @@ -transformers[torch]==4.33.2 -onnxruntime<1.16.0 -optimum==1.13.2 +transformers[torch]==4.38.2 +onnxruntime==1.17.1 +optimum==1.17.1 +onnx==1.15.0 +onnxconverter-common==1.14.0 tqdm -onnx==1.13.1 From d60515dd0de90b9b341dc522cc7951beb76263ad Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 19 Mar 2024 01:45:14 +0200 Subject: [PATCH 052/473] Update convert.py --- scripts/convert.py | 172 ++++++++++++++++++++++++++++++++------------- 1 file changed, 124 insertions(+), 48 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index 90ffc062b..9eff98928 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -23,10 +23,11 @@ ) from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer +from onnxconverter_common import float16 + + +PER_CHANNEL_REDUCE_RANGE_MODELS = { -DEFAULT_QUANTIZE_PARAMS = { - 'per_channel': True, - 'reduce_range': True, } MODEL_SPECIFIC_QUANTIZE_PARAMS = { @@ -229,6 +230,33 @@ class ConversionArguments: "help": "Whether to quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode" } ) + block_size: int = field( + default=None, + metadata={ + "help": "Block size for blockwise quantization. 
Note: bnb.nn.Linear4bit only uses block_size=64" + } + ) + quant_type: int = field( + default=MatMulBnb4Quantizer.NF4, + choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], + metadata={ + "help": "Quantization data type. 0: FP4, 1: NF4" + } + ) + symmetric: bool = field( + default=True, + metadata={ + "help": "Indicate whether to quantize the model symmetrically" + } + ) + accuracy_level: int = field( + default=None, + metadata={ + "help": "Accuracy level of the 4-bit quantized MatMul computation. " + "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details " + "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits)." + } + ) output_attentions: bool = field( default=False, @@ -276,41 +304,36 @@ def traverse_graph(graph): return operators -def quantize(mode, model_names_or_paths, **quantize_kwargs): +def quantize( + mode, model_names_or_paths, *, + + # 8-bit quantization + per_channel: bool = True, + reduce_range: bool = True, + + # 4-bit quantization + block_size: int | None = None, + + # MatMul4BitsQuantizer + is_symmetric: bool = True, + accuracy_level: int | None = None, + + # MatMulBnb4Quantizer + quant_type: int | None = None, +): """ - Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU - - Uses unsigned ints for activation values, signed ints for weights, per - https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection - it is faster on most CPU architectures - Args: - onnx_model_path: Path to location the exported ONNX model is stored - Returns: The Path generated for the quantized + Quantize the weights of the model (e.g., from float32 to int8) to allow for more efficient inference. """ - quantize_config = dict( - **quantize_kwargs, - per_model_config={} - ) + quantize_config = {} directory_path = os.path.dirname(model_names_or_paths[0]) outputs = [] for model in tqdm(model_names_or_paths, desc='Quantizing'): file_name_without_extension = os.path.splitext( - os.path.basename(model))[0] - - # NOTE: - # As of 2024/03/18, the current latest version of onnxruntime-web is 1.17.1, and does not support INT8 weights for Conv layers. - # If you attempt to run a model with INT8 weights for Conv layers, you will get an error like: - # `Can't create a session. ERROR_CODE: 9, ERROR_MESSAGE: Could not find an implementation for ConvInteger(10) node with name '/.../Conv_quant'` - # - # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web. - # - # As per docs, signed weight type (QInt8) is faster on most CPUs, so, we use that unless the model contains a Conv layer. - # For more information, see: - # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621 - # - https://github.com/microsoft/onnxruntime/issues/2339 + os.path.basename(model) + )[0] loaded_model = onnx.load_model(model) suffix = 'quantized' if mode == QuantMode.Q8 else mode.value @@ -319,17 +342,40 @@ def quantize(mode, model_names_or_paths, **quantize_kwargs): f'{file_name_without_extension}_{suffix}.onnx', ) + quantize_kwargs = {} + if mode in (QuantMode.Q8, QuantMode.QI8, QuantMode.QU8): + quantize_kwargs.update( + per_channel=per_channel, + reduce_range=reduce_range, + ) op_types = get_operators(loaded_model) if mode == QuantMode.Q8: + # NOTE: + # As of 2024/03/18, the current latest version of onnxruntime-web is 1.17.1, and does not support INT8 weights for Conv layers. 
+ # If you attempt to run a model with INT8 weights for Conv layers, you will get an error like: + # `Can't create a session. ERROR_CODE: 9, ERROR_MESSAGE: Could not find an implementation for ConvInteger(10) node with name '/.../Conv_quant'` + # + # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web. + # + # As per docs, signed weight type (QInt8) is faster on most CPUs, so, we use that unless the model contains a Conv layer. + # For more information, see: + # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621 + # - https://github.com/microsoft/onnxruntime/issues/2339 weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8 + elif mode == QuantMode.QI8: weight_type = QuantType.QInt8 + else: # mode == QuantMode.QU8: weight_type = QuantType.QUInt8 del loaded_model + + # Uses unsigned ints for activation values, signed ints for weights, per + # https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection + # it is faster on most CPU architectures quantize_dynamic( model_input=model, model_output=save_path, @@ -340,40 +386,60 @@ def quantize(mode, model_names_or_paths, **quantize_kwargs): extra_options=dict( EnableSubgraph=True ), + + **quantize_kwargs, ) + if 'per_model_config' not in quantize_config: + quantize_config['per_model_config'] = {} + quantize_config['per_model_config'][file_name_without_extension] = dict( - op_types=list(op_types), + op_types=sorted(list(op_types)), weight_type=str(weight_type), ) elif mode == QuantMode.Q4: - quant = MatMul4BitsQuantizer( + block_size = block_size if block_size is not None else 32 + quantize_kwargs.update( + block_size=block_size, + is_symmetric=is_symmetric, + accuracy_level=accuracy_level, + ) + quantizer = MatMul4BitsQuantizer( model=loaded_model, - block_size=quantize_kwargs.get('block_size', 128), - accuracy_level=quantize_kwargs.get('accuracy_level'), - is_symmetric=quantize_kwargs.get('is_symmetric', True), + **quantize_kwargs, ) - quant.process() - check_and_save_model(quant.model.model, save_path) - del quant + quantizer.process() + check_and_save_model(quantizer.model.model, save_path) + del quantizer elif mode == QuantMode.BNB4: - quant = MatMulBnb4Quantizer( + block_size = block_size if block_size is not None else 64 + quant_type = quant_type if quant_type is not None else MatMulBnb4Quantizer.NF4 + quantize_kwargs.update( + block_size=block_size, + quant_type=quant_type, + ) + + quantizer = MatMulBnb4Quantizer( model=loaded_model, - block_size=quantize_kwargs.get('block_size', 64), - quant_type=quantize_kwargs.get( - 'quant_type', MatMulBnb4Quantizer.NF4), + **quantize_kwargs, ) - quant.process() - check_and_save_model(quant.model.model, save_path) - del quant + quantizer.process() + check_and_save_model(quantizer.model.model, save_path) + del quantizer elif mode == QuantMode.FP16: - pass + model_fp16 = float16.convert_float_to_float16( + loaded_model, + keep_io_types=True, + ) + onnx.save(model_fp16, save_path) + else: raise ValueError(f'Invalid quantization mode: {mode}') + quantize_config.update(quantize_kwargs) outputs.append(save_path) return quantize_config, outputs @@ -582,15 +648,25 @@ def main(): # Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size. 
if conv_args.quantize: # Update quantize config with model specific defaults - quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get( - config.model_type, DEFAULT_QUANTIZE_PARAMS) + use_per_channel_reduce_range = config.model_type in PER_CHANNEL_REDUCE_RANGE_MODELS - # Update if user specified values + quantize_config = {} + + # Update with user-specified values if conv_args.per_channel is not None: quantize_config['per_channel'] = conv_args.per_channel + elif use_per_channel_reduce_range: + quantize_config['per_channel'] = True if conv_args.reduce_range is not None: quantize_config['reduce_range'] = conv_args.reduce_range + elif use_per_channel_reduce_range: + quantize_config['reduce_range'] = True + + quantize_config['block_size'] = conv_args.block_size + quantize_config['quant_type'] = conv_args.quant_type + quantize_config['is_symmetric'] = conv_args.symmetric + quantize_config['accuracy_level'] = conv_args.accuracy_level final_quantize_configs = {} quantize_modes = [x.value for x in QuantMode] \ From 8f33adebaec315459797690f1e6388e62c76d37d Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 19 Mar 2024 01:49:04 +0200 Subject: [PATCH 053/473] Update convert.py --- scripts/convert.py | 119 +++++++++------------------------------------ 1 file changed, 23 insertions(+), 96 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index 9eff98928..43e5514a8 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -27,107 +27,34 @@ PER_CHANNEL_REDUCE_RANGE_MODELS = { - -} - -MODEL_SPECIFIC_QUANTIZE_PARAMS = { # Decoder-only models - 'codegen': { - 'per_channel': False, - 'reduce_range': False, - }, - 'gpt2': { - 'per_channel': False, - 'reduce_range': False, - }, - 'gpt_bigcode': { - 'per_channel': False, - 'reduce_range': False, - }, - 'gptj': { - 'per_channel': False, - 'reduce_range': False, - }, - 'gpt-neo': { - 'per_channel': False, - 'reduce_range': False, - }, - 'gpt-neox': { - 'per_channel': False, - 'reduce_range': False, - }, - 'mpt': { - 'per_channel': False, - 'reduce_range': False, - }, - 'bloom': { - 'per_channel': False, - 'reduce_range': False, - }, - 'llama': { - 'per_channel': False, - 'reduce_range': False, - }, - 'opt': { - 'per_channel': False, - 'reduce_range': False, - }, - 'mistral': { - 'per_channel': False, - 'reduce_range': False, - }, - 'falcon': { - 'per_channel': False, - 'reduce_range': False, - }, - 'phi': { - 'per_channel': False, - 'reduce_range': False, - }, - 'qwen2': { - 'per_channel': False, - 'reduce_range': False, - }, - 'stablelm': { - 'per_channel': False, - 'reduce_range': False, - }, - 'starcoder2': { - 'per_channel': False, - 'reduce_range': False, - }, + 'codegen', + 'gpt2', + 'gpt_bigcode', + 'gptj', + 'gpt-neo', + 'gpt-neox', + 'mpt', + 'bloom', + 'llama', + 'opt', + 'mistral', + 'falcon', + 'phi', + 'qwen2', + 'stablelm', + 'starcoder2', # Encoder-decoder models - 'whisper': { - 'per_channel': False, - 'reduce_range': False, - }, - 'vision-encoder-decoder': { - 'per_channel': False, - 'reduce_range': False, - }, + 'whisper', + 'vision-encoder-decoder', # Encoder-only models - 'owlv2': { - 'per_channel': False, - 'reduce_range': False, - }, - 'wavlm': { - 'per_channel': False, - 'reduce_range': False, - }, - 'wav2vec2': { - 'per_channel': False, - 'reduce_range': False, - }, - 'unispeech': { - 'per_channel': False, - 'reduce_range': False, - }, - 'unispeech-sat': { - 'per_channel': False, - 'reduce_range': False, - }, + 'owlv2', + 'wavlm', + 'wav2vec2', + 'unispeech', + 'unispeech-sat', } MODELS_WITHOUT_TOKENIZERS = [ From 
a6db3a569bbdc2b8cee2e069c88e15fd243f60ad Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 20 Mar 2024 17:23:18 +0200 Subject: [PATCH 054/473] Move choices to metadata in `field()` --- scripts/convert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index 43e5514a8..76f921bb4 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -165,9 +165,9 @@ class ConversionArguments: ) quant_type: int = field( default=MatMulBnb4Quantizer.NF4, - choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], metadata={ - "help": "Quantization data type. 0: FP4, 1: NF4" + "help": "Quantization data type. 0: FP4, 1: NF4", + "choices": [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], } ) symmetric: bool = field( From 08a9b4077993e0ca4426d2e3b067198774003800 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 20 Mar 2024 23:09:46 +0200 Subject: [PATCH 055/473] Surround quantization step with try-except block --- scripts/convert.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index 76f921bb4..d8e9f1992 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -599,19 +599,22 @@ def main(): quantize_modes = [x.value for x in QuantMode] \ if conv_args.quantize_mode is None else [conv_args.quantize_mode] for quantize_mode in quantize_modes: - - final_quantize_config, quantized_paths = quantize(QuantMode(quantize_mode), [ - os.path.join(output_model_folder, x) - for x in os.listdir(output_model_folder) - if x.endswith('.onnx') - ], **quantize_config) - - final_quantize_configs[quantize_mode] = final_quantize_config - - for path in quantized_paths: - file_name = os.path.basename(path) - shutil.move(path, - os.path.join(output_model_folder, 'onnx', file_name)) + try: + final_quantize_config, quantized_paths = quantize(QuantMode(quantize_mode), [ + os.path.join(output_model_folder, x) + for x in os.listdir(output_model_folder) + if x.endswith('.onnx') + ], **quantize_config) + + final_quantize_configs[quantize_mode] = final_quantize_config + + for path in quantized_paths: + file_name = os.path.basename(path) + shutil.move(path, + os.path.join(output_model_folder, 'onnx', file_name)) + except Exception as e: + print( + f'Failed to quantize model with mode {quantize_mode}: {e}') # Save quantization config with open(os.path.join(output_model_folder, 'quantize_config.json'), 'w') as fp: From 11bed38814739a413bad689c5b0586b1674bf1bb Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 21 Mar 2024 00:46:40 +0200 Subject: [PATCH 056/473] Invert per channel settings --- scripts/convert.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index d8e9f1992..8cab6f586 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -26,7 +26,7 @@ from onnxconverter_common import float16 -PER_CHANNEL_REDUCE_RANGE_MODELS = { +NO_PER_CHANNEL_REDUCE_RANGE_MODELS = { # Decoder-only models 'codegen', 'gpt2', @@ -433,7 +433,7 @@ def main(): opset=conv_args.opset, device=conv_args.device, trust_remote_code=conv_args.trust_remote_code, - legacy=True, + legacy=True, # TODO: remove this when transformers.js config is updated **custom_kwargs, ) @@ -575,7 +575,7 @@ def main(): # Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size. 
if conv_args.quantize: # Update quantize config with model specific defaults - use_per_channel_reduce_range = config.model_type in PER_CHANNEL_REDUCE_RANGE_MODELS + use_per_channel_reduce_range = config.model_type in NO_PER_CHANNEL_REDUCE_RANGE_MODELS quantize_config = {} @@ -583,12 +583,12 @@ def main(): if conv_args.per_channel is not None: quantize_config['per_channel'] = conv_args.per_channel elif use_per_channel_reduce_range: - quantize_config['per_channel'] = True + quantize_config['per_channel'] = False if conv_args.reduce_range is not None: quantize_config['reduce_range'] = conv_args.reduce_range elif use_per_channel_reduce_range: - quantize_config['reduce_range'] = True + quantize_config['reduce_range'] = False quantize_config['block_size'] = conv_args.block_size quantize_config['quant_type'] = conv_args.quant_type From d5cfd1c482636dbf6cf36b6ae9a0203a57366da4 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Mar 2024 00:59:35 +0200 Subject: [PATCH 057/473] Update pipelines.js --- src/pipelines.js | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index e421a28a6..60f3c960e 100755 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -2600,7 +2600,7 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi // Load vocoder, if not provided if (!this.vocoder) { console.log('No vocoder specified, using default HifiGan vocoder.'); - this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { quantized: false }); + this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' }); } // Load speaker embeddings as Float32Array from path/URL @@ -3101,7 +3101,6 @@ export async function pipeline( task, model = null, { - quantized = true, progress_callback = null, config = null, cache_dir = null, @@ -3131,7 +3130,6 @@ export async function pipeline( } const pretrainedOptions = { - quantized, progress_callback, config, cache_dir, From e1809b9ffa7b7b7ed2f99d8fb5ad8eff96caa35a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Mar 2024 01:12:37 +0200 Subject: [PATCH 058/473] Update package-lock.json --- package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/package-lock.json b/package-lock.json index 7809f299b..2b50fdf0c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -7836,9 +7836,9 @@ } }, "node_modules/webpack-dev-middleware": { - "version": "5.3.3", - "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.3.tgz", - "integrity": "sha512-hj5CYrY0bZLB+eTO+x/j67Pkrquiy7kWepMHmUMoPsmcUaeEnQJqFzHJOyxgWlq746/wUuA64p9ta34Kyb01pA==", + "version": "5.3.4", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", + "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", "dev": true, "dependencies": { "colorette": "^2.0.10", From aadeed936af7cc9049e1d238fd4ec70545dfd508 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Mar 2024 01:16:02 +0200 Subject: [PATCH 059/473] Enable top-level await (Closes #657) --- webpack.config.js | 1 + 1 file changed, 1 insertion(+) diff --git a/webpack.config.js b/webpack.config.js index ab7549932..b0751cced 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -63,6 +63,7 @@ function buildConfig({ }, experiments: { outputModule, + topLevelAwait: true, }, resolve: { alias }, From 4b3abfed2e5e534387d30b5de598d1485618303b Mon Sep 17 00:00:00 2001 
From: Joshua Lochner Date: Sat, 30 Mar 2024 02:35:03 +0200 Subject: [PATCH 060/473] Add `tensor.mean()` function --- src/utils/tensor.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 91b2bda9c..38c2baee6 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -625,6 +625,10 @@ export class Tensor { return this.clone().round_(); } + mean(dim = null, keepdim = false) { + return mean(this, dim, keepdim); + } + /** * Performs Tensor dtype conversion. * @param {DataType} type The desired data type. From 1fc2aa0917c1f8550545b1560defd74e681ad207 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 3 Apr 2024 02:33:51 +0200 Subject: [PATCH 061/473] Update musicgen conversion script --- scripts/convert.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/convert.py b/scripts/convert.py index 8cab6f586..df33eac32 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -48,6 +48,7 @@ # Encoder-decoder models 'whisper', 'vision-encoder-decoder', + 'musicgen', # Encoder-only models 'owlv2', @@ -429,11 +430,14 @@ def main(): if config.model_type not in MODELS_WITHOUT_TOKENIZERS: raise e + # TODO: remove this when transformers.js config is updated + use_legacy = config.model_type not in ('musicgen', ) + core_export_kwargs = dict( opset=conv_args.opset, device=conv_args.device, trust_remote_code=conv_args.trust_remote_code, - legacy=True, # TODO: remove this when transformers.js config is updated + legacy=use_legacy, **custom_kwargs, ) From 1515923958498ee469235fbc215ac991451c1ee3 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 3 Apr 2024 03:45:40 +0200 Subject: [PATCH 062/473] Huge `.generate()` refactor - Use object instead of positional args - Batched generation - Improved configuration handling - Add Musicgen --- src/generation/configuration_utils.js | 367 +++++ src/generation/logits_process.js | 624 ++++++++ src/generation/logits_sampler.js | 215 +++ src/generation/parameters.js | 26 + src/generation/stopping_criteria.js | 124 ++ src/models.js | 1873 ++++++++++++++++--------- src/pipelines.js | 60 +- src/processors.js | 5 + src/tokenizers.js | 17 +- src/utils/core.js | 51 +- src/utils/generation.js | 873 ------------ src/utils/generic.js | 35 + src/utils/hub.js | 2 + src/utils/tensor.js | 63 +- 14 files changed, 2734 insertions(+), 1601 deletions(-) create mode 100644 src/generation/configuration_utils.js create mode 100644 src/generation/logits_process.js create mode 100644 src/generation/logits_sampler.js create mode 100644 src/generation/parameters.js create mode 100644 src/generation/stopping_criteria.js delete mode 100644 src/utils/generation.js create mode 100644 src/utils/generic.js diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js new file mode 100644 index 000000000..428b079a3 --- /dev/null +++ b/src/generation/configuration_utils.js @@ -0,0 +1,367 @@ +import { pick } from "../utils/core.js"; + +/** + * Class that holds a configuration for a generation task. + */ +export class GenerationConfig { + // Parameters that control the length of the output + /** + * The maximum length the generated tokens can have. + * Corresponds to the length of the input prompt + `max_new_tokens`. + * Its effect is overridden by `max_new_tokens`, if also set. + * @type {number} + * @default 20 + */ + max_length = 20; + + /** + * The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. 
+ * @type {number} + * @default null + */ + max_new_tokens = null; + + /** + * The minimum length of the sequence to be generated. + * Corresponds to the length of the input prompt + `min_new_tokens`. + * Its effect is overridden by `min_new_tokens`, if also set. + * @type {number} + * @default 0 + */ + min_length = 0; + + /** + * The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. + * @type {number} + * @default null + */ + min_new_tokens = null; + + /** + * Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: + * - `true`, where the generation stops as soon as there are `num_beams` complete candidates; + * - `false`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; + * - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). + * @type {boolean|"never"} + * @default false + */ + early_stopping = false; + + /** + * The maximum amount of time you allow the computation to run for in seconds. + * Generation will still finish the current pass after allocated time has been passed. + * @type {number} + * @default null + */ + max_time = null; + + // Parameters that control the generation strategy used + /** + * Whether or not to use sampling; use greedy decoding otherwise. + * @type {boolean} + * @default false + */ + do_sample = false; + + /** + * Number of beams for beam search. 1 means no beam search. + * @type {number} + * @default 1 + */ + num_beams = 1; + + /** + * Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + * See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + * @type {number} + * @default 1 + */ + num_beam_groups = 1; + + /** + * The values balance the model confidence and the degeneration penalty in contrastive search decoding. + * @type {number} + * @default null + */ + penalty_alpha = null; + + /** + * Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + * @type {boolean} + * @default true + */ + use_cache = true; + + // Parameters for manipulation of the model output logits + /** + * The value used to modulate the next token probabilities. + * @type {number} + * @default 1.0 + */ + temperature = 1.0; + + /** + * The number of highest probability vocabulary tokens to keep for top-k-filtering. + * @type {number} + * @default 50 + */ + top_k = 50; + + /** + * If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. + * @type {number} + * @default 1.0 + */ + top_p = 1.0; + + /** + * Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. + * If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. + * See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details. + * @type {number} + * @default 1.0 + */ + typical_p = 1.0; + + /** + * If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. + * In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. 
+ * See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. + * @type {number} + * @default 0.0 + */ + epsilon_cutoff = 0.0; + + /** + * Eta sampling is a hybrid of locally typical sampling and epsilon sampling. + * If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`. + * The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + * See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. + * @type {number} + * @default 0.0 + */ + eta_cutoff = 0.0; + + /** + * This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. + * Note that `diversity_penalty` is only effective if `group beam search` is enabled. + * @type {number} + * @default 0.0 + */ + diversity_penalty = 0.0; + + /** + * The parameter for repetition penalty. 1.0 means no penalty. + * See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + * @type {number} + * @default 1.0 + */ + repetition_penalty = 1.0; + + /** + * The paramater for encoder_repetition_penalty. + * An exponential penalty on sequences that are not in the original input. + * 1.0 means no penalty. + * @type {number} + * @default 1.0 + */ + encoder_repetition_penalty = 1.0; + + /** + * Exponential penalty to the length that is used with beam-based generation. + * It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. + * Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences. + * @type {number} + * @default 1.0 + */ + length_penalty = 1.0; + + /** + * If set to int > 0, all ngrams of that size can only occur once. + * @type {number} + * @default 0 + */ + no_repeat_ngram_size = 0; + + /** + * List of token ids that are not allowed to be generated. + * In order to get the token ids of the words that should not appear in the generated text, use + * `tokenizer(bad_words, { add_prefix_space: true, add_special_tokens: false }).input_ids`. + * @type {number[][]} + * @default null + */ + bad_words_ids = null; + + /** + * List of token ids that must be generated. + * If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. + * If given `number[][][]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. + * @type {number[][]|number[][][]} + * @default null + */ + force_words_ids = null; + + /** + * Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). + * It's highly recommended to set this flag to `true` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. + * @type {boolean} + * @default false + */ + renormalize_logits = false; + + /** + * Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible. 
+ * @type {Object[]} + * @default null + */ + constraints = null; + + /** + * The id of the token to force as the first generated token after the `decoder_start_token_id`. + * Useful for multilingual models like mBART where the first generated token needs to be the target language token. + * @type {number} + * @default null + */ + forced_bos_token_id = null; + + /** + * The id of the token to force as the last generated token when `max_length` is reached. + * Optionally, use a list to set multiple *end-of-sequence* tokens. + * @type {number|number[]} + * @default null + */ + forced_eos_token_id = null; + + /** + * Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation. + * @type {boolean} + */ + remove_invalid_values = false; + + /** + * This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. + * The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay. + * @type {[number, number]} + * @default null + */ + exponential_decay_length_penalty = null; + + /** + * A list of tokens that will be suppressed at generation. + * The `SuppressTokens` logit processor will set their log probs to `-inf` so that they are not sampled. + * @type {number[]} + * @default null + */ + suppress_tokens = null; + + /** + * A list of tokens that will be suppressed at the beginning of the generation. + * The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. + * @type {number[]} + * @default null + */ + begin_suppress_tokens = null; + + /** + * A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. + * For example, `[[1, 123]]` means the second generated token will always be a token of index 123. + * @type {[number, number][]} + * @default null + */ + forced_decoder_ids = null; + + // Parameters that define the output variables of `generate` + /** + * The number of independently computed returned sequences for each element in the batch. + * @type {number} + * @default 1 + */ + num_return_sequences = 1; + + /** + * Whether or not to return the attentions tensors of all attention layers. + * See `attentions` under returned tensors for more details. + * @type {boolean} + * @default false + */ + output_attentions = false; + + /** + * Whether or not to return the hidden states of all layers. + * See `hidden_states` under returned tensors for more details. + * @type {boolean} + * @default false + */ + output_hidden_states = false; + + /** + * Whether or not to return the prediction scores. + * See `scores` under returned tensors for more details. + * @type {boolean} + * @default false + */ + output_scores = false; + + /** + * Whether or not to return a `ModelOutput` instead of a plain tuple. + * @type {boolean} + * @default false + */ + return_dict_in_generate = false; + + // Special tokens that can be used at generation time + /** + * The id of the *padding* token. + * @type {number} + * @default null + */ + pad_token_id = null; + + /** + * The id of the *beginning-of-sequence* token. + * @type {number} + * @default null + */ + bos_token_id = null; + + /** + * The id of the *end-of-sequence* token. + * Optionally, use a list to set multiple *end-of-sequence* tokens. 
+ * @type {number|number[]} + * @default null + */ + eos_token_id = null; + + // Generation parameters exclusive to encoder-decoder models + /** + * If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`. + * @type {number} + * @default 0 + */ + encoder_no_repeat_ngram_size = 0; + + /** + * If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. + * @type {number} + * @default null + */ + decoder_start_token_id = null; + + // Wild card + /** + * Additional generation kwargs will be forwarded to the `generate` function of the model. + * Kwargs that are not present in `generate`'s signature will be used in the model forward pass. + * @type {Object} + * @default {} + */ + generation_kwargs = {}; + + /** + * + * @param {GenerationConfig} config + */ + constructor(config) { + Object.assign(this, pick(config, Object.getOwnPropertyNames(this))); + } +} + diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js new file mode 100644 index 000000000..409fd3ba8 --- /dev/null +++ b/src/generation/logits_process.js @@ -0,0 +1,624 @@ +import { Callable } from "../utils/generic.js"; +import { Tensor } from "../utils/tensor.js"; + +import { max, log_softmax } from "../utils/maths.js"; + +/** + * Abstract base class for all logit processors that can be applied during generation. + */ +export class LogitsProcessor extends Callable { + /** + * Apply the processor to the input logits. + * + * @abstract + * @param {number[]} input_ids The input ids. + * @param {Tensor} logits The logits to process. + * @throws {Error} Throws an error if `_call` is not implemented in the subclass. + */ + _call(input_ids, logits) { + throw Error("`_call` should be implemented in a subclass") + } +} + + +/** + * Abstract base class for all logit warpers that can be applied during generation with multinomial sampling. + */ +export class LogitsWarper extends Callable { + /** + * Apply the processor to the input logits. + * + * @abstract + * @param {number[]} input_ids The input ids. + * @param {Tensor} logits The logits to process. + * @throws {Error} Throws an error if `_call` is not implemented in the subclass. + */ + _call(input_ids, logits) { + throw Error("`_call` should be implemented in a subclass") + } +} + + +/** + * A class representing a list of logits processors. A logits processor is a function that modifies the logits + * output of a language model. This class provides methods for adding new processors and applying all processors to a + * batch of logits. + */ +export class LogitsProcessorList extends Callable { + /** + * Constructs a new instance of `LogitsProcessorList`. + */ + constructor() { + super(); + this.processors = []; + } + + /** + * Adds a new logits processor to the list. + * + * @param {LogitsProcessor} item The logits processor function to add. + */ + push(item) { + this.processors.push(item); + } + + /** + * Adds multiple logits processors to the list. + * + * @param {LogitsProcessor[]} items The logits processor functions to add. + */ + extend(items) { + this.processors.push(...items); + } + + /** + * Applies all logits processors in the list to a batch of logits, modifying them in-place. + * + * @param {number[]} input_ids The input IDs for the language model. + * @param {Tensor} logits + */ + _call(input_ids, logits) { + // NOTE: This is different from the Python code, since vanilla JS does not support vectorized operations. 
+ // As a result, we apply each processor to each + // Modifies logits inplace + this.processors.forEach( + func => func(input_ids, logits) + ) + } + + [Symbol.iterator]() { + return this.processors.values(); + } +} + +// DEPRECATED: https://github.com/huggingface/transformers/pull/29485 +// /** +// * A logits processor that forces a specific token to be generated by the decoder. +// */ +// export class ForceTokensLogitsProcessor extends LogitsProcessor { +// /** +// * Constructs a new instance of `ForceTokensLogitsProcessor`. +// * +// * @param {[number, number][]} forced_decoder_ids The ids of tokens that should be forced. +// */ +// constructor(forced_decoder_ids) { +// super(); +// // TODO: convert to `new Map(forced_decoder_ids)` +// this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []); +// } + +// /** +// * Apply the processor to the input logits. +// * +// * @param {number[]} input_ids The input ids. +// * @param {Tensor} logits The logits to process. +// * @returns {Tensor} The processed logits. +// */ +// _call(input_ids, logits) { +// console.log('this.force_token_map', this.force_token_map) +// console.log('call ForceTokensLogitsProcessor', input_ids, logits) +// console.log('input_ids.length', input_ids.length) +// let map = this.force_token_map[input_ids.length]; +// if (map) { // There exists a mapping +// logits.data.fill(-Infinity) +// logits.data[map] = 0; +// } +// console.log('map', map) +// // throw Error("Not implemented") +// return logits; +// } +// } + +/** + * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence. + */ +export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { + /** + * Create a ForcedBOSTokenLogitsProcessor. + * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced. + */ + constructor(bos_token_id) { + super(); + this.bos_token_id = bos_token_id; + } + + /** + * Apply the BOS token forcing to the logits. + * @param {number[]} input_ids The input IDs. + * @param {Tensor} logits The logits. + * @returns {Object} The logits with BOS token forcing. + */ + _call(input_ids, logits) { + if (input_ids.length === 1) { + logits.data.fill(-Infinity) + logits.data[this.bos_token_id] = 0; + } + return logits; + } +} + +/** + * A logits processor that forces end-of-sequence token probability to 1. + */ +export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor { + /** + * Create a ForcedEOSTokenLogitsProcessor. + * @param {number} max_length Max length of the sequence. + * @param {number|number[]} forced_eos_token_id The ID of the end-of-sequence token to be forced. + */ + constructor(max_length, forced_eos_token_id) { + super(); + this.max_length = max_length; + this.forced_eos_token_id = forced_eos_token_id; + } + + /** + * Apply the processor to input_ids and logits. + * + * @param {number[]} input_ids The input ids. + * @param {Tensor} logits The logits tensor. + */ + _call(input_ids, logits) { + // console.log('call ForcedEOSTokenLogitsProcessor') + // TODO + } +} + +/** + * A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function starts + * generating using `begin_index` tokens. This should ensure that the tokens defined by + * `begin_suppress_tokens` at not sampled at the begining of the generation. + */ +export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { + /** + * Create a SuppressTokensAtBeginLogitsProcessor. + * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress. 
+     * @param {number} begin_index The number of tokens to generate before suppressing tokens.
+     */
+    constructor(begin_suppress_tokens, begin_index) {
+        super();
+        this.begin_suppress_tokens = begin_suppress_tokens;
+        this.begin_index = begin_index;
+    }
+
+    /**
+     * Suppress the specified tokens if generation has just reached `begin_index`.
+     * @param {number[]} input_ids The input IDs.
+     * @param {Tensor} logits The logits.
+     * @returns {Object} The logits with the begin-suppressed tokens masked out.
+     */
+    _call(input_ids, logits) {
+        if (input_ids.length === this.begin_index) {
+            for (let token_id of this.begin_suppress_tokens) {
+                logits.data[token_id] = -Infinity;
+            }
+        }
+        return logits;
+    }
+}
+
+/**
+ * A LogitsProcessor that handles adding timestamps to generated text.
+ */
+export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
+    /**
+     * Constructs a new WhisperTimeStampLogitsProcessor.
+     * @param {Object} generate_config The config object passed to the `generate()` method of a transformer model.
+     * @param {number} generate_config.eos_token_id The ID of the end-of-sequence token.
+     * @param {number} generate_config.no_timestamps_token_id The ID of the token used to indicate that a token should not have a timestamp.
+     * @param {[number, number][]} [generate_config.forced_decoder_ids] An array of two-element arrays representing decoder IDs that are forced to appear in the output. The second element of each array indicates whether the token is a timestamp.
+     * @param {number} [generate_config.max_initial_timestamp_index] The maximum index at which an initial timestamp can appear.
+     */
+    constructor(generate_config) {
+        super();
+        this.eos_token_id = generate_config.eos_token_id;
+        this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
+        this.timestamp_begin = this.no_timestamps_token_id + 1;
+
+        this.begin_index = (generate_config.forced_decoder_ids || []).length + 2;
+        if (generate_config.forced_decoder_ids.slice(-1)[0][1] === this.no_timestamps_token_id) {
+            this.begin_index -= 1;
+        }
+        this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
+
+    }
+
+    /**
+     * Modify the logits to handle timestamp tokens.
+     * @param {number[]} input_ids The input sequence of tokens.
+     * @param {Tensor} logits The logits output by the model.
+     * @returns {Tensor} The modified logits.
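For orientation (not part of the patch itself), here is a minimal sketch of how these processors are meant to compose through `LogitsProcessorList`. The vocabulary size and token ids are invented, and the import paths are shown relative to `src/`:

```js
import { LogitsProcessorList, ForcedBOSTokenLogitsProcessor, SuppressTokensAtBeginLogitsProcessor } from './generation/logits_process.js';
import { Tensor } from './utils/tensor.js';

// Dummy logits for one sequence over a toy vocabulary of 8 tokens.
const logits = new Tensor('float32', new Float32Array(8).fill(1), [8]);

const processors = new LogitsProcessorList();
processors.push(new ForcedBOSTokenLogitsProcessor(2));                     // force token 2 as BOS
processors.extend([new SuppressTokensAtBeginLogitsProcessor([5, 6], 3)]);  // ban tokens 5 and 6 at step 3

// `input_ids` holds the tokens generated so far; with a single token generated,
// only the forced BOS id keeps a finite score.
processors([0], logits);
console.log(logits.data[2]); // 0
console.log(logits.data[3]); // -Infinity
```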
+ */ + _call(input_ids, logits) { + const logitsData = /** @type {Float32Array} */(logits.data); + + // suppress <|notimestamps|> which is handled by without_timestamps + logitsData[this.no_timestamps_token_id] = -Infinity; + + if (input_ids.length === this.begin_index - 1) { + logitsData.fill(-Infinity); + logitsData[this.timestamp_begin] = 0; + return logits; + } + + // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly + const seq = input_ids.slice(this.begin_index); + const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin; + const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin; + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { // has to be non-timestamp + logitsData.subarray(this.timestamp_begin).fill(-Infinity); + } else { // cannot be normal text tokens + logitsData.subarray(0, this.eos_token_id).fill(-Infinity); + } + } + + // apply the `max_initial_timestamp` option + if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) { + const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index; + logitsData.subarray(last_allowed + 1).fill(-Infinity); + } + + // if sum of probability over timestamps is above any other token, sample timestamp + const logprobs = log_softmax(logitsData); + const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b)); + const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0]; + + if (timestamp_logprob > max_text_token_logprob) { + logitsData.subarray(0, this.timestamp_begin).fill(-Infinity); + } + + return logits; + } +} + +/** + * A logits processor that disallows ngrams of a certain size to be repeated. + */ +export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { + /** + * Create a NoRepeatNGramLogitsProcessor. + * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once. + */ + constructor(no_repeat_ngram_size) { + super(); + this.no_repeat_ngram_size = no_repeat_ngram_size; + } + + /** + * Generate n-grams from a sequence of token ids. + * @param {number[]} prevInputIds List of previous input ids + * @returns {Map} Map of generated n-grams + */ + getNgrams(prevInputIds) { + const curLen = prevInputIds.length; + + /**@type {number[][]} */ + const ngrams = []; + for (let j = 0; j < curLen + 1 - this.no_repeat_ngram_size; ++j) { + const ngram = []; + for (let k = 0; k < this.no_repeat_ngram_size; ++k) { + ngram.push(prevInputIds[j + k]); + } + ngrams.push(ngram); + } + + /** @type {Map} */ + const generatedNgram = new Map(); + for (const ngram of ngrams) { + const prevNgram = ngram.slice(0, ngram.length - 1); + const prevNgramKey = JSON.stringify(prevNgram); + const prevNgramValue = generatedNgram.get(prevNgramKey) ?? []; + prevNgramValue.push(ngram[ngram.length - 1]); + generatedNgram.set(prevNgramKey, prevNgramValue); + } + return generatedNgram; + } + + /** + * Generate n-grams from a sequence of token ids. + * @param {Map} bannedNgrams Map of banned n-grams + * @param {number[]} prevInputIds List of previous input ids + * @returns {number[]} Map of generated n-grams + */ + getGeneratedNgrams(bannedNgrams, prevInputIds) { + const ngramIdx = prevInputIds.slice(prevInputIds.length + 1 - this.no_repeat_ngram_size, prevInputIds.length); + const banned = bannedNgrams.get(JSON.stringify(ngramIdx)) ?? 
[]; + return banned; + } + + /** + * Calculate banned n-gram tokens + * @param {number[]} prevInputIds List of previous input ids + * @returns {number[]} Map of generated n-grams + */ + calcBannedNgramTokens(prevInputIds) { + const bannedTokens = []; + if (prevInputIds.length + 1 < this.no_repeat_ngram_size) { + // return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return bannedTokens; + + } else { + const generatedNgrams = this.getNgrams(prevInputIds); + const bannedTokens = this.getGeneratedNgrams(generatedNgrams, prevInputIds); + return bannedTokens; + } + } + + /** + * Apply the no-repeat-ngram processor to the logits. + * @param {number[]} input_ids The input IDs. + * @param {Tensor} logits The logits. + * @returns {Object} The logits with no-repeat-ngram processing. + */ + _call(input_ids, logits) { + const bannedTokens = this.calcBannedNgramTokens(input_ids); + + for (const token of bannedTokens) { + logits.data[token] = -Infinity; + } + return logits; + } +} + +/** + * A logits processor that penalises repeated output tokens. + */ +export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { + /** + * Create a RepetitionPenaltyLogitsProcessor. + * @param {number} penalty The penalty to apply for repeated tokens. + */ + constructor(penalty) { + super(); + this.penalty = penalty; + } + + /** + * Apply the repetition penalty to the logits. + * @param {number[]} input_ids The input IDs. + * @param {Tensor} logits The logits. + * @returns {Object} The logits with repetition penalty processing. + */ + _call(input_ids, logits) { + // Modify the logits corresponding to each element in `input_ids`. + // As a consequence, the logits corresponding to tokens that appear + // many times in the output will be penalised more. + for (const input_id of input_ids) { + if (logits.data[input_id] < 0) { + logits.data[input_id] *= this.penalty; + } else { + logits.data[input_id] /= this.penalty; + } + } + return logits + } +} + +/** + * A logits processor that enforces a minimum number of tokens. + */ +export class MinLengthLogitsProcessor extends LogitsProcessor { + /** + * Create a MinLengthLogitsProcessor. + * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity. + * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. + */ + constructor(min_length, eos_token_id) { + super(); + this.min_length = min_length; + this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; + } + + /** + * Apply logit processor. + * @param {number[]} input_ids The input IDs. + * @param {Tensor} logits The logits. + * @returns {Object} The processed logits. + */ + _call(input_ids, logits) { + if (input_ids.length < this.min_length) { + for (const eos_token of this.eos_token_id) { + logits.data[eos_token] = -Infinity; + } + } + + return logits + } +} + +/** + * A logits processor that enforces a minimum number of new tokens. + */ +export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { + /** + * Create a MinNewTokensLengthLogitsProcessor. + * @param {number} prompt_length_to_skip The input tokens length. + * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity. + * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. 
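To make the n-gram bookkeeping above concrete, a small sketch (token ids invented): with a `no_repeat_ngram_size` of 2, a history whose last token has already been followed by some token bans that follower from being generated again:

```js
import { NoRepeatNGramLogitsProcessor } from './generation/logits_process.js';
import { Tensor } from './utils/tensor.js';

const processor = new NoRepeatNGramLogitsProcessor(2);

// History 4 -> 7 -> 4: the bigram (4, 7) has already occurred, and the
// sequence currently ends in 4, so generating 7 again would repeat it.
const input_ids = [4, 7, 4];
const logits = new Tensor('float32', new Float32Array(10).fill(0), [10]);

processor(input_ids, logits);
console.log(logits.data[7]); // -Infinity (banned)
console.log(logits.data[3]); // 0 (unaffected)
```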
+     */
+    constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
+        super();
+        this.prompt_length_to_skip = prompt_length_to_skip;
+        this.min_new_tokens = min_new_tokens;
+        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+    }
+
+    /**
+     * Apply logit processor.
+     * @param {number[]} input_ids The input IDs.
+     * @param {Tensor} logits The logits.
+     * @returns {Object} The processed logits.
+     */
+    _call(input_ids, logits) {
+        const new_tokens_length = input_ids.length - this.prompt_length_to_skip;
+        if (new_tokens_length < this.min_new_tokens) {
+            for (const eos_token of this.eos_token_id) {
+                logits.data[eos_token] = -Infinity;
+            }
+        }
+
+        return logits
+    }
+}
+
+export class NoBadWordsLogitsProcessor extends LogitsProcessor {
+    /**
+     * Create a `NoBadWordsLogitsProcessor`.
+     * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
+     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+     */
+    constructor(bad_words_ids, eos_token_id) {
+        super();
+        this.bad_words_ids = bad_words_ids;
+        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
+    }
+
+    /**
+     * Apply logit processor.
+     * @param {number[]} input_ids The input IDs.
+     * @param {Tensor} logits The logits.
+     * @returns {Object} The processed logits.
+     */
+    _call(input_ids, logits) {
+
+        for (const bad_word_ids of this.bad_words_ids) {
+            // Whether to modify the logits of the last token in the bad word id sequence
+            let mark = true;
+
+            // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last),
+            // then we set the logits of the last bad word id to -Infinity.
+            for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) {
+
+                if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) {
+                    // We have found a mismatch
+                    mark = false;
+                    break;
+                }
+            }
+            if (mark) {
+                logits.data[bad_word_ids.at(-1)] = -Infinity;
+            }
+        }
+
+        return logits
+    }
+}
+
+/**
+ * [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
+ * that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`].
+ */
+export class TemperatureLogitsWarper extends LogitsWarper {
+    /**
+     * Create a `TemperatureLogitsWarper`.
+     * @param {number} temperature Strictly positive float value used to modulate the logits distribution.
+     * A value smaller than `1` decreases randomness (and vice versa), with `0` being equivalent to shifting
+     * all probability mass to the most likely token.
+     */
+    constructor(temperature) {
+        super();
+
+        if (typeof temperature !== 'number' || temperature <= 0) {
+            let errorMessage =
+                `\`temperature\` (=${temperature}) must be a strictly positive float, otherwise your next token scores will be invalid.`;
+
+            if (temperature === 0) {
+                errorMessage += " If you're looking for greedy decoding strategies, set `do_sample=false`."
+            }
+            throw new Error(errorMessage);
+        }
+        this.temperature = temperature;
+    }
+
+    /**
+     * Apply logit warper.
+     * @param {number[]} input_ids The input IDs.
+     * @param {Tensor} logits The logits.
+     * @returns {Object} The processed logits.
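A similar sketch (values invented) for the bad-words matching above: a multi-token bad word only has its final token banned once all of its earlier tokens are the most recent generations, while a single-token bad word is banned unconditionally by this implementation:

```js
import { NoBadWordsLogitsProcessor } from './generation/logits_process.js';
import { Tensor } from './utils/tensor.js';

// Ban the two-token sequence [3, 9] and the single token [5]; eos id is 2.
const processor = new NoBadWordsLogitsProcessor([[3, 9], [5]], 2);

const logits = new Tensor('float32', new Float32Array(12).fill(0), [12]);
processor([7, 1, 3], logits);  // history ends in 3, the prefix of [3, 9]
console.log(logits.data[9]);   // -Infinity (completing the bad word is blocked)
console.log(logits.data[5]);   // -Infinity (single-token bad word)
```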
+ */ + _call(input_ids, logits) { + const logitsData = /** @type {Float32Array} */(logits.data); + for (let i = 0; i < logitsData.length; ++i) { + logitsData[i] /= this.temperature; + } + return logits; + } +} + +/** + * [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + * Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. + */ +export class TopPLogitsWarper extends LogitsWarper { + /** + * Create a `TopPLogitsWarper`. + * @param {number} top_p If set to < 1, only the smallest set of most probable tokens with + * probabilities that add up to `top_p` or higher are kept for generation. + * @param {Object} options Additional options for the top-p sampling. + * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value. + * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered. + */ + constructor(top_p, { + filter_value = -Infinity, + min_tokens_to_keep = 1, + } = {}) { + super(); + if (top_p < 0 || top_p > 1.0) { + throw new Error(`\`top_p\` must be a float > 0 and < 1, but is ${top_p}`) + } + if (!Number.isInteger(min_tokens_to_keep) || min_tokens_to_keep < 1) { + throw new Error(`\`min_tokens_to_keep\` must be a positive integer, but is ${min_tokens_to_keep}`) + } + + this.top_p = top_p + this.filter_value = filter_value + this.min_tokens_to_keep = min_tokens_to_keep + } +} + +/** + * [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + * Often used together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. + */ +export class TopKLogitsWarper extends LogitsWarper { + /** + * Create a `TopKLogitsWarper`. + * @param {number} top_k If set to > 0, only the top `top_k` tokens are kept for generation. + * @param {Object} options Additional options for the top-k sampling. + * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value. + * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered. + */ + constructor(top_k, { + filter_value = -Infinity, + min_tokens_to_keep = 1, + } = {}) { + super(); + if (!Number.isInteger(top_k) || top_k < 0) { + throw new Error(`\`top_k\` must be a positive integer, but is ${top_k}`) + } + + this.top_k = Math.max(top_k, min_tokens_to_keep) + this.filter_value = filter_value + } +} \ No newline at end of file diff --git a/src/generation/logits_sampler.js b/src/generation/logits_sampler.js new file mode 100644 index 000000000..00a405da8 --- /dev/null +++ b/src/generation/logits_sampler.js @@ -0,0 +1,215 @@ + +import { Callable } from "../utils/generic.js"; +import { Tensor } from "../utils/tensor.js"; + +import { + max, + softmax, + getTopItems, +} from '../utils/maths.js'; +import { GenerationConfig } from '../generation/configuration_utils.js'; + +/** + * Sampler is a base class for all sampling methods used for text generation. + */ +export class LogitsSampler extends Callable { + /** + * Creates a new Sampler object with the specified generation config. + * @param {GenerationConfig} generation_config The generation config. + */ + constructor(generation_config) { + super(); + this.generation_config = generation_config; + } + + /** + * Executes the sampler, using the specified logits. 
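As a rough numeric illustration of the temperature warper above (values invented): dividing the logits by a temperature below 1 sharpens the softmax distribution, while values above 1 flatten it:

```js
import { TemperatureLogitsWarper } from './generation/logits_process.js';
import { Tensor } from './utils/tensor.js';
import { softmax } from './utils/maths.js';

const logits = new Tensor('float32', new Float32Array([2, 1, 0]), [3]);

const warper = new TemperatureLogitsWarper(0.5);
warper([], logits); // logits become [4, 2, 0]

// softmax([2, 1, 0]) ~ [0.67, 0.24, 0.09], softmax([4, 2, 0]) ~ [0.87, 0.12, 0.02]:
// the most likely token becomes even more likely at temperature 0.5.
console.log(softmax(Array.from(logits.data)));
```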
+ * @param {Tensor} logits + * @param {number} index + * @returns {[number, number][]} + */ + _call(logits, index = -1) { + // Sample from logits, of dims [batch, sequence_length, vocab_size]. + // If index is specified, sample from [batch, index, vocab_size]. + return this.sample(logits, index); + } + + /** + * Abstract method for sampling the logits. + * @param {Tensor} logits + * @param {number} index + * @throws {Error} + * @returns {[number, number][]} + */ + sample(logits, index) { + throw Error("sample should be implemented in subclasses.") + } + + /** + * Returns the specified logits as an array, with temperature applied. + * @param {Tensor} logits + * @param {number} index + * @returns {Float32Array} + */ + getLogits(logits, index) { + let vocabSize = logits.dims.at(-1); + + let logs = /** @type {Float32Array} */(logits.data); + + if (index === -1) { + logs = logs.slice(-vocabSize); + } else { + let startIndex = index * vocabSize; + logs = logs.slice(startIndex, startIndex + vocabSize); + } + + // add temperature + // if (this.generation_config.temperature > 0) { + // logs = logs.map(x => x / this.generation_config.temperature) + // } + return logs; + } + + /** + * Selects an item randomly based on the specified probabilities. + * @param {Array} probabilities An array of probabilities to use for selection. + * @returns {number} The index of the selected item. + */ + randomSelect(probabilities) { + // Return index of chosen item + let sumProbabilities = probabilities.reduce((acc, curr) => acc + curr, 0); + + let r = Math.random() * sumProbabilities; + for (let i = 0; i < probabilities.length; ++i) { + r -= probabilities[i]; + if (r <= 0) { + return i; + } + } + return 0; // return first (most probable) as a fallback + } + + /** + * Returns a Sampler object based on the specified options. + * @param {GenerationConfig} generation_config An object containing options for the sampler. + * @returns {LogitsSampler} A Sampler object. + */ + static getSampler(generation_config) { + // - *greedy decoding*: `num_beams=1` and `do_sample=False` + // - *contrastive search*: `penalty_alpha>0` and `top_k>1` + // - *multinomial sampling*: `num_beams=1` and `do_sample=True` + // - *beam-search decoding*: `num_beams>1` and `do_sample=False` + // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True` + // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1` + // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None` + + // NOTE: beam search is implemented directly into the generation function + if (generation_config.do_sample) { + return new MultinomialSampler(generation_config); + + } else if (generation_config.num_beams > 1) { + return new BeamSearchSampler(generation_config); + + } else { + if (generation_config.num_return_sequences > 1) { + throw Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`) + } + return new GreedySampler(generation_config); + } + } +} + +/** + * Class representing a Greedy Sampler. + */ +class GreedySampler extends LogitsSampler { + /** + * Sample the maximum probability of a given logits tensor. + * @param {Tensor} logits + * @param {number} [index=-1] + * @returns {[number, number][]} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search). 
+ */ + sample(logits, index = -1) { + // NOTE: no need to do log_softmax here since we only take the maximum + let logs = this.getLogits(logits, index); + let argmax = max(logs)[1]; + + // Note: score is meaningless in this context, since we are performing + // greedy search (p = 1 => log(p) = 0) + return [ + [argmax, 0] + ]; + } +} + +/** + * Class representing a MultinomialSampler. + */ +class MultinomialSampler extends LogitsSampler { + + /** + * Sample from the logits. + * @param {Tensor} logits + * @param {number} index + * @returns {[number, number][]} + */ + sample(logits, index = -1) { + let k = logits.dims.at(-1); // defaults to vocab size + if (this.generation_config.top_k > 0) { + k = Math.min(this.generation_config.top_k, k); + } + + // Get logits of nth token + const logs = this.getLogits(logits, index); + + // Get top k tokens + const topLogits = getTopItems(logs, k); + + // Compute softmax over logits + const probabilities = softmax(topLogits.map(x => x[1])); + + return Array.from({ length: this.generation_config.num_beams }, () => { + const sampledIndex = this.randomSelect(probabilities); + return [ + topLogits[sampledIndex][0], // token id + Math.log(probabilities[sampledIndex]), // score + ]; + }); + } +} + + +/** + * Class representing a BeamSearchSampler. + */ +class BeamSearchSampler extends LogitsSampler { + + /** + * Sample from the logits. + * @param {Tensor} logits + * @param {number} index + * @returns {[number, number][]} + */ + sample(logits, index = -1) { + let k = logits.dims.at(-1); // defaults to vocab size + if (this.generation_config.top_k > 0) { + k = Math.min(this.generation_config.top_k, k); + } + + // Get logits of nth token + const logs = this.getLogits(logits, index); + + // Get top k tokens + const topLogits = getTopItems(logs, k); + + // Compute softmax over logits + const probabilities = softmax(topLogits.map(x => x[1])); + + return Array.from({ length: this.generation_config.num_beams }, (_, i) => { + return [ + topLogits[i][0], // token id + Math.log(probabilities[i]), // score + ]; + }); + } +} diff --git a/src/generation/parameters.js b/src/generation/parameters.js new file mode 100644 index 000000000..b1dc4b5a8 --- /dev/null +++ b/src/generation/parameters.js @@ -0,0 +1,26 @@ + +/** + * @typedef {Object} GenerationFunctionParameters + * @property {import('../utils/tensor.js').Tensor} [inputs=null] (`Tensor` of varying shape depending on the modality, *optional*): + * The sequence used as a prompt for the generation or as model inputs to the encoder. If `null` the + * method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + * should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + * `input_ids`, `input_values`, `input_features`, or `pixel_values`. + * @property {import('./configuration_utils.js').GenerationConfig} [generation_config=null] (`GenerationConfig`, *optional*): + * The generation configuration to be used as base parametrization for the generation call. + * `**kwargs` passed to generate matching the attributes of `generation_config` will override them. + * If `generation_config` is not provided, the default will be used, which has the following loading + * priority: + * - (1) from the `generation_config.json` model file, if it exists; + * - (2) from the model configuration. Please note that unspecified parameters will inherit [`GenerationConfig`]'s + * default values, whose documentation should be checked to parameterize generation. 
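The samplers above are normally obtained through the static factory rather than constructed directly; a hedged sketch (config values arbitrary, and assuming the usual `do_sample`/`top_k`/`num_beams` fields on `GenerationConfig`):

```js
import { LogitsSampler } from './generation/logits_sampler.js';
import { GenerationConfig } from './generation/configuration_utils.js';
import { Tensor } from './utils/tensor.js';

// do_sample=true selects the multinomial branch; top_k limits the candidate pool.
const sampler = LogitsSampler.getSampler(
    new GenerationConfig({ do_sample: true, top_k: 5, num_beams: 1 })
);

// Fake logits for a single decoding step: [batch=1, seq=1, vocab=10].
const logits = new Tensor('float32', Float32Array.from({ length: 10 }, () => Math.random()), [1, 1, 10]);

const [[tokenId, logProb]] = sampler(logits);
console.log(tokenId, logProb);
```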
+ * @property {import('./logits_process.js').LogitsProcessorList} [logits_processor=null] (`LogitsProcessorList`, *optional*): + * Custom logits processors that complement the default logits processors built from arguments and + * generation config. If a logit processor is passed that is already created with the arguments or a + * generation config an error is thrown. This feature is intended for advanced users. + * @property {import('./stopping_criteria.js').StoppingCriteriaList} [stopping_criteria=null] (`StoppingCriteriaList`, *optional*): + * Custom stopping criteria that complements the default stopping criteria built from arguments and a + * generation config. If a stopping criteria is passed that is already created with the arguments or a + * generation config an error is thrown. This feature is intended for advanced users. + * @param {any} [kwargs] (`Dict[str, any]`, *optional*): + */ diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.js new file mode 100644 index 000000000..c9863bca8 --- /dev/null +++ b/src/generation/stopping_criteria.js @@ -0,0 +1,124 @@ + +import { Callable } from "../utils/generic.js"; + +// NOTE: +// Stopping Criteria returns a list of `batch_size` booleans, indicating whether each sequence in the batch should be stopped. + +/** + * Abstract base class for all stopping criteria that can be applied during generation. + */ +export class StoppingCriteria extends Callable { + /** + * + * @param {number[][]} input_ids (`number[][]` of shape `(batch_size, sequence_length)`): + * Indices of input sequence tokens in the vocabulary. + * @param {number[][]} scores scores (`number[][]` of shape `(batch_size, config.vocab_size)`): + * Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax + * or scores for each vocabulary token after SoftMax. + * @returns {boolean[]} A list of booleans indicating whether each sequence should be stopped. + */ + _call(input_ids, scores) { + throw Error("StoppingCriteria needs to be subclassed"); + } +} +/** + */ +export class StoppingCriteriaList extends Callable { + /** + * Constructs a new instance of `StoppingCriteriaList`. + */ + constructor() { + super(); + this.criteria = []; + } + + /** + * Adds a new stopping criterion to the list. + * + * @param {StoppingCriteria} item The stopping criterion to add. + */ + push(item) { + this.criteria.push(item); + } + + /** + * Adds multiple stopping criteria to the list. + * + * @param {StoppingCriteriaList|StoppingCriteria[]} items The stopping criteria to add. + */ + extend(items) { + if (items instanceof StoppingCriteriaList) { + items = items.criteria; + } + this.criteria.push(...items); + } + + _call(input_ids, scores) { + const is_done = new Array(input_ids.length).fill(false); + for (const criterion of this.criteria) { + const criterion_done = criterion(input_ids, scores); + for (let i = 0; i < is_done.length; ++i) { + is_done[i] ||= criterion_done[i]; + } + } + return is_done; + } + + [Symbol.iterator]() { + return this.criteria.values(); + } +} + +/** + * This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. + * Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens. + */ +export class MaxLengthCriteria extends StoppingCriteria { + + /** + * + * @param {number} max_length The maximum length that the output sequence can have in number of tokens. 
+ * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's `config.max_position_embeddings` attribute. + */ + constructor(max_length, max_position_embeddings = null) { + super(); + this.max_length = max_length; + this.max_position_embeddings = max_position_embeddings; + } + + _call(input_ids) { + return input_ids.map(ids => ids.length >= this.max_length); + } +} + +// TODO: add MaxTimeCriteria + +/** + * This class can be used to stop generation whenever the "end-of-sequence" token is generated. + * By default, it uses the `model.generation_config.eos_token_id`. + */ +export class EosTokenCriteria extends StoppingCriteria { + + /** + * + * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. + * Optionally, use a list to set multiple *end-of-sequence* tokens. + */ + constructor(eos_token_id) { + super(); + if (!Array.isArray(eos_token_id)) { + eos_token_id = [eos_token_id]; + } + this.eos_token_id = eos_token_id; + } + + /** + * + * @param {number[][]} input_ids + * @param {number[][]} scores + * @returns + */ + _call(input_ids, scores) { + return input_ids.map(ids => this.eos_token_id.includes(ids.at(-1))); + } +} \ No newline at end of file diff --git a/src/models.js b/src/models.js index 32f615ff4..2ef4559d8 100644 --- a/src/models.js +++ b/src/models.js @@ -56,9 +56,13 @@ import { import { Callable, +} from './utils/generic.js'; + +import { isIntegralNumber, isTypedArray, mergeArrays, + pick, } from './utils/core.js'; import { @@ -68,8 +72,6 @@ import { import { LogitsProcessorList, - GenerationConfig, - ForceTokensLogitsProcessor, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, @@ -80,13 +82,20 @@ import { MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, - Sampler, -} from './utils/generation.js'; + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +} from './generation/logits_process.js'; + +import { + GenerationConfig, +} from './generation/configuration_utils.js'; import { cat, dynamicTimeWarping, mean, + ones, ones_like, stack, std_mean, @@ -94,6 +103,8 @@ import { } from './utils/tensor.js'; import { medianFilter } from './utils/maths.js'; +import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; +import { LogitsSampler } from './generation/logits_sampler.js'; ////////////////////////////////////////////////// // Model types: used internally @@ -104,6 +115,8 @@ const MODEL_TYPES = { Vision2Seq: 3, DecoderOnly: 4, MaskGeneration: 5, + ImageTextToText: 6, + Musicgen: 7, } ////////////////////////////////////////////////// @@ -141,7 +154,7 @@ async function constructSession(pretrained_model_name_or_path, fileName, options // Construct the model file name const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[dtype]; - const modelFileName = `onnx/${fileName}${suffix}.onnx`; + const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`; const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); @@ -336,41 +349,6 @@ function prepareAttentionMask(self, tokens) { } } -/** - * Add position IDs to the feeds object. - * @param {Object} session The inference session. - * @param {Object} feeds The input to the model. - * @param {boolean} use_cache_branch Whether to use the cache branch of the model. 
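A short sketch (ids and lengths invented) of how the two criteria above combine through `StoppingCriteriaList`, which ORs the per-sequence results:

```js
import { StoppingCriteriaList, MaxLengthCriteria, EosTokenCriteria } from './generation/stopping_criteria.js';

const criteria = new StoppingCriteriaList();
criteria.push(new MaxLengthCriteria(8));
criteria.push(new EosTokenCriteria(2));

// One row per sequence: the first ends with the EOS token, the second does not,
// and neither has reached the maximum length yet.
const batch = [
    [5, 6, 2],
    [5, 6, 7],
];
console.log(criteria(batch, null)); // [true, false]
```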
- * @returns {void} - * @private - */ -function preparePositionIds(session, feeds, use_cache_branch) { - if (!session.inputNames.includes('position_ids')) return; - - const data = new BigInt64Array(feeds.attention_mask.data.length); - - // Compute cumulative sum of the attention mask along the sequence length dimension - for (let i = 0; i < feeds.attention_mask.dims[0]; ++i) { - let start = i * feeds.attention_mask.dims[1]; - let sum = BigInt(0); - for (let j = 0; j < feeds.attention_mask.dims[1]; ++j) { - const index = start + j; - if (feeds.attention_mask.data[index] === 0n) { - data[index] = BigInt(1); - } else { // === 1n - data[index] = sum; - sum += feeds.attention_mask.data[index]; - } - } - } - - feeds.position_ids = new Tensor('int64', data, feeds.attention_mask.dims); - - if (use_cache_branch) { - feeds.position_ids = feeds.position_ids.slice(null, -1).unsqueeze_(-1); - } -} - /** * Creates a boolean tensor with a single value. * @param {boolean} value The value of the tensor. @@ -393,144 +371,37 @@ async function seq2seqForward(self, model_inputs) { let { encoder_outputs, past_key_values } = model_inputs; + // Encode if needed if (!encoder_outputs) { + const encoder_inputs = pick(model_inputs, self.session.inputNames); // Encoder outputs are not given, so we must compute them. - encoder_outputs = (await encoderForward(self, model_inputs)).last_hidden_state; + encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state; } - let decoderFeeds = { - input_ids: model_inputs.decoder_input_ids, - encoder_hidden_states: encoder_outputs, - }; + + + const { input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs; + other_decoder_inputs.input_ids = decoder_input_ids; + other_decoder_inputs.encoder_hidden_states = encoder_outputs; const use_cache_branch = !!past_key_values; if (self.decoder_merged_session.inputNames.includes('use_cache_branch')) { - decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); + other_decoder_inputs.use_cache_branch = boolTensor(use_cache_branch); } - if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { - decoderFeeds.encoder_attention_mask = model_inputs.attention_mask + other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask } - preparePositionIds(self.decoder_merged_session, decoderFeeds, use_cache_branch); - self.addPastKeyValues(decoderFeeds, past_key_values); + this.addPastKeyValues(other_decoder_inputs, past_key_values); - const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds); - let logits = decoderResults.logits; - past_key_values = self.getPastKeyValues(decoderResults, past_key_values); + // Rename decoder inputs + const decoder_inputs = pick(other_decoder_inputs, self.decoder_merged_session.inputNames); - // Get cross attention and/or decoder attentions if they are present - const attns = self.getAttentions(decoderResults); - - return new Seq2SeqLMOutput({ logits, past_key_values, encoder_outputs, ...attns }); -} - -/** - * Start the beam search process for the seq2seq model. - * @param {PreTrainedModel} self The seq2seq model object. - * @param {Tensor} inputTokenIds Array of input token ids for each input sequence. - * @param {Object} generation_config The generation config. - * @param {number} numOutputTokens The maximum number of output tokens for the model. - * @returns {Object[]} Array of beam search objects. 
- * @private - */ -function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputTokens) { - let beams = []; - let beamId = 0; - - // @ts-ignore - const requires_attention_mask = self.requires_attention_mask ?? true; - - // decoder_input_ids == output_token_ids - let decoder_input_ids = - generation_config.decoder_input_ids - ?? generation_config.decoder_start_token_id - ?? generation_config.bos_token_id - ?? generation_config.eos_token_id; - - // Support input as tensor or list - // TODO support batched decoder_input_ids - if (decoder_input_ids instanceof Tensor) { - decoder_input_ids = decoder_input_ids.tolist().flat(); - } else if (!Array.isArray(decoder_input_ids)) { - decoder_input_ids = [decoder_input_ids]; - } - - for (let tokens of inputTokenIds) { - // TODO: Improve - // Currently, just add back batch dimension. - // In future, allow for true parallel execution - tokens.dims = [1, ...tokens.dims] - - // Create beam - let start = { - inputs: tokens, - encoder_outputs: null, - prev_model_outputs: null, - - output_token_ids: decoder_input_ids, - done: false, - score: 0, - id: beamId++ // assign unique id to beams - } - - if (requires_attention_mask) { - start.attention_mask = prepareAttentionMask(self, tokens); - } - - beams.push(start); - } - - return beams; -} - -/** - * Run beam search on the seq2seq model for a single beam. - * @param {PreTrainedModel} self The seq2seq model object. - * @param {Object} beam The beam search object for which to run the model. - * @param {Object} options options - * @param {string} [options.input_name='input_ids'] The name of the input tensor for the encoder. - * @returns {Promise} Promise that resolves with the output of the seq2seq model for the given beam. - * @private - */ -async function seq2seqRunBeam(self, beam) { - const input_name = self.main_input_name; - - let decoder_input_ids = beam.output_token_ids; - if (beam.prev_model_outputs) { - // After the first step, `prev_model_outputs` won't be null. - // So, we cut decoder_input_ids if past is used - decoder_input_ids = decoder_input_ids.slice(-1); - } - - // 1. Prepare - let model_inputs = { - [input_name]: beam.inputs, - decoder_input_ids: toI64Tensor(decoder_input_ids), - encoder_outputs: beam.encoder_outputs, - past_key_values: beam.prev_model_outputs?.past_key_values, - } - if (beam.attention_mask) { - model_inputs.attention_mask = beam.attention_mask - } + const decoderResults = await sessionRun(self.decoder_merged_session, decoder_inputs); - // 2. Run - let output = await self.forward(model_inputs); - - // 3. Update - beam.prev_model_outputs = output; - beam.encoder_outputs = output.encoder_outputs; + // Get cross attention and/or decoder attentions if they are present + // const attns = self.getAttentions(decoderResults); - return output; -} - -/** - * Update a beam with a new token ID. - * @param {Object} beam The beam to update. - * @param {number} newTokenId The new token ID to add to the beam's output. - * @private - */ -function seq2seqUpdatebeam(beam, newTokenId) { - beam.output_token_ids = [...beam.output_token_ids, newTokenId]; + return decoderResults; } /** @@ -557,7 +428,6 @@ async function encoderForward(self, model_inputs) { return await sessionRun(self.session, encoderFeeds); } - /** * Forward pass of a decoder model. * @param {Object} self The decoder model. 
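The rewritten forward passes above rely on `pick` to narrow an arbitrary feeds object down to the inputs a session actually declares; roughly like this (the session object and feed names are stand-ins for illustration):

```js
import { pick } from './utils/core.js';

// A merged decoder session typically declares only a subset of what we carry around.
const session = { inputNames: ['input_ids', 'attention_mask'] };

const model_inputs = {
    input_ids: /* Tensor */ null,
    attention_mask: /* Tensor */ null,
    encoder_outputs: null, // not an ONNX input of this session
};

// Only the declared names are forwarded, so `sessionRun` never sees unknown feeds.
const feeds = pick(model_inputs, session.inputNames);
console.log(Object.keys(feeds)); // ['input_ids', 'attention_mask']
```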
@@ -566,126 +436,71 @@ async function encoderForward(self, model_inputs) { * @private */ async function decoderForward(self, model_inputs) { - let { input_ids, past_key_values, attention_mask } = model_inputs; - let decoderFeeds = { - input_ids: input_ids, - attention_mask: attention_mask ?? prepareAttentionMask(self, input_ids), - } - const use_cache_branch = !!past_key_values; - - if (self.session.inputNames.includes('use_cache_branch')) { - decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); - } - - preparePositionIds(self.session, decoderFeeds, use_cache_branch); - - self.addPastKeyValues(decoderFeeds, past_key_values); - - let decoderResults = await sessionRun(self.session, decoderFeeds); - - let logits = decoderResults.logits; - - past_key_values = self.getPastKeyValues(decoderResults, past_key_values); - return { logits, past_key_values }; + // TODO move addPastKeyValues from decoder_prepare_inputs_for_generation here + return await sessionRun(self.session, model_inputs); } -/** - * Starts the generation of text by initializing the beams for the given input token IDs. - * @param {Object} self The text generation model object. - * @param {Tensor} inputTokenIds An tensor of input token IDs to generate text from. - * @param {Object} generation_config The generation config. - * @param {number} numOutputTokens The maximum number of tokens to generate for each beam. - * @param {Tensor} [inputs_attention_mask] The attention mask tensor for the input token IDs. - * @returns {Object[]} An array of beams initialized with the given inputs and parameters. - * @private - */ -function decoderStartBeams(self, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) { - let beams = []; +function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { - let beamId = 0; - for (let tokens of inputTokenIds) { - let output_token_ids = tokens.tolist().map(Number); + const { past_key_values, ...new_model_inputs } = model_inputs; - // TODO: Improve - // Currently, just add back batch dimension. - // In future, allow for true parallel execution - tokens.dims = [1, ...tokens.dims] + self.addPastKeyValues(new_model_inputs, past_key_values); - let attn_mask; - if (inputs_attention_mask) { - attn_mask = inputs_attention_mask[beamId]; - attn_mask.dims = [1, ...attn_mask.dims] + const fixed = pick(new_model_inputs, self.session.inputNames); - } else { - attn_mask = prepareAttentionMask(self, tokens) + if (self.session.inputNames.includes('position_ids') && fixed.attention_mask && !fixed.position_ids) { + // If the model supports providing position_ids, we create position_ids on the fly for batch generation, + // by computing the cumulative sum of the attention mask along the sequence length dimension. 
+ // + // Equivalent to: + // position_ids = attention_mask.long().cumsum(-1) - 1 + // position_ids.masked_fill_(attention_mask == 0, 1) + // if past_key_values: + // position_ids = position_ids[:, -input_ids.shape[1] :] + const [bz, seq_len] = fixed.attention_mask.dims; + + const data = new BigInt64Array(fixed.attention_mask.data.length); + for (let i = 0; i < bz; ++i) { + const start = i * seq_len; + let sum = BigInt(0); + for (let j = 0; j < seq_len; ++j) { + const index = start + j; + if (fixed.attention_mask.data[index] === 0n) { + data[index] = BigInt(1); + } else { // === 1n + data[index] = sum; + sum += fixed.attention_mask.data[index]; + } + } } - let start = { - input: tokens, - model_input_ids: tokens, - attention_mask: attn_mask, - prev_model_outputs: null, - - output_token_ids: output_token_ids, - num_output_tokens: numOutputTokens, - - done: false, - score: 0, - id: beamId++ // assign unique id to beams + fixed.position_ids = new Tensor('int64', data, fixed.attention_mask.dims); + if (past_key_values) { + fixed.position_ids = fixed.position_ids.slice(null, -1).unsqueeze_(-1); } - - beams.push(start); } - return beams; -} -/** - * Runs a single step of the text generation process for a given beam. - * - * @param {Object} self The decoder object. - * @param {Object} beam The beam to run. - * @param {Tensor} beam.input The input tensor. - * @param {Tensor} beam.model_input_ids The input ids to the model. - * @param {Tensor} beam.attention_mask The attention mask. - * @param {Object} beam.prev_model_outputs The past key values. - * @param {number[]} beam.output_token_ids The output token ids. - * @returns {Promise} The output of the generation step. - * @private - */ -async function decoderRunBeam(self, beam) { - let attnMaskData = new BigInt64Array(beam.output_token_ids.length).fill(1n) - - // 1. Prepare - let model_inputs = { - input_ids: beam.model_input_ids, - attention_mask: new Tensor( - 'int64', - attnMaskData, - [1, attnMaskData.length] - ), - past_key_values: beam.prev_model_outputs?.past_key_values, - } + return fixed; +} - // 2. Run - let output = await self.forward(model_inputs); +function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { - // 3. Update - beam.prev_model_outputs = output; + // console.log('model_inputs', model_inputs) + const { ...new_model_inputs } = model_inputs; - return output; -} + const past_key_values = model_inputs.past_key_values; + // self.addPastKeyValues(new_model_inputs, past_key_values); -/** - * Update a beam with a new token ID. - * @param {Object} beam The beam to update. - * @param {number} newTokenId The new token ID to add to the beam's output. - * @private - */ -function decoderUpdatebeam(beam, newTokenId) { - beam.output_token_ids = [...beam.output_token_ids, newTokenId]; - beam.model_input_ids = new Tensor('int64', [BigInt(newTokenId)], [1, 1]); + if (past_key_values) { + // keep only final IDs: + input_ids = input_ids.map(x => [x.at(-1)]); + } else { + // input_ids; + } + new_model_inputs['decoder_input_ids'] = toI64Tensor(input_ids); + // throw new Error('Not implemented'); + return new_model_inputs; } - ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -694,7 +509,7 @@ function decoderUpdatebeam(beam, newTokenId) { */ export class PreTrainedModel extends Callable { main_input_name = 'input_ids'; - + forward_params = ['input_ids', 'attention_mask']; /** * Creates a new instance of the `PreTrainedModel` class. 
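The cumulative-sum construction of `position_ids` above is easier to see on a tiny left-padded batch (mask values invented); this standalone restatement mirrors the loop in `decoder_prepare_inputs_for_generation`:

```js
// Attention mask for batch size 2, sequence length 4, row 0 left-padded:
//   row 0: [0, 1, 1, 1]  ->  position_ids [1, 0, 1, 2]
//   row 1: [1, 1, 1, 1]  ->  position_ids [0, 1, 2, 3]
const attention_mask = [0n, 1n, 1n, 1n, 1n, 1n, 1n, 1n];
const [bz, seq_len] = [2, 4];

const position_ids = new BigInt64Array(attention_mask.length);
for (let i = 0; i < bz; ++i) {
    let sum = 0n;
    for (let j = 0; j < seq_len; ++j) {
        const index = i * seq_len + j;
        if (attention_mask[index] === 0n) {
            position_ids[index] = 1n; // padded slots get the placeholder value 1
        } else {
            position_ids[index] = sum; // real tokens count up from 0
            sum += attention_mask[index];
        }
    }
}
console.log(position_ids); // 1n, 0n, 1n, 2n, 0n, 1n, 2n, 3n
```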
* @param {Object} config The model configuration. @@ -710,29 +525,29 @@ export class PreTrainedModel extends Callable { const modelType = MODEL_TYPE_MAPPING.get(modelName); this.can_generate = false; - this._runBeam = null; - this._getStartBeams = null; - this._updateBeam = null; this._forward = null; + + this._prepare_inputs_for_generation = null; if (modelType === MODEL_TYPES.DecoderOnly) { this.can_generate = true; - this._runBeam = decoderRunBeam; - this._getStartBeams = decoderStartBeams; - this._updateBeam = decoderUpdatebeam; this._forward = decoderForward; + this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation; - } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) { + } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq || modelType === MODEL_TYPES.Musicgen) { this.can_generate = true; - this._runBeam = seq2seqRunBeam; - this._getStartBeams = seq2seqStartBeams; - this._updateBeam = seq2seqUpdatebeam; this._forward = seq2seqForward; + this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation; } else if (modelType === MODEL_TYPES.EncoderDecoder) { - this._forward = encoderForward; + // console.warn('TODO: Implement EncoderDecoderForward') + this._forward = seq2seqForward; + } else if (modelType === MODEL_TYPES.ImageTextToText) { + this.can_generate = true; + console.warn('TODO: Implement visionDecoderForward'); + // this._forward = visionDecoderForward; } else { // should be MODEL_TYPES.EncoderOnly this._forward = encoderForward; } @@ -777,6 +592,7 @@ export class PreTrainedModel extends Callable { local_files_only = false, revision = 'main', model_file_name = null, + subfolder = 'onnx', device = null, dtype = null, session_options = {}, @@ -789,6 +605,7 @@ export class PreTrainedModel extends Callable { local_files_only, revision, model_file_name, + subfolder, device, dtype, session_options, @@ -801,7 +618,7 @@ export class PreTrainedModel extends Callable { if (modelType === MODEL_TYPES.DecoderOnly) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options), + constructSession(pretrained_model_name_or_path, options.model_file_name ?? 
'model', options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); @@ -827,6 +644,24 @@ export class PreTrainedModel extends Callable { constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), ]); + } else if (modelType === MODEL_TYPES.ImageTextToText) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + constructSession(pretrained_model_name_or_path, 'embed_tokens', options), + constructSession(pretrained_model_name_or_path, 'vision_encoder', options), + getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), + ]); + + } else if (modelType === MODEL_TYPES.Musicgen) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'text_encoder', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + constructSession(pretrained_model_name_or_path, 'encodec_decode', options), + getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), + ]); + } else { // should be MODEL_TYPES.EncoderOnly if (modelType !== MODEL_TYPES.EncoderOnly) { console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`) @@ -862,7 +697,33 @@ export class PreTrainedModel extends Callable { } /** - * @param {import('./utils/generation.js').GenerationConfigType} generation_config + * This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] + * instances used for multinomial sampling. + * @param {GenerationConfig} generation_config The generation config. + * @returns {LogitsProcessorList} generation_config + */ + _get_logits_warper(generation_config) { + + // instantiate warpers list + const warpers = new LogitsProcessorList(); + + if (generation_config.temperature !== null && generation_config.temperature !== 1.0) { + warpers.push(new TemperatureLogitsWarper(generation_config.temperature)); + } + if (generation_config.top_k !== null && generation_config.top_k !== 0) { + // TODO: add min_tokens_to_keep + warpers.push(new TopKLogitsWarper(generation_config.top_k)); + } + if (generation_config.top_p !== null && generation_config.top_p < 1.0) { + // TODO: add min_tokens_to_keep + warpers.push(new TopPLogitsWarper(generation_config.top_p)); + } + + return warpers; + } + + /** + * @param {GenerationConfig} generation_config * @param {number} input_ids_seq_length The starting sequence length for the input ids. 
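Which warpers `_get_logits_warper` attaches follows directly from the config; a hedged helper sketch (assuming `model` is any loaded generative `PreTrainedModel`, and that `temperature`/`top_k`/`top_p` are fields of `GenerationConfig`):

```js
import { GenerationConfig } from './generation/configuration_utils.js';

// Lists the warper class names that would be applied for the given overrides.
function describeWarpers(model, options) {
    const warpers = model._get_logits_warper(new GenerationConfig(options));
    return [...warpers].map(w => w.constructor.name);
}

// e.g. describeWarpers(model, { temperature: 0.7, top_k: 50, top_p: 0.9 })
//   -> ['TemperatureLogitsWarper', 'TopKLogitsWarper', 'TopPLogitsWarper']
```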
* @returns {LogitsProcessorList} * @private @@ -973,9 +834,10 @@ export class PreTrainedModel extends Callable { processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)); } - if (generation_config.forced_decoder_ids !== null) { - processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)); - } + // DEPRECATED: https://github.com/huggingface/transformers/pull/29485 + // if (generation_config.forced_decoder_ids !== null) { + // processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)); + // } if (logits_processor !== null) { processors.extend(logits_processor) @@ -992,235 +854,490 @@ export class PreTrainedModel extends Callable { /** * This function merges multiple generation configs together to form a final generation config to be used by the model for text generation. * It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object. - * @param {import('./utils/generation.js').GenerationConfigType} generation_config A `GenerationConfig` object containing generation parameters. - * @returns {import('./utils/generation.js').GenerationConfigType} The final generation config object to be used by the model for text generation. + * @param {GenerationConfig} generation_config A `GenerationConfig` object containing generation parameters. + * @param {Object} kwargs Additional generation parameters to be used in place of those in the `generation_config` object. + * @returns {GenerationConfig} The final generation config object to be used by the model for text generation. */ - _get_generation_config(generation_config) { + _prepare_generation_config(generation_config, kwargs) { // Create empty generation config (contains defaults) // We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them - let gen_config = new GenerationConfig(this.config); + const gen_config = new GenerationConfig(this.config); // Apply model's generation config, if it exists if ('generation_config' in this) { Object.assign(gen_config, this.generation_config); } - // Finally, use any generation config specified by the user + // Next, use any generation config specified by the user // when calling `generate` - if (generation_config !== null) { + if (generation_config) { Object.assign(gen_config, generation_config); } + + // Finally, if any kwargs were passed, use them to overwrite + if (kwargs) { + Object.assign(gen_config, pick(kwargs, Object.getOwnPropertyNames(gen_config))); + } + return gen_config; } /** - * @typedef {import('./utils/maths.js').TypedArray} TypedArray + * + * @param {GenerationConfig} generation_config + * @param {StoppingCriteriaList} [stopping_criteria=null] */ + _get_stopping_criteria(generation_config, stopping_criteria = null) { + const criteria = new StoppingCriteriaList(); + + if (generation_config.max_length !== null) { + criteria.push(new MaxLengthCriteria( + generation_config.max_length, + this.config.max_position_embeddings ?? 
null, + )); + } + // if (generation_config.max_time !== null) { + // criteria.push(new MaxTimeCriteria(generation_config.max_time)); + // } + if (generation_config.eos_token_id !== null) { + criteria.push(new EosTokenCriteria(generation_config.eos_token_id)); + } + + if (stopping_criteria) { + criteria.extend(stopping_criteria); + } + return criteria; + + } /** - * @typedef {{ sequences: Tensor, decoder_attentions: Tensor, cross_attentions: Tensor }} EncoderDecoderOutput - * @typedef {Object} DecoderOutput - * - * Generates text based on the given inputs and generation configuration using the model. - * @param {Tensor|Array|TypedArray} inputs An array of input token IDs. - * @param {Object|GenerationConfig|null} generation_config The generation configuration to use. If null, default configuration will be used. - * @param {Object|null} logits_processor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created. - * @param {Object} options options - * @param {Object} [options.inputs_attention_mask=null] An optional attention mask for the inputs. - * @returns {Promise} An array of generated output sequences, where each sequence is an array of token IDs. - * @throws {Error} Throws an error if the inputs array is empty. - */ - async generate( - inputs, - generation_config = null, - logits_processor = null, - { - inputs_attention_mask = null - } = {}, - ) { + * Confirms that the model class is compatible with generation. + * If not, raises an exception that points to the right class to use. + */ + _validate_model_class() { if (!this.can_generate) { + const generate_compatible_mappings = [ + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + // MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, // TODO + MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, + ]; + const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); - let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.` + const generate_compatible_classes = new Set(); const modelType = this.config.model_type; - const possibleInfo = - MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType) - ?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType) - ?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType) - // ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType) // TODO - ?? 
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType); - - if (possibleInfo) { - // TODO: support multiple possible classes - errorMessage += ` Please use the following class instead: '${possibleInfo[0]}'`; + for (const model_mapping of generate_compatible_mappings) { + const supported_models = model_mapping.get(modelType); + if (supported_models) { + generate_compatible_classes.add(supported_models[0]); + } + } + + let errorMessage = `The current model class (${modelName}) is not compatible with \`.generate()\`, as it doesn't have a language model head.` + if (generate_compatible_classes.size > 0) { + errorMessage += ` Please use the following class instead: ${[...generate_compatible_classes].join(', ')}`; } throw Error(errorMessage); } + } + + prepare_inputs_for_generation(...args) { + return this._prepare_inputs_for_generation(this, ...args); + } + + _update_model_kwargs_for_generation({ generated_input_ids, outputs, model_inputs, is_encoder_decoder }) { + // update past_key_values + model_inputs['past_key_values'] = this.getPastKeyValues(outputs, model_inputs.past_key_values); + + // update inputs for next run + model_inputs['input_ids'] = new Tensor('int64', generated_input_ids, [generated_input_ids.length, 1]); + + if (!is_encoder_decoder) { + // update attention mask + model_inputs.attention_mask = cat( + [ + model_inputs.attention_mask, + ones([model_inputs.attention_mask.dims[0], 1]), + ], 1 + ); + } else if ('decoder_attention_mask' in model_inputs) { + // update decoder attention mask + console.warn('TODO: update decoder attention mask') + } - if (!(inputs instanceof Tensor) && !isTypedArray(inputs) && !Array.isArray(inputs)) { - throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs.constructor.name}".`); + // force recreate position_ids in next iteration + model_inputs['position_ids'] = null; + + return model_inputs; + } + + /** + * This function extracts the model-specific `inputs` for generation. + * @param {Object} params + * @param {Tensor} [params.inputs=null] + * @param {number} [params.bos_token_id=null] + * @param {Record} [params.model_kwargs] + * @returns {{inputs_tensor: Tensor, model_inputs: Record, model_input_name: string}} The model-specific inputs for generation. + */ + _prepare_model_inputs({ inputs, bos_token_id, model_kwargs }) { + const model_inputs = pick(model_kwargs, this.forward_params); + // console.log('model_inputs', model_inputs) + const input_name = this.main_input_name; + if (input_name in model_inputs) { + if (inputs) { + throw new Error( + "`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " + + "Make sure to either pass {inputs} or {input_name}=..." 
+ ); + } + } else { + model_inputs[input_name] = inputs; } - let input_ids_seq_length; + const inputs_tensor = model_inputs[input_name]; + + return { inputs_tensor, model_inputs, model_input_name: input_name }; + } + + async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name }) { + const encoder_kwargs = pick(model_inputs, this.session.inputNames); + + const encoder_outputs = await encoderForward(this, encoder_kwargs); + + model_inputs['encoder_outputs'] = encoder_outputs.last_hidden_state; + + return model_inputs; + } + + /** + * Prepares `decoder_input_ids` for generation with encoder-decoder models + * @param {*} param0 + */ + _prepare_decoder_input_ids_for_generation({ batch_size, model_input_name, model_kwargs, decoder_start_token_id, bos_token_id }) { + + decoder_start_token_id = decoder_start_token_id ?? bos_token_id; - // Prepare `input_ids` which will be used for auto-regressive generation + let decoder_input_ids_start_data; + if (this.config.model_type === 'musicgen') { + // Custom logic + // TODO: move to Musicgen class + decoder_input_ids_start_data = + new Array(batch_size * this.config.decoder.num_codebooks) + .fill(decoder_start_token_id); + + } else if (Array.isArray(decoder_start_token_id)) { + if (decoder_start_token_id.length !== batch_size) { + throw new Error( + `\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}` + ) + } + // TODO: support list of start tokens? + decoder_input_ids_start_data = decoder_start_token_id; + } else { + decoder_input_ids_start_data = new Array(batch_size).fill(decoder_start_token_id); + } + const decoder_input_ids_start = new Tensor( + 'int64', + decoder_input_ids_start_data, + [decoder_input_ids_start_data.length, 1], + ); + + // TODO add other functionality + const decoder_input_ids = decoder_input_ids_start; + model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids); + + return { input_ids: decoder_input_ids, model_inputs: model_kwargs }; + } + + /** + * Generates sequences of token ids for models with a language modeling head. + * @param {import('./generation/parameters.js').GenerationFunctionParameters} options + * @returns {Promise} The output of the model, which can contain the generated token ids, attentions, and scores. + */ + async generate({ + inputs = null, + generation_config = null, + logits_processor = null, + stopping_criteria = null, + + // inputs_attention_mask = null, + ...kwargs + }) { + this._validate_model_class(); + + // Update generation config with defaults and kwargs + generation_config = this._prepare_generation_config(generation_config, kwargs); + + // 3. Define model inputs + let { inputs_tensor, model_inputs, model_input_name } = this._prepare_model_inputs({ + inputs, + model_kwargs: kwargs, + }); + + const is_encoder_decoder = this.config.is_encoder_decoder; + + // 4. Define other model kwargs + if (!is_encoder_decoder) { + // decoder-only models should use left-padding for generation + } else if (!('encoder_outputs' in model_inputs)) { + // if model is encoder decoder encoder_outputs are created + // and added to `model_kwargs` + model_inputs = await this._prepare_encoder_decoder_kwargs_for_generation( + { inputs_tensor, model_inputs, model_input_name } + ) + } + + // 5. 
Prepare `input_ids` which will be used for auto-regressive generation // TODO: Update to align with HF transformers' implementation - if (this.config.is_encoder_decoder) { + let input_ids; + if (is_encoder_decoder) { // Generating from the encoder outputs - input_ids_seq_length = 0; + ({ input_ids, model_inputs } = this._prepare_decoder_input_ids_for_generation({ + batch_size: model_inputs[model_input_name].dims.at(0), + model_input_name, + model_kwargs: model_inputs, + decoder_start_token_id: generation_config.decoder_start_token_id, + bos_token_id: generation_config.bos_token_id, + })) } else { - input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length; + input_ids = model_inputs[model_input_name] + } + // 6. Prepare `max_length` depending on other stopping criteria. + let input_ids_length = input_ids.dims.at(-1); - // decoder-only - if (input_ids_seq_length === 0) { - throw Error("Must supply a non-empty array of input token ids.") - } + if (generation_config.max_new_tokens !== null) { + generation_config.max_length = input_ids_length + generation_config.max_new_tokens; } - // Update generation config with defaults - generation_config = this._get_generation_config(generation_config); + // input_ids_length = model_inputs[model_input_name].dims.at(1); + // // inputs instanceof Tensor ? : inputs.length; - logits_processor = logits_processor ?? new LogitsProcessorList() + // // decoder-only + // if (input_ids_length === 0) { + // throw Error("Must supply a non-empty array of input token ids.") + // } + + // let decoder_input_ids = + // generation_config.decoder_input_ids + // ?? generation_config.decoder_start_token_id + // ?? generation_config.bos_token_id + // ?? generation_config.eos_token_id; // Update logits processor - logits_processor = this._get_logits_processor( + // 8. prepare distribution pre_processing samplers + const prepared_logits_processor = this._get_logits_processor( generation_config, - input_ids_seq_length, - logits_processor + input_ids_length, + logits_processor, ) - /** @type {number[]} */ - let eos_token_ids = generation_config.eos_token_id; - if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) { - eos_token_ids = [eos_token_ids]; - } + // 9. prepare stopping criteria + const prepared_stopping_criteria = this._get_stopping_criteria( + generation_config, stopping_criteria + ) - // TODO implement early_stopping - // https://huggingface.co/blog/how-to-generate + // /** @type {number[]} */ + // let eos_token_ids = generation_config.eos_token_id; + // if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) { + // eos_token_ids = [eos_token_ids]; + // } - let numOutputTokens = 1; - const maxOutputTokens = numOutputTokens + (generation_config.max_new_tokens ?? Infinity); + const numInputs = model_inputs[model_input_name].dims.at(0); - // Only use max length if max_new_tokens is not provided - const useMaxLength = Number.isInteger(generation_config.max_length) && (generation_config.max_new_tokens ?? 
null) === null; - let sampler = Sampler.getSampler(generation_config); + // TODO: + // done is a list of booleans to keep track of which inputs are done + // const done = new Array(numInputs).fill(false); + // For efficiency purposes, we remove completed rows from model_inputs + // when the beam is complete, and we keep track of the row index + // const rowIndexToBatchIndex = new Map(); - // @ts-ignore - let beams = this.getStartBeams(inputs, generation_config, numOutputTokens, inputs_attention_mask); - - while (beams.some(x => !x.done) && numOutputTokens < maxOutputTokens) { - let newest_beams = []; - for (let beam of beams) { - if (beam.done) { - // Add this beam back into the pool - newest_beams.push(beam); - continue - } - if (useMaxLength && beam.output_token_ids.length >= generation_config.max_length) { - // Set this beam to done and add it back into the pool - beam.done = true; - newest_beams.push(beam); - continue - } + const sampler = LogitsSampler.getSampler(generation_config); - // @ts-ignore - let output = await this.runBeam(beam); + // TODO make > numInputs + const scores = new Array(numInputs).fill(0); + const all_input_ids = input_ids.tolist(); + // const all_generated_input_ids = Array.from({ length: numInputs }, () => []); - // add attentions/scores to beam only if user requested - if (generation_config.output_attentions) { - this.addAttentionsToBeam(beam, output); - } - if (generation_config.output_scores) { - // TODO add + // NOTE: For now, we don't support spawning new beams + // TODO: when we do, we simply copy past key values and accumulate into single large tensor + + //////////////////////////////////////////////////// + // Generic search which handles 4 generation modes: + // - GenerationMode.GREEDY_SEARCH + // - GenerationMode.SAMPLE + // - GenerationMode.BEAM_SEARCH + // - GenerationMode.BEAM_SAMPLE + //////////////////////////////////////////////////// + while (true) { + // prepare model inputs + model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs) + + const outputs = await this.forward(model_inputs); + + // Logits are of the form [batch_size, out_seq_length, vocab_size] + // In most cases, this will be [batch_size, 1, vocab_size] + // So, we select the last token's logits: + // (equivalent to `logits = outputs.logits[:, -1, :]`) + const logits = outputs.logits.slice(null, -1, null); + + // only for this batch + const generated_input_ids = []; + // const new_kv_cache = [];// NOTE: Only used for beam search when concatenating new kv + // Loop over each batch + for (let batch_idx = 0; batch_idx < logits.dims.at(0); ++batch_idx) { + const logs = logits[batch_idx]; + + prepared_logits_processor(all_input_ids[batch_idx], logs); + let sampledTokens = sampler(logs); + for (let [newTokenId, logProb] of sampledTokens) { + const bigint = BigInt(newTokenId); + // TODO: If branching, use previous beam as a starting point + // update generated ids, model inputs, and length for next step + scores[batch_idx] += logProb; + all_input_ids[batch_idx].push(bigint); + generated_input_ids.push(bigint); } + } + // if(streamer) { + // streamer.put(next_tokens.cpu()) + // } - // Logits are of the form [batch_size, out_seq_length, vocab_size] - // In most cases, this will be [batch_size, 1, vocab_size] - // So, we select the last token's logits: - // (equivalent to `logits = outputs.logits[:, -1, :]`) - let logits = output.logits.slice(null, -1, null); + const stop = prepared_stopping_criteria(all_input_ids); + if (stop.every(x => x)) { + break; + } - // Apply 
logits processor - logits_processor(beam.output_token_ids, logits); + model_inputs = this._update_model_kwargs_for_generation({ + generated_input_ids, outputs, model_inputs, is_encoder_decoder, + }) + } - let sampledTokens = sampler(logits); - for (let [newTokenId, logProb] of sampledTokens) { - // use previous beam as a starting point - let newBeam = { ...beam }; + // TODO: ensure all_input_ids is padded correctly... + return new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]); - // update new beam - // @ts-ignore - this.updateBeam(newBeam, newTokenId); + // TODO: + // let numOutputTokens = 1; + // const maxOutputTokens = numOutputTokens + (generation_config.max_new_tokens ?? Infinity); - newBeam.score += logProb; + // // Only use max length if max_new_tokens is not provided + // const useMaxLength = Number.isInteger(generation_config.max_length) && (generation_config.max_new_tokens ?? null) === null; - if (eos_token_ids && eos_token_ids.includes(newTokenId)) { - newBeam.done = true; - } + // // console.log('inputs', inputs) + // let beams = this.getStartBeams(inputs, generation_config, numOutputTokens, inputs_attention_mask); - newest_beams.push(newBeam); - } - } - ++numOutputTokens; + // while (beams.some(x => !x.done) && numOutputTokens < maxOutputTokens) { + // let newest_beams = []; + // for (let beam of beams) { + // if (beam.done) { + // // Add this beam back into the pool + // newest_beams.push(beam); + // continue + // } + // if (useMaxLength && beam.output_token_ids.length >= generation_config.max_length) { + // // Set this beam to done and add it back into the pool + // beam.done = true; + // newest_beams.push(beam); + // continue + // } - // Next, we get the best beams, per ID - newest_beams = this.groupBeams(newest_beams).map( - group => group - .sort((a, b) => b.score - a.score) // sort by score - .slice(0, generation_config.num_beams) // remove outside beam width - ); + // // TODO generalize + // let output = await this.runBeam(beam); - // Flatten beams - beams = newest_beams.flat(); - // Run callback - if (generation_config.callback_function) { - generation_config.callback_function(beams); - } - } + // // add attentions/scores to beam only if user requested + // if (generation_config.output_attentions) { + // this.addAttentionsToBeam(beam, output); + // } + // if (generation_config.output_scores) { + // // TODO add + // } - // TODO: Ensure that we can return non-batched outputs + // let logits = output.logits.slice(null, -1, null); - const groupedBeams = this.groupBeams(beams); + // // Apply logits processor + // logits_processor(beam.output_token_ids, logits); - const getFlattened = (key) => groupedBeams.map( - batch => { - if (generation_config.num_return_sequences > 1) { - return batch.slice(0, generation_config.num_return_sequences).map(x => x[key]); - } else { - return [batch[0][key]]; - } - } - ).flat(); // Flatten across batches (depth=1) - - const sequences = getFlattened('output_token_ids'); // [1, seqLength] - - if (generation_config.return_dict_in_generate) { - // NOTE: `decoder_attentions` and `cross_attentions` should be: - // list (one element for each generated token) - // of list (one element for each layer of the decoder) - // of torch.FloatTensor of shape (batch_size, num_heads, generated_length, sequence_length) - // However, since we are only generating one batch at a time, they are of the form: - // list (batches) - // of list (one element for each generated token) - // of list (one element for each layer of the 
decoder) - // of torch.FloatTensor of shape (1, num_heads, generated_length, sequence_length) - // - // TODO: In future (when true parallelism, we should be able to return the correct shape) - - const decoder_attentions = getFlattened('decoder_attentions'); - const cross_attentions = getFlattened('cross_attentions'); - - return { - sequences, - - decoder_attentions, - cross_attentions, - } - } else { - return sequences; - } + // let sampledTokens = sampler(logits); + // for (let [newTokenId, logProb] of sampledTokens) { + // // use previous beam as a starting point + // let newBeam = { ...beam }; + + // // update new beam + // // @ts-ignore + // this.updateBeam(newBeam, newTokenId); + + // newBeam.score += logProb; + + // if (eos_token_ids && eos_token_ids.includes(newTokenId)) { + // newBeam.done = true; + // } + + // newest_beams.push(newBeam); + // } + // } + // ++numOutputTokens; + + // // Next, we get the best beams, per ID + // newest_beams = this.groupBeams(newest_beams).map( + // group => group + // .sort((a, b) => b.score - a.score) // sort by score + // .slice(0, generation_config.num_beams) // remove outside beam width + // ); + + // // Flatten beams + // beams = newest_beams.flat(); + + // // Run callback + // if (generation_config.callback_function) { + // throw new Error("Callback function not yet implemented") + // generation_config.callback_function(beams); + // } + // } + + // // TODO: Ensure that we can return non-batched outputs + + // const groupedBeams = this.groupBeams(beams); + + // const getFlattened = (key) => groupedBeams.map( + // batch => { + // if (generation_config.num_return_sequences > 1) { + // return batch.slice(0, generation_config.num_return_sequences).map(x => x[key]); + // } else { + // return [batch[0][key]]; + // } + // } + // ).flat(); // Flatten across batches (depth=1) + + // const sequences = getFlattened('output_token_ids'); // [1, seqLength] + + // if (generation_config.return_dict_in_generate) { + // // NOTE: `decoder_attentions` and `cross_attentions` should be: + // // list (one element for each generated token) + // // of list (one element for each layer of the decoder) + // // of torch.FloatTensor of shape (batch_size, num_heads, generated_length, sequence_length) + // // However, since we are only generating one batch at a time, they are of the form: + // // list (batches) + // // of list (one element for each generated token) + // // of list (one element for each layer of the decoder) + // // of torch.FloatTensor of shape (1, num_heads, generated_length, sequence_length) + // // + // // TODO: In future (when true parallelism, we should be able to return the correct shape) + + // const decoder_attentions = getFlattened('decoder_attentions'); + // const cross_attentions = getFlattened('cross_attentions'); + + // return { + // sequences, + + // decoder_attentions, + // cross_attentions, + // } + // } else { + // return sequences; + // } } /** @@ -1283,7 +1400,6 @@ export class PreTrainedModel extends Callable { * @returns {Object} An object containing past key values. */ getPastKeyValues(decoderResults, pastKeyValues) { - const pkvs = Object.create(null); for (const name in decoderResults) { @@ -1337,8 +1453,9 @@ export class PreTrainedModel extends Callable { } else { // TODO support batches (i.e., batch_size > 1) const batch_size = 1; - const dtype = this.config.precision || 'float32'; - const empty = (dtype === 'float16') ? 
new Uint16Array() : []; + const dtype = 'float32'; // this.config.precision || + const empty = []; + // (dtype === 'float16') ? new Uint16Array() : []; // @ts-ignore if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) { @@ -1392,39 +1509,6 @@ export class PreTrainedModel extends Callable { } } } - - /** - * Initializes and returns the beam for text generation task - * @param {Tensor} inputTokenIds The input token ids. - * @param {Object} generation_config The generation config. - * @param {number} numOutputTokens The number of tokens to be generated. - * @param {Tensor} inputs_attention_mask Optional input attention mask. - * @returns {any} A Beam object representing the initialized beam. - * @private - */ - getStartBeams(inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) { - return this._getStartBeams(this, inputTokenIds, generation_config, numOutputTokens, inputs_attention_mask) - } - - /** - * Runs a single step of the beam search generation algorithm. - * @param {any} beam The current beam being generated. - * @returns {Promise} The updated beam after a single generation step. - * @private - */ - async runBeam(beam) { - return await this._runBeam(this, beam); - } - - /** - * Update a beam with a new token ID. - * @param {Object} beam The beam to update. - * @param {number} newTokenId The new token ID to add to the beam's output. - * @private - */ - updateBeam(beam, newTokenId) { - return this._updateBeam(beam, newTokenId); - } } ////////////////////////////////////////////////// @@ -2281,17 +2365,11 @@ export class AlbertForMaskedLM extends AlbertPreTrainedModel { ////////////////////////////////////////////////// // T5 models -export class T5PreTrainedModel extends PreTrainedModel { }; - -export class T5Model extends T5PreTrainedModel { } - -/** - * T5Model is a class representing a T5 model for conditional generation. - */ -export class T5ForConditionalGeneration extends T5PreTrainedModel { +export class T5PreTrainedModel extends PreTrainedModel { + forward_params = ['input_ids', 'attention_mask', 'encoder_outputs', 'decoder_input_ids', 'decoder_attention_mask', 'past_key_values']; /** - * Creates a new instance of the `T5ForConditionalGeneration` class. + * Creates a new instance of the `T5PreTrainedModel` class. * @param {Object} config The model configuration. * @param {any} session session for the model. * @param {any} decoder_merged_session session for the decoder. @@ -2310,26 +2388,23 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { this.num_encoder_heads = this.config.num_heads; this.encoder_dim_kv = this.config.d_kv; } -} -////////////////////////////////////////////////// +}; +export class T5Model extends T5PreTrainedModel { } -////////////////////////////////////////////////// -// LONGT5 models /** - * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + * T5Model is a class representing a T5 model for conditional generation. */ -export class LongT5PreTrainedModel extends PreTrainedModel { }; +export class T5ForConditionalGeneration extends T5PreTrainedModel { } +////////////////////////////////////////////////// -/** - * The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top. - */ -export class LongT5Model extends LongT5PreTrainedModel { } +////////////////////////////////////////////////// +// LONGT5 models /** - * LONGT5 Model with a `language modeling` head on top. 
+ * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. */ -export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { +export class LongT5PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `LongT5ForConditionalGeneration` class. * @param {Object} config The model configuration. @@ -2350,20 +2425,23 @@ export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { this.num_encoder_heads = this.config.num_heads; this.encoder_dim_kv = this.config.d_kv; } -} -////////////////////////////////////////////////// +}; +/** + * The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top. + */ +export class LongT5Model extends LongT5PreTrainedModel { } +/** + * LONGT5 Model with a `language modeling` head on top. + */ +export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { } ////////////////////////////////////////////////// -// MT5 models -export class MT5PreTrainedModel extends PreTrainedModel { }; -export class MT5Model extends MT5PreTrainedModel { } -/** - * A class representing a conditional sequence-to-sequence model based on the MT5 architecture. - */ -export class MT5ForConditionalGeneration extends MT5PreTrainedModel { +////////////////////////////////////////////////// +// MT5 models +export class MT5PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `MT5ForConditionalGeneration` class. @@ -2385,22 +2463,19 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { this.num_encoder_heads = this.config.num_heads; this.encoder_dim_kv = this.config.d_kv; } -} -////////////////////////////////////////////////// +}; -////////////////////////////////////////////////// -// Bart models -export class BartPretrainedModel extends PreTrainedModel { }; +export class MT5Model extends MT5PreTrainedModel { } /** - * The bare BART Model outputting raw hidden-states without any specific head on top. + * A class representing a conditional sequence-to-sequence model based on the MT5 architecture. */ -export class BartModel extends BartPretrainedModel { } +export class MT5ForConditionalGeneration extends MT5PreTrainedModel { } +////////////////////////////////////////////////// -/** - * The BART Model with a language modeling head. Can be used for summarization. - */ -export class BartForConditionalGeneration extends BartPretrainedModel { +////////////////////////////////////////////////// +// Bart models +export class BartPretrainedModel extends PreTrainedModel { /** * Creates a new instance of the `BartForConditionalGeneration` class. @@ -2422,8 +2497,17 @@ export class BartForConditionalGeneration extends BartPretrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } +}; -} +/** + * The bare BART Model outputting raw hidden-states without any specific head on top. + */ +export class BartModel extends BartPretrainedModel { } + +/** + * The BART Model with a language modeling head. Can be used for summarization. 
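+ *
+ * **Example:** Summarize text with the new object-based `.generate()` API introduced in this PR
+ * (a sketch only; the checkpoint name `Xenova/distilbart-cnn-6-6` and the exact options are
+ * illustrative assumptions, not part of this change).
+ * ```javascript
+ * import { AutoTokenizer, AutoModelForSeq2SeqLM } from '@xenova/transformers';
+ *
+ * const model_id = 'Xenova/distilbart-cnn-6-6'; // assumed ONNX checkpoint
+ * const tokenizer = await AutoTokenizer.from_pretrained(model_id);
+ * const model = await AutoModelForSeq2SeqLM.from_pretrained(model_id);
+ *
+ * // Tokenize the article, then generate a summary by passing a single options object.
+ * const { input_ids, attention_mask } = tokenizer('The tower is 324 metres (1,063 ft) tall, ...');
+ * const output_ids = await model.generate({ inputs: input_ids, attention_mask, max_new_tokens: 100 });
+ *
+ * // `output_ids` is a Tensor of shape [batch_size, sequence_length].
+ * const [summary] = tokenizer.batch_decode(output_ids, { skip_special_tokens: true });
+ * ```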
+ */ +export class BartForConditionalGeneration extends BartPretrainedModel { } /** * Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) @@ -2444,17 +2528,7 @@ export class BartForSequenceClassification extends BartPretrainedModel { ////////////////////////////////////////////////// // MBart models -export class MBartPreTrainedModel extends PreTrainedModel { }; - -/** - * The bare MBART Model outputting raw hidden-states without any specific head on top. - */ -export class MBartModel extends MBartPreTrainedModel { } - -/** - * The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models. - */ -export class MBartForConditionalGeneration extends MBartPreTrainedModel { +export class MBartPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `MBartForConditionalGeneration` class. @@ -2476,8 +2550,17 @@ export class MBartForConditionalGeneration extends MBartPreTrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } +}; -} +/** + * The bare MBART Model outputting raw hidden-states without any specific head on top. + */ +export class MBartModel extends MBartPreTrainedModel { } + +/** + * The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models. + */ +export class MBartForConditionalGeneration extends MBartPreTrainedModel { } /** * MBart model with a sequence classification/head on top (a linear layer on top of the pooled output). @@ -2520,17 +2603,7 @@ export class MBartForCausalLM extends MBartPreTrainedModel { ////////////////////////////////////////////////// // Blenderbot models -export class BlenderbotPreTrainedModel extends PreTrainedModel { }; - -/** - * The bare Blenderbot Model outputting raw hidden-states without any specific head on top. - */ -export class BlenderbotModel extends BlenderbotPreTrainedModel { } - -/** - * The Blenderbot Model with a language modeling head. Can be used for summarization. - */ -export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel { +export class BlenderbotPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `BlenderbotForConditionalGeneration` class. @@ -2552,23 +2625,23 @@ export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedMode this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } -} +}; + +/** + * The bare Blenderbot Model outputting raw hidden-states without any specific head on top. + */ +export class BlenderbotModel extends BlenderbotPreTrainedModel { } + +/** + * The Blenderbot Model with a language modeling head. Can be used for summarization. + */ +export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // Blenderbot models -export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { }; - -/** - * The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top. - */ -export class BlenderbotSmallModel extends BlenderbotSmallPreTrainedModel { } - -/** - * The BlenderbotSmall Model with a language modeling head. Can be used for summarization. 
- */ -export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel { +export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `BlenderbotForConditionalGeneration` class. @@ -2590,7 +2663,17 @@ export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreT this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } -} +}; + +/** + * The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top. + */ +export class BlenderbotSmallModel extends BlenderbotSmallPreTrainedModel { } + +/** + * The BlenderbotSmall Model with a language modeling head. Can be used for summarization. + */ +export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel { } ////////////////////////////////////////////////// @@ -2818,20 +2901,11 @@ export class ASTForAudioClassification extends ASTPreTrainedModel { } ////////////////////////////////////////////////// // Whisper models -export class WhisperPreTrainedModel extends PreTrainedModel { }; - -/** - * WhisperModel class for training Whisper models without a language model head. - */ -export class WhisperModel extends WhisperPreTrainedModel { } - -/** - * WhisperForConditionalGeneration class for generating conditional outputs from Whisper models. - */ -export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { +export class WhisperPreTrainedModel extends PreTrainedModel { requires_attention_mask = false; main_input_name = 'input_features'; + forward_params = ['input_features', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask', 'past_key_values']; /** * Creates a new instance of the `WhisperForConditionalGeneration` class. @@ -2853,37 +2927,113 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } +}; + +/** + * WhisperModel class for training Whisper models without a language model head. + */ +export class WhisperModel extends WhisperPreTrainedModel { } + + +class WhisperGenerationConfig extends GenerationConfig { /** - * @typedef {Object} WhisperGenerationConfig - * @extends GenerationConfig - * @property {boolean} [return_timestamps=null] Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. - * @property {boolean} [return_token_timestamps=null] Whether to return token-level timestamps + * Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. + * @type {boolean} + */ + return_timestamps = null; + + /** + * Whether to return token-level timestamps * with the text. This can be used with or without the `return_timestamps` option. To get word-level * timestamps, use the tokenizer to group the tokens into words. - * @property {number} [num_frames=null] The number of audio frames available in this chunk. This is only used generating word-level timestamps. + * @type {boolean} + */ + return_token_timestamps = null; + + /** + * The number of audio frames available in this chunk. This is only used generating word-level timestamps. + * @type {number} + */ + num_frames = null; + + /** + * Alignment heads to predict word-level timestamps. This is a list of [layer, head] pairs that + * select the cross-attention heads that are highly correlated to word-level timing. 
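+     * For example (illustrative values only, not taken from a real model config),
+     * `alignment_heads: [[5, 3], [5, 9], [6, 2]]` would select heads 3 and 9 of decoder layer 5
+     * and head 2 of decoder layer 6 when computing word-level timestamps.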
+ * @type {[number, number][]} + */ + alignment_heads = null; + + /** + * Task to use for generation, either "translate" or "transcribe". + * @type {string} + */ + task = null; + + /** + * Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. + * You can find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. + * @type {string} + */ + language = null; +} + +/** + * WhisperForConditionalGeneration class for generating conditional outputs from Whisper models. + */ +export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { + + /** + * + * @param {WhisperGenerationConfig} generation_config + */ + _retrieve_init_tokens(generation_config) { + const init_tokens = [generation_config.decoder_start_token_id] + + throw new Error("Not implemented yet") + } + + /** + * @typedef {Object} WhisperGenerationSpecificParams + * @property {WhisperGenerationConfig} generation_config */ /** - * Generates outputs based on input and generation configuration. - * @param {Object} inputs Input data for the model. - * @param {WhisperGenerationConfig} generation_config Configuration object for the generation process. - * @param {Object} logits_processor Optional logits processor object. - * @returns {Promise} Promise object represents the generated outputs. + * Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids. + * @param {import('./generation/parameters.js').GenerationFunctionParameters & {generation_config: WhisperGenerationConfig} & WhisperGenerationConfig} options + * @returns {Promise} The output of the model, which can contain the generated token ids, attentions, and scores. */ - async generate( - inputs, + async generate({ + inputs = null, generation_config = null, logits_processor = null, - // { - // return_timestamps = null, - // return_token_timestamps = null, - // language = null, - // task = null, - // } = {}, - ) { + stopping_criteria = null, + + // Whisper-specific options + language = null, + task = null, + + ...kwargs + }) { + throw new Error("WhisperForConditionalGeneration.generate is not yet in Transformers.js v3.") + + // console.log('inputs', inputs); + // console.log('kwargs', kwargs); + // async generate({ + // inputs, + // }, + // generation_config = null, + // logits_processor = null, + // // { + // // return_timestamps = null, + // // return_token_timestamps = null, + // // language = null, + // // task = null, + // // } = {}, + // ) { // Create generation config object - generation_config = this._get_generation_config(generation_config); + // TODO: this doesn't create a WhisperGenerationConfig, it makes a GenerationConfig + generation_config = this._prepare_generation_config(generation_config); // Whisper has additional options for returning timestamps @@ -2892,7 +3042,8 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { // TODO add language and task if (generation_config.return_timestamps) { - logits_processor = [new WhisperTimeStampLogitsProcessor(generation_config)] + throw new Error("Not implemented yet") + // logits_processor = [new WhisperTimeStampLogitsProcessor(generation_config)] } if (generation_config.return_token_timestamps) { @@ -2911,7 +3062,13 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { } } - const outputs = await super.generate(inputs, generation_config, logits_processor); + const init_tokens = this._retrieve_init_tokens(generation_config) + + // 
https://github.com/huggingface/transformers/pull/28687/files + + const outputs = await super.generate({ + inputs, generation_config, logits_processor, ...kwargs + }); if (generation_config.return_token_timestamps && generation_config.alignment_heads) { outputs["token_timestamps"] = this._extract_token_timestamps( @@ -3062,7 +3219,7 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { } // Validate decoder - const decoderModel = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(decoderConfig.model_type); + const decoderModel = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(decoderConfig.model_type); if (!decoderModel) { throw new Error(`Unable to construct \`VisionEncoderDecoder\` due to unsupported decoder: "${this.config.decoder.model_type}"`); } @@ -3093,6 +3250,296 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { } ////////////////////////////////////////////////// + +export class LlavaPreTrainedModel extends PreTrainedModel { + forward_params = [ + 'input_ids', + 'past_key_values', + 'pixel_values', + 'attention_mask', + ]; + + constructor(config, session, embed_tokens_session, vision_encoder_session, generation_config) { + super(config, session); + this.input_embeds_session = embed_tokens_session; + this.vision_encoder_session = vision_encoder_session; + this.generation_config = generation_config; + + const decoderConfig = this.config.text_config; + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = decoderConfig.eos_token_id; + + this.num_heads = decoderConfig.num_attention_heads; + this.num_layers = decoderConfig.num_hidden_layers; + this.dim_kv = decoderConfig.hidden_size / this.num_heads; + } +} + +/** + * The LLAVA model which consists of a vision backbone and a language model. 
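+ *
+ * At a high level, `forward` embeds the text tokens with `encode_text`, encodes the image into
+ * image features with `encode_image`, and then splices those features into the sequence of token
+ * embeddings at the position of the special image token (see `_merge_input_ids_with_image_features`).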
+ */ +export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { + + /** + * + */ + async encode_image({ pixel_values }) { + // image_inputs === { pixel_values } + return (await sessionRun(this.vision_encoder_session, { pixel_values })).image_features; + } + + async encode_text({ input_ids }) { + // text_inputs === { input_ids, attention_mask } + return (await sessionRun(this.input_embeds_session, { input_ids })).inputs_embeds; + } + + // async answer_question({ + // image_embeds, + // question, + // tokenizer, + // chat_history = "", + // }) { + // const prompt = `\n\n${chat_history}Question: ${question}\n\nAnswer:`; + // const answer = await this.generate({ + // image_embeds, + // prompt, + // tokenizer, + // eos_text: "", + // max_new_tokens: 512, + // })[0]; + // console.log('answer', answer) + // return 'todo' + // const cleaned_answer = answer.replace(/<$| { + const input_ids = tokenizer(text); + return this.encode_text(input_ids); + } + + if (prompt.includes('')) { + const splits = prompt.split(''); + if (splits.length !== 2) { + throw new Error('Prompt should contain only one tag'); + } + const [before, after] = splits; + const embeds = []; + if (before.length > 0) { + embeds.push(await e(before)); + } + embeds.push(image_embeds); + if (after.length > 0) { + embeds.push(await e(after)); + } + console.log('embeds', embeds) + return cat(embeds, 1); + } else { + return await e(prompt); + } + } + + _merge_input_ids_with_image_features({ + inputs_embeds, + image_features, + input_ids, + attention_mask, + }) { + console.log('_merge_input_ids_with_image_features'); + + const image_token_index = this.config.image_token_index; + + // NOTE: we use .findIndex instead of .indexOf to perform weak comparison (==) between BigInt and Number + const idsList = input_ids.tolist(); + // TODO: validate at most 1 image token + const indexOfImage = idsList.map(x => x.findIndex(x => x == image_token_index)); + + if (!(indexOfImage.every(x => x === -1) || indexOfImage.every(x => x !== -1))) { + // Check for padding reasons + throw new Error('Every input should contain either 0 or 1 image token.'); + } + + let stacked = []; + for (let i = 0; i < indexOfImage.length; ++i) { + const index = indexOfImage[i]; + + const e = inputs_embeds[i]; + const im = image_features[i]; + if (index === -1) { + stacked.push(e); + } else { + const sliced = [ + e.slice([0, index]), + e.slice([index + 1, e.dims[0]]) + ]; + stacked.push( + cat([sliced[0], im, sliced[1]], 0) + ); + } + } + + return { + inputs_embeds: stack(stacked, 0), + attention_mask, + position_ids: null, + }; + } + + prepare_inputs_for_generation({ + input_ids, + past_key_values = null, + inputs_embeds = null, + pixel_values = null, + attention_mask = null, + ...kwargs + }) { + // if (input_ids.dims[0] !== 1) { + // throw new Error('Only single input is supported for now'); + // } + if (past_key_values) { + + } + + } + + + /** + * + * @param {Object} params + * @param {Tensor} [params.input_ids=null] + * @param {Tensor} [params.attention_mask=null] + * @param {Tensor} [params.pixel_values=null] + * @param {Tensor} [params.position_ids=null] + * @param {Tensor} [params.inputs_embeds=null] + * @param {Tensor} [params.past_key_values=null] + * @param {Object} [params.generation_config=null] + * @param {Object} [params.logits_processor=null] + * @returns + */ + async forward({ + // These are produced by the processors: + input_ids = null, + attention_mask = null, + pixel_values = null, + + // Used during generation: + position_ids = null, + 
inputs_embeds = null, + past_key_values = null, + + // Generic generation parameters + generation_config = null, + logits_processor = null, + + // TODO: needed? + ...kwargs + }) { + + if (!inputs_embeds) { + // 1. Extract the input embeddings + inputs_embeds = await this.encode_text({ input_ids }); + + // 2. Possibly, merge text and images + if (pixel_values && input_ids.dims[1] !== 1) { + const image_features = await this.encode_image({ pixel_values }); + + ({ inputs_embeds, inputs_embeds, position_ids } = this._merge_input_ids_with_image_features({ + image_features, + inputs_embeds, + input_ids, + attention_mask, + })); + } else if (past_key_values && pixel_values && input_ids.dims[1] === 1) { + // In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + // generation with cache + } + } + + // TEMP: for now, just recreate attention mask + attention_mask = ones(inputs_embeds.dims.slice(0, 2)); + + const outputs = await decoderForward(this, { + inputs_embeds, + past_key_values, + attention_mask, + generation_config, + logits_processor, + }) + // return super.generate({ + // inputs_embeds, + + // }) + return outputs; + + // TODO generalize with decoderForward + // let { input_ids, image_embeds, pixel_values, past_key_values, attention_mask } = model_inputs; + + // const inputs_embeds = await this.input_embeds({ prompt, image_embeds, tokenizer }); + // console.log('inputs_embeds', inputs_embeds) + + // const decoderFeeds = { + // inputs_embeds: inputs_embeds, + // attention_mask: ones_like(inputs_embeds), + // } + // let { input_ids, past_key_values, attention_mask } = model_inputs; + // let decoderFeeds = { + // inputs_embeds, + // attention_mask: ones(inputs_embeds.dims.slice(0, 2)), + // } + + // const use_cache_branch = !!past_key_values; + + // if (this.session.inputNames.includes('use_cache_branch')) { + // decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); + // } + + + // this.addPastKeyValues(decoderFeeds, past_key_values); + + // let decoderResults = await sessionRun(this.session, decoderFeeds); + + // let logits = decoderResults.logits; + + // past_key_values = this.getPastKeyValues(decoderResults, past_key_values); + // const result = { logits, past_key_values }; + // return result; + + // Retrieve the input embeddings + + // If an image is provided, compute the image embeddings + // if (!image_embeds && pixel_values) { + // image_embeds = await this.encode_image({ pixel_values }); + // console.log('image_embeds', image_embeds) + // } + + // let decoderFeeds = { + // inputs_embeds: inputs_embeds, + // attention_mask: attention_mask ?? 
ones_like(inputs_embeds), + // // prepareAttentionMask(this, input_ids), + // } + // const use_cache_branch = !!past_key_values; + + // if (this.session.inputNames.includes('use_cache_branch')) { + // decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); + // } + + + // this.addPastKeyValues(decoderFeeds, past_key_values); + // console.log('run', decoderFeeds) + + // let decoderResults = await sessionRun(this.session, decoderFeeds); + + // let logits = decoderResults.logits; + + // past_key_values = this.getPastKeyValues(decoderResults, past_key_values); + // return { logits, past_key_values }; + } +} + ////////////////////////////////////////////////// // CLIP models export class CLIPPreTrainedModel extends PreTrainedModel { } @@ -3412,7 +3859,7 @@ export class GPT2PreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.n_head this.num_layers = this.config.n_layer @@ -3445,7 +3892,7 @@ export class GPTNeoPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_heads; this.num_layers = this.config.num_layers; @@ -3471,7 +3918,7 @@ export class GPTNeoXPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_attention_heads; this.num_layers = this.config.num_hidden_layers; @@ -3498,7 +3945,7 @@ export class GPTJPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.n_head this.num_layers = this.config.n_layer @@ -3526,7 +3973,7 @@ export class GPTBigCodePreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.n_head this.num_layers = this.config.n_layer @@ -3553,7 +4000,7 @@ export class CodeGenPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.n_head this.num_layers = this.config.n_layer @@ -3590,7 +4037,7 @@ export class LlamaPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_key_value_heads ?? 
this.config.num_attention_heads this.num_layers = this.config.num_hidden_layers @@ -3623,7 +4070,7 @@ export class Qwen2PreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads this.num_layers = this.config.num_hidden_layers @@ -3654,7 +4101,7 @@ export class PhiPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id; + // this.config.pad_token_id = this.config.eos_token_id; this.num_heads = this.config.num_attention_heads; this.num_layers = this.config.num_hidden_layers; @@ -3687,7 +4134,7 @@ export class BloomPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.n_head this.num_layers = this.config.n_layer @@ -3720,7 +4167,7 @@ export class MptPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.n_heads this.num_layers = this.config.n_layers @@ -3754,7 +4201,7 @@ export class OPTPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_attention_heads; this.num_layers = this.config.num_hidden_layers; @@ -4498,11 +4945,7 @@ export class SamImageSegmentationOutput extends ModelOutput { ////////////////////////////////////////////////// // MarianMT models -export class MarianPreTrainedModel extends PreTrainedModel { }; - -export class MarianModel extends MarianPreTrainedModel { } - -export class MarianMTModel extends MarianPreTrainedModel { +export class MarianPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `MarianMTModel` class. @@ -4524,16 +4967,16 @@ export class MarianMTModel extends MarianPreTrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } -} +}; + +export class MarianModel extends MarianPreTrainedModel { } + +export class MarianMTModel extends MarianPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // M2M100 models -export class M2M100PreTrainedModel extends PreTrainedModel { }; - -export class M2M100Model extends M2M100PreTrainedModel { } - -export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { +export class M2M100PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `M2M100ForConditionalGeneration` class. 
@@ -4555,8 +4998,11 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } +}; -} +export class M2M100Model extends M2M100PreTrainedModel { } + +export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -4976,7 +5422,29 @@ export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel { /** * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. */ -export class SpeechT5PreTrainedModel extends PreTrainedModel { }; +export class SpeechT5PreTrainedModel extends PreTrainedModel { + + /** + * Creates a new instance of the `SpeechT5ForTextToSpeech` class. + * @param {Object} config The model configuration. + * @param {any} session session for the model. + * @param {any} decoder_merged_session session for the decoder. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, decoder_merged_session, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; + + this.num_decoder_layers = this.config.decoder_layers; + this.num_decoder_heads = this.config.decoder_attention_heads; + this.decoder_dim_kv = this.config.hidden_size / this.num_decoder_heads; + + this.num_encoder_layers = this.config.encoder_layers; + this.num_encoder_heads = this.config.encoder_attention_heads; + this.encoder_dim_kv = this.config.hidden_size / this.num_encoder_heads; + } +}; /** * The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets. @@ -5030,27 +5498,6 @@ export class SpeechT5ForSpeechToText extends SpeechT5PreTrainedModel { } */ export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel { - /** - * Creates a new instance of the `SpeechT5ForTextToSpeech` class. - * @param {Object} config The model configuration. - * @param {any} session session for the model. - * @param {any} decoder_merged_session session for the decoder. - * @param {GenerationConfig} generation_config The generation configuration. 
- */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; - this.generation_config = generation_config; - - this.num_decoder_layers = this.config.decoder_layers; - this.num_decoder_heads = this.config.decoder_attention_heads; - this.decoder_dim_kv = this.config.hidden_size / this.num_decoder_heads; - - this.num_encoder_layers = this.config.encoder_layers; - this.num_encoder_heads = this.config.encoder_attention_heads; - this.encoder_dim_kv = this.config.hidden_size / this.num_encoder_heads; - } - /** * @typedef {Object} SpeechOutput * @property {Tensor} [spectrogram] The predicted log-mel spectrogram of shape @@ -5170,7 +5617,7 @@ export class TrOCRPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id; + // this.config.pad_token_id = this.config.eos_token_id; this.num_encoder_layers = this.num_decoder_layers = this.config.decoder_layers; this.num_encoder_heads = this.num_decoder_heads = this.config.decoder_attention_heads; @@ -5203,7 +5650,7 @@ export class MistralPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_key_value_heads; this.num_layers = this.config.num_hidden_layers; @@ -5234,7 +5681,7 @@ export class Starcoder2PreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_key_value_heads; this.num_layers = this.config.num_hidden_layers; @@ -5265,7 +5712,7 @@ export class FalconPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_attention_heads; this.num_layers = this.config.num_hidden_layers; @@ -5433,7 +5880,7 @@ export class StableLmPreTrainedModel extends PreTrainedModel { this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id + // this.config.pad_token_id = this.config.eos_token_id this.num_heads = this.config.num_attention_heads; this.num_layers = this.config.num_hidden_layers; @@ -5474,6 +5921,109 @@ export class EfficientNetForImageClassification extends EfficientNetPreTrainedMo } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// Musicgen models +export class MusicgenPreTrainedModel extends PreTrainedModel { } + +/** + * The bare Musicgen decoder model outputting raw hidden-states without any specific head on top. + */ +export class MusicgenModel extends MusicgenPreTrainedModel { } + +/** + * The MusicGen decoder model with a language modelling head on top. 
+ */ +export class MusicgenForCausalLM extends MusicgenPreTrainedModel { } + +/** + * The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, + * for music generation tasks with one or both of text and audio prompts. + */ +export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: not MusicgenPreTrainedModel + forward_params = ['input_ids', 'attention_mask', 'encoder_outputs', 'decoder_input_ids', 'decoder_attention_mask', 'past_key_values']; + + /** + * Creates a new instance of the `MusicgenForConditionalGeneration` class. + * @param {Object} config The model configuration. + * @param {any} session session for the model. + * @param {any} decoder_merged_session session for the decoder. + * @param {any} encodec_decode session for the encodec.decode function. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, decoder_merged_session, encodec_decode, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.encodec_decode = encodec_decode; + this.generation_config = generation_config; + + // decoder + const decoderConfig = config.decoder; + this.num_decoder_layers = decoderConfig.num_hidden_layers; + this.num_decoder_heads = decoderConfig.num_attention_heads; + this.decoder_dim_kv = decoderConfig.hidden_size / this.num_decoder_heads; + + // text encoder + const textConfig = config.text_encoder; + this.num_encoder_layers = textConfig.num_layers; + this.num_encoder_heads = textConfig.num_heads; + this.encoder_dim_kv = textConfig.d_model / textConfig.num_heads; // Should be textConfig.d_kv; + } + + /** + * Apply the pattern mask to the final ids, + * then revert the pattern delay mask by filtering the pad token id in a single step. + * @param {Tensor} outputs The output tensor from the model. + * @returns {Tensor} The filtered output tensor. + */ + _apply_and_filter_by_delay_pattern_mask(outputs) { + const [bs_x_codebooks, seqLength] = outputs.dims; + const num_codebooks = this.config.decoder.num_codebooks; + const upperBound = (seqLength - num_codebooks); + + let newDataSize = 0; + for (let i = 0; i < outputs.size; ++i) { + if (outputs.data[i] === this.config.pad_token_id) { + continue; + } + + const row = (i % seqLength); + const col = Math.floor(i / seqLength) % num_codebooks; + + const diff = row - col; + if (diff > 0 && diff <= upperBound) { + outputs.data[newDataSize++] = outputs.data[i]; + } + } + + const batch_size = Math.floor(bs_x_codebooks / num_codebooks); + const inferred = newDataSize / (batch_size * num_codebooks); + // TODO: assert `inferred` is an integer + return new Tensor( + outputs.type, + outputs.data.slice(0, newDataSize), [batch_size, num_codebooks, inferred]); + } + + /** + * Generates sequences of token ids for models with a language modeling head. + * @param {import('./generation/parameters.js').GenerationFunctionParameters} options + * @returns {Promise} The output of the model, which can contain the generated token ids, attentions, and scores. 
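+     *
+     * **Example:** Generate music from a text prompt (a sketch only; the checkpoint name
+     * `Xenova/musicgen-small` and the generation options are illustrative assumptions).
+     * ```javascript
+     * import { AutoTokenizer, MusicgenForConditionalGeneration } from '@xenova/transformers';
+     *
+     * const model_id = 'Xenova/musicgen-small'; // assumed ONNX checkpoint
+     * const tokenizer = await AutoTokenizer.from_pretrained(model_id);
+     * const model = await MusicgenForConditionalGeneration.from_pretrained(model_id);
+     *
+     * // Tokenize the prompt and generate audio codes; the EnCodec decoder session then
+     * // converts the codes into a waveform, returned here as `audio_values`.
+     * const inputs = tokenizer('80s pop track with bassy drums and synth');
+     * const audio_values = await model.generate({ ...inputs, max_new_tokens: 500 });
+     * ```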
+ */ + async generate(options) { + let output_ids = await super.generate(options); + + // apply the pattern mask to the final ids + output_ids = this._apply_and_filter_by_delay_pattern_mask( + /** @type {Tensor} */(output_ids) + ).unsqueeze_(0); // append the frame dimension back to the audio codes + + const output_values = await sessionRun(this.encodec_decode, { + // tensor: int64[1,batch_size,4,chunk_length] + audio_codes: output_ids, + }) + + return output_values.audio_values; + } +} ////////////////////////////////////////////////// // AutoModels, used to simplify construction of PreTrainedModels @@ -5497,7 +6047,7 @@ export class PretrainedMixin { static BASE_IF_FAIL = false; - /** @type {PreTrainedModel.from_pretrained} */ + /** @type {typeof PreTrainedModel.from_pretrained} */ static async from_pretrained(pretrained_model_name_or_path, { progress_callback = null, config = null, @@ -5505,6 +6055,7 @@ export class PretrainedMixin { local_files_only = false, revision = 'main', model_file_name = null, + subfolder = 'onnx', device = null, dtype = null, session_options = {}, @@ -5517,6 +6068,7 @@ export class PretrainedMixin { local_files_only, revision, model_file_name, + subfolder, device, dtype, session_options, @@ -5646,6 +6198,7 @@ const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([ const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([ ['vits', ['VitsModel', VitsModel]], + ['musicgen', ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration]], ]); const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ @@ -5697,7 +6250,7 @@ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([ ['blenderbot-small', ['BlenderbotSmallForConditionalGeneration', BlenderbotSmallForConditionalGeneration]], ]); -const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ +const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([ ['bloom', ['BloomForCausalLM', BloomForCausalLM]], ['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]], ['gptj', ['GPTJForCausalLM', GPTJForCausalLM]], @@ -5759,6 +6312,11 @@ const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], ]); +const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ + ['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]], + // ['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]], +]); + const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], ]); @@ -5859,10 +6417,11 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], - [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES.DecoderOnly], + [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly], [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq], + [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText], [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], @@ -5893,6 +6452,10 @@ for (const [mappings, type] of 
MODEL_CLASS_TYPE_MAPPING) { } const CUSTOM_MAPPING = [ + // OVERRIDE: + // TODO: Refactor to allow class to specify model + ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration, MODEL_TYPES.Musicgen], + ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], ['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly], ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly], @@ -5993,7 +6556,7 @@ export class AutoModelForTextToWaveform extends PretrainedMixin { * let model = await AutoModelForCausalLM.from_pretrained('Xenova/gpt2'); */ export class AutoModelForCausalLM extends PretrainedMixin { - static MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]; + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES]; } /** diff --git a/src/pipelines.js b/src/pipelines.js index 60f3c960e..50ca9b230 100755 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -47,9 +47,11 @@ import { Processor } from './processors.js'; - import { Callable, +} from './utils/generic.js'; + +import { dispatchCallback, pop, product, @@ -644,7 +646,7 @@ export class FillMaskPipeline extends (/** @type {new (options: TextPipelineCons * * @callback Text2TextGenerationPipelineCallback Generate the output text(s) using text(s) given as inputs. * @param {string|string[]} texts Input text for the encoder. - * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model. + * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model. * @returns {Promise} * * @typedef {TextPipelineConstructorArgs & Text2TextGenerationPipelineCallback & Disposable} Text2TextGenerationPipelineType @@ -676,6 +678,7 @@ export class Text2TextGenerationPipeline extends (/** @type {new (options: TextP /** @type {Text2TextGenerationPipelineCallback} */ async _call(texts, generate_kwargs = {}) { + throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented if (!Array.isArray(texts)) { texts = [texts]; } @@ -713,9 +716,9 @@ export class Text2TextGenerationPipeline extends (/** @type {new (options: TextP input_ids = tokenizer(texts, tokenizer_options).input_ids; } - const outputTokenIds = await this.model.generate(input_ids, generate_kwargs); + const outputTokenIds = await this.model.generate({ inputs: input_ids, ...generate_kwargs }); - return tokenizer.batch_decode(outputTokenIds, { + return tokenizer.batch_decode(/** @type {Tensor} */(outputTokenIds), { skip_special_tokens: true, }).map(text => ({ [this._key]: text })); } @@ -729,7 +732,7 @@ export class Text2TextGenerationPipeline extends (/** @type {new (options: TextP * * @callback SummarizationPipelineCallback Summarize the text(s) given as inputs. * @param {string|string[]} texts One or several articles (or one list of articles) to summarize. - * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model. + * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model. 
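As context for the pipeline changes above: `generate` now takes a single options object whose `inputs` field carries the encoded ids, with any generation parameters spread alongside. Since the text2text pipeline itself is temporarily disabled in this patch, a caller-side sketch of the underlying model call may help; the checkpoint id and `max_new_tokens` value here are illustrative only.

```js
import { AutoTokenizer, AutoModelForSeq2SeqLM } from '@xenova/transformers';

// Load an example seq2seq checkpoint (any T5-style model should behave the same way).
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
const model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');

const { input_ids } = tokenizer('translate English to French: Hello, how are you?');

// v3-style call: a single options object instead of positional arguments.
const output = await model.generate({ inputs: input_ids, max_new_tokens: 40 });

const [translated] = tokenizer.batch_decode(output, { skip_special_tokens: true });
console.log(translated);
```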
* @returns {Promise} * * @typedef {TextPipelineConstructorArgs & SummarizationPipelineCallback & Disposable} SummarizationPipelineType @@ -776,7 +779,7 @@ export class SummarizationPipeline extends (/** @type {new (options: TextPipelin * * @callback TranslationPipelineCallback Translate the text(s) given as inputs. * @param {string|string[]} texts Texts to be translated. - * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model. + * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model. * @returns {Promise} * * @typedef {TextPipelineConstructorArgs & TranslationPipelineCallback & Disposable} TranslationPipelineType @@ -848,7 +851,7 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC * * @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines. * @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences. - * @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig + * @typedef {import('./generation/configuration_utils.js').GenerationConfig & TextGenerationSpecificParams} TextGenerationConfig * * @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs. * @param {string|string[]} texts One or several prompts (or one list of prompts) to complete. @@ -920,6 +923,7 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli /** @type {TextGenerationPipelineCallback} */ async _call(texts, generate_kwargs = {}) { + throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented const isBatched = Array.isArray(texts); if (!isBatched) { @@ -936,9 +940,13 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli truncation: true, }); - const outputTokenIds = await this.model.generate(input_ids, generate_kwargs, null, { - inputs_attention_mask: attention_mask - }); + const outputTokenIds = /** @type {Tensor} */(await this.model.generate({ + inputs: input_ids, + + // TODO: add back? + // inputs_attention_mask: attention_mask, + ...generate_kwargs + })); const decoded = this.tokenizer.batch_decode(outputTokenIds, { skip_special_tokens: true, @@ -1504,7 +1512,7 @@ export class ZeroShotAudioClassificationPipeline extends (/** @type {new (option * @property {number[][]} [kwargs.forced_decoder_ids] A list of pairs of integers which indicates a mapping from generation indices to token indices * that will be forced before sampling. For example, [[1, 123]] means the second generated token will always be a token of index 123. * @property {number} [num_frames] The number of frames in the input audio. - * @typedef {import('./utils/generation.js').GenerationConfigType & AutomaticSpeechRecognitionSpecificParams} AutomaticSpeechRecognitionConfig + * @typedef {import('./generation/configuration_utils.js').GenerationConfig & AutomaticSpeechRecognitionSpecificParams} AutomaticSpeechRecognitionConfig * * @callback AutomaticSpeechRecognitionPipelineCallback Transcribe the audio sequence(s) given as inputs to text. * @param {AudioPipelineInputs} audio The input audio file(s) to be transcribed. 
The input is either: @@ -1598,6 +1606,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options /** @type {AutomaticSpeechRecognitionPipelineCallback} */ async _call(audio, kwargs = {}) { + throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented switch (this.model.config.model_type) { case 'whisper': return this._call_whisper(audio, kwargs) @@ -1742,7 +1751,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options kwargs.num_frames = Math.floor(chunk.stride[0] / hop_length); // NOTE: doing sequentially for now - const data = await this.model.generate(chunk.input_features, kwargs); + const data = await this.model.generate({ inputs: chunk.input_features, ...kwargs }); // TODO: Right now we only get top beam if (return_timestamps === 'word') { @@ -1782,7 +1791,7 @@ export class AutomaticSpeechRecognitionPipeline extends (/** @type {new (options * * @callback ImageToTextPipelineCallback Assign labels to the image(s) passed as inputs. * @param {ImagePipelineInputs} texts The images to be captioned. - * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model. + * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model. * @returns {Promise} An object (or array of objects) containing the generated text(s). * * @typedef {TextImagePipelineConstructorArgs & ImageToTextPipelineCallback & Disposable} ImageToTextPipelineType @@ -1819,6 +1828,7 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe /** @type {ImageToTextPipelineCallback} */ async _call(images, generate_kwargs = {}) { + throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented const isBatched = Array.isArray(images); const preparedImages = await prepareImages(images); @@ -1828,8 +1838,8 @@ export class ImageToTextPipeline extends (/** @type {new (options: TextImagePipe const toReturn = []; for (const batch of pixel_values) { batch.dims = [1, ...batch.dims] - const output = await this.model.generate(batch, generate_kwargs); - const decoded = this.tokenizer.batch_decode(output, { + const output = await this.model.generate({ inputs: batch, ...generate_kwargs }); + const decoded = this.tokenizer.batch_decode(/** @type {Tensor} */(output), { skip_special_tokens: true, }).map(x => ({ generated_text: x.trim() })) toReturn.push(decoded); @@ -2421,7 +2431,7 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T * @callback DocumentQuestionAnsweringPipelineCallback Answer the question given as input by using the document. * @param {ImageInput} image The image of the document to use. * @param {string} question A question to ask of the document. - * @param {import('./utils/generation.js').GenerationConfigType} [options] Additional keyword arguments to pass along to the generate method of the model. + * @param {import('./generation/configuration_utils.js').GenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model. * @returns {Promise} An object (or array of objects) containing the answer(s). 
* * @typedef {TextImagePipelineConstructorArgs & DocumentQuestionAnsweringPipelineCallback & Disposable} DocumentQuestionAnsweringPipelineType @@ -2453,6 +2463,7 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: /** @type {DocumentQuestionAnsweringPipelineCallback} */ async _call(image, question, generate_kwargs = {}) { + throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented // NOTE: For now, we only support a batch size of 1 @@ -2469,17 +2480,15 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: }).input_ids; // Run model - const output = await this.model.generate( - pixel_values, - { - ...generate_kwargs, - decoder_input_ids, - max_length: this.model.config.decoder.max_position_embeddings, - } - ); + const output = await this.model.generate({ + inputs: pixel_values, + max_length: this.model.config.decoder.max_position_embeddings, + decoder_input_ids, + ...generate_kwargs, + }); // Decode output - const decoded = this.tokenizer.batch_decode(output)[0]; + const decoded = this.tokenizer.batch_decode(/** @type {Tensor} */(output))[0]; // Parse answer const match = decoded.match(/(.*?)<\/s_answer>/); @@ -2568,6 +2577,7 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi async _call(text_inputs, { speaker_embeddings = null, } = {}) { + throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented // If this.processor is not set, we are using a `AutoModelForTextToWaveform` model if (this.processor) { diff --git a/src/processors.js b/src/processors.js index 38ff2797e..cd1e09a55 100644 --- a/src/processors.js +++ b/src/processors.js @@ -21,6 +21,9 @@ */ import { Callable, +} from './utils/generic.js'; + +import { calculateDimensions, calculateReflectOffset, } from './utils/core.js'; @@ -775,6 +778,7 @@ export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends export class BitImageProcessor extends ImageFeatureExtractor { } export class GLPNFeatureExtractor extends ImageFeatureExtractor { } export class CLIPFeatureExtractor extends ImageFeatureExtractor { } +export class CLIPImageProcessor extends CLIPFeatureExtractor { } // NOTE: extends CLIPFeatureExtractor export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { } export class SiglipImageProcessor extends ImageFeatureExtractor { } export class ConvNextFeatureExtractor extends ImageFeatureExtractor { @@ -2162,6 +2166,7 @@ export class AutoProcessor { OwlViTFeatureExtractor, Owlv2ImageProcessor, CLIPFeatureExtractor, + CLIPImageProcessor, ChineseCLIPFeatureExtractor, SiglipImageProcessor, ConvNextFeatureExtractor, diff --git a/src/tokenizers.js b/src/tokenizers.js index 5b58e37c0..00a52c484 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -19,9 +19,11 @@ * * @module tokenizers */ - import { Callable, +} from './utils/generic.js'; + +import { reverseDictionary, escapeRegExp, isIntegralNumber, @@ -2398,7 +2400,7 @@ const SPECIAL_TOKEN_ATTRIBUTES = [ * @param {Record} item The input object. * @param {number} length The length to pad to. * @param {(key: string) => any} value_fn Determine the value to fill the array, based on its key. - * @param {'right'|'left'} side Which side to pad the array. + * @param {string} side Which side to pad the array. 
* @private */ function padHelper(item, length, value_fn, side) { @@ -2434,6 +2436,7 @@ export class PreTrainedTokenizer extends Callable { _default_chat_template = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}`; + padding_side = 'right'; /** * Create a new PreTrainedTokenizer instance. * @param {Object} tokenizerJSON The JSON of the tokenizer. @@ -2512,9 +2515,9 @@ export class PreTrainedTokenizer extends Callable { this.clean_up_tokenization_spaces = tokenizerConfig.clean_up_tokenization_spaces ?? true; this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false; - // TODO allow user to change this - /** @type {'right'|'left'} */ - this.padding_side = 'right'; + if (tokenizerConfig.padding_side) { + this.padding_side = tokenizerConfig.padding_side; + } this.legacy = false; @@ -3192,6 +3195,8 @@ export class LlamaTokenizer extends PreTrainedTokenizer { "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not " + "correct. If you don't know the answer to a question, please don't share false information." + padding_side = 'left'; + constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); this.use_default_system_prompt = tokenizerConfig.use_default_system_prompt ?? false; @@ -4386,7 +4391,6 @@ export class AutoTokenizer { * @returns {Promise} A new instance of the PreTrainedTokenizer class. */ static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, progress_callback = null, config = null, cache_dir = null, @@ -4396,7 +4400,6 @@ export class AutoTokenizer { } = {}) { const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { - quantized, progress_callback, config, cache_dir, diff --git a/src/utils/core.js b/src/utils/core.js index 4ed0f15ef..ba43aaf40 100644 --- a/src/utils/core.js +++ b/src/utils/core.js @@ -42,40 +42,6 @@ export function escapeRegExp(string) { return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string } -/** - * A base class for creating callable objects. - * - * @type {new () => {(...args: any[]): any, _call(...args: any[]): any}} - */ -export const Callable = /** @type {any} */ (class { - /** - * Creates a new instance of the Callable class. - */ - constructor() { - /** - * Creates a closure that delegates to a private method '_call' with the given arguments. - * @type {any} - * @param {...any} args Zero or more arguments to pass to the '_call' method. - * @returns {*} The result of calling the '_call' method. - */ - let closure = function (...args) { - return closure._call(...args) - } - return Object.setPrototypeOf(closure, new.target.prototype) - } - - /** - * This method should be implemented in subclasses to provide the - * functionality of the callable object. - * - * @param {any[]} args - * @throws {Error} If the subclass does not implement the `_call` method. - */ - _call(...args) { - throw Error('Must implement _call method in subclass') - } -}); - /** * Check if a value is a typed array. * @param {*} val The value to check. 
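The new `padding_side` handling above (a class field that tokenizer configs can override, with `LlamaTokenizer` defaulting to `'left'`) matters mostly for batched generation with decoder-only models: with right padding, the last position of a shorter sequence would be a pad token rather than its final real token. A toy illustration of the id layout only (the real `padHelper` pads every field of the encoding in place):

```js
// Pad a list of token ids to `length` on the requested side.
function pad(ids, length, padId, side) {
  const pads = new Array(length - ids.length).fill(padId);
  return side === 'left' ? [...pads, ...ids] : [...ids, ...pads];
}

const PAD = 0;
console.log(pad([5, 6], 4, PAD, 'right')); // [5, 6, 0, 0]  -> last slot is padding
console.log(pad([5, 6], 4, PAD, 'left'));  // [0, 0, 5, 6]  -> last slot is a real token
```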
@@ -173,3 +139,20 @@ export function product(...a) { export function calculateReflectOffset(i, w) { return Math.abs((i + w) % (2 * w) - w); } + +/** + * + * @param {Object} o + * @param {string[]} props + * @returns {Object} + */ +export function pick(o, props) { + return Object.assign( + {}, + ...props.map((prop) => { + if (o[prop] !== undefined) { + return { [prop]: o[prop] }; + } + }) + ); +} diff --git a/src/utils/generation.js b/src/utils/generation.js deleted file mode 100644 index 1f9dc898b..000000000 --- a/src/utils/generation.js +++ /dev/null @@ -1,873 +0,0 @@ - -/** - * @file Classes, functions, and utilities for generation. - * - * @todo Describe how to create a custom `GenerationConfig`. - * - * @module utils/generation - */ -import { Tensor } from './tensor.js'; -import { - Callable, - exists, -} from './core.js'; -import { - max, - softmax, - log_softmax, - getTopItems, -} from './maths.js'; - -/** - * A class representing a list of logits processors. A logits processor is a function that modifies the logits - * output of a language model. This class provides methods for adding new processors and applying all processors to a - * batch of logits. - * - * @extends Callable - */ -export class LogitsProcessorList extends Callable { - /** - * Constructs a new instance of `LogitsProcessorList`. - */ - constructor() { - super(); - this.processors = []; - } - - /** - * Adds a new logits processor to the list. - * - * @param {LogitsProcessor} item The logits processor function to add. - */ - push(item) { - this.processors.push(item); - } - - /** - * Adds multiple logits processors to the list. - * - * @param {LogitsProcessor[]} items The logits processor functions to add. - */ - extend(items) { - this.processors.push(...items); - } - - /** - * Applies all logits processors in the list to a batch of logits, modifying them in-place. - * - * @param {number[]} input_ids The input IDs for the language model. - * @param {number[][]} batchedLogits A 2D array of logits, where each row corresponds to a single - * input sequence in the batch. - */ - _call(input_ids, batchedLogits) { - // NOTE: This is different from the Python code, since vanilla JS does not support vectorized operations. - // As a result, we apply each processor to each item in the batch. - for (let logits of batchedLogits) { - // Modifies logits inplace - this.processors.forEach( - func => func(input_ids, logits) - ) - } - } - - [Symbol.iterator]() { - return this.processors.values(); - } -} - -/** - * Base class for processing logits. - * @extends Callable - */ -export class LogitsProcessor extends Callable { - /** - * Apply the processor to the input logits. - * - * @abstract - * @param {Array} input_ids The input ids. - * @param {Tensor} logits The logits to process. - * @throws {Error} Throws an error if `_call` is not implemented in the subclass. - */ - _call(input_ids, logits) { - throw Error("`_call` should be implemented in a subclass") - } -} - -/** - * A logits processor that forces a specific token to be generated by the decoder. - * - * @extends LogitsProcessor - */ -export class ForceTokensLogitsProcessor extends LogitsProcessor { - /** - * Constructs a new instance of `ForceTokensLogitsProcessor`. - * - * @param {Array} forced_decoder_ids The ids of tokens that should be forced. - */ - constructor(forced_decoder_ids) { - super(); - this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []); - } - - /** - * Apply the processor to the input logits. - * - * @param {Array} input_ids The input ids. 
- * @param {Tensor} logits The logits to process. - * @returns {Tensor} The processed logits. - */ - _call(input_ids, logits) { - let map = this.force_token_map[input_ids.length]; - if (exists(map)) { // There exists a mapping - logits.data.fill(-Infinity) - logits.data[map] = 0; - } - return logits; - } -} - -/** - * A LogitsProcessor that forces a BOS token at the beginning of the generated sequence. - * @extends LogitsProcessor - */ -export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { - /** - * Create a ForcedBOSTokenLogitsProcessor. - * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced. - */ - constructor(bos_token_id) { - super(); - this.bos_token_id = bos_token_id; - } - - /** - * Apply the BOS token forcing to the logits. - * @param {Array} input_ids The input IDs. - * @param {Object} logits The logits. - * @returns {Object} The logits with BOS token forcing. - */ - _call(input_ids, logits) { - if (input_ids.length === 1) { - logits.data.fill(-Infinity) - logits.data[this.bos_token_id] = 0; - } - return logits; - } -} - -/** - * A logits processor that forces end-of-sequence token probability to 1. - * - * @extends LogitsProcessor - */ -export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor { - /** - * Create a ForcedEOSTokenLogitsProcessor. - * @param {number} max_length Max length of the sequence. - * @param {number|number[]} forced_eos_token_id The ID of the end-of-sequence token to be forced. - */ - constructor(max_length, forced_eos_token_id) { - super(); - this.max_length = max_length; - this.forced_eos_token_id = forced_eos_token_id; - } - - /** - * Apply the processor to input_ids and logits. - * - * @param {number[]} input_ids The input ids. - * @param {Tensor} logits The logits tensor. - */ - _call(input_ids, logits) { - // console.log('call ForcedEOSTokenLogitsProcessor') - // TODO - } -} - -/** - * A LogitsProcessor that suppresses a list of tokens as soon as the `generate` function starts - * generating using `begin_index` tokens. This should ensure that the tokens defined by - * `begin_suppress_tokens` at not sampled at the begining of the generation. - * @extends LogitsProcessor - */ -export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { - /** - * Create a SuppressTokensAtBeginLogitsProcessor. - * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress. - * @param {number} begin_index The number of tokens to generate before suppressing tokens. - */ - constructor(begin_suppress_tokens, begin_index) { - super(); - this.begin_suppress_tokens = begin_suppress_tokens; - this.begin_index = begin_index; - } - - /** - * Apply the BOS token forcing to the logits. - * @param {Array} input_ids The input IDs. - * @param {Object} logits The logits. - * @returns {Object} The logits with BOS token forcing. - */ - _call(input_ids, logits) { - if (input_ids.length === this.begin_index) { - for (let token_id of this.begin_suppress_tokens) { - logits.data[token_id] = -Infinity; - } - } - return logits; - } -} - -/** - * A LogitsProcessor that handles adding timestamps to generated text. - * @extends LogitsProcessor - */ -export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { - /** - * Constructs a new WhisperTimeStampLogitsProcessor. - * @param {Object} generate_config The config object passed to the `generate()` method of a transformer model. - * @param {number} generate_config.eos_token_id The ID of the end-of-sequence token. 
- * @param {number} generate_config.no_timestamps_token_id The ID of the token used to indicate that a token should not have a timestamp. - * @param {number[][]} [generate_config.forced_decoder_ids] An array of two-element arrays representing decoder IDs that are forced to appear in the output. The second element of each array indicates whether the token is a timestamp. - * @param {number} [generate_config.max_initial_timestamp_index] The maximum index at which an initial timestamp can appear. - */ - constructor(generate_config) { - super(); - this.eos_token_id = generate_config.eos_token_id; - this.no_timestamps_token_id = generate_config.no_timestamps_token_id; - this.timestamp_begin = this.no_timestamps_token_id + 1; - - this.begin_index = (generate_config.forced_decoder_ids || []).length + 2; - if (generate_config.forced_decoder_ids.slice(-1)[0][1] === this.no_timestamps_token_id) { - this.begin_index -= 1; - } - this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index; - - } - - /** - * Modify the logits to handle timestamp tokens. - * @param {Array} input_ids The input sequence of tokens. - * @param {Tensor} logits The logits output by the model. - * @returns {Tensor} The modified logits. - */ - _call(input_ids, logits) { - const logitsData = /** @type {Float32Array} */(logits.data); - - // suppress <|notimestamps|> which is handled by without_timestamps - logitsData[this.no_timestamps_token_id] = -Infinity; - - if (input_ids.length === this.begin_index - 1) { - logitsData.fill(-Infinity); - logitsData[this.timestamp_begin] = 0; - return logits; - } - - // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly - const seq = input_ids.slice(this.begin_index); - const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin; - const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin; - - if (last_was_timestamp) { - if (penultimate_was_timestamp) { // has to be non-timestamp - logitsData.subarray(this.timestamp_begin).fill(-Infinity); - } else { // cannot be normal text tokens - logitsData.subarray(0, this.eos_token_id).fill(-Infinity); - } - } - - // apply the `max_initial_timestamp` option - if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) { - const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index; - logitsData.subarray(last_allowed + 1).fill(-Infinity); - } - - // if sum of probability over timestamps is above any other token, sample timestamp - const logprobs = log_softmax(logitsData); - const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b)); - const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0]; - - if (timestamp_logprob > max_text_token_logprob) { - logitsData.subarray(0, this.timestamp_begin).fill(-Infinity); - } - - return logits; - } -} - -/** - * A logits processor that disallows ngrams of a certain size to be repeated. - * - * @extends LogitsProcessor - */ -export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { - /** - * Create a NoRepeatNGramLogitsProcessor. - * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once. - */ - constructor(no_repeat_ngram_size) { - super(); - this.no_repeat_ngram_size = no_repeat_ngram_size; - } - - /** - * Generate n-grams from a sequence of token ids. 
- * @param {number[]} prevInputIds List of previous input ids - * @returns {Map} Map of generated n-grams - */ - getNgrams(prevInputIds) { - const curLen = prevInputIds.length; - - /**@type {number[][]} */ - const ngrams = []; - for (let j = 0; j < curLen + 1 - this.no_repeat_ngram_size; ++j) { - const ngram = []; - for (let k = 0; k < this.no_repeat_ngram_size; ++k) { - ngram.push(prevInputIds[j + k]); - } - ngrams.push(ngram); - } - - /** @type {Map} */ - const generatedNgram = new Map(); - for (const ngram of ngrams) { - const prevNgram = ngram.slice(0, ngram.length - 1); - const prevNgramKey = JSON.stringify(prevNgram); - const prevNgramValue = generatedNgram.get(prevNgramKey) ?? []; - prevNgramValue.push(ngram[ngram.length - 1]); - generatedNgram.set(prevNgramKey, prevNgramValue); - } - return generatedNgram; - } - - /** - * Generate n-grams from a sequence of token ids. - * @param {Map} bannedNgrams Map of banned n-grams - * @param {number[]} prevInputIds List of previous input ids - * @returns {number[]} Map of generated n-grams - */ - getGeneratedNgrams(bannedNgrams, prevInputIds) { - const ngramIdx = prevInputIds.slice(prevInputIds.length + 1 - this.no_repeat_ngram_size, prevInputIds.length); - const banned = bannedNgrams.get(JSON.stringify(ngramIdx)) ?? []; - return banned; - } - - /** - * Calculate banned n-gram tokens - * @param {number[]} prevInputIds List of previous input ids - * @returns {number[]} Map of generated n-grams - */ - calcBannedNgramTokens(prevInputIds) { - const bannedTokens = []; - if (prevInputIds.length + 1 < this.no_repeat_ngram_size) { - // return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet - return bannedTokens; - - } else { - const generatedNgrams = this.getNgrams(prevInputIds); - const bannedTokens = this.getGeneratedNgrams(generatedNgrams, prevInputIds); - return bannedTokens; - } - } - - /** - * Apply the no-repeat-ngram processor to the logits. - * @param {Array} input_ids The input IDs. - * @param {Object} logits The logits. - * @returns {Object} The logits with no-repeat-ngram processing. - */ - _call(input_ids, logits) { - const bannedTokens = this.calcBannedNgramTokens(input_ids); - - for (const token of bannedTokens) { - logits.data[token] = -Infinity; - } - return logits; - } -} - -/** - * A logits processor that penalises repeated output tokens. - * - * @extends LogitsProcessor - */ -export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { - /** - * Create a RepetitionPenaltyLogitsProcessor. - * @param {number} penalty The penalty to apply for repeated tokens. - */ - constructor(penalty) { - super(); - this.penalty = penalty; - } - - /** - * Apply the repetition penalty to the logits. - * @param {Array} input_ids The input IDs. - * @param {Object} logits The logits. - * @returns {Object} The logits with repetition penalty processing. - */ - _call(input_ids, logits) { - // Modify the logits corresponding to each element in `input_ids`. - // As a consequence, the logits corresponding to tokens that appear - // many times in the output will be penalised more. - for (const input_id of input_ids) { - if (logits.data[input_id] < 0) { - logits.data[input_id] *= this.penalty; - } else { - logits.data[input_id] /= this.penalty; - } - } - return logits - } -} - -/** - * A logits processor that enforces a minimum number of tokens. - * - * @extends LogitsProcessor - */ -export class MinLengthLogitsProcessor extends LogitsProcessor { - /** - * Create a MinLengthLogitsProcessor. 
- * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity. - * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. - */ - constructor(min_length, eos_token_id) { - super(); - this.min_length = min_length; - this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; - } - - /** - * Apply logit processor. - * @param {Array} input_ids The input IDs. - * @param {Object} logits The logits. - * @returns {Object} The processed logits. - */ - _call(input_ids, logits) { - if (input_ids.length < this.min_length) { - for (const eos_token of this.eos_token_id) { - logits.data[eos_token] = -Infinity; - } - } - - return logits - } -} - -/** - * A logits processor that enforces a minimum number of new tokens. - * - * @extends LogitsProcessor - */ -export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { - /** - * Create a MinNewTokensLengthLogitsProcessor. - * @param {number} prompt_length_to_skip The input tokens length. - * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity. - * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token. - */ - constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) { - super(); - this.prompt_length_to_skip = prompt_length_to_skip; - this.min_new_tokens = min_new_tokens; - this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; - } - - /** - * Apply logit processor. - * @param {Array} input_ids The input IDs. - * @param {Object} logits The logits. - * @returns {Object} The processed logits. - */ - _call(input_ids, logits) { - const new_tokens_length = input_ids.length - this.prompt_length_to_skip; - if (new_tokens_length < this.min_new_tokens) { - for (const eos_token of this.eos_token_id) { - logits.data[eos_token] = -Infinity; - } - } - - return logits - } -} - -export class NoBadWordsLogitsProcessor extends LogitsProcessor { - /** - * Create a `NoBadWordsLogitsProcessor`. - * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated. - * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - */ - constructor(bad_words_ids, eos_token_id) { - super(); - this.bad_words_ids = bad_words_ids; - this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; - } - - /** - * Apply logit processor. - * @param {Array} input_ids The input IDs. - * @param {Object} logits The logits. - * @returns {Object} The processed logits. - */ - _call(input_ids, logits) { - - for (const bad_word_ids of this.bad_words_ids) { - // Whether to modify the logits of the last token in the bad word id sequence - let mark = true; - - // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last), - // then we set the logits of the last bad word id to -Infinity. - for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) { - - if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) { - // We have found a mismatch - mark = false; - break; - } - } - if (mark) { - logits.data[bad_word_ids.at(-1)] = -Infinity; - } - } - - return logits - } -} - -/** - * @typedef {Object} GenerationConfigType The default configuration parameters. - * @property {number} [max_length=20] The maximum length the generated tokens can have. 
Corresponds to the length of the input prompt + `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. - * @property {number} [max_new_tokens=null] The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. - * @property {number} [min_length=0] The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set. - * @property {number} [min_new_tokens=null] The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt. - * @property {boolean|"never"} [early_stopping=false] Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: - * - `true`, where the generation stops as soon as there are `num_beams` complete candidates; - * - `false`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; - * - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm). - * @property {number} [max_time=null] The maximum amount of time you allow the computation to run for in seconds. Generation will still finish the current pass after allocated time has been passed. - * - * @property {boolean} [do_sample=false] Whether or not to use sampling; use greedy decoding otherwise. - * @property {number} [num_beams=1] Number of beams for beam search. 1 means no beam search. - * @property {number} [num_beam_groups=1] Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. - * @property {number} [penalty_alpha=null] The values balance the model confidence and the degeneration penalty in contrastive search decoding. - * @property {boolean} [use_cache=true] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. - * - * @property {number} [temperature=1.0] The value used to modulate the next token probabilities. - * @property {number} [top_k=50] The number of highest probability vocabulary tokens to keep for top-k-filtering. - * @property {number} [top_p=1.0] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. - * @property {number} [typical_p=1.0] Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details. - * @property {number} [epsilon_cutoff=0.0] If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. - * @property {number} [eta_cutoff=0.0] Eta sampling is a hybrid of locally typical sampling and epsilon sampling. 
If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. - * @property {number} [diversity_penalty=0.0] This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. - * @property {number} [repetition_penalty=1.0] The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. - * @property {number} [encoder_repetition_penalty=1.0] The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the original input. 1.0 means no penalty. - * @property {number} [length_penalty=1.0] Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences. - * @property {number} [no_repeat_ngram_size=0] If set to int > 0, all ngrams of that size can only occur once. - * @property {number[][]} [bad_words_ids=null] List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `(await tokenizer(bad_words, {add_prefix_space: true, add_special_tokens: false})).input_ids`. - * @property {number[][]|number[][][]} [force_words_ids=null] List of token ids that must be generated. If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `number[][][]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. - * @property {boolean} [renormalize_logits=false] Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `true` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. - * @property {Object[]} [constraints=null] Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible. - * - * @property {number} [forced_bos_token_id=null] The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for multilingual models like mBART where the first generated token needs to be the target language token. - * @property {number|number[]} [forced_eos_token_id=null] The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a list to set multiple *end-of-sequence* tokens. - * @property {boolean} [remove_invalid_values=false] Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. 
Note that using `remove_invalid_values` can slow down generation. - * @property {number[]} [exponential_decay_length_penalty=null] This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay. - * @property {number[]} [suppress_tokens=null] A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled. - * @property {number[]} [begin_suppress_tokens=null] A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. - * @property {number[][]} [forced_decoder_ids=null] A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123. - * - * @property {number} [num_return_sequences=1] The number of independently computed returned sequences for each element in the batch. - * @property {boolean} [output_attentions=false] Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. - * @property {boolean} [output_hidden_states=false] Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. - * @property {boolean} [output_scores=false] Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - * @property {boolean} [return_dict_in_generate=false] Whether or not to return a `ModelOutput` instead of a plain tuple. - * - * @property {number} [pad_token_id=null] The id of the *padding* token. - * @property {number} [bos_token_id=null] The id of the *beginning-of-sequence* token. - * @property {number|number[]} [eos_token_id=null] The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - * - * @property {number} [encoder_no_repeat_ngram_size=0] If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`. - * @property {number} [decoder_start_token_id=null] If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. - * - * @property {Object} [generation_kwargs={}] Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not present in `generate`'s signature will be used in the model forward pass. - */ - -/** - * Class that holds a configuration for a generation task. - * @type {new (kwargs?: GenerationConfigType) => GenerationConfigType} - */ -export const GenerationConfig = /** @type {any} */ (class { - - /** - * Create a new GenerationConfig object. - * @param {GenerationConfigType} kwargs - */ - constructor(kwargs = {}) { - // Parameters that control the length of the output - this.max_length = kwargs.max_length ?? 20; - this.max_new_tokens = kwargs.max_new_tokens ?? null; - this.min_length = kwargs.min_length ?? 0; - this.min_new_tokens = kwargs.min_new_tokens ?? null; - this.early_stopping = kwargs.early_stopping ?? false; - this.max_time = kwargs.max_time ?? 
null; - - // Parameters that control the generation strategy used - this.do_sample = kwargs.do_sample ?? false; - this.num_beams = kwargs.num_beams ?? 1; - this.num_beam_groups = kwargs.num_beam_groups ?? 1; - this.penalty_alpha = kwargs.penalty_alpha ?? null; - this.use_cache = kwargs.use_cache ?? true; - - // Parameters for manipulation of the model output logits - this.temperature = kwargs.temperature ?? 1.0; - this.top_k = kwargs.top_k ?? 50; - this.top_p = kwargs.top_p ?? 1.0; - this.typical_p = kwargs.typical_p ?? 1.0; - this.epsilon_cutoff = kwargs.epsilon_cutoff ?? 0.0; - this.eta_cutoff = kwargs.eta_cutoff ?? 0.0; - this.diversity_penalty = kwargs.diversity_penalty ?? 0.0; - this.repetition_penalty = kwargs.repetition_penalty ?? 1.0; - this.encoder_repetition_penalty = kwargs.encoder_repetition_penalty ?? 1.0; - this.length_penalty = kwargs.length_penalty ?? 1.0; - this.no_repeat_ngram_size = kwargs.no_repeat_ngram_size ?? 0; - this.bad_words_ids = kwargs.bad_words_ids ?? null; - this.force_words_ids = kwargs.force_words_ids ?? null; - this.renormalize_logits = kwargs.renormalize_logits ?? false; - this.constraints = kwargs.constraints ?? null; - this.forced_bos_token_id = kwargs.forced_bos_token_id ?? null; - this.forced_eos_token_id = kwargs.forced_eos_token_id ?? null; - this.remove_invalid_values = kwargs.remove_invalid_values ?? false; - this.exponential_decay_length_penalty = kwargs.exponential_decay_length_penalty ?? null; - this.suppress_tokens = kwargs.suppress_tokens ?? null; - this.begin_suppress_tokens = kwargs.begin_suppress_tokens ?? null; - this.forced_decoder_ids = kwargs.forced_decoder_ids ?? null; - - // Parameters that define the output variables of `generate` - this.num_return_sequences = kwargs.num_return_sequences ?? 1; - this.output_attentions = kwargs.output_attentions ?? false; - this.output_hidden_states = kwargs.output_hidden_states ?? false; - this.output_scores = kwargs.output_scores ?? false; - this.return_dict_in_generate = kwargs.return_dict_in_generate ?? false; - - // Special tokens that can be used at generation time - this.pad_token_id = kwargs.pad_token_id ?? null; - this.bos_token_id = kwargs.bos_token_id ?? null; - this.eos_token_id = kwargs.eos_token_id ?? null; - - // Generation parameters exclusive to encoder-decoder models - this.encoder_no_repeat_ngram_size = kwargs.encoder_no_repeat_ngram_size ?? 0; - this.decoder_start_token_id = kwargs.decoder_start_token_id ?? null; - - // Wild card - this.generation_kwargs = kwargs.generation_kwargs ?? {}; - } -}); - -/** - * Sampler is a base class for all sampling methods used for text generation. - */ -export class Sampler extends Callable { - /** - * Creates a new Sampler object with the specified generation config. - * @param {GenerationConfigType} generation_config The generation config. - */ - constructor(generation_config) { - super(); - this.generation_config = generation_config; - } - - /** - * Executes the sampler, using the specified logits. - * @param {Tensor} logits - * @param {number} index - * @returns {void} - */ - _call(logits, index = -1) { - // Sample from logits, of dims [batch, sequence_length, vocab_size]. - // If index is specified, sample from [batch, index, vocab_size]. - return this.sample(logits, index); - } - - /** - * Abstract method for sampling the logits. 
- * @param {Tensor} logits - * @param {number} index - * @throws {Error} - */ - sample(logits, index) { - throw Error("sample should be implemented in subclasses.") - } - - /** - * Returns the specified logits as an array, with temperature applied. - * @param {Tensor} logits - * @param {number} index - * @returns {Float32Array} - */ - getLogits(logits, index) { - let vocabSize = logits.dims.at(-1); - - let logs = /** @type {Float32Array} */(logits.data); - - if (index === -1) { - logs = logs.slice(-vocabSize); - } else { - let startIndex = index * vocabSize; - logs = logs.slice(startIndex, startIndex + vocabSize); - } - - // add temperature - if (this.generation_config.temperature > 0) { - logs = logs.map(x => x / this.generation_config.temperature) - } - return logs; - } - - /** - * Selects an item randomly based on the specified probabilities. - * @param {Array} probabilities An array of probabilities to use for selection. - * @returns {number} The index of the selected item. - */ - randomSelect(probabilities) { - // Return index of chosen item - let sumProbabilities = probabilities.reduce((acc, curr) => acc + curr, 0); - - let r = Math.random() * sumProbabilities; - for (let i = 0; i < probabilities.length; ++i) { - r -= probabilities[i]; - if (r <= 0) { - return i; - } - } - return 0; // return first (most probable) as a fallback - } - - /** - * Returns a Sampler object based on the specified options. - * @param {GenerationConfigType} generation_config An object containing options for the sampler. - * @returns {Sampler} A Sampler object. - */ - static getSampler(generation_config) { - // - *greedy decoding*: `num_beams=1` and `do_sample=False` - // - *contrastive search*: `penalty_alpha>0` and `top_k>1` - // - *multinomial sampling*: `num_beams=1` and `do_sample=True` - // - *beam-search decoding*: `num_beams>1` and `do_sample=False` - // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True` - // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1` - // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None` - - // NOTE: beam search is implemented directly into the generation function - if (generation_config.do_sample) { - return new MultinomialSampler(generation_config); - - } else if (generation_config.num_beams > 1) { - return new BeamSearchSampler(generation_config); - - } else { - if (generation_config.num_return_sequences > 1) { - throw Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`) - } - return new GreedySampler(generation_config); - } - } -} - -/** - * Class representing a Greedy Sampler. - * @extends Sampler - */ -class GreedySampler extends Sampler { - /** - * Sample the maximum probability of a given logits tensor. - * @param {Tensor} logits - * @param {number} [index=-1] - * @returns {Array} An array with a single tuple, containing the index of the maximum value and a meaningless score (since this is a greedy search). - */ - sample(logits, index = -1) { - // NOTE: no need to do log_softmax here since we only take the maximum - let logs = this.getLogits(logits, index); - let argmax = max(logs)[1]; - - // Note: score is meaningless in this context, since we are performing - // greedy search (p = 1 => log(p) = 0) - return [ - [argmax, 0] - ]; - } -} - -/** - * Class representing a MultinomialSampler. - * @extends Sampler - */ -class MultinomialSampler extends Sampler { - - /** - * Sample from the logits. 
- * @param {Tensor} logits - * @param {number} index - * @returns {Array} - */ - sample(logits, index = -1) { - let k = logits.dims.at(-1); // defaults to vocab size - if (this.generation_config.top_k > 0) { - k = Math.min(this.generation_config.top_k, k); - } - - // Get logits of nth token - const logs = this.getLogits(logits, index); - - // Get top k tokens - const topLogits = getTopItems(logs, k); - - // Compute softmax over logits - const probabilities = softmax(topLogits.map(x => x[1])); - - return Array.from({ length: this.generation_config.num_beams }, () => { - const sampledIndex = this.randomSelect(probabilities); - return [ - topLogits[sampledIndex][0], // token id - Math.log(probabilities[sampledIndex]), // score - ]; - }); - } -} - - -/** - * Class representing a BeamSearchSampler. - * @extends Sampler - */ -class BeamSearchSampler extends Sampler { - - /** - * Sample from the logits. - * @param {Tensor} logits - * @param {number} index - * @returns {Array} - */ - sample(logits, index = -1) { - let k = logits.dims.at(-1); // defaults to vocab size - if (this.generation_config.top_k > 0) { - k = Math.min(this.generation_config.top_k, k); - } - - // Get logits of nth token - const logs = this.getLogits(logits, index); - - // Get top k tokens - const topLogits = getTopItems(logs, k); - - // Compute softmax over logits - const probabilities = softmax(topLogits.map(x => x[1])); - - return Array.from({ length: this.generation_config.num_beams }, (_, i) => { - return [ - topLogits[i][0], // token id - Math.log(probabilities[i]), // score - ]; - }); - } -} diff --git a/src/utils/generic.js b/src/utils/generic.js new file mode 100644 index 000000000..5ccd467ad --- /dev/null +++ b/src/utils/generic.js @@ -0,0 +1,35 @@ + +/** + * A base class for creating callable objects. + * See [here](https://stackoverflow.com/q/76073890) for more information. + * + * @type {new () => {(...args: any[]): any, _call(...args: any[]): any}} + */ +export const Callable = /** @type {any} */ (class { + /** + * Creates a new instance of the Callable class. + */ + constructor() { + /** + * Creates a closure that delegates to a private method '_call' with the given arguments. + * @type {any} + * @param {...any} args Zero or more arguments to pass to the '_call' method. + * @returns {*} The result of calling the '_call' method. + */ + let closure = function (...args) { + return closure._call(...args) + } + return Object.setPrototypeOf(closure, new.target.prototype) + } + + /** + * This method should be implemented in subclasses to provide the + * functionality of the callable object. + * + * @param {any[]} args + * @throws {Error} If the subclass does not implement the `_call` method. + */ + _call(...args) { + throw Error('Must implement _call method in subclass') + } +}); diff --git a/src/utils/hub.js b/src/utils/hub.js index 32407bed6..7f4b0547e 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -26,6 +26,8 @@ import { dispatchCallback } from './core.js'; /** * @typedef {Object} ModelSpecificPretrainedOptions Options for loading a pretrained model. + * @property {string} [subfolder='onnx'] In case the relevant files are located inside a subfolder of the model repo on huggingface.co, + * you can specify the folder name here. * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models. * @property {import("./devices.js").DeviceType} [device=null] The device to run the model on. 
If not specified, the device will be chosen from the environment settings. * @property {import("./dtypes.js").DataType} [dtype=null] The data type to use for the model. If not specified, the data type will be chosen from the environment settings. diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 38c2baee6..bf64456e0 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -1196,24 +1196,73 @@ function dimsToStride(dims) { return stride; } +function fullHelper(size, fill_value, dtype, cls) { + const numElements = size.reduce((a, b) => a * b, 1); + return new Tensor( + dtype, + new cls(numElements).fill(fill_value), + size + ) +} + +/** + * Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. + * @param {number[]} size A sequence of integers defining the shape of the output tensor. + * @param {number|bigint} fill_value The value to fill the output tensor with. + * @returns {Tensor} The filled tensor. + */ +export function full(size, fill_value) { + let dtype; + let typedArrayCls; + if (typeof fill_value === 'number') { + dtype = 'float32'; + typedArrayCls = Float32Array; + } else if (typeof fill_value === 'bigint') { + dtype = 'int64'; + typedArrayCls = BigInt64Array; + } else { + // TODO: support other dtypes + throw new Error(`Unsupported data type: ${typeof fill_value}`); + } + return fullHelper(size, fill_value, dtype, typedArrayCls); +} + +export function full_like(tensor, fill_value) { + return full(tensor.dims, fill_value); +} + /** * Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. * @param {number[]} size A sequence of integers defining the shape of the output tensor. + * @returns {Tensor} The ones tensor. */ export function ones(size) { - const numElements = size.reduce((a, b) => a * b, 1); - return new Tensor( - 'int64', - new BigInt64Array(numElements).fill(1n), - size - ) + return fullHelper(size, 1n, 'int64', BigInt64Array); } /** * Returns a tensor filled with the scalar value 1, with the same size as input. * @param {Tensor} tensor The size of input will determine size of the output tensor. - * @returns The ones tensor. + * @returns {Tensor} The ones tensor. */ export function ones_like(tensor) { return ones(tensor.dims); } + +/** + * Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size. + * @param {number[]} size A sequence of integers defining the shape of the output tensor. + * @returns {Tensor} The zeros tensor. + */ +export function zeros(size) { + return fullHelper(size, 0n, 'int64', BigInt64Array); +} + +/** + * Returns a tensor filled with the scalar value 0, with the same size as input. + * @param {Tensor} tensor The size of input will determine size of the output tensor. + * @returns {Tensor} The zeros tensor. 
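A short usage sketch of the tensor helpers added here, assuming (as in the new test suite further down, which imports `full` from the package entry point) that they are re-exported from `@xenova/transformers`: `full` infers its dtype from the fill value (`number` -> float32, `bigint` -> int64), while `ones` and `zeros` always build int64 tensors.

```js
import { full, ones, zeros } from '@xenova/transformers';

const probs = full([2, 3], 0.5); // float32 tensor of shape [2, 3], filled with 0.5
const mask  = ones([1, 4]);      // int64 tensor filled with 1n
const ids   = zeros([1, 4]);     // int64 tensor filled with 0n

console.log(probs.type, mask.type, ids.type); // 'float32' 'int64' 'int64'
```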
+ */ +export function zeros_like(tensor) { + return zeros(tensor.dims); +} \ No newline at end of file From a4cce2ccc9128ed8d04c699cc7eb9c099d7d62b9 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 3 Apr 2024 03:45:55 +0200 Subject: [PATCH 063/473] Create tiny model unit test suite --- tests/tiny_random.test.js | 869 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 869 insertions(+) create mode 100644 tests/tiny_random.test.js diff --git a/tests/tiny_random.test.js b/tests/tiny_random.test.js new file mode 100644 index 000000000..4fe1dfd47 --- /dev/null +++ b/tests/tiny_random.test.js @@ -0,0 +1,869 @@ + + +import { + CodeGenTokenizer, + LlamaForCausalLM, + LlamaTokenizer, + OPTForCausalLM, + GPT2Tokenizer, + GPTNeoXForCausalLM, + GPTNeoXTokenizer, + GPTJForCausalLM, + BloomForCausalLM, + BloomTokenizer, + GPTBigCodeForCausalLM, + GPT2LMHeadModel, + MptForCausalLM, + CodeGenForCausalLM, + MistralForCausalLM, + GPTNeoForCausalLM, + BertTokenizer, + BertForMaskedLM, + BertForSequenceClassification, + T5ForConditionalGeneration, + T5Tokenizer, + T5Model, + BertModel, + BertForTokenClassification, + BertForQuestionAnswering, + MusicgenForConditionalGeneration, + full, +} from '../src/transformers.js'; + +import { init } from './init.js'; +init(); + +const MAX_MODEL_LOAD_TIME = 10_000; // 10 seconds +const MAX_TEST_EXECUTION_TIME = 10_000; // 10 seconds +const MAX_MODEL_DISPOSE_TIME = 1_000; // 1 second + +const DEFAULT_MODEL_OPTIONS = { + dtype: 'fp32', +} +describe('Tiny random models', () => { + + describe('bert', () => { + describe('BertModel', () => { + const model_id = 'hf-internal-testing/tiny-random-BertModel'; + + /** @type {BertModel} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const { last_hidden_state } = await model(inputs); + expect(last_hidden_state.dims).toEqual([1, 7, 32]); + expect(last_hidden_state.mean().item()).toBeCloseTo(0.0, 5); + + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const { last_hidden_state } = await model(inputs); + expect(last_hidden_state.dims).toEqual([2, 12, 32]); + expect(last_hidden_state.mean().item()).toBeCloseTo(1.4901161193847656e-08, 5); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe('BertForMaskedLM', () => { + const model_id = 'hf-internal-testing/tiny-random-BertForMaskedLM'; + + const texts = [ + 'The goal of life is [MASK].', + 'Paris is the [MASK] of France.', + ]; + + /** @type {BertForMaskedLM} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForMaskedLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer(texts[0]); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([1, 19, 1124]); + expect(logits.mean().item()).toBeCloseTo(0.0016587056452408433, 5); + + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(texts, { padding: true 
}); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([2, 22, 1124]); + expect(logits.mean().item()).toBeCloseTo(0.0017160633578896523, 5); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe('BertForSequenceClassification', () => { + const model_id = 'hf-internal-testing/tiny-random-BertForSequenceClassification'; + + /** @type {BertForSequenceClassification} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForSequenceClassification.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const { logits } = await model(inputs); + const target = [ + [0.00043986947275698185, -0.030218850821256638], + ].flat(); + expect(logits.dims).toEqual([1, 2]); + logits.tolist().flat().forEach((item, i) => { + expect(item).toBeCloseTo(target[i], 5); + }); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const { logits } = await model(inputs); + const target = [ + [0.00043986947275698185, -0.030218850821256638], + [0.0003853091038763523, -0.03022204339504242] + ].flat(); + expect(logits.dims).toEqual([2, 2]); + logits.tolist().flat().forEach((item, i) => { + expect(item).toBeCloseTo(target[i], 5); + }); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe('BertForTokenClassification', () => { + const model_id = 'hf-internal-testing/tiny-random-BertForTokenClassification'; + + /** @type {BertForTokenClassification} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForTokenClassification.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([1, 7, 2]); + expect(logits.mean().item()).toBeCloseTo(0.07089076191186905, 5); + + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const { logits } = await model(inputs); + expect(logits.dims).toEqual([2, 12, 2]); + expect(logits.mean().item()).toBeCloseTo(0.04702216014266014, 5); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + + describe('BertForQuestionAnswering', () => { + const model_id = 'hf-internal-testing/tiny-random-BertForQuestionAnswering'; + + /** @type {BertForQuestionAnswering} */ + let model; + /** @type {BertTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BertForQuestionAnswering.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await BertTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const { start_logits, end_logits } = await model(inputs); + expect(start_logits.dims).toEqual([1, 7]); + expect(start_logits.mean().item()).toBeCloseTo(0.12772157788276672, 5); + 
expect(end_logits.dims).toEqual([1, 7]); + expect(end_logits.mean().item()).toBeCloseTo(0.11811424791812897, 5); + + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const { start_logits, end_logits } = await model(inputs); + expect(start_logits.dims).toEqual([2, 12]); + expect(start_logits.mean().item()).toBeCloseTo(0.12843115627765656, 5); + expect(end_logits.dims).toEqual([2, 12]); + expect(end_logits.mean().item()).toBeCloseTo(0.11745202541351318, 5); + + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('t5', () => { + + describe('T5Model', () => { + const model_id = 'hf-internal-testing/tiny-random-T5Model'; + + /** @type {T5Model} */ + let model; + /** @type {T5Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await T5Model.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await T5Tokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('forward', async () => { + // Example adapted from https://huggingface.co/google-t5/t5-small#how-to-get-started-with-the-model + const inputs = tokenizer( + "Studies have been shown that owning a dog is good for you", + ); + const { input_ids: decoder_input_ids } = tokenizer( + "Studies show that", + ); + + const { last_hidden_state } = await model({ ...inputs, decoder_input_ids }); + expect(last_hidden_state.dims).toEqual([1, 4, 32]); + expect(last_hidden_state.mean().item()).toBeCloseTo(7.492632721550763e-05, 8); + }); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + describe('T5ForConditionalGeneration', () => { + const model_id = 'hf-internal-testing/tiny-random-T5ForConditionalGeneration'; + + /** @type {T5ForConditionalGeneration} */ + let model; + /** @type {T5Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await T5ForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await T5Tokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('forward', async () => { + // Example adapted from https://huggingface.co/google-t5/t5-small#how-to-get-started-with-the-model + const inputs = tokenizer( + "Studies have been shown that owning a dog is good for you", + ); + const { input_ids: decoder_input_ids } = tokenizer( + "Studies show that", + ); + + const model = await T5ForConditionalGeneration.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS); + const outputs = await model({ ...inputs, decoder_input_ids }); + expect(outputs.logits.dims).toEqual([1, 4, 32100]); + expect(outputs.logits.mean().item()).toBeCloseTo(8.867568901393952e-09, 12); + }); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n], + [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + 
describe('musicgen', () => { + describe('MusicgenForConditionalGeneration', () => { + const model_id = 'hf-internal-testing/tiny-random-MusicgenForConditionalGeneration'; + + // Example adapted from https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation + const texts = [ + "80s pop track with bassy drums and synth", + "90s rock song with loud guitars and heavy drums", + ]; + + /** @type {MusicgenForConditionalGeneration} */ + let model; + /** @type {T5Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await MusicgenForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await T5Tokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('forward', async () => { + // Example from https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenForConditionalGeneration.forward.example + const inputs = tokenizer(texts, { padding: true }); + const pad_token_id = BigInt(model.generation_config.pad_token_id); + const decoder_input_ids = full( + [inputs.input_ids.dims[0] * model.config.decoder.num_codebooks, 1], + pad_token_id, + ); + const { logits } = await model({ ...inputs, decoder_input_ids }); + expect(logits.dims).toEqual([8, 1, 99]); + expect(logits.mean().item()).toBeCloseTo(-0.0018370470497757196, 5); + }); + + it('batch_size=1', async () => { + const inputs = tokenizer(texts[0]); + const audio_values = await model.generate({ ...inputs, max_length: 10 }); + expect(audio_values.dims).toEqual([1, 1, 1920]); + expect(audio_values.mean().item()).toBeCloseTo(0.16644205152988434, 5); + }, MAX_TEST_EXECUTION_TIME); + + it('batch_size>1', async () => { + const inputs = tokenizer(texts, { padding: true }); + const audio_values = await model.generate({ ...inputs, max_length: 10 }); + expect(audio_values.dims).toEqual([2, 1, 1920]); + expect(audio_values.mean().item()).toBeCloseTo(0.16644206643104553, 5); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('opt', () => { + describe('OPTForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-OPTForCausalLM'; + /** @type {OPTForCausalLM} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await OPTForCausalLM.from_pretrained(model_id, { + // TODO move to config + revision: 'refs/pr/2', + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id, { + // TODO update this + revision: 'refs/pr/3', + }); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [2n, 42891n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n], + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [1n, 2n, 42891n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n, 39144n], + [2n, 42891n, 232n, 24680n, 24680n, 24680n, 24680n, 24680n, 24680n, 24680n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('llama', () => { + describe('LlamaForCausalLM', () => 
{ + const model_id = 'hf-internal-testing/tiny-random-LlamaForCausalLM'; + /** @type {LlamaForCausalLM} */ + let model; + /** @type {LlamaTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await LlamaForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await LlamaTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n, 15330n, 26591n, 15721n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 1n, 22172n, 18547n, 8143n, 22202n, 9456n, 17213n, 15330n, 26591n], + [1n, 22172n, 3186n, 24786n, 19169n, 20222n, 29993n, 27146n, 27426n, 24562n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('gpt_neo', () => { + describe('GPTNeoForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-GPTNeoForCausalLM'; + /** @type {GPTNeoForCausalLM} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTNeoForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [258n, 863n, 79n, 79n, 79n, 949n, 949n, 949n, 949n, 949n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 79n, 79n, 949n, 949n, 949n], + [258n, 863n, 79n, 269n, 813n, 849n, 849n, 849n, 849n, 849n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('gpt_neox', () => { + describe('GPTNeoXForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-GPTNeoXForCausalLM'; + /** @type {GPTNeoXForCausalLM} */ + let model; + /** @type {GPTNeoXTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTNeoXForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [259n, 864n, 80n, 881n, 502n, 895n, 938n, 668n, 502n, 895n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 259n, 864n, 80n, 881n, 
502n, 895n, 938n, 668n], + [259n, 864n, 80n, 270n, 814n, 522n, 112n, 268n, 503n, 468n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('gptj', () => { + describe('GPTJForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-GPTJForCausalLM'; + /** @type {GPTJForCausalLM} */ + let model; + /** @type {GPTNeoXTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTJForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [258n, 863n, 79n, 102n, 401n, 773n, 889n, 159n, 957n, 869n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 102n, 401n, 773n, 889n, 159n], + [258n, 863n, 79n, 269n, 813n, 879n, 175n, 39n, 141n, 1000n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('bloom', () => { + describe('BloomForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-BloomForCausalLM'; + /** @type {BloomForCausalLM} */ + let model; + /** @type {BloomTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await BloomForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await BloomTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [198n, 803n, 82n, 82n, 82n, 82n, 82n, 82n, 82n, 82n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [3n, 3n, 198n, 803n, 82n, 82n, 82n, 82n, 82n, 82n], + [198n, 803n, 82n, 209n, 753n, 753n, 753n, 753n, 753n, 753n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('gpt_bigcode', () => { + describe('GPTBigCodeForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-GPTBigCodeForCausalLM'; + /** @type {GPTBigCodeForCausalLM} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPTBigCodeForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n, 79n, 79n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => 
{ + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n], + [258n, 863n, 79n, 269n, 813n, 832n, 93n, 93n, 93n, 93n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('gpt2', () => { + describe('GPT2LMHeadModel', () => { + const model_id = 'hf-internal-testing/tiny-random-GPT2LMHeadModel'; + /** @type {GPT2LMHeadModel} */ + let model; + /** @type {GPT2Tokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await GPT2LMHeadModel.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await GPT2Tokenizer.from_pretrained(model_id); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n, 79n, 243n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 79n, 79n, 79n, 79n, 79n], + [258n, 863n, 79n, 269n, 813n, 813n, 813n, 813n, 813n, 813n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('mpt', () => { + describe('MptForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-MptForCausalLM'; + /** @type {MptForCausalLM} */ + let model; + /** @type {GPTNeoXTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await MptForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await GPTNeoXTokenizer.from_pretrained(model_id); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [259n, 864n, 80n, 80n, 80n, 80n, 80n, 80n, 80n, 80n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 259n, 864n, 80n, 80n, 80n, 80n, 80n, 80n], + [259n, 864n, 80n, 270n, 814n, 293n, 293n, 293n, 293n, 293n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('codegen', () => { + describe('CodeGenForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-CodeGenForCausalLM'; + /** @type {CodeGenForCausalLM} */ + let model; + /** @type {CodeGenTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await CodeGenForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await CodeGenTokenizer.from_pretrained(model_id); + tokenizer.padding_side = 'left'; + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = 
await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [258n, 863n, 79n, 437n, 334n, 450n, 294n, 621n, 375n, 385n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [0n, 0n, 258n, 863n, 79n, 437n, 334n, 450n, 294n, 621n], + [258n, 863n, 79n, 269n, 813n, 759n, 113n, 295n, 574n, 987n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('mistral', () => { + describe('MistralForCausalLM', () => { + const model_id = 'hf-internal-testing/tiny-random-MistralForCausalLM'; + /** @type {MistralForCausalLM} */ + let model; + /** @type {LlamaTokenizer} */ + let tokenizer; + beforeAll(async () => { + model = await MistralForCausalLM.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await LlamaTokenizer.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('batch_size=1', async () => { + const inputs = tokenizer('hello'); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [1n, 6312n, 28709n, 24704n, 8732n, 1310n, 9808n, 13771n, 27309n, 4779n] + ]); + }, MAX_TEST_EXECUTION_TIME); + it('batch_size>1', async () => { + const inputs = tokenizer(['hello', 'hello world'], { padding: true }); + const outputs = await model.generate({ + ...inputs, + max_length: 10, + }); + expect(outputs.tolist()).toEqual([ + [2n, 1n, 6312n, 28709n, 24704n, 8732n, 1310n, 9808n, 13771n, 27309n], + [1n, 6312n, 28709n, 1526n, 8687n, 5690n, 1770n, 30811n, 12501n, 3325n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); +}); From 94c587d9938bb73da769e04a2007d4ff6faa1c65 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 3 Apr 2024 18:07:05 +0200 Subject: [PATCH 064/473] Fix musicgen config --- src/models.js | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/models.js b/src/models.js index 2ef4559d8..701241465 100644 --- a/src/models.js +++ b/src/models.js @@ -5958,15 +5958,9 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: // decoder const decoderConfig = config.decoder; - this.num_decoder_layers = decoderConfig.num_hidden_layers; - this.num_decoder_heads = decoderConfig.num_attention_heads; - this.decoder_dim_kv = decoderConfig.hidden_size / this.num_decoder_heads; - - // text encoder - const textConfig = config.text_encoder; - this.num_encoder_layers = textConfig.num_layers; - this.num_encoder_heads = textConfig.num_heads; - this.encoder_dim_kv = textConfig.d_model / textConfig.num_heads; // Should be textConfig.d_kv; + this.num_encoder_layers = this.num_decoder_layers = decoderConfig.num_hidden_layers; + this.num_encoder_heads = this.num_decoder_heads = decoderConfig.num_attention_heads; + this.encoder_dim_kv = this.decoder_dim_kv = decoderConfig.hidden_size / this.num_decoder_heads; } /** From d9412405692af83210b8ce7235cf6b08969ed015 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Apr 2024 01:37:27 +0200 Subject: [PATCH 065/473] Add support for llava image-text-to-text model --- src/models.js | 199 ++++++++++++-------------------------- tests/tiny_random.test.js | 76 +++++++++++++++ 2 files changed, 140 
insertions(+), 135 deletions(-) diff --git a/src/models.js b/src/models.js index 701241465..80b29e15f 100644 --- a/src/models.js +++ b/src/models.js @@ -378,25 +378,15 @@ async function seq2seqForward(self, model_inputs) { encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state; } - const { input_ids, decoder_input_ids, ...other_decoder_inputs } = model_inputs; other_decoder_inputs.input_ids = decoder_input_ids; other_decoder_inputs.encoder_hidden_states = encoder_outputs; - const use_cache_branch = !!past_key_values; - if (self.decoder_merged_session.inputNames.includes('use_cache_branch')) { - other_decoder_inputs.use_cache_branch = boolTensor(use_cache_branch); - } if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask } - this.addPastKeyValues(other_decoder_inputs, past_key_values); - - // Rename decoder inputs - const decoder_inputs = pick(other_decoder_inputs, self.decoder_merged_session.inputNames); - - const decoderResults = await sessionRun(self.decoder_merged_session, decoder_inputs); + const decoderResults = await decoderForward(self, other_decoder_inputs, true); // Get cross attention and/or decoder attentions if they are present // const attns = self.getAttentions(decoderResults); @@ -435,20 +425,24 @@ async function encoderForward(self, model_inputs) { * @returns {Promise} Promise that resolves with an object containing the logits and past key values. * @private */ -async function decoderForward(self, model_inputs) { - // TODO move addPastKeyValues from decoder_prepare_inputs_for_generation here - return await sessionRun(self.session, model_inputs); -} +async function decoderForward(self, model_inputs, is_encoder_decoder = false) { + const { past_key_values, ...new_model_inputs } = model_inputs; -function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { + const session = is_encoder_decoder ? self.decoder_merged_session : self.session; - const { past_key_values, ...new_model_inputs } = model_inputs; + if (session.inputNames.includes('use_cache_branch')) { + new_model_inputs.use_cache_branch = boolTensor(!!past_key_values); + } + // Unpack the `past_key_values` object into model inputs self.addPastKeyValues(new_model_inputs, past_key_values); + const fixed = pick(new_model_inputs, session.inputNames); + return await sessionRun(session, fixed); +} - const fixed = pick(new_model_inputs, self.session.inputNames); +function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { - if (self.session.inputNames.includes('position_ids') && fixed.attention_mask && !fixed.position_ids) { + if (self.session.inputNames.includes('position_ids') && model_inputs.attention_mask && !model_inputs.position_ids) { // If the model supports providing position_ids, we create position_ids on the fly for batch generation, // by computing the cumulative sum of the attention mask along the sequence length dimension. 
// @@ -457,30 +451,30 @@ function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { // position_ids.masked_fill_(attention_mask == 0, 1) // if past_key_values: // position_ids = position_ids[:, -input_ids.shape[1] :] - const [bz, seq_len] = fixed.attention_mask.dims; + const [bz, seq_len] = model_inputs.attention_mask.dims; - const data = new BigInt64Array(fixed.attention_mask.data.length); + const data = new BigInt64Array(model_inputs.attention_mask.data.length); for (let i = 0; i < bz; ++i) { const start = i * seq_len; let sum = BigInt(0); for (let j = 0; j < seq_len; ++j) { const index = start + j; - if (fixed.attention_mask.data[index] === 0n) { + if (model_inputs.attention_mask.data[index] === 0n) { data[index] = BigInt(1); } else { // === 1n data[index] = sum; - sum += fixed.attention_mask.data[index]; + sum += model_inputs.attention_mask.data[index]; } } } - fixed.position_ids = new Tensor('int64', data, fixed.attention_mask.dims); - if (past_key_values) { - fixed.position_ids = fixed.position_ids.slice(null, -1).unsqueeze_(-1); + model_inputs.position_ids = new Tensor('int64', data, model_inputs.attention_mask.dims); + if (model_inputs.past_key_values) { + model_inputs.position_ids = model_inputs.position_ids.slice(null, -1).unsqueeze_(-1); } } - return fixed; + return model_inputs; } function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { @@ -647,9 +641,9 @@ export class PreTrainedModel extends Callable { } else if (modelType === MODEL_TYPES.ImageTextToText) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), constructSession(pretrained_model_name_or_path, 'embed_tokens', options), constructSession(pretrained_model_name_or_path, 'vision_encoder', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); @@ -1105,6 +1099,7 @@ export class PreTrainedModel extends Callable { } else { input_ids = model_inputs[model_input_name] } + // 6. Prepare `max_length` depending on other stopping criteria. 
let input_ids_length = input_ids.dims.at(-1); @@ -1191,8 +1186,8 @@ export class PreTrainedModel extends Callable { const logs = logits[batch_idx]; prepared_logits_processor(all_input_ids[batch_idx], logs); - let sampledTokens = sampler(logs); - for (let [newTokenId, logProb] of sampledTokens) { + const sampledTokens = sampler(logs); + for (const [newTokenId, logProb] of sampledTokens) { const bigint = BigInt(newTokenId); // TODO: If branching, use previous beam as a starting point // update generated ids, model inputs, and length for next step @@ -3259,10 +3254,11 @@ export class LlavaPreTrainedModel extends PreTrainedModel { 'attention_mask', ]; - constructor(config, session, embed_tokens_session, vision_encoder_session, generation_config) { - super(config, session); - this.input_embeds_session = embed_tokens_session; + constructor(config, input_embeds_session, vision_encoder_session, decoder_merged_session, generation_config) { + super(config, input_embeds_session); this.vision_encoder_session = vision_encoder_session; + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; const decoderConfig = this.config.text_config; @@ -3291,7 +3287,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { async encode_text({ input_ids }) { // text_inputs === { input_ids, attention_mask } - return (await sessionRun(this.input_embeds_session, { input_ids })).inputs_embeds; + return (await sessionRun(this.session, { input_ids })).inputs_embeds; } // async answer_question({ @@ -3353,57 +3349,60 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { const image_token_index = this.config.image_token_index; - // NOTE: we use .findIndex instead of .indexOf to perform weak comparison (==) between BigInt and Number const idsList = input_ids.tolist(); - // TODO: validate at most 1 image token + + // NOTE: we use .findIndex instead of .indexOf to perform weak comparison (==) between BigInt and Number const indexOfImage = idsList.map(x => x.findIndex(x => x == image_token_index)); - if (!(indexOfImage.every(x => x === -1) || indexOfImage.every(x => x !== -1))) { + const noImages = indexOfImage.every(x => x === -1); + const allImages = indexOfImage.every(x => x !== -1); + if (!noImages && !allImages) { // Check for padding reasons throw new Error('Every input should contain either 0 or 1 image token.'); } + if (noImages) { + return { + inputs_embeds, + attention_mask, + position_ids: null, + }; + } + let stacked = []; + let stacked_attention_mask = []; for (let i = 0; i < indexOfImage.length; ++i) { const index = indexOfImage[i]; const e = inputs_embeds[i]; const im = image_features[i]; - if (index === -1) { - stacked.push(e); - } else { - const sliced = [ + const am = attention_mask[i]; + stacked.push( + cat([ e.slice([0, index]), - e.slice([index + 1, e.dims[0]]) - ]; - stacked.push( - cat([sliced[0], im, sliced[1]], 0) - ); - } + im, + e.slice([index + 1, e.dims[0]]), + ], 0) + ); + + stacked_attention_mask.push( + cat([ + am.slice([0, index]), + ones([im.dims[0]]), + am.slice([index + 1, am.dims[0]]) + ], 0) + ) } return { inputs_embeds: stack(stacked, 0), - attention_mask, + attention_mask: stack(stacked_attention_mask, 0), position_ids: null, }; } - prepare_inputs_for_generation({ - input_ids, - past_key_values = null, - inputs_embeds = null, - pixel_values = null, - attention_mask = null, - ...kwargs - }) { - // if (input_ids.dims[0] !== 1) { - // throw new Error('Only single input is supported for now'); - // } - if 
(past_key_values) { - - } - + prepare_inputs_for_generation(input_ids, model_inputs) { + return model_inputs; } @@ -3447,7 +3446,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { if (pixel_values && input_ids.dims[1] !== 1) { const image_features = await this.encode_image({ pixel_values }); - ({ inputs_embeds, inputs_embeds, position_ids } = this._merge_input_ids_with_image_features({ + ({ inputs_embeds, inputs_embeds, attention_mask, position_ids } = this._merge_input_ids_with_image_features({ image_features, inputs_embeds, input_ids, @@ -3459,84 +3458,14 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { } } - // TEMP: for now, just recreate attention mask - attention_mask = ones(inputs_embeds.dims.slice(0, 2)); - const outputs = await decoderForward(this, { inputs_embeds, past_key_values, attention_mask, generation_config, logits_processor, - }) - // return super.generate({ - // inputs_embeds, - - // }) + }, true); return outputs; - - // TODO generalize with decoderForward - // let { input_ids, image_embeds, pixel_values, past_key_values, attention_mask } = model_inputs; - - // const inputs_embeds = await this.input_embeds({ prompt, image_embeds, tokenizer }); - // console.log('inputs_embeds', inputs_embeds) - - // const decoderFeeds = { - // inputs_embeds: inputs_embeds, - // attention_mask: ones_like(inputs_embeds), - // } - // let { input_ids, past_key_values, attention_mask } = model_inputs; - // let decoderFeeds = { - // inputs_embeds, - // attention_mask: ones(inputs_embeds.dims.slice(0, 2)), - // } - - // const use_cache_branch = !!past_key_values; - - // if (this.session.inputNames.includes('use_cache_branch')) { - // decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); - // } - - - // this.addPastKeyValues(decoderFeeds, past_key_values); - - // let decoderResults = await sessionRun(this.session, decoderFeeds); - - // let logits = decoderResults.logits; - - // past_key_values = this.getPastKeyValues(decoderResults, past_key_values); - // const result = { logits, past_key_values }; - // return result; - - // Retrieve the input embeddings - - // If an image is provided, compute the image embeddings - // if (!image_embeds && pixel_values) { - // image_embeds = await this.encode_image({ pixel_values }); - // console.log('image_embeds', image_embeds) - // } - - // let decoderFeeds = { - // inputs_embeds: inputs_embeds, - // attention_mask: attention_mask ?? 
ones_like(inputs_embeds), - // // prepareAttentionMask(this, input_ids), - // } - // const use_cache_branch = !!past_key_values; - - // if (this.session.inputNames.includes('use_cache_branch')) { - // decoderFeeds.use_cache_branch = boolTensor(use_cache_branch); - // } - - - // this.addPastKeyValues(decoderFeeds, past_key_values); - // console.log('run', decoderFeeds) - - // let decoderResults = await sessionRun(this.session, decoderFeeds); - - // let logits = decoderResults.logits; - - // past_key_values = this.getPastKeyValues(decoderResults, past_key_values); - // return { logits, past_key_values }; } } diff --git a/tests/tiny_random.test.js b/tests/tiny_random.test.js index 4fe1dfd47..39175ff20 100644 --- a/tests/tiny_random.test.js +++ b/tests/tiny_random.test.js @@ -27,6 +27,10 @@ import { BertForTokenClassification, BertForQuestionAnswering, MusicgenForConditionalGeneration, + LlavaForConditionalGeneration, + CLIPImageProcessor, + AutoProcessor, + RawImage, full, } from '../src/transformers.js'; @@ -382,6 +386,78 @@ describe('Tiny random models', () => { }); }); + + describe('llava', () => { + + const prompts = [ + // Example adapted from https://huggingface.co/docs/transformers/model_doc/llava#transformers.LlavaForConditionalGeneration.forward.example + "\nUSER: What's the content of the image?\nASSISTANT:", + "Hi", + ] + + // Empty white image + const dims = [224, 224, 3]; + const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims); + + describe('LlavaForConditionalGeneration', () => { + const model_id = 'Xenova/tiny-random-LlavaForConditionalGeneration'; + + /** @type {LlavaForConditionalGeneration} */ + let model; + /** @type {LlamaTokenizer} */ + let tokenizer; + /** @type {CLIPImageProcessor} */ + let processor; + beforeAll(async () => { + model = await LlavaForConditionalGeneration.from_pretrained(model_id, { + // TODO move to config + ...DEFAULT_MODEL_OPTIONS, + }); + tokenizer = await LlamaTokenizer.from_pretrained(model_id); + processor = await AutoProcessor.from_pretrained(model_id); + }, MAX_MODEL_LOAD_TIME); + + it('forward', async () => { + const text_inputs = tokenizer(prompts[0]); + const vision_inputs = await processor(image); + const inputs = { ...text_inputs, ...vision_inputs }; + + const { logits } = await model(inputs); + expect(logits.dims).toEqual([1, 244, 32002]); + expect(logits.mean().item()).toBeCloseTo(-0.0005755752790719271, 8); + }); + + it('batch_size=1', async () => { + const text_inputs = tokenizer(prompts[0]); + const vision_inputs = await processor(image); + const inputs = { ...text_inputs, ...vision_inputs }; + + const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([ + [1n, 32000n, 29871n, 13n, 11889n, 29901n, 1724n, 29915n, 29879n, 278n, 2793n, 310n, 278n, 1967n, 29973n, 13n, 22933n, 9047n, 13566n, 29901n, 21557n, 16781n, 27238n, 8279n, 20454n, 11927n, 12462n, 12306n, 2414n, 7561n] + ]); + }, MAX_TEST_EXECUTION_TIME); + + it('batch_size>1', async () => { + const text_inputs = tokenizer(prompts, { padding: true }); + const vision_inputs = await processor([image, image]); + const inputs = { ...text_inputs, ...vision_inputs }; + + const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 }); + expect(generate_ids.tolist()).toEqual([ + [1n, 32000n, 29871n, 13n, 11889n, 29901n, 1724n, 29915n, 29879n, 278n, 2793n, 310n, 278n, 1967n, 29973n, 13n, 22933n, 9047n, 13566n, 29901n, 21557n, 16781n, 27238n, 8279n, 20454n, 11927n, 12462n, 
12306n, 2414n, 7561n], + [0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 1n, 32000n, 6324n, 1217n, 22958n, 22913n, 10381n, 148n, 31410n, 31736n, 7358n, 9150n, 28635n] + ]); + + }, MAX_TEST_EXECUTION_TIME); + + afterAll(async () => { + await model?.dispose(); + }, MAX_MODEL_DISPOSE_TIME); + }); + }); + + describe('opt', () => { describe('OPTForCausalLM', () => { const model_id = 'hf-internal-testing/tiny-random-OPTForCausalLM'; From a0458cac1beffa93b5c1ad04203061bd62573fea Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 4 Apr 2024 01:42:10 +0200 Subject: [PATCH 066/473] Add Llava and Musicgen to supported models --- README.md | 2 ++ docs/snippets/6_supported-models.snippet | 2 ++ 2 files changed, 4 insertions(+) diff --git a/README.md b/README.md index 83a23490a..4945abf7e 100644 --- a/README.md +++ b/README.md @@ -310,10 +310,12 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. **[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom. +1. **[LLaVa](https://huggingface.co/docs/transformers/model_doc/llava)** (from Microsoft Research & University of Wisconsin-Madison) released with the paper [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. 1. 
**[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. +1. **[MusicGen](https://huggingface.co/docs/transformers/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed. 1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index ac09bdbdb..552d6f184 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -45,10 +45,12 @@ 1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang. 1. 
**[LLaMA](https://huggingface.co/docs/transformers/model_doc/llama)** (from The FAIR team of Meta AI) released with the paper [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. 1. **[Llama2](https://huggingface.co/docs/transformers/model_doc/llama2)** (from The FAIR team of Meta AI) released with the paper [Llama2: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/XXX) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom. +1. **[LLaVa](https://huggingface.co/docs/transformers/model_doc/llava)** (from Microsoft Research & University of Wisconsin-Madison) released with the paper [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485) by Haotian Liu, Chunyuan Li, Yuheng Li and Yong Jae Lee. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[mBART](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[mBART-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. +1. 
**[MusicGen](https://huggingface.co/docs/transformers/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[Mistral](https://huggingface.co/docs/transformers/model_doc/mistral)** (from Mistral AI) by The [Mistral AI](https://mistral.ai) team: Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed. 1. **[MMS](https://huggingface.co/docs/transformers/model_doc/mms)** (from Facebook) released with the paper [Scaling Speech Technology to 1,000+ Languages](https://arxiv.org/abs/2305.13516) by Vineel Pratap, Andros Tjandra, Bowen Shi, Paden Tomasello, Arun Babu, Sayani Kundu, Ali Elkahky, Zhaoheng Ni, Apoorv Vyas, Maryam Fazel-Zarandi, Alexei Baevski, Yossi Adi, Xiaohui Zhang, Wei-Ning Hsu, Alexis Conneau, Michael Auli. 1. **[MobileBERT](https://huggingface.co/docs/transformers/model_doc/mobilebert)** (from CMU/Google Brain) released with the paper [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. From b51a9545ad77af367b631bba1fce78ed236f8b58 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 5 Apr 2024 01:21:57 +0200 Subject: [PATCH 067/473] Update llava attention mask --- src/models.js | 65 +++++++++------------------------------------------ 1 file changed, 11 insertions(+), 54 deletions(-) diff --git a/src/models.js b/src/models.js index 80b29e15f..566b7064e 100644 --- a/src/models.js +++ b/src/models.js @@ -3246,6 +3246,8 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// LLaVa Models export class LlavaPreTrainedModel extends PreTrainedModel { forward_params = [ 'input_ids', @@ -3277,9 +3279,6 @@ export class LlavaPreTrainedModel extends PreTrainedModel { */ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { - /** - * - */ async encode_image({ pixel_values }) { // image_inputs === { pixel_values } return (await sessionRun(this.vision_encoder_session, { pixel_values })).image_features; @@ -3290,55 +3289,6 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { return (await sessionRun(this.session, { input_ids })).inputs_embeds; } - // async answer_question({ - // image_embeds, - // question, - // tokenizer, - // chat_history = "", - // }) { - // const prompt = `\n\n${chat_history}Question: ${question}\n\nAnswer:`; - // const answer = await this.generate({ - // image_embeds, - // prompt, - // tokenizer, - // eos_text: "", - // max_new_tokens: 512, - // })[0]; - // console.log('answer', answer) - // return 'todo' - // const cleaned_answer = answer.replace(/<$| { - const input_ids = tokenizer(text); - return this.encode_text(input_ids); - } - - if (prompt.includes('')) { - const splits = prompt.split(''); - if (splits.length !== 2) { - throw new Error('Prompt should contain only one tag'); - } - const [before, after] = splits; - const embeds = []; - if (before.length > 0) { - embeds.push(await e(before)); - } - 
embeds.push(image_embeds); - if (after.length > 0) { - embeds.push(await e(after)); - } - console.log('embeds', embeds) - return cat(embeds, 1); - } else { - return await e(prompt); - } - } - _merge_input_ids_with_image_features({ inputs_embeds, image_features, @@ -3405,7 +3355,6 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { return model_inputs; } - /** * * @param {Object} params @@ -3452,9 +3401,17 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { input_ids, attention_mask, })); + } else if (past_key_values && pixel_values && input_ids.dims[1] === 1) { // In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of // generation with cache + const target_length = input_ids.dims[1]; // always 1 + const past_length = Object.values(past_key_values)[0].dims.at(-2); + + attention_mask = cat([ + ones([input_ids.dims[0], past_length]), + attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]), + ], 1); } } @@ -3468,6 +3425,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { return outputs; } } +////////////////////////////////////////////////// ////////////////////////////////////////////////// // CLIP models @@ -6237,7 +6195,6 @@ const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ ['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]], - // ['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]], ]); const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ From ee53edc1575db914c2ba045dfda8dff5cddbcaf7 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 17:58:46 +0200 Subject: [PATCH 068/473] Fix musicgen --- src/generation/configuration_utils.js | 9 + src/generation/logits_process.js | 256 +++++++++++++++++--------- src/generation/logits_sampler.js | 5 - src/models.js | 104 ++++++++--- 4 files changed, 255 insertions(+), 119 deletions(-) diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js index 428b079a3..04039fffc 100644 --- a/src/generation/configuration_utils.js +++ b/src/generation/configuration_utils.js @@ -270,6 +270,15 @@ export class GenerationConfig { */ forced_decoder_ids = null; + /** + * The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. + * Higher guidance scale encourages the model to generate samples that are more closely linked to the input + * prompt, usually at the expense of poorer quality. + * @type {number} + * @default null + */ + guidance_scale = null; + // Parameters that define the output variables of `generate` /** * The number of independently computed returned sequences for each element in the batch. diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index 409fd3ba8..a1be72d9b 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -11,7 +11,7 @@ export class LogitsProcessor extends Callable { * Apply the processor to the input logits. * * @abstract - * @param {number[]} input_ids The input ids. + * @param {number[][]} input_ids The input ids. * @param {Tensor} logits The logits to process. * @throws {Error} Throws an error if `_call` is not implemented in the subclass. */ @@ -29,7 +29,7 @@ export class LogitsWarper extends Callable { * Apply the processor to the input logits. 
* * @abstract - * @param {number[]} input_ids The input ids. + * @param {number[][]} input_ids The input ids. * @param {Tensor} logits The logits to process. * @throws {Error} Throws an error if `_call` is not implemented in the subclass. */ @@ -74,16 +74,16 @@ export class LogitsProcessorList extends Callable { /** * Applies all logits processors in the list to a batch of logits, modifying them in-place. * - * @param {number[]} input_ids The input IDs for the language model. - * @param {Tensor} logits + * @param {number[][]} input_ids The input IDs for the language model. + * @param {Tensor} logits */ _call(input_ids, logits) { - // NOTE: This is different from the Python code, since vanilla JS does not support vectorized operations. - // As a result, we apply each processor to each - // Modifies logits inplace - this.processors.forEach( - func => func(input_ids, logits) - ) + let toReturn = logits; + // NOTE: Most processors modify logits inplace + for (const processor of this.processors) { + toReturn = processor(input_ids, toReturn); + } + return toReturn; } [Symbol.iterator]() { @@ -110,7 +110,7 @@ export class LogitsProcessorList extends Callable { // /** // * Apply the processor to the input logits. // * -// * @param {number[]} input_ids The input ids. +// * @param {number[][]} input_ids The input ids. // * @param {Tensor} logits The logits to process. // * @returns {Tensor} The processed logits. // */ @@ -144,14 +144,17 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor { /** * Apply the BOS token forcing to the logits. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Object} The logits with BOS token forcing. */ _call(input_ids, logits) { - if (input_ids.length === 1) { - logits.data.fill(-Infinity) - logits.data[this.bos_token_id] = 0; + for (let i = 0; i < input_ids.length; ++i) { + if (input_ids[i].length === 1) { + const batch_logits = logits[i]; + batch_logits.data.fill(-Infinity); + batch_logits.data[this.bos_token_id] = 0; + } } return logits; } @@ -175,7 +178,7 @@ export class ForcedEOSTokenLogitsProcessor extends LogitsProcessor { /** * Apply the processor to input_ids and logits. * - * @param {number[]} input_ids The input ids. + * @param {number[][]} input_ids The input ids. * @param {Tensor} logits The logits tensor. */ _call(input_ids, logits) { @@ -203,14 +206,17 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor { /** * Apply the BOS token forcing to the logits. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Object} The logits with BOS token forcing. */ _call(input_ids, logits) { - if (input_ids.length === this.begin_index) { - for (let token_id of this.begin_suppress_tokens) { - logits.data[token_id] = -Infinity; + for (let i = 0; i < input_ids.length; ++i) { + if (input_ids[i].length === this.begin_index) { + const batch_logits = logits[i]; + for (const token_id of this.begin_suppress_tokens) { + batch_logits.data[token_id] = -Infinity; + } } } return logits; @@ -245,48 +251,51 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor { /** * Modify the logits to handle timestamp tokens. - * @param {number[]} input_ids The input sequence of tokens. + * @param {number[][]} input_ids The input sequence of tokens. * @param {Tensor} logits The logits output by the model. * @returns {Tensor} The modified logits. 
*/ _call(input_ids, logits) { - const logitsData = /** @type {Float32Array} */(logits.data); + for (let i = 0; i < input_ids.length; ++i) { + const batch_logits = logits[i]; + const logitsData = /** @type {Float32Array} */(batch_logits.data); - // suppress <|notimestamps|> which is handled by without_timestamps - logitsData[this.no_timestamps_token_id] = -Infinity; + // suppress <|notimestamps|> which is handled by without_timestamps + logitsData[this.no_timestamps_token_id] = -Infinity; - if (input_ids.length === this.begin_index - 1) { - logitsData.fill(-Infinity); - logitsData[this.timestamp_begin] = 0; - return logits; - } + if (input_ids[i].length === this.begin_index - 1) { + logitsData.fill(-Infinity); + logitsData[this.timestamp_begin] = 0; + continue; + } - // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly - const seq = input_ids.slice(this.begin_index); - const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin; - const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin; + // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly + const seq = input_ids[i].slice(this.begin_index); + const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin; + const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin; - if (last_was_timestamp) { - if (penultimate_was_timestamp) { // has to be non-timestamp - logitsData.subarray(this.timestamp_begin).fill(-Infinity); - } else { // cannot be normal text tokens - logitsData.subarray(0, this.eos_token_id).fill(-Infinity); + if (last_was_timestamp) { + if (penultimate_was_timestamp) { // has to be non-timestamp + logitsData.subarray(this.timestamp_begin).fill(-Infinity); + } else { // cannot be normal text tokens + logitsData.subarray(0, this.eos_token_id).fill(-Infinity); + } } - } - // apply the `max_initial_timestamp` option - if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) { - const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index; - logitsData.subarray(last_allowed + 1).fill(-Infinity); - } + // apply the `max_initial_timestamp` option + if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index !== null) { + const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index; + logitsData.subarray(last_allowed + 1).fill(-Infinity); + } - // if sum of probability over timestamps is above any other token, sample timestamp - const logprobs = log_softmax(logitsData); - const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b)); - const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0]; + // if sum of probability over timestamps is above any other token, sample timestamp + const logprobs = log_softmax(logitsData); + const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b)); + const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0]; - if (timestamp_logprob > max_text_token_logprob) { - logitsData.subarray(0, this.timestamp_begin).fill(-Infinity); + if (timestamp_logprob > max_text_token_logprob) { + logitsData.subarray(0, this.timestamp_begin).fill(-Infinity); + } } return logits; @@ -368,15 +377,17 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor { /** * 
Apply the no-repeat-ngram processor to the logits. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Object} The logits with no-repeat-ngram processing. */ _call(input_ids, logits) { - const bannedTokens = this.calcBannedNgramTokens(input_ids); - - for (const token of bannedTokens) { - logits.data[token] = -Infinity; + for (let i = 0; i < input_ids.length; ++i) { + const batch_logits = logits[i]; + const bannedTokens = this.calcBannedNgramTokens(input_ids[i]); + for (const token of bannedTokens) { + batch_logits.data[token] = -Infinity; + } } return logits; } @@ -397,7 +408,7 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { /** * Apply the repetition penalty to the logits. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Object} The logits with repetition penalty processing. */ @@ -405,13 +416,19 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor { // Modify the logits corresponding to each element in `input_ids`. // As a consequence, the logits corresponding to tokens that appear // many times in the output will be penalised more. - for (const input_id of input_ids) { - if (logits.data[input_id] < 0) { - logits.data[input_id] *= this.penalty; - } else { - logits.data[input_id] /= this.penalty; + + for (let i = 0; i < input_ids.length; ++i) { + const batch_logits = logits[i]; + + for (const input_id of input_ids[i]) { + if (batch_logits.data[input_id] < 0) { + batch_logits.data[input_id] *= this.penalty; + } else { + batch_logits.data[input_id] /= this.penalty; + } } } + return logits } } @@ -433,14 +450,17 @@ export class MinLengthLogitsProcessor extends LogitsProcessor { /** * Apply logit processor. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Object} The processed logits. */ _call(input_ids, logits) { - if (input_ids.length < this.min_length) { - for (const eos_token of this.eos_token_id) { - logits.data[eos_token] = -Infinity; + for (let i = 0; i < input_ids.length; ++i) { + if (input_ids[i].length < this.min_length) { + const batch_logits = logits[i]; + for (const eos_token of this.eos_token_id) { + batch_logits.data[eos_token] = -Infinity; + } } } @@ -467,18 +487,20 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor { /** * Apply logit processor. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Object} The processed logits. */ _call(input_ids, logits) { - const new_tokens_length = input_ids.length - this.prompt_length_to_skip; - if (new_tokens_length < this.min_new_tokens) { - for (const eos_token of this.eos_token_id) { - logits.data[eos_token] = -Infinity; + for (let i = 0; i < input_ids.length; ++i) { + const new_tokens_length = input_ids[i].length - this.prompt_length_to_skip; + if (new_tokens_length < this.min_new_tokens) { + const batch_logits = logits[i]; + for (const eos_token of this.eos_token_id) { + batch_logits[eos_token] = -Infinity; + } } } - return logits } } @@ -497,32 +519,88 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor { /** * Apply logit processor. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. 
* @returns {Object} The processed logits. */ _call(input_ids, logits) { + for (let i = 0; i < input_ids.length; ++i) { + const batch_logits = logits[i]; + for (const bad_word_ids of this.bad_words_ids) { + // Whether to modify the logits of the last token in the bad word id sequence + let mark = true; + + // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last), + // then we set the logits of the last bad word id to -Infinity. + for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids[i].length; ++i) { + + if (bad_word_ids.at(-i - 1) !== input_ids[i].at(-i)) { + // We have found a mismatch + mark = false; + break; + } + } + if (mark) { + batch_logits[bad_word_ids.at(-1)] = -Infinity; + } + } + } + return logits + } +} - for (const bad_word_ids of this.bad_words_ids) { - // Whether to modify the logits of the last token in the bad word id sequence - let mark = true; +/** + * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, + * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half + * correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a + * weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. + * + * See [the paper](https://arxiv.org/abs/2306.05284) for more information. + */ +export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor { - // For each bad word in the list, if the current sequence of input ids ends with this sequence (excluding the last), - // then we set the logits of the last bad word id to -Infinity. - for (let i = 1; i <= bad_word_ids.length - 1 && bad_word_ids.length < input_ids.length; ++i) { + /** + * Create a `ClassifierFreeGuidanceLogitsProcessor`. + * @param {number} guidance_scale The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. + * Higher guidance scale encourages the model to generate samples that are more closely linked to the input + * prompt, usually at the expense of poorer quality. + */ + constructor(guidance_scale) { + super(); + if (guidance_scale <= 1) { + throw new Error( + `Require guidance scale >1 to use the classifier free guidance processor, got guidance scale ${guidance_scale}.` + ) + } + this.guidance_scale = guidance_scale; + } - if (bad_word_ids.at(-i - 1) !== input_ids.at(-i)) { - // We have found a mismatch - mark = false; - break; - } - } - if (mark) { - logits.data[bad_word_ids.at(-1)] = -Infinity; - } + /** + * Apply logit processor. + * @param {number[][]} input_ids The input IDs. + * @param {Tensor} logits The logits. + * @returns {Object} The processed logits. + */ + _call(input_ids, logits) { + if (logits.dims[0] !== 2 * input_ids.length) { + throw new Error( + `Logits should have twice the batch size of the input ids, the first half of batches corresponding to ` + + `the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got ` + + `batch size ${logits.dims[0]} for the logits and ${input_ids.length} for the input ids.` + ) } - return logits + const unguided_bsz = input_ids.length; + const cond_logits = logits.slice([0, unguided_bsz], null); + const uncond_logits = logits.slice([unguided_bsz, logits.dims[0]], null); + + // Merge into uncond_logits (to save memory). 
This is equivalent to the following: + // scores = uncond_logits + (cond_logits - uncond_logits) * guidance_scale + for (let i = 0; i < uncond_logits.data.length; ++i) { + uncond_logits.data[i] += (cond_logits.data[i] - uncond_logits.data[i]) * this.guidance_scale; + } + + return uncond_logits; } } @@ -553,7 +631,7 @@ export class TemperatureLogitsWarper extends LogitsWarper { /** * Apply logit warper. - * @param {number[]} input_ids The input IDs. + * @param {number[][]} input_ids The input IDs. * @param {Tensor} logits The logits. * @returns {Object} The processed logits. */ diff --git a/src/generation/logits_sampler.js b/src/generation/logits_sampler.js index 00a405da8..c298d538a 100644 --- a/src/generation/logits_sampler.js +++ b/src/generation/logits_sampler.js @@ -62,11 +62,6 @@ export class LogitsSampler extends Callable { let startIndex = index * vocabSize; logs = logs.slice(startIndex, startIndex + vocabSize); } - - // add temperature - // if (this.generation_config.temperature > 0) { - // logs = logs.map(x => x / this.generation_config.temperature) - // } return logs; } diff --git a/src/models.js b/src/models.js index 566b7064e..9dcff7158 100644 --- a/src/models.js +++ b/src/models.js @@ -60,7 +60,6 @@ import { import { isIntegralNumber, - isTypedArray, mergeArrays, pick, } from './utils/core.js'; @@ -85,6 +84,7 @@ import { TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, + ClassifierFreeGuidanceLogitsProcessor, } from './generation/logits_process.js'; import { @@ -94,12 +94,14 @@ import { import { cat, dynamicTimeWarping, + full_like, mean, ones, ones_like, stack, std_mean, Tensor, + zeros_like, } from './utils/tensor.js'; import { medianFilter } from './utils/maths.js'; @@ -440,7 +442,7 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) { return await sessionRun(session, fixed); } -function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { +function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) { if (self.session.inputNames.includes('position_ids') && model_inputs.attention_mask && !model_inputs.position_ids) { // If the model supports providing position_ids, we create position_ids on the fly for batch generation, @@ -477,7 +479,7 @@ function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { return model_inputs; } -function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs) { +function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) { // console.log('model_inputs', model_inputs) const { ...new_model_inputs } = model_inputs; @@ -492,7 +494,7 @@ function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_in // input_ids; } new_model_inputs['decoder_input_ids'] = toI64Tensor(input_ids); - // throw new Error('Not implemented'); + return new_model_inputs; } ////////////////////////////////////////////////// @@ -833,6 +835,12 @@ export class PreTrainedModel extends Callable { // processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)); // } + + // 8. 
prepare batched CFG externally + if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { + processors.push(new ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)); + } + if (logits_processor !== null) { processors.extend(logits_processor) } @@ -995,12 +1003,27 @@ export class PreTrainedModel extends Callable { return { inputs_tensor, model_inputs, model_input_name: input_name }; } - async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name }) { + async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name, generation_config }) { const encoder_kwargs = pick(model_inputs, this.session.inputNames); - const encoder_outputs = await encoderForward(this, encoder_kwargs); + let { last_hidden_state } = await encoderForward(this, encoder_kwargs); + + // for classifier free guidance we need to add a 'null' input to our encoder hidden states + if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { - model_inputs['encoder_outputs'] = encoder_outputs.last_hidden_state; + last_hidden_state = cat([ + last_hidden_state, + full_like(last_hidden_state, 0.0), + ], 0); + + if ('attention_mask' in model_inputs) { + model_inputs['attention_mask'] = cat([ + model_inputs['attention_mask'], + zeros_like(model_inputs['attention_mask']), + ], 0); + } + } + model_inputs['encoder_outputs'] = last_hidden_state; return model_inputs; } @@ -1009,7 +1032,7 @@ export class PreTrainedModel extends Callable { * Prepares `decoder_input_ids` for generation with encoder-decoder models * @param {*} param0 */ - _prepare_decoder_input_ids_for_generation({ batch_size, model_input_name, model_kwargs, decoder_start_token_id, bos_token_id }) { + _prepare_decoder_input_ids_for_generation({ batch_size, model_input_name, model_kwargs, decoder_start_token_id, bos_token_id, generation_config }) { decoder_start_token_id = decoder_start_token_id ?? 
bos_token_id; @@ -1042,6 +1065,13 @@ export class PreTrainedModel extends Callable { const decoder_input_ids = decoder_input_ids_start; model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids); + // if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { + // model_kwargs['decoder_attention_mask'] = cat([ + // model_kwargs['decoder_attention_mask'], + // zeros_like(model_kwargs['decoder_attention_mask']), + // ], 0) + // } + return { input_ids: decoder_input_ids, model_inputs: model_kwargs }; } @@ -1079,7 +1109,7 @@ export class PreTrainedModel extends Callable { // if model is encoder decoder encoder_outputs are created // and added to `model_kwargs` model_inputs = await this._prepare_encoder_decoder_kwargs_for_generation( - { inputs_tensor, model_inputs, model_input_name } + { inputs_tensor, model_inputs, model_input_name, generation_config } ) } @@ -1095,7 +1125,8 @@ export class PreTrainedModel extends Callable { model_kwargs: model_inputs, decoder_start_token_id: generation_config.decoder_start_token_id, bos_token_id: generation_config.bos_token_id, - })) + generation_config, + })); } else { input_ids = model_inputs[model_input_name] } @@ -1168,7 +1199,7 @@ export class PreTrainedModel extends Callable { //////////////////////////////////////////////////// while (true) { // prepare model inputs - model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs) + model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs, generation_config); const outputs = await this.forward(model_inputs); @@ -1178,14 +1209,15 @@ export class PreTrainedModel extends Callable { // (equivalent to `logits = outputs.logits[:, -1, :]`) const logits = outputs.logits.slice(null, -1, null); + const next_tokens_scores = prepared_logits_processor(all_input_ids, logits); + // only for this batch const generated_input_ids = []; // const new_kv_cache = [];// NOTE: Only used for beam search when concatenating new kv // Loop over each batch - for (let batch_idx = 0; batch_idx < logits.dims.at(0); ++batch_idx) { - const logs = logits[batch_idx]; + for (let batch_idx = 0; batch_idx < next_tokens_scores.dims.at(0); ++batch_idx) { + const logs = next_tokens_scores[batch_idx]; - prepared_logits_processor(all_input_ids[batch_idx], logs); const sampledTokens = sampler(logs); for (const [newTokenId, logProb] of sampledTokens) { const bigint = BigInt(newTokenId); @@ -3295,7 +3327,6 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { input_ids, attention_mask, }) { - console.log('_merge_input_ids_with_image_features'); const image_token_index = this.config.image_token_index; @@ -3351,7 +3382,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { }; } - prepare_inputs_for_generation(input_ids, model_inputs) { + prepare_inputs_for_generation(input_ids, model_inputs, generation_config) { return model_inputs; } @@ -5863,7 +5894,7 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: let newDataSize = 0; for (let i = 0; i < outputs.size; ++i) { - if (outputs.data[i] === this.config.pad_token_id) { + if (outputs.data[i] === this.config.decoder.pad_token_id) { continue; } @@ -5881,7 +5912,31 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: // TODO: assert `inferred` is an integer return new Tensor( outputs.type, - outputs.data.slice(0, newDataSize), [batch_size, num_codebooks, inferred]); + outputs.data.slice(0, newDataSize), + [batch_size, 
num_codebooks, inferred] + ); + } + + + prepare_inputs_for_generation(input_ids, model_inputs, generation_config) { + // apply the delay pattern mask + let clonedInputIds = structuredClone(input_ids); + for (let i = 0; i < clonedInputIds.length; ++i) { + for (let j = 0; j < clonedInputIds[i].length; ++j) { + if ((i % this.config.decoder.num_codebooks) >= j) { + clonedInputIds[i][j] = BigInt(this.config.decoder.pad_token_id); + } + } + } + // for classifier free guidance we need to replicate the decoder args across the batch dim + // (we'll split these before sampling) + if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { + // [batch, seqLength] -> [2 * batch, seqLength] + clonedInputIds = clonedInputIds.concat(clonedInputIds); + } + + const prepped = super.prepare_inputs_for_generation(clonedInputIds, model_inputs, generation_config); + return prepped; } /** @@ -5890,19 +5945,18 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: * @returns {Promise} The output of the model, which can contain the generated token ids, attentions, and scores. */ async generate(options) { - let output_ids = await super.generate(options); + + const output_ids = await super.generate(options); // apply the pattern mask to the final ids - output_ids = this._apply_and_filter_by_delay_pattern_mask( + // tensor: int64[1,batch_size,4,chunk_length] + const audio_codes = this._apply_and_filter_by_delay_pattern_mask( /** @type {Tensor} */(output_ids) ).unsqueeze_(0); // append the frame dimension back to the audio codes - const output_values = await sessionRun(this.encodec_decode, { - // tensor: int64[1,batch_size,4,chunk_length] - audio_codes: output_ids, - }) + const { audio_values } = await sessionRun(this.encodec_decode, { audio_codes }) - return output_values.audio_values; + return audio_values; } } From 991118648c07707a97a808b396f1d7480f26829f Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 18:00:38 +0200 Subject: [PATCH 069/473] Remove dead code --- src/models.js | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/models.js b/src/models.js index 9dcff7158..2dec7e7bf 100644 --- a/src/models.js +++ b/src/models.js @@ -1065,13 +1065,6 @@ export class PreTrainedModel extends Callable { const decoder_input_ids = decoder_input_ids_start; model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids); - // if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) { - // model_kwargs['decoder_attention_mask'] = cat([ - // model_kwargs['decoder_attention_mask'], - // zeros_like(model_kwargs['decoder_attention_mask']), - // ], 0) - // } - return { input_ids: decoder_input_ids, model_inputs: model_kwargs }; } From 44f8a0bbbf6689220da61eb056009d506a9a7c46 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 20:58:42 +0200 Subject: [PATCH 070/473] Fix JSDoc --- src/models.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models.js b/src/models.js index 2dec7e7bf..c1919e066 100644 --- a/src/models.js +++ b/src/models.js @@ -3390,7 +3390,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { * @param {Tensor} [params.past_key_values=null] * @param {Object} [params.generation_config=null] * @param {Object} [params.logits_processor=null] - * @returns + * @returns {Promise} The model's output tensor */ async forward({ // These are produced by the processors: From b88b6a2c555f123183beba5ae98cec537a9c9f9a Mon Sep 17 00:00:00 2001 From: 
Joshua Lochner Date: Sat, 6 Apr 2024 21:08:33 +0200 Subject: [PATCH 071/473] Add return types to `Tensor` class --- src/utils/tensor.js | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index bf64456e0..92cb03d90 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -480,7 +480,7 @@ export class Tensor { * If you would like a copy, use `tensor.clone()` before squeezing. * * @param {number} [dim=null] If given, the input will be squeezed only in the specified dimensions. - * @returns The squeezed tensor + * @returns {Tensor} The squeezed tensor */ squeeze(dim = null) { return new Tensor( @@ -504,7 +504,7 @@ export class Tensor { * NOTE: The returned tensor shares the same underlying data with this tensor. * * @param {number} dim The index at which to insert the singleton dimension - * @returns The unsqueezed tensor + * @returns {Tensor} The unsqueezed tensor */ unsqueeze(dim = null) { return new Tensor( @@ -543,7 +543,7 @@ export class Tensor { * and ending with `end_dim` are flattened. The order of elements in input is unchanged. * @param {number} start_dim the first dim to flatten * @param {number} end_dim the last dim to flatten - * @returns The flattened tensor. + * @returns {Tensor} The flattened tensor. */ flatten(start_dim = 0, end_dim = -1) { return this.clone().flatten_(start_dim, end_dim); @@ -601,7 +601,7 @@ export class Tensor { * Clamps all elements in input into the range [ min, max ] * @param {number} min lower-bound of the range to be clamped to * @param {number} max upper-bound of the range to be clamped to - * @returns the output tensor. + * @returns {Tensor} the output tensor. */ clamp(min, max) { return this.clone().clamp_(min, max); @@ -619,7 +619,7 @@ export class Tensor { /** * Rounds elements of input to the nearest integer. - * @returns the output tensor. + * @returns {Tensor} the output tensor. */ round() { return this.clone().round_(); @@ -828,7 +828,7 @@ export function layer_norm(input, normalized_shape, { * Helper function to calculate new dimensions when performing a squeeze operation. * @param {number[]} dims The dimensions of the tensor. * @param {number|number[]|null} dim The dimension(s) to squeeze. - * @returns The new dimensions. + * @returns {number[]} The new dimensions. * @private */ function calc_squeeze_dims(dims, dim) { @@ -851,7 +851,7 @@ function calc_squeeze_dims(dims, dim) { * Helper function to calculate new dimensions when performing an unsqueeze operation. * @param {number[]} dims The dimensions of the tensor. * @param {number} dim The dimension to unsqueeze. - * @returns The new dimensions. + * @returns {number[]} The new dimensions. * @private */ function calc_unsqueeze_dims(dims, dim) { @@ -1038,7 +1038,7 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) { * @param {Tensor} input the input tensor. * @param {number|null} dim the dimension to reduce. * @param {boolean} keepdim whether the output tensor has dim retained or not. - * @returns A new tensor with means taken along the specified dimension. + * @returns {Tensor} A new tensor with means taken along the specified dimension. 
*/ export function mean(input, dim = null, keepdim = false) { From ed748834d6bdaefed4f3858b2c17effb35722bf8 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 21:16:08 +0200 Subject: [PATCH 072/473] Fix JSDoc --- src/generation/stopping_criteria.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.js index c9863bca8..c343d3faa 100644 --- a/src/generation/stopping_criteria.js +++ b/src/generation/stopping_criteria.js @@ -116,7 +116,7 @@ export class EosTokenCriteria extends StoppingCriteria { * * @param {number[][]} input_ids * @param {number[][]} scores - * @returns + * @returns {boolean[]} */ _call(input_ids, scores) { return input_ids.map(ids => this.eos_token_id.includes(ids.at(-1))); From 194a2ce5412876ed90d8c41f7150b243a3757b89 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 22:39:58 +0200 Subject: [PATCH 073/473] Update _toctree.yml --- docs/source/_toctree.yml | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1fe9150f6..0a473e8d4 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -48,6 +48,19 @@ title: ONNX title: Backends isExpanded: false + - sections: + - local: api/generation/parameters + title: Parameters + - local: api/generation/configuration_utils + title: Configuration + - local: api/generation/logits_process + title: Logits Processors + - local: api/generation/logits_sampler + title: Logits Samplers + - local: api/generation/stopping_criteria + title: Stopping Criteria + title: Generation + isExpanded: false - sections: - local: api/utils/core title: Core @@ -61,8 +74,6 @@ title: Tensor - local: api/utils/maths title: Maths - - local: api/utils/generation - title: Generation - local: api/utils/data-structures title: Data Structures title: Utilities From 6b9fbca4293da381ca1b431e1205a53bef4b76af Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 22:43:33 +0200 Subject: [PATCH 074/473] Add `module` JSDoc annotations --- src/generation/configuration_utils.js | 5 +++++ src/generation/logits_process.js | 5 +++++ src/generation/logits_sampler.js | 4 ++++ src/generation/parameters.js | 4 ++++ src/generation/stopping_criteria.js | 4 ++++ 5 files changed, 22 insertions(+) diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js index 04039fffc..0becde567 100644 --- a/src/generation/configuration_utils.js +++ b/src/generation/configuration_utils.js @@ -1,3 +1,8 @@ + +/** + * @module generation/configuration_utils + */ + import { pick } from "../utils/core.js"; /** diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index a1be72d9b..92066788b 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -1,3 +1,8 @@ + +/** + * @module generation/configuration_utils + */ + import { Callable } from "../utils/generic.js"; import { Tensor } from "../utils/tensor.js"; diff --git a/src/generation/logits_sampler.js b/src/generation/logits_sampler.js index c298d538a..4b292ac01 100644 --- a/src/generation/logits_sampler.js +++ b/src/generation/logits_sampler.js @@ -1,4 +1,8 @@ +/** + * @module generation/logits_sampler + */ + import { Callable } from "../utils/generic.js"; import { Tensor } from "../utils/tensor.js"; diff --git a/src/generation/parameters.js b/src/generation/parameters.js index b1dc4b5a8..d0df2f0b0 100644 --- a/src/generation/parameters.js 
+++ b/src/generation/parameters.js @@ -1,4 +1,8 @@ +/** + * @module generation/parameters + */ + /** * @typedef {Object} GenerationFunctionParameters * @property {import('../utils/tensor.js').Tensor} [inputs=null] (`Tensor` of varying shape depending on the modality, *optional*): diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.js index c343d3faa..95da854ef 100644 --- a/src/generation/stopping_criteria.js +++ b/src/generation/stopping_criteria.js @@ -1,4 +1,8 @@ +/** + * @module generation/stopping_criteria + */ + import { Callable } from "../utils/generic.js"; // NOTE: From 1866791e6c84c188be5316d863ae11bed39d912c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 22:44:03 +0200 Subject: [PATCH 075/473] Fix typo --- src/generation/logits_process.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index 92066788b..999705ae1 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -1,6 +1,6 @@ /** - * @module generation/configuration_utils + * @module generation/logits_process */ import { Callable } from "../utils/generic.js"; From a3288289b2100b66a34a5fda466718d49cfe047c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 6 Apr 2024 23:00:13 +0200 Subject: [PATCH 076/473] Add example Musicgen code --- src/models.js | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/src/models.js b/src/models.js index c1919e066..b122005d0 100644 --- a/src/models.js +++ b/src/models.js @@ -5849,6 +5849,37 @@ export class MusicgenForCausalLM extends MusicgenPreTrainedModel { } /** * The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, * for music generation tasks with one or both of text and audio prompts. + * + * **Example:** Generate music from text with `Xenova/musicgen-small`. + * ```javascript + * import { AutoTokenizer, MusicgenForConditionalGeneration } from '@xenova/transformers'; + * + * // Load tokenizer and model + * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/musicgen-small'); + * const model = await MusicgenForConditionalGeneration.from_pretrained( + * 'Xenova/musicgen-small', { dtype: 'fp32' } + * ); + * + * // Prepare text input + * const prompt = '80s pop track with bassy drums and synth'; + * const inputs = tokenizer(prompt); + * + * // Generate audio + * const audio_values = await model.generate({ + * ...inputs, + * max_new_tokens: 512, + * do_sample: true, + * guidance_scale: 3, + * }); + * + * // (Optional) Write the output to a WAV file + * import wavefile from 'wavefile'; + * import fs from 'fs'; + * + * const wav = new wavefile.WaveFile(); + * wav.fromScratch(1, model.config.audio_encoder.sampling_rate, '32f', audio_values.data); + * fs.writeFileSync('musicgen_out.wav', wav.toBuffer()); + * ``` */ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: not MusicgenPreTrainedModel forward_params = ['input_ids', 'attention_mask', 'encoder_outputs', 'decoder_input_ids', 'decoder_attention_mask', 'past_key_values']; From ead5220f36e06013c244575cf87f91912fb50ac7 Mon Sep 17 00:00:00 2001 From: flatsiedatsie Date: Wed, 10 Apr 2024 17:56:33 +0200 Subject: [PATCH 077/473] Update worker.js - fix TTS example (#686) * Update worker.js - fix TTS example I added the `dtype: 'fp32',` tip you gave me earlier, so the example should then work again. 
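For context, the fix that follows replaces the removed `quantized` flag with the new `dtype` loading option. A minimal sketch of the updated loading call (the model id is illustrative, not taken from this patch):

```javascript
import { SpeechT5ForTextToSpeech } from '@xenova/transformers';

// Load full-precision (unquantized) weights by passing `dtype` instead of the old `quantized: false`.
const model = await SpeechT5ForTextToSpeech.from_pretrained('Xenova/speecht5_tts', {
    dtype: 'fp32',
});
```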
--- examples/text-to-speech-client/src/worker.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-to-speech-client/src/worker.js b/examples/text-to-speech-client/src/worker.js index 76b8f76ef..4644890d3 100644 --- a/examples/text-to-speech-client/src/worker.js +++ b/examples/text-to-speech-client/src/worker.js @@ -25,14 +25,14 @@ class MyTextToSpeechPipeline { if (this.model_instance === null) { this.model_instance = SpeechT5ForTextToSpeech.from_pretrained(this.model_id, { - quantized: false, + dtype: 'fp32', progress_callback, }); } if (this.vocoder_instance === null) { this.vocoder_instance = SpeechT5HifiGan.from_pretrained(this.vocoder_id, { - quantized: false, + dtype: 'fp32', progress_callback, }); } From c0ca37edbff912cd5e51c54a482d0ca19c8c1ffe Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 10 Apr 2024 23:33:42 +0200 Subject: [PATCH 078/473] Align tokenizer API with python version --- src/tokenizers.js | 71 +++++++++++++++++++++++++++++----------- tests/tokenizers.test.js | 2 +- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/src/tokenizers.js b/src/tokenizers.js index b06920a47..042f32e54 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2542,6 +2542,7 @@ export class PreTrainedTokenizer extends Callable { * @param {...string} keys One or more keys to search for in the tokenizer config object. * @returns {string|null} The value associated with the first matching key, or null if no match is found. * @throws {Error} If an object is found for a matching key and its __type property is not "AddedToken". + * @private */ getToken(...keys) { for (const key of keys) { @@ -2648,11 +2649,11 @@ export class PreTrainedTokenizer extends Callable { } encodedTokens = text.map( - (t, i) => this._encode_plus(t, text_pair[i], { add_special_tokens }) + (t, i) => this._encode_plus(t, { text_pair: text_pair[i], add_special_tokens }) ) } else { - encodedTokens = text.map(x => this._encode_plus(x, null, { add_special_tokens })); + encodedTokens = text.map(x => this._encode_plus(x, { add_special_tokens })); } } else { @@ -2665,7 +2666,7 @@ export class PreTrainedTokenizer extends Callable { } // For single input, we just wrap in an array, and then unwrap later. - encodedTokens = [this._encode_plus(text, text_pair, { add_special_tokens })]; + encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens })]; } // At this point, tokens is batched: [batch_size, tokens] // However, array may be jagged. So, we pad to max_length @@ -2820,51 +2821,83 @@ export class PreTrainedTokenizer extends Callable { * Encodes a single text or a pair of texts using the model's tokenizer. * * @param {string} text The text to encode. - * @param {string|null} text_pair The optional second text to encode. * @param {Object} options An optional object containing the following properties: + * @param {string} [options.text_pair=null] The optional second text to encode. * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. * @returns {EncodingSingle} An object containing the encoded text. * @private */ - _encode_plus(text, text_pair = null, { + _encode_plus(text, { + text_pair = null, add_special_tokens = true, } = {}) { - // Function called by users to encode possibly multiple texts - const tokens = this._encode_text(text); - const tokens2 = this._encode_text(text_pair); - const combinedTokens = this.post_processor - ? 
this.post_processor(tokens, tokens2, { add_special_tokens }) - : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) }; + const { tokens, token_type_ids } = this._tokenize_helper(text, { pair: text_pair, add_special_tokens }); - const input_ids = this.model.convert_tokens_to_ids(combinedTokens.tokens); + const input_ids = this.model.convert_tokens_to_ids(tokens); const result = { input_ids, attention_mask: new Array(input_ids.length).fill(1), } - if (this.return_token_type_ids && combinedTokens.token_type_ids) { - result.token_type_ids = combinedTokens.token_type_ids; + if (this.return_token_type_ids && token_type_ids) { + result.token_type_ids = token_type_ids; } return result; } + /** + * Internal helper function to tokenize a text, and optionally a pair of texts. + * @param {string} text The text to tokenize. + * @param {Object} options An optional object containing the following properties: + * @param {string} [options.pair=null] The optional second text to tokenize. + * @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model. + * @returns {{tokens: string[], token_type_ids?: number[]}} An object containing the tokens and optionally the token type IDs. + */ + _tokenize_helper(text, { + pair = null, + add_special_tokens = false, + } = {}) { + const tokens = this._encode_text(text); + const tokens2 = this._encode_text(pair); + + return this.post_processor + ? this.post_processor(tokens, tokens2, { add_special_tokens }) + : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) }; + } + + /** + * Converts a string into a sequence of tokens. + * @param {string} text The sequence to be encoded. + * @param {Object} options An optional object containing the following properties: + * @param {string} [options.pair] A second sequence to be encoded with the first. + * @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model. + * @returns {string[]} The list of tokens. + */ + tokenize(text, { + pair = null, + add_special_tokens = false, + } = {}) { + return this._tokenize_helper(text, { pair, add_special_tokens }).tokens; + } + /** * Encodes a single text or a pair of texts using the model's tokenizer. * * @param {string} text The text to encode. - * @param {string|null} text_pair The optional second text to encode. * @param {Object} options An optional object containing the following properties: + * @param {string} [options.text_pair=null] The optional second text to encode. * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. * @returns {number[]} An array of token IDs representing the encoded text(s). 
*/ - encode(text, text_pair = null, { + encode(text, { + text_pair = null, add_special_tokens = true, } = {}) { - const { input_ids } = this._encode_plus(text, text_pair, { + return this._encode_plus(text, { + text_pair, add_special_tokens, - }); - return input_ids; + }).input_ids; } /** diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js index 8b92c6702..050b65967 100644 --- a/tests/tokenizers.test.js +++ b/tests/tokenizers.test.js @@ -140,7 +140,7 @@ describe('Tokenizers (hard-coded)', () => { const tokenizer = await AutoTokenizer.from_pretrained(m(tokenizerName), { legacy }); for (const [text, expected] of Object.entries(data)) { - const token_ids = tokenizer.encode(text, null, { add_special_tokens: false }); + const token_ids = tokenizer.encode(text, { add_special_tokens: false }); expect(token_ids).toEqual(expected); // If reversible, test that decoding produces the original text From 28fc75efcf918bcfc71946675b292a470b7eb5d0 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 12:26:34 +0200 Subject: [PATCH 079/473] Make a clone of `session_options` before passing to ORT ORT modifies the object in-place, which means you can't pass different session objects to multiple sessions --- src/models.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models.js b/src/models.js index b122005d0..ebc5ecd6d 100644 --- a/src/models.js +++ b/src/models.js @@ -160,7 +160,7 @@ async function constructSession(pretrained_model_name_or_path, fileName, options const buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options); - const session_options = options.session_options ?? {}; + const session_options = { ...options.session_options } ?? {}; // Overwrite `executionProviders` if not specified session_options.executionProviders ??= executionProviders; From 9af35996badbcf3a00096cebd06f6cc896887b20 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 17:14:40 +0200 Subject: [PATCH 080/473] Update default quantization settings for musicgen models --- scripts/convert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/convert.py b/scripts/convert.py index df33eac32..5d9ec0ea4 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -48,7 +48,6 @@ # Encoder-decoder models 'whisper', 'vision-encoder-decoder', - 'musicgen', # Encoder-only models 'owlv2', From 57da953a5235409981df0ab57a294a6abbc8b60a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 18:11:08 +0200 Subject: [PATCH 081/473] Refactor transformers.js `env` var --- src/backends/onnx.js | 20 +++------------ src/env.js | 58 +++++++++++++++++++++++++++++++++----------- src/transformers.js | 2 +- src/utils/dtypes.js | 43 ++++++++++++++++++++++---------- src/utils/hub.js | 4 +-- webpack.config.js | 1 - 6 files changed, 81 insertions(+), 47 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 710557648..111b14846 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -17,7 +17,7 @@ */ import path from 'path'; -import { env, RUNNING_LOCALLY } from '../env.js'; +import { env, apis } from '../env.js'; // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`. // In either case, we select the default export if it exists, otherwise we use the named export. 
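As a usage note for the tokenizer patch above ("Align tokenizer API with python version"): the second text now goes into an options object rather than being a positional argument. A minimal sketch, assuming any pretrained tokenizer (the model id here is illustrative):

```javascript
import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

// Encode a single text, or a pair of texts via the new options object:
const ids = tokenizer.encode('Hello world', { add_special_tokens: true });
const pair_ids = tokenizer.encode('Question?', { text_pair: 'Answer.', add_special_tokens: true });

// The new `tokenize` helper returns the string tokens rather than ids:
const tokens = tokenizer.tokenize('Hello world', { add_special_tokens: false });
```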
@@ -26,22 +26,19 @@ import * as ONNX_WEB from 'onnxruntime-web/webgpu'; export { Tensor } from 'onnxruntime-common'; -const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; -const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node'; - /** @type {import('../utils/devices.js').DeviceType[]} */ const supportedExecutionProviders = []; /** @type {import('../utils/devices.js').DeviceType[]} */ let defaultExecutionProviders; let ONNX; -if (USE_ONNXRUNTIME_NODE) { +if (apis.IS_NODE_ENV) { ONNX = ONNX_NODE.default ?? ONNX_NODE; supportedExecutionProviders.push('cpu'); defaultExecutionProviders = ['cpu']; } else { ONNX = ONNX_WEB; - if (WEBGPU_AVAILABLE) { + if (apis.IS_WEBGPU_AVAILABLE) { supportedExecutionProviders.push('webgpu'); } supportedExecutionProviders.push('wasm'); @@ -75,13 +72,7 @@ export function deviceToExecutionProviders(device) { * @returns {Promise} The ONNX inference session. */ export async function createInferenceSession(buffer, session_options) { - - // NOTE: Important to create a clone, since ORT modifies the object. - const options = { - ...session_options - } - - return await InferenceSession.create(buffer, options); + return await InferenceSession.create(buffer, session_options); } /** @@ -102,9 +93,6 @@ if (ONNX_ENV?.wasm) { // https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths // We use remote wasm files by default to make it easier for newer users. // In practice, users should probably self-host the necessary .wasm files. - // ONNX_ENV.wasm.wasmPaths = RUNNING_LOCALLY - // ? path.join(env.__dirname, '/dist/') - // : `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`; // TODO: update this before release ONNX_ENV.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/'; diff --git a/src/env.js b/src/env.js index e4a736457..284cd231a 100644 --- a/src/env.js +++ b/src/env.js @@ -29,13 +29,42 @@ import url from 'url'; const VERSION = '3.0.0-alpha.0'; // Check if various APIs are available (depends on environment) -const BROWSER_ENV = typeof self !== 'undefined'; -const WEB_CACHE_AVAILABLE = BROWSER_ENV && 'caches' in self; -const FS_AVAILABLE = !isEmpty(fs); // check if file system is available -const PATH_AVAILABLE = !isEmpty(path); // check if path is available +const IS_BROWSER_ENV = typeof self !== 'undefined'; +const IS_WEBWORKER_ENV = IS_BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope'; +const IS_WEB_CACHE_AVAILABLE = IS_BROWSER_ENV && 'caches' in self; +const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator; -export const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE; +const IS_NODE_ENV = typeof process !== 'undefined' && process?.release?.name === 'node'; +const IS_FS_AVAILABLE = !isEmpty(fs); +const IS_PATH_AVAILABLE = !isEmpty(path); +/** + * A read-only object containing information about the APIs available in the current environment. 
+ */ +export const apis = Object.freeze({ + /** Whether we are running in a browser environment */ + IS_BROWSER_ENV, + + /** Whether we are running in a web worker environment */ + IS_WEBWORKER_ENV, + + /** Whether the Cache API is available */ + IS_WEB_CACHE_AVAILABLE, + + /** Whether the WebGPU API is available */ + IS_WEBGPU_AVAILABLE, + + /** Whether we are running in a Node.js environment */ + IS_NODE_ENV, + + /** Whether the filesystem API is available */ + IS_FS_AVAILABLE, + + /** Whether the path API is available */ + IS_PATH_AVAILABLE, +}); + +const RUNNING_LOCALLY = IS_FS_AVAILABLE && IS_PATH_AVAILABLE; const __dirname = RUNNING_LOCALLY ? path.dirname(path.dirname(url.fileURLToPath(import.meta.url))) : './'; @@ -52,11 +81,11 @@ const localModelPath = RUNNING_LOCALLY : DEFAULT_LOCAL_MODEL_PATH; /** - * Global variable used to control execution. This provides users a simple way to configure Transformers.js. + * Global variable given visible to users to control execution. This provides users a simple way to configure Transformers.js. + * @typedef {Object} TransformersEnvironment + * @property {string} version This version of Transformers.js. * @property {Object} backends Expose environment variables of different backends, * allowing users to set these variables if they want to. - * @property {string} __dirname Directory name of module. Useful for resolving local paths. - * @property {string} version This version of Transformers.js. * @property {boolean} allowRemoteModels Whether to allow loading of remote files, defaults to `true`. * If set to `false`, it will have the same effect as setting `local_files_only=true` when loading pipelines, models, tokenizers, processors, etc. * @property {string} remoteHost Host URL to load models from. Defaults to the Hugging Face Hub. @@ -73,7 +102,10 @@ const localModelPath = RUNNING_LOCALLY * implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache */ +/** @type {TransformersEnvironment} */ export const env = { + version: VERSION, + /////////////////// Backends settings /////////////////// // NOTE: These will be populated later by the backends themselves. 
backends: { @@ -84,22 +116,20 @@ export const env = { tfjs: {}, }, - __dirname, - version: VERSION, /////////////////// Model settings /////////////////// allowRemoteModels: true, remoteHost: 'https://huggingface.co/', remotePathTemplate: '{model}/resolve/{revision}/', - allowLocalModels: !BROWSER_ENV, + allowLocalModels: !IS_BROWSER_ENV, localModelPath: localModelPath, - useFS: FS_AVAILABLE, + useFS: IS_FS_AVAILABLE, /////////////////// Cache settings /////////////////// - useBrowserCache: WEB_CACHE_AVAILABLE, + useBrowserCache: IS_WEB_CACHE_AVAILABLE, - useFSCache: FS_AVAILABLE, + useFSCache: IS_FS_AVAILABLE, cacheDir: DEFAULT_CACHE_DIR, useCustomCache: false, diff --git a/src/transformers.js b/src/transformers.js index 9dcd0160c..25daa2efb 100644 --- a/src/transformers.js +++ b/src/transformers.js @@ -11,8 +11,8 @@ * @module transformers */ +export { env } from './env.js'; export * from './pipelines.js'; -export * from './env.js'; export * from './models.js'; export * from './tokenizers.js'; export * from './processors.js'; diff --git a/src/utils/dtypes.js b/src/utils/dtypes.js index 268c8b60e..8b64ff4a6 100644 --- a/src/utils/dtypes.js +++ b/src/utils/dtypes.js @@ -1,17 +1,36 @@ +import { apis } from "../env.js"; + import { DEVICE_TYPES } from "./devices.js"; // TODO: Use the adapter from `env.backends.onnx.webgpu.adapter` to check for `shader-f16` support, // when available in https://github.com/microsoft/onnxruntime/pull/19940. // For more information, see https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753 -async function isFp16Supported() { - try { - const adapter = await navigator.gpu.requestAdapter(); - return adapter.features.has('shader-f16'); - } catch (e) { - return false - } -} -export const FP16_SUPPORTED = await isFp16Supported(); + +/** + * Checks if fp16 support is available in the current environment. + */ +export const isFp16Supported = (function () { + /** @type {boolean} */ + let cachedResult; + + return async function () { + if (cachedResult === undefined) { + if (apis.IS_NODE_ENV) { + cachedResult = true; + } else if (!apis.IS_WEBGPU_AVAILABLE) { + cachedResult = false; + } else { + try { + const adapter = await navigator.gpu.requestAdapter(); + cachedResult = adapter.features.has('shader-f16'); + } catch (e) { + cachedResult = false; + } + } + } + return cachedResult; + }; +})(); export const DATA_TYPES = Object.freeze({ fp32: 'fp32', @@ -20,15 +39,13 @@ export const DATA_TYPES = Object.freeze({ int8: 'int8', uint8: 'uint8', }); - /** @typedef {keyof typeof DATA_TYPES} DataType */ -const defaultGpuDtype = FP16_SUPPORTED ? DATA_TYPES.fp16 : DATA_TYPES.fp32; export const DEFAULT_DEVICE_DTYPE_MAPPING = Object.freeze({ [DEVICE_TYPES.cpu]: DATA_TYPES.q8, - [DEVICE_TYPES.gpu]: defaultGpuDtype, + [DEVICE_TYPES.gpu]: DATA_TYPES.fp32, [DEVICE_TYPES.wasm]: DATA_TYPES.q8, - [DEVICE_TYPES.webgpu]: defaultGpuDtype, + [DEVICE_TYPES.webgpu]: DATA_TYPES.fp32, }); /** @type {Record} */ diff --git a/src/utils/hub.js b/src/utils/hub.js index 7f4b0547e..700aa3a74 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -29,8 +29,8 @@ import { dispatchCallback } from './core.js'; * @property {string} [subfolder='onnx'] In case the relevant files are located inside a subfolder of the model repo on huggingface.co, * you can specify the folder name here. * @property {string} [model_file_name=null] If specified, load the model with this name (excluding the .onnx suffix). Currently only valid for encoder- or decoder-only models. 
- * @property {import("./devices.js").DeviceType} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings. - * @property {import("./dtypes.js").DataType} [dtype=null] The data type to use for the model. If not specified, the data type will be chosen from the environment settings. + * @property {import("./devices.js").DeviceType|Record} [device=null] The device to run the model on. If not specified, the device will be chosen from the environment settings. + * @property {import("./dtypes.js").DataType|Record} [dtype=null] The data type to use for the model. If not specified, the data type will be chosen from the environment settings. * @property {Object} [session_options] (Optional) User-specified session options passed to the runtime. If not provided, suitable defaults will be chosen. */ diff --git a/webpack.config.js b/webpack.config.js index b0751cced..ab7549932 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -63,7 +63,6 @@ function buildConfig({ }, experiments: { outputModule, - topLevelAwait: true, }, resolve: { alias }, From 82680b890fc6bd8e4508c9263d4d622bfff653a7 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 18:11:48 +0200 Subject: [PATCH 082/473] Store sessions for a model as a record --- src/models.js | 300 ++++++++++++++++++++++++++++---------------------- 1 file changed, 166 insertions(+), 134 deletions(-) diff --git a/src/models.js b/src/models.js index ebc5ecd6d..41988f169 100644 --- a/src/models.js +++ b/src/models.js @@ -51,7 +51,8 @@ import { import { DATA_TYPES, DEFAULT_DEVICE_DTYPE_MAPPING, - DEFAULT_DTYPE_SUFFIX_MAPPING, FP16_SUPPORTED, + DEFAULT_DTYPE_SUFFIX_MAPPING, + isFp16Supported, } from './utils/dtypes.js'; import { @@ -137,20 +138,40 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map(); * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. * @param {string} fileName The name of the model file. * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. - * @returns {Promise} A Promise that resolves to an InferenceSession object. + * @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object. * @private */ -async function constructSession(pretrained_model_name_or_path, fileName, options) { +async function getSession(pretrained_model_name_or_path, fileName, options) { + let device = options.device; + if (device && typeof device !== 'string') { + if (device.hasOwnProperty(fileName)) { + device = device[fileName]; + } else { + console.warn(`Device not specified for ${fileName}. Using the default device.`); + device = null; + } + } // If the device is not specified, we use the default (supported) execution providers. - const executionProviders = deviceToExecutionProviders(options.device); + const executionProviders = deviceToExecutionProviders( + /** @type {import("./utils/devices.js").DeviceType|null} */(device) + ); // If options.dtype is specified, we use it to choose the suffix for the model file. // Otherwise, we use the default dtype for the device. - const dtype = options.dtype ?? 
DEFAULT_DEVICE_DTYPE_MAPPING[executionProviders[0]]; + let dtype = options.dtype; + if (typeof dtype !== 'string') { + if (dtype && dtype.hasOwnProperty(fileName)) { + dtype = dtype[fileName]; + } else { + dtype = DEFAULT_DEVICE_DTYPE_MAPPING[executionProviders[0]]; + console.warn(`Dtype not specified for ${fileName}. Using the default dtype: ${dtype}.`); + } + } + if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(dtype)) { throw new Error(`Invalid dtype: ${dtype}. Should be one of: ${Object.keys(DATA_TYPES).join(', ')}`); - } else if (dtype === DATA_TYPES.fp16 && !FP16_SUPPORTED) { + } else if (dtype === DATA_TYPES.fp16 && !(await isFp16Supported())) { throw new Error(`The device does not support fp16.`); } @@ -184,8 +205,32 @@ async function constructSession(pretrained_model_name_or_path, fileName, options // options.session_options.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer'; // } // } + return { buffer, session_options }; +} + +/** + * Helper function to sequentially create multiple InferenceSession objects. + * NOTE: It is important to create the sessions sequentially, otherwise ORT will throw an error indicating + * that multiple calls to `initWasm` were made. + * + * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. + * @param {Record} names The names of the model files to load. + * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. + * @returns {Promise>} A Promise that resolves to a dictionary of InferenceSession objects. + */ +async function constructSessions(pretrained_model_name_or_path, names, options) { + const keys = Object.keys(names); + const sessionData = await Promise.all( + keys.map(async (name) => getSession(pretrained_model_name_or_path, names[name], options)) + ); - return await createInferenceSession(buffer, session_options); + const sessions = {}; + for (let i = 0; i < keys.length; ++i) { + const { buffer, session_options } = sessionData[i]; + const session = await createInferenceSession(buffer, session_options); + sessions[keys[i]] = session; + } + return sessions; } /** @@ -375,7 +420,7 @@ async function seq2seqForward(self, model_inputs) { // Encode if needed if (!encoder_outputs) { - const encoder_inputs = pick(model_inputs, self.session.inputNames); + const encoder_inputs = pick(model_inputs, self.sessions['model'].inputNames); // Encoder outputs are not given, so we must compute them. 
encoder_outputs = (await encoderForward(self, encoder_inputs)).last_hidden_state; } @@ -384,7 +429,7 @@ async function seq2seqForward(self, model_inputs) { other_decoder_inputs.input_ids = decoder_input_ids; other_decoder_inputs.encoder_hidden_states = encoder_outputs; - if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { + if (self.sessions['decoder_model_merged'].inputNames.includes('encoder_attention_mask')) { other_decoder_inputs.encoder_attention_mask = model_inputs.attention_mask } @@ -404,11 +449,12 @@ async function seq2seqForward(self, model_inputs) { * @private */ async function encoderForward(self, model_inputs) { + const session = self.sessions['model']; const encoderFeeds = Object.create(null); - for (const key of self.session.inputNames) { + for (const key of session.inputNames) { encoderFeeds[key] = model_inputs[key]; } - if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) { + if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) { // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it, // but they weren't created by the tokenizer. encoderFeeds.token_type_ids = new Tensor( @@ -417,7 +463,7 @@ async function encoderForward(self, model_inputs) { encoderFeeds.input_ids.dims ) } - return await sessionRun(self.session, encoderFeeds); + return await sessionRun(session, encoderFeeds); } /** @@ -428,9 +474,12 @@ async function encoderForward(self, model_inputs) { * @private */ async function decoderForward(self, model_inputs, is_encoder_decoder = false) { - const { past_key_values, ...new_model_inputs } = model_inputs; - const session = is_encoder_decoder ? self.decoder_merged_session : self.session; + const session = self.sessions[ + is_encoder_decoder ? 'decoder_model_merged' : 'model' + ] + + const { past_key_values, ...new_model_inputs } = model_inputs; if (session.inputNames.includes('use_cache_branch')) { new_model_inputs.use_cache_branch = boolTensor(!!past_key_values); @@ -443,8 +492,9 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) { } function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) { + const session = self.sessions['model']; - if (self.session.inputNames.includes('position_ids') && model_inputs.attention_mask && !model_inputs.position_ids) { + if (session.inputNames.includes('position_ids') && model_inputs.attention_mask && !model_inputs.position_ids) { // If the model supports providing position_ids, we create position_ids on the fly for batch generation, // by computing the cumulative sum of the attention mask along the sequence length dimension. // @@ -509,13 +559,13 @@ export class PreTrainedModel extends Callable { /** * Creates a new instance of the `PreTrainedModel` class. * @param {Object} config The model configuration. - * @param {any} session session for the model. + * @param {Record} sessions The inference sessions for the model. 
*/ - constructor(config, session) { + constructor(config, sessions) { super(); this.config = config; - this.session = session; + this.sessions = sessions; const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor); const modelType = MODEL_TYPE_MAPPING.get(modelName); @@ -614,47 +664,59 @@ export class PreTrainedModel extends Callable { if (modelType === MODEL_TYPES.DecoderOnly) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'model', options), + constructSessions(pretrained_model_name_or_path, { + model: options.model_file_name ?? 'model', + }, options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + constructSessions(pretrained_model_name_or_path, { + model: 'encoder_model', + decoder_model_merged: 'decoder_model_merged', + }, options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.MaskGeneration) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'vision_encoder', options), - constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options), + constructSessions(pretrained_model_name_or_path, { + model: 'vision_encoder', + prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder', + }, options), ]); } else if (modelType === MODEL_TYPES.EncoderDecoder) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + constructSessions(pretrained_model_name_or_path, { + model: 'encoder_model', + decoder_model_merged: 'decoder_model_merged', + }, options), ]); } else if (modelType === MODEL_TYPES.ImageTextToText) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'embed_tokens', options), - constructSession(pretrained_model_name_or_path, 'vision_encoder', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + constructSessions(pretrained_model_name_or_path, { + model: 'embed_tokens', + vision_encoder: 'vision_encoder', + decoder_model_merged: 'decoder_model_merged', + }, options), getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); } else if (modelType === MODEL_TYPES.Musicgen) { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'text_encoder', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - constructSession(pretrained_model_name_or_path, 'encodec_decode', options), + constructSessions(pretrained_model_name_or_path, { + model: 'text_encoder', + decoder_model_merged: 'decoder_model_merged', + encodec_decode: 'encodec_decode', + }, options), 
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), ]); @@ -664,7 +726,9 @@ export class PreTrainedModel extends Callable { } info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'model', options) + constructSessions(pretrained_model_name_or_path, { + model: options.model_file_name ?? 'model', + }, options), ]); } @@ -1004,7 +1068,7 @@ export class PreTrainedModel extends Callable { } async _prepare_encoder_decoder_kwargs_for_generation({ inputs_tensor, model_inputs, model_input_name, generation_config }) { - const encoder_kwargs = pick(model_inputs, this.session.inputNames); + const encoder_kwargs = pick(model_inputs, this.sessions['model'].inputNames); let { last_hidden_state } = await encoderForward(this, encoder_kwargs); @@ -2391,13 +2455,11 @@ export class T5PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `T5PreTrainedModel` class. * @param {Object} config The model configuration. - * @param {any} session session for the model. - * @param {any} decoder_merged_session session for the decoder. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.num_decoder_layers; @@ -2428,13 +2490,11 @@ export class LongT5PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `LongT5ForConditionalGeneration` class. * @param {Object} config The model configuration. - * @param {any} session session for the model. - * @param {any} decoder_merged_session session for the decoder. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.num_decoder_layers; @@ -2465,14 +2525,12 @@ export class MT5PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `MT5ForConditionalGeneration` class. - * @param {any} config The model configuration. - * @param {any} session The ONNX session containing the encoder weights. - * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. 
*/ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.num_decoder_layers; @@ -2499,14 +2557,12 @@ export class BartPretrainedModel extends PreTrainedModel { /** * Creates a new instance of the `BartForConditionalGeneration` class. - * @param {Object} config The configuration object for the Bart model. - * @param {Object} session The ONNX session used to execute the model. - * @param {Object} decoder_merged_session The ONNX session used to execute the decoder. - * @param {Object} generation_config The generation configuration object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -2552,14 +2608,12 @@ export class MBartPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `MBartForConditionalGeneration` class. - * @param {Object} config The configuration object for the Bart model. - * @param {Object} session The ONNX session used to execute the model. - * @param {Object} decoder_merged_session The ONNX session used to execute the decoder. - * @param {Object} generation_config The generation configuration object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -2601,13 +2655,12 @@ export class MBartForSequenceClassification extends MBartPreTrainedModel { export class MBartForCausalLM extends MBartPreTrainedModel { /** * Creates a new instance of the `MBartForCausalLM` class. - * @param {Object} config Configuration object for the model. - * @param {Object} decoder_merged_session ONNX Session object for the decoder. - * @param {Object} generation_config Configuration object for the generation process. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, decoder_merged_session, generation_config) { - super(config, decoder_merged_session); - this.generation_config = generation_config; + constructor(config, sessions, generation_config) { + super(config, sessions, generation_config); this.num_decoder_layers = this.config.decoder_layers; this.num_decoder_heads = this.config.decoder_attention_heads; @@ -2627,14 +2680,12 @@ export class BlenderbotPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `BlenderbotForConditionalGeneration` class. 
- * @param {any} config The model configuration. - * @param {any} session The ONNX session containing the encoder weights. - * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -2665,14 +2716,12 @@ export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `BlenderbotForConditionalGeneration` class. - * @param {any} config The model configuration. - * @param {any} session The ONNX session containing the encoder weights. - * @param {any} decoder_merged_session The ONNX session containing the merged decoder weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -2929,14 +2978,12 @@ export class WhisperPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `WhisperForConditionalGeneration` class. - * @param {Object} config Configuration object for the model. - * @param {Object} session ONNX Session object for the model. - * @param {Object} decoder_merged_session ONNX Session object for the decoder. - * @param {Object} generation_config Configuration object for the generation process. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. + * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -3215,14 +3262,12 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { /** * Creates a new instance of the `VisionEncoderDecoderModel` class. - * @param {Object} config The configuration object specifying the hyperparameters and other model settings. - * @param {Object} session The ONNX session containing the encoder model. - * @param {any} decoder_merged_session The ONNX session containing the merged decoder model. - * @param {Object} generation_config Configuration object for the generation process. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. + * @param {GenerationConfig} generation_config The generation configuration. 
*/ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // Extract configs @@ -3247,7 +3292,7 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { // @ts-ignore const decoderModelClass = decoderModel[1]; // @ts-ignore - const decoder = new decoderModelClass(decoderConfig, decoder_merged_session, generation_config); + const decoder = new decoderModelClass(decoderConfig, { /* No sessions */ }, generation_config); this.add_encoder_pkv = 'num_decoder_layers' in decoder; if (this.add_encoder_pkv) { @@ -3281,11 +3326,8 @@ export class LlavaPreTrainedModel extends PreTrainedModel { 'attention_mask', ]; - constructor(config, input_embeds_session, vision_encoder_session, decoder_merged_session, generation_config) { - super(config, input_embeds_session); - this.vision_encoder_session = vision_encoder_session; - this.decoder_merged_session = decoder_merged_session; - + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; const decoderConfig = this.config.text_config; @@ -3306,12 +3348,12 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { async encode_image({ pixel_values }) { // image_inputs === { pixel_values } - return (await sessionRun(this.vision_encoder_session, { pixel_values })).image_features; + return (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features; } async encode_text({ input_ids }) { // text_inputs === { input_ids, attention_mask } - return (await sessionRun(this.session, { input_ids })).inputs_embeds; + return (await sessionRun(this.sessions['input_embeds'], { input_ids })).inputs_embeds; } _merge_input_ids_with_image_features({ @@ -4860,14 +4902,12 @@ export class MarianPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `MarianMTModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. - * @param {any} decoder_merged_session - * @param {any} generation_config - */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -4891,14 +4931,12 @@ export class M2M100PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `M2M100ForConditionalGeneration` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. - * @param {any} decoder_merged_session - * @param {any} generation_config - */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. + * @param {GenerationConfig} generation_config The generation configuration. 
+ */ + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -5338,13 +5376,11 @@ export class SpeechT5PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `SpeechT5ForTextToSpeech` class. * @param {Object} config The model configuration. - * @param {any} session session for the model. - * @param {any} decoder_merged_session session for the decoder. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, decoder_merged_session, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; @@ -5478,7 +5514,7 @@ export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel { }; this.addPastKeyValues(decoderFeeds, past_key_values); - decoder_outputs = await sessionRun(this.decoder_merged_session, decoderFeeds); + decoder_outputs = await sessionRun(this.sessions['decoder_model_merged'], decoderFeeds); past_key_values = this.getPastKeyValues(decoder_outputs, past_key_values); const { prob, spectrum } = decoder_outputs; @@ -5493,7 +5529,7 @@ export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel { } const spectrogram = cat(spectrogramParts); - const { waveform } = await sessionRun(vocoder.session, { spectrogram }); + const { waveform } = await sessionRun(vocoder.sessions['model'], { spectrogram }); return { spectrogram, @@ -5887,15 +5923,11 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: /** * Creates a new instance of the `MusicgenForConditionalGeneration` class. * @param {Object} config The model configuration. - * @param {any} session session for the model. - * @param {any} decoder_merged_session session for the decoder. - * @param {any} encodec_decode session for the encodec.decode function. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. 
*/ - constructor(config, session, decoder_merged_session, encodec_decode, generation_config) { - super(config, session); - this.decoder_merged_session = decoder_merged_session; - this.encodec_decode = encodec_decode; + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // decoder @@ -5978,7 +6010,7 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE: /** @type {Tensor} */(output_ids) ).unsqueeze_(0); // append the frame dimension back to the audio codes - const { audio_values } = await sessionRun(this.encodec_decode, { audio_codes }) + const { audio_values } = await sessionRun(this.sessions['encodec_decode'], { audio_codes }) return audio_values; } From 549cf6ad93681d080c340da0a33107aa50bf2fed Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 18:13:03 +0200 Subject: [PATCH 083/473] Mark `constructSessions` as private --- src/models.js | 1 + 1 file changed, 1 insertion(+) diff --git a/src/models.js b/src/models.js index 41988f169..5b0960fed 100644 --- a/src/models.js +++ b/src/models.js @@ -217,6 +217,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) { * @param {Record} names The names of the model files to load. * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model. * @returns {Promise>} A Promise that resolves to a dictionary of InferenceSession objects. + * @private */ async function constructSessions(pretrained_model_name_or_path, names, options) { const keys = Object.keys(names); From 8834263d41343042f968b2d63d1998ee7afa43f2 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 18:27:40 +0200 Subject: [PATCH 084/473] Update llava session name --- src/models.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models.js b/src/models.js index 5b0960fed..4503cf8de 100644 --- a/src/models.js +++ b/src/models.js @@ -703,7 +703,7 @@ export class PreTrainedModel extends Callable { info = await Promise.all([ AutoConfig.from_pretrained(pretrained_model_name_or_path, options), constructSessions(pretrained_model_name_or_path, { - model: 'embed_tokens', + embed_tokens: 'embed_tokens', vision_encoder: 'vision_encoder', decoder_model_merged: 'decoder_model_merged', }, options), @@ -3354,7 +3354,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { async encode_text({ input_ids }) { // text_inputs === { input_ids, attention_mask } - return (await sessionRun(this.sessions['input_embeds'], { input_ids })).inputs_embeds; + return (await sessionRun(this.sessions['embed_tokens'], { input_ids })).inputs_embeds; } _merge_input_ids_with_image_features({ From cf4bfdfac914e9b7eab22da6b36722b8a84a8cb6 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 11 Apr 2024 18:32:57 +0200 Subject: [PATCH 085/473] Remove unused function --- src/utils/core.js | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/utils/core.js b/src/utils/core.js index ba43aaf40..453d0ba08 100644 --- a/src/utils/core.js +++ b/src/utils/core.js @@ -63,15 +63,6 @@ export function isIntegralNumber(x) { return Number.isInteger(x) || typeof x === 'bigint' } -/** - * Check if a value is exists. - * @param {*} x The value to check. - * @returns {boolean} True if the value exists, false otherwise. - */ -export function exists(x) { - return x !== undefined && x !== null; -} - /** * Calculates the dimensions of a nested array. 
* From 667b2c0e819ce886caeab166ba0f8720b3897964 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 12 Apr 2024 13:44:43 +0200 Subject: [PATCH 086/473] Add support for token streaming --- docs/source/_toctree.yml | 2 ++ src/generation/parameters.js | 3 +++ src/generation/streamers.js | 21 +++++++++++++++++++++ src/models.js | 15 ++++++++++++--- 4 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 src/generation/streamers.js diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0a473e8d4..4458c049b 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,6 +59,8 @@ title: Logits Samplers - local: api/generation/stopping_criteria title: Stopping Criteria + - local: api/generation/streamers + title: Streamers title: Generation isExpanded: false - sections: diff --git a/src/generation/parameters.js b/src/generation/parameters.js index d0df2f0b0..7a3269f2e 100644 --- a/src/generation/parameters.js +++ b/src/generation/parameters.js @@ -26,5 +26,8 @@ * Custom stopping criteria that complements the default stopping criteria built from arguments and a * generation config. If a stopping criteria is passed that is already created with the arguments or a * generation config an error is thrown. This feature is intended for advanced users. + * @property {import('./streamers.js').BaseStreamer} [streamer=null] (`BaseStreamer`, *optional*): + * Streamer object that will be used to stream the generated sequences. Generated tokens are passed + * through `streamer.put(token_ids)` and the streamer is responsible for any further processing. * @param {any} [kwargs] (`Dict[str, any]`, *optional*): */ diff --git a/src/generation/streamers.js b/src/generation/streamers.js new file mode 100644 index 000000000..a8b53da62 --- /dev/null +++ b/src/generation/streamers.js @@ -0,0 +1,21 @@ + +/** + * @module generation/streamers + */ + +export class BaseStreamer { + /** + * Function that is called by `.generate()` to push new tokens + * @param {bigint[][]} value + */ + put(value) { + throw Error('Not implemented'); + } + + /** + * Function that is called by `.generate()` to signal the end of generation + */ + end() { + throw Error('Not implemented'); + } +} diff --git a/src/models.js b/src/models.js index 4503cf8de..7b3bc648f 100644 --- a/src/models.js +++ b/src/models.js @@ -1143,6 +1143,7 @@ export class PreTrainedModel extends Callable { generation_config = null, logits_processor = null, stopping_criteria = null, + streamer = null, // inputs_attention_mask = null, ...kwargs @@ -1242,7 +1243,11 @@ export class PreTrainedModel extends Callable { // TODO make > numInputs const scores = new Array(numInputs).fill(0); + /** @type {bigint[][]} */ const all_input_ids = input_ids.tolist(); + if (streamer) { + streamer.put(all_input_ids); + } // const all_generated_input_ids = Array.from({ length: numInputs }, () => []); // NOTE: For now, we don't support spawning new beams @@ -1286,9 +1291,9 @@ export class PreTrainedModel extends Callable { generated_input_ids.push(bigint); } } - // if(streamer) { - // streamer.put(next_tokens.cpu()) - // } + if (streamer) { + streamer.put(all_input_ids); + } const stop = prepared_stopping_criteria(all_input_ids); if (stop.every(x => x)) { @@ -1300,6 +1305,10 @@ export class PreTrainedModel extends Callable { }) } + if (streamer) { + streamer.end(); + } + // TODO: ensure all_input_ids is padded correctly... 
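As an aside, a minimal sketch (not taken from this patch) of how a caller could consume the new `streamer` option wired into `generate()` above, using only the `BaseStreamer` API (`put`/`end`) introduced in `src/generation/streamers.js`. The model id and the `max_new_tokens` value are assumptions for illustration:

```js
import { AutoTokenizer, AutoModelForCausalLM, BaseStreamer } from '@xenova/transformers';

// Reports how many tokens the model has produced so far.
class TokenCountStreamer extends BaseStreamer {
    put(token_ids) {
        // `token_ids` holds the ids so far (prompt + generated) for each sequence, as bigint[][]
        console.log(`Tokens so far: ${token_ids[0].length}`);
    }
    end() {
        console.log('Generation finished');
    }
}

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2'); // assumed model id
const model = await AutoModelForCausalLM.from_pretrained('Xenova/gpt2');

const inputs = tokenizer('The quick brown fox');
await model.generate({ ...inputs, max_new_tokens: 32, streamer: new TokenCountStreamer() });
```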
return new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]); From 33ff33617158e3f85c115cd36661f583f5839775 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 12 Apr 2024 14:28:42 +0200 Subject: [PATCH 087/473] Remove warning debug log --- src/models.js | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/models.js b/src/models.js index 7b3bc648f..a63fa1f93 100644 --- a/src/models.js +++ b/src/models.js @@ -1030,8 +1030,7 @@ export class PreTrainedModel extends Callable { ], 1 ); } else if ('decoder_attention_mask' in model_inputs) { - // update decoder attention mask - console.warn('TODO: update decoder attention mask') + // TODO: update decoder attention mask if the model requires it } // force recreate position_ids in next iteration From 7194d06d7916b73c1663def6d5b34af7d954531a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 12 Apr 2024 14:29:30 +0200 Subject: [PATCH 088/473] Make `BaseStreamer` visible to users --- src/transformers.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers.js b/src/transformers.js index 25daa2efb..af2a116f4 100644 --- a/src/transformers.js +++ b/src/transformers.js @@ -22,3 +22,5 @@ export * from './utils/audio.js'; export * from './utils/image.js'; export * from './utils/tensor.js'; export * from './utils/maths.js'; + +export { BaseStreamer } from './generation/streamers.js'; From e205f182180f33cca7ef29246910fec7eb286ed9 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 12 Apr 2024 16:49:21 +0200 Subject: [PATCH 089/473] Create MusicGen Web demo --- examples/musicgen-web/.eslintrc.cjs | 21 +++ examples/musicgen-web/.gitignore | 24 +++ examples/musicgen-web/README.md | 8 + examples/musicgen-web/index.html | 12 ++ examples/musicgen-web/package.json | 30 +++ examples/musicgen-web/postcss.config.js | 6 + examples/musicgen-web/src/App.css | 9 + examples/musicgen-web/src/App.jsx | 228 +++++++++++++++++++++++ examples/musicgen-web/src/index.css | 3 + examples/musicgen-web/src/main.jsx | 10 + examples/musicgen-web/src/utils.js | 59 ++++++ examples/musicgen-web/tailwind.config.js | 12 ++ examples/musicgen-web/vite.config.js | 7 + 13 files changed, 429 insertions(+) create mode 100644 examples/musicgen-web/.eslintrc.cjs create mode 100644 examples/musicgen-web/.gitignore create mode 100644 examples/musicgen-web/README.md create mode 100644 examples/musicgen-web/index.html create mode 100644 examples/musicgen-web/package.json create mode 100644 examples/musicgen-web/postcss.config.js create mode 100644 examples/musicgen-web/src/App.css create mode 100644 examples/musicgen-web/src/App.jsx create mode 100644 examples/musicgen-web/src/index.css create mode 100644 examples/musicgen-web/src/main.jsx create mode 100644 examples/musicgen-web/src/utils.js create mode 100644 examples/musicgen-web/tailwind.config.js create mode 100644 examples/musicgen-web/vite.config.js diff --git a/examples/musicgen-web/.eslintrc.cjs b/examples/musicgen-web/.eslintrc.cjs new file mode 100644 index 000000000..3e212e1d4 --- /dev/null +++ b/examples/musicgen-web/.eslintrc.cjs @@ -0,0 +1,21 @@ +module.exports = { + root: true, + env: { browser: true, es2020: true }, + extends: [ + 'eslint:recommended', + 'plugin:react/recommended', + 'plugin:react/jsx-runtime', + 'plugin:react-hooks/recommended', + ], + ignorePatterns: ['dist', '.eslintrc.cjs'], + parserOptions: { ecmaVersion: 'latest', sourceType: 'module' }, + settings: { react: { version: '18.2' } }, + plugins: ['react-refresh'], + rules: { + 
'react/jsx-no-target-blank': 'off', + 'react-refresh/only-export-components': [ + 'warn', + { allowConstantExport: true }, + ], + }, +} diff --git a/examples/musicgen-web/.gitignore b/examples/musicgen-web/.gitignore new file mode 100644 index 000000000..a547bf36d --- /dev/null +++ b/examples/musicgen-web/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/examples/musicgen-web/README.md b/examples/musicgen-web/README.md new file mode 100644 index 000000000..f768e33fc --- /dev/null +++ b/examples/musicgen-web/README.md @@ -0,0 +1,8 @@ +# React + Vite + +This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules. + +Currently, two official plugins are available: + +- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh +- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh diff --git a/examples/musicgen-web/index.html b/examples/musicgen-web/index.html new file mode 100644 index 000000000..cad1bcd1a --- /dev/null +++ b/examples/musicgen-web/index.html @@ -0,0 +1,12 @@ + + + + + + MusicGen Web | In-browser text-to-music w/ 🤗 Transformers.js! + + +
+ + + diff --git a/examples/musicgen-web/package.json b/examples/musicgen-web/package.json new file mode 100644 index 000000000..0175494d7 --- /dev/null +++ b/examples/musicgen-web/package.json @@ -0,0 +1,30 @@ +{ + "name": "musicgen-web", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0", + "preview": "vite preview" + }, + "dependencies": { + "@xenova/transformers": "github:xenova/transformers.js#v3", + "react": "^18.2.0", + "react-dom": "^18.2.0" + }, + "devDependencies": { + "@types/react": "^18.2.66", + "@types/react-dom": "^18.2.22", + "@vitejs/plugin-react": "^4.2.1", + "autoprefixer": "^10.4.19", + "eslint": "^8.57.0", + "eslint-plugin-react": "^7.34.1", + "eslint-plugin-react-hooks": "^4.6.0", + "eslint-plugin-react-refresh": "^0.4.6", + "postcss": "^8.4.38", + "tailwindcss": "^3.4.3", + "vite": "^5.2.0" + } +} diff --git a/examples/musicgen-web/postcss.config.js b/examples/musicgen-web/postcss.config.js new file mode 100644 index 000000000..2e7af2b7f --- /dev/null +++ b/examples/musicgen-web/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/examples/musicgen-web/src/App.css b/examples/musicgen-web/src/App.css new file mode 100644 index 000000000..91ab868f6 --- /dev/null +++ b/examples/musicgen-web/src/App.css @@ -0,0 +1,9 @@ +#root { + max-width: 960px; + height: 100vh; + margin: 0 auto; + text-align: center; + display: flex; + justify-content: center; + align-items: center; +} diff --git a/examples/musicgen-web/src/App.jsx b/examples/musicgen-web/src/App.jsx new file mode 100644 index 000000000..da4e640bd --- /dev/null +++ b/examples/musicgen-web/src/App.jsx @@ -0,0 +1,228 @@ +import { useEffect, useState, useRef } from 'react'; +import { AutoTokenizer, MusicgenForConditionalGeneration, BaseStreamer } from '@xenova/transformers'; +import { encodeWAV, share } from './utils.js'; + +import './App.css'; + +const MODEL_ID = 'Xenova/musicgen-small'; + +// Adapted from https://huggingface.co/spaces/facebook/MusicGen +const EXAMPLES = [ + '80s pop track with bassy drums and synth', + '90s rock song with loud guitars and heavy drums', + 'a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130', + 'A cheerful country song with acoustic guitars', + 'lofi slow bpm electro chill with organic samples', +]; + +// Enable sharing if running on Hugging Face Spaces +const SHARING_ENABLED = window.location.host.endsWith('.hf.space'); + +// Streamer to update progress +class CallbackStreamer extends BaseStreamer { + constructor(callback_fn) { + super(); + this.callback_fn = callback_fn; + } + + put(value) { + return this.callback_fn(value); + } + + end() { + return this.callback_fn(); + } +} + +// Main App component +const App = () => { + // Input/output state + const [textInput, setTextInput] = useState(EXAMPLES[0]); + const [progress, setProgress] = useState(0); + const [loadProgress, setLoadProgress] = useState({}); + const [statusText, setStatusText] = useState('Loading model (656MB)...'); + const [result, setResult] = useState(null); + const audioRef = useRef(null); + + // Model and tokenizer references + const modelPromise = useRef(null); + const tokenizerPromise = useRef(null); + + // Generation parameters + const [guidance_scale, setGuidanceScale] = useState(3); + const [temperature, setTemperature] = useState(1); + const 
[duration, setDuration] = useState(10); + + // Load model and tokenizer on first render + useEffect(() => { + modelPromise.current ??= MusicgenForConditionalGeneration.from_pretrained(MODEL_ID, { + progress_callback: (data) => { + if (data.status !== 'progress') return; + setLoadProgress(prev => ({ ...prev, [data.file]: data })) + }, + dtype: { + text_encoder: 'q8', + decoder_model_merged: 'q8', + encodec_decode: 'fp32', + }, + device: 'wasm', + }); + + tokenizerPromise.current ??= AutoTokenizer.from_pretrained(MODEL_ID); + }, []); + + // Update progress bar based on load progress + useEffect(() => { + const items = Object.values(loadProgress); + if (items.length !== 5) return; // 5 files to load + let loaded = 0; + let total = 0; + for (const data of Object.values(loadProgress)) { + loaded += data.loaded; + total += data.total; + } + const progress = loaded / total; + setProgress(progress); + setStatusText(progress === 1 + ? 'Ready!' + : `Loading model (${(progress * 100).toFixed()}% of 656MB)...` + ); + }, [loadProgress]); + + // Function to handle generating music + const generateMusic = async () => { + // Reset audio player and result + audioRef.current.src = ''; + setResult(null); + + // Get model and tokenizer + const tokenizer = await tokenizerPromise.current; + const model = await modelPromise.current; + + // Get number of tokens to match user-specified duration (more intuitive for user) + // 503 tokens -> 10 seconds generated => ~50 tokens per second + // https://huggingface.co/docs/transformers/model_doc/musicgen#generation + const max_length = Math.min( + Math.max(Math.floor(duration * 50), 1) + 4, + model.generation_config.max_length ?? 1500, + ); + + // Create a streamer to update progress + const streamer = new CallbackStreamer((value) => { + const percent = value === undefined ? 1 : value[0].length / max_length; + setStatusText(`Generating (${(percent * 100).toFixed()}%)...`); + setProgress(percent); + }); + + // Tokenize input text + const inputs = tokenizer(textInput); + + // Generate music + const audio_values = await model.generate({ + // Inputs + ...inputs, + + // Generation parameters + max_length, + guidance_scale, + temperature, + + // Outputs + streamer, + }); + + setStatusText('Encoding audio...'); + + // Encode audio values to WAV + const sampling_rate = model.config.audio_encoder.sampling_rate; + const wav = encodeWAV(audio_values.data, sampling_rate); + const blob = new Blob([wav], { type: 'audio/wav' }); + setResult(blob); + + audioRef.current.src = URL.createObjectURL(blob); + setStatusText('Done!'); + }; + + return ( +
+    <div>
+      <h1>MusicGen Web</h1>
+      <h2>In-browser text-to-music w/ 🤗 Transformers.js!</h2>
+
+      {/* Text input for user */}
+      <input
+        type="text"
+        value={textInput}
+        onChange={(e) => setTextInput(e.target.value)}
+        className="border border-gray-300 p-2 mb-4 w-full rounded"
+      />
+
+      {/* Example buttons */}
+      <div className="mb-4">
+        {EXAMPLES.map((example, i) => (
+          <button key={i} onClick={() => setTextInput(example)}>{example}</button>
+        ))}
+      </div>
+
+      {/* Generation parameters */}
+      <div className="flex justify-center gap-4 mb-4">
+        {/* Duration */}
+        <div>
+          <label>Duration</label>
+          <input type="range" min={1} max={30} value={duration} onChange={(e) => setDuration(e.target.value)} />
+          <p>{`${duration} second${duration > 1 ? 's' : ''}`}</p>
+        </div>
+
+        {/* Guidance Scale */}
+        <div>
+          <label>Guidance Scale</label>
+          <input type="range" min={1} max={10} value={guidance_scale} onChange={(e) => setGuidanceScale(e.target.value)} />
+          <p>{guidance_scale}</p>
+        </div>
+
+        {/* Temperature */}
+        <div>
+          <label>Temperature</label>
+          <input type="range" min={0.1} max={2} step={0.1} value={temperature} onChange={(e) => setTemperature(e.target.value)} />
+          <p>{temperature}</p>
+        </div>
+      </div>
+
+      {/* Button to generate music */}
+      <button onClick={generateMusic}>Generate Music</button>
+
+      {/* Progress bar */}
+      <div className="w-full bg-gray-200 rounded-full h-4 mb-2">
+        <div className="bg-blue-400 h-4 rounded-full" style={{ width: `${100 * progress}%` }}></div>
+      </div>
+      <p>{statusText}</p>
+
+      {/* Audio player */}
+      {
+        <audio ref={audioRef} controls />
+      }
+    </div>
+ ); +}; + +export default App; diff --git a/examples/musicgen-web/src/index.css b/examples/musicgen-web/src/index.css new file mode 100644 index 000000000..bd6213e1d --- /dev/null +++ b/examples/musicgen-web/src/index.css @@ -0,0 +1,3 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; \ No newline at end of file diff --git a/examples/musicgen-web/src/main.jsx b/examples/musicgen-web/src/main.jsx new file mode 100644 index 000000000..54b39dd1d --- /dev/null +++ b/examples/musicgen-web/src/main.jsx @@ -0,0 +1,10 @@ +import React from 'react' +import ReactDOM from 'react-dom/client' +import App from './App.jsx' +import './index.css' + +ReactDOM.createRoot(document.getElementById('root')).render( + + + , +) diff --git a/examples/musicgen-web/src/utils.js b/examples/musicgen-web/src/utils.js new file mode 100644 index 000000000..436c9daab --- /dev/null +++ b/examples/musicgen-web/src/utils.js @@ -0,0 +1,59 @@ + +// Adapted from https://www.npmjs.com/package/audiobuffer-to-wav +export function encodeWAV(samples, sampleRate = 16000) { + let offset = 44; + const buffer = new ArrayBuffer(offset + samples.length * 4); + const view = new DataView(buffer); + + /* RIFF identifier */ + writeString(view, 0, 'RIFF') + /* RIFF chunk length */ + view.setUint32(4, 36 + samples.length * 4, true) + /* RIFF type */ + writeString(view, 8, 'WAVE') + /* format chunk identifier */ + writeString(view, 12, 'fmt ') + /* format chunk length */ + view.setUint32(16, 16, true) + /* sample format (raw) */ + view.setUint16(20, 3, true) + /* channel count */ + view.setUint16(22, 1, true) + /* sample rate */ + view.setUint32(24, sampleRate, true) + /* byte rate (sample rate * block align) */ + view.setUint32(28, sampleRate * 4, true) + /* block align (channel count * bytes per sample) */ + view.setUint16(32, 4, true) + /* bits per sample */ + view.setUint16(34, 32, true) + /* data chunk identifier */ + writeString(view, 36, 'data') + /* data chunk length */ + view.setUint32(40, samples.length * 4, true) + + for (let i = 0; i < samples.length; ++i, offset += 4) { + view.setFloat32(offset, samples[i], true) + } + + return buffer +} +function writeString(view, offset, string) { + for (let i = 0; i < string.length; ++i) { + view.setUint8(offset + i, string.charCodeAt(i)) + } +} + +export async function share(body, settings) { + const response = await fetch('https://huggingface.co/uploads', { method: 'POST', body }); + if (!response.ok) throw new Error(`Failed to upload audio: ${response.statusText}`); + const url = await response.text(); + + const params = new URLSearchParams({ + title: `🎵 ${settings.prompt}`, + description: `\n${JSON.stringify(settings, null, 2)}`, + }); + + const shareURL = `https://huggingface.co/spaces/Xenova/musicgen-web/discussions/new?${params.toString()}`; + window.open(shareURL, '_blank'); +} \ No newline at end of file diff --git a/examples/musicgen-web/tailwind.config.js b/examples/musicgen-web/tailwind.config.js new file mode 100644 index 000000000..d37737fc0 --- /dev/null +++ b/examples/musicgen-web/tailwind.config.js @@ -0,0 +1,12 @@ +/** @type {import('tailwindcss').Config} */ +export default { + content: [ + "./index.html", + "./src/**/*.{js,ts,jsx,tsx}", + ], + theme: { + extend: {}, + }, + plugins: [], +} + diff --git a/examples/musicgen-web/vite.config.js b/examples/musicgen-web/vite.config.js new file mode 100644 index 000000000..5a33944a9 --- /dev/null +++ b/examples/musicgen-web/vite.config.js @@ -0,0 +1,7 @@ +import { defineConfig } from 'vite' +import react from 
'@vitejs/plugin-react' + +// https://vitejs.dev/config/ +export default defineConfig({ + plugins: [react()], +}) From f92204669bb77cb2053c09adbcc371f61dc218fc Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 19 Apr 2024 16:09:53 +0200 Subject: [PATCH 090/473] Upgrade `onnxruntime-web` to 1.17.3 --- package-lock.json | 16 ++++++++-------- package.json | 2 +- src/backends/onnx.js | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/package-lock.json b/package-lock.json index 7e5cebc8a..634704561 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,7 +10,7 @@ "license": "Apache-2.0", "dependencies": { "@huggingface/jinja": "^0.2.2", - "onnxruntime-web": "1.17.1", + "onnxruntime-web": "1.17.3", "sharp": "^0.33.2" }, "devDependencies": { @@ -6138,22 +6138,22 @@ } }, "node_modules/onnxruntime-web": { - "version": "1.17.1", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.17.1.tgz", - "integrity": "sha512-EotY9uJU4xFY/ZVZ2Zrl2OZmBcbTVTWn/2OOh4cCWODPwtsYN2xeJYgoz8LfCgZSrhenGg0q4ceYUWATXqEsYQ==", + "version": "1.17.3", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.17.3.tgz", + "integrity": "sha512-MSDrNUWgc1biP0YzY488OJ9n/jTMS9EXysgm9Aw4CUj2A836ALbO2J1sgzguWJeVUHTlM6p7tRzo8IGAgaXWKw==", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", "long": "^5.2.3", - "onnxruntime-common": "1.17.1", + "onnxruntime-common": "1.17.3", "platform": "^1.3.6", "protobufjs": "^7.2.4" } }, "node_modules/onnxruntime-web/node_modules/onnxruntime-common": { - "version": "1.17.1", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.1.tgz", - "integrity": "sha512-6wLNhpn+1hnsKN+jq6ulqUEJ61TdRmyFkGCvtRNnZkAupH8Yfr805UeNxjl9jtiX9B1q48pq6Q/67fEFpxT7Dw==" + "version": "1.17.3", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.3.tgz", + "integrity": "sha512-IkbaDelNVX8cBfHFgsNADRIq2TlXMFWW+nG55mwWvQT4i0NZb32Jf35Pf6h9yjrnK78RjcnlNYaI37w394ovMw==" }, "node_modules/open": { "version": "8.4.2", diff --git a/package.json b/package.json index 6e7e20c2b..36bc4a515 100644 --- a/package.json +++ b/package.json @@ -39,7 +39,7 @@ "homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { "@huggingface/jinja": "^0.2.2", - "onnxruntime-web": "1.17.1", + "onnxruntime-web": "1.17.3", "sharp": "^0.33.2" }, "optionalDependencies": { diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 111b14846..4e585b2d3 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -94,7 +94,7 @@ if (ONNX_ENV?.wasm) { // We use remote wasm files by default to make it easier for newer users. // In practice, users should probably self-host the necessary .wasm files. 
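As the comment above suggests, applications will often want to self-host those binaries rather than rely on the CDN. A minimal sketch from the consuming application's side, using the `env` object exported earlier in this series; the `/wasm/` path is an assumption and just needs to point at wherever the `onnxruntime-web` `*.wasm` files are served:

```js
import { env } from '@xenova/transformers';

// Serve the onnxruntime-web *.wasm files yourself and point the runtime at them,
// overriding the library's remote CDN default.
env.backends.onnx.wasm.wasmPaths = '/wasm/';
```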
// TODO: update this before release - ONNX_ENV.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/'; + ONNX_ENV.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.3/dist/'; // Proxy the WASM backend to prevent the UI from freezing ONNX_ENV.wasm.proxy = true; From d62e4102948415a769af33978512ac3cb063b137 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 19 Apr 2024 16:10:11 +0200 Subject: [PATCH 091/473] Remove unused import --- src/backends/onnx.js | 1 - 1 file changed, 1 deletion(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 4e585b2d3..225c90f08 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -16,7 +16,6 @@ * @module backends/onnx */ -import path from 'path'; import { env, apis } from '../env.js'; // NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`. From 6301327fd0f272383fb0a58a763bdfcb2f9e7fc8 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 19 Apr 2024 16:13:33 +0200 Subject: [PATCH 092/473] Allow user to update stream size --- examples/video-object-detection/main.js | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/examples/video-object-detection/main.js b/examples/video-object-detection/main.js index 5eea3aa91..12c3552a4 100644 --- a/examples/video-object-detection/main.js +++ b/examples/video-object-detection/main.js @@ -18,6 +18,13 @@ const thresholdSlider = document.getElementById('threshold'); const thresholdLabel = document.getElementById('threshold-value'); const sizeSlider = document.getElementById('size'); const sizeLabel = document.getElementById('size-value'); +const scaleSlider = document.getElementById('scale'); +const scaleLabel = document.getElementById('scale-value'); + +function setStreamSize(width, height) { + video.width = canvas.width = Math.round(width); + video.height = canvas.height = Math.round(height); +} status.textContent = 'Loading model...'; @@ -27,6 +34,14 @@ const model = await AutoModel.from_pretrained(model_id); const processor = await AutoProcessor.from_pretrained(model_id); // Set up controls +let scale = 0.5; +scaleSlider.addEventListener('input', () => { + scale = Number(scaleSlider.value); + setStreamSize(video.videoWidth * scale, video.videoHeight * scale); + scaleLabel.textContent = scale; +}); +scaleSlider.disabled = false; + let threshold = 0.25; thresholdSlider.addEventListener('input', () => { threshold = Number(thresholdSlider.value); @@ -130,8 +145,7 @@ navigator.mediaDevices.getUserMedia( const videoTrack = stream.getVideoTracks()[0]; const { width, height } = videoTrack.getSettings(); - canvas.width = width; - canvas.height = height; + setStreamSize(width * scale, height * scale); // Set container width and height depending on the image aspect ratio const ar = width / height; From 061af0158a7cc211efb470359b22cc5d94234d85 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 20 Apr 2024 15:23:18 +0200 Subject: [PATCH 093/473] Early de-referencing --- src/processors.js | 71 +++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/src/processors.js b/src/processors.js index cd1e09a55..a2a49d33b 100644 --- a/src/processors.js +++ b/src/processors.js @@ -336,10 +336,11 @@ export class ImageFeatureExtractor extends FeatureExtractor { const threshold = gray_threshold / 255; let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0; + const gray_image_data = gray_image.data; for (let j = 0; j < 
gray_image.height; ++j) { const row = j * gray_image.width; for (let i = 0; i < gray_image.width; ++i) { - if ((gray_image.data[row + i] - minValue) / diff < threshold) { + if ((gray_image_data[row + i] - minValue) / diff < threshold) { // We have a non-zero pixel, so we update the min/max values accordingly x_min = Math.min(x_min, i); y_min = Math.min(y_min, j); @@ -673,7 +674,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { return { original_size: [srcHeight, srcWidth], reshaped_input_size: reshaped_input_size, - pixel_values: pixel_values, + pixel_values, } } @@ -747,12 +748,13 @@ export class SegformerFeatureExtractor extends ImageFeatureExtractor { // Buffer to store current largest value const buffer = data[0].data; + const segmentation_data = segmentation.data; for (let j = 1; j < data.dims[0]; ++j) { const row = data[j].data; for (let k = 0; k < row.length; ++k) { if (row[k] > buffer[k]) { buffer[k] = row[k]; - segmentation.data[k] = j; + segmentation_data[k] = j; } } } @@ -980,6 +982,8 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor { let mask_k_area = 0; let original_area = 0; + const mask_probs_k_data = mask_probs[k].data; + // Compute the area of all the stuff in query k for (let i = 0; i < mask_labels.length; ++i) { if (mask_labels[i] === k) { @@ -987,7 +991,7 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor { ++mask_k_area; } - if (mask_probs[k].data[i] >= mask_threshold) { + if (mask_probs_k_data[i] >= mask_threshold) { ++original_area; } } @@ -1050,11 +1054,13 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor { for (let i = 0; i < mask_probs.length; ++i) { let score = pred_scores[i]; - for (let j = 0; j < mask_probs[i].data.length; ++j) { - mask_probs[i].data[j] *= score - if (mask_probs[i].data[j] > bestScores[j]) { + const mask_probs_i_data = mask_probs[i].data; + + for (let j = 0; j < mask_probs_i_data.length; ++j) { + mask_probs_i_data[j] *= score + if (mask_probs_i_data[j] > bestScores[j]) { mask_labels[j] = i; - bestScores[j] = mask_probs[i].data[j]; + bestScores[j] = mask_probs_i_data[j]; } } } @@ -1062,6 +1068,7 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor { let current_segment_id = 0; // let stuff_memory_list = {} + const segmentation_data = segmentation.data; for (let k = 0; k < pred_labels.length; ++k) { let pred_class = pred_labels[k]; @@ -1093,7 +1100,7 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor { // Add current object segment to final segmentation map for (let index of mask_k) { - segmentation.data[index] = current_segment_id; + segmentation_data[index] = current_segment_id; } segments.push({ @@ -1369,9 +1376,10 @@ export class SamImageProcessor extends ImageFeatureExtractor { interpolated_mask = interpolate(interpolated_mask, original_size, 'bilinear', false); if (binarize) { - const binarizedMaskData = new Uint8Array(interpolated_mask.data.length); - for (let i = 0; i < interpolated_mask.data.length; ++i) { - if (interpolated_mask.data[i] > mask_threshold) { + const data = interpolated_mask.data; + const binarizedMaskData = new Uint8Array(data.length); + for (let i = 0; i < data.length; ++i) { + if (data[i] > mask_threshold) { binarizedMaskData[i] = 1; } } @@ -1468,7 +1476,7 @@ export class VitMatteImageProcessor extends ImageFeatureExtractor { ), 0); return { - pixel_values: pixel_values, + pixel_values, // Original sizes of images original_sizes: imageData.map(x => x.original_size), @@ -1685,27 +1693,28 @@ export class 
SeamlessM4TFeatureExtractor extends FeatureExtractor { validate_audio_inputs(audio, 'SeamlessM4TFeatureExtractor'); let features = this._extract_fbank_features(audio, this.config.max_length); + const features_data = features.data; if (do_normalize_per_mel_bins) { const [num_features, feature_size] = features.dims; for (let i = 0; i < feature_size; ++i) { let sum = 0; for (let j = 0; j < num_features; ++j) { - sum += features.data[j * feature_size + i]; + sum += features_data[j * feature_size + i]; } const mean = sum / num_features; let variance = 0; for (let j = 0; j < num_features; ++j) { - variance += (features.data[j * feature_size + i] - mean) ** 2; + variance += (features_data[j * feature_size + i] - mean) ** 2; } variance /= num_features - 1; // NOTE: We use ddof=1 const std = Math.sqrt(variance + 1e-7); for (let j = 0; j < num_features; ++j) { const index = j * feature_size + i; - features.data[index] = (features.data[index] - mean) / std; + features_data[index] = (features_data[index] - mean) / std; } } } @@ -1717,8 +1726,8 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor { const pad_size = num_frames % pad_to_multiple_of; if (pad_size > 0) { const padded_data = new Float32Array(num_channels * (num_frames + pad_size)); - padded_data.set(features.data) - padded_data.fill(this.config.padding_value, features.data.length) + padded_data.set(features_data) + padded_data.fill(this.config.padding_value, features_data.length) const numPaddedFrames = num_frames + pad_size; features = { @@ -1746,7 +1755,7 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor { } const input_features = new Tensor('float32', - features.data, + features_data, features.dims, ).view( 1, @@ -1759,20 +1768,21 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor { if (return_attention_mask) { const reshapedNumFrames = input_features.dims[1]; - const attention_mask = new Tensor( - 'int64', - new BigInt64Array(reshapedNumFrames), - [1, reshapedNumFrames], - ); + const attention_mask_data = new BigInt64Array(reshapedNumFrames); + if (padded_attention_mask) { + const padded_attention_mask_data = padded_attention_mask.data; for (let i = 1, j = 0; i < num_frames; i += stride, ++j) { - attention_mask.data[j] = padded_attention_mask.data[i]; + attention_mask_data[j] = padded_attention_mask_data[i]; } } else { - attention_mask.data.fill(1n); + attention_mask_data.fill(1n); } - - result.attention_mask = attention_mask; + result.attention_mask = new Tensor( + 'int64', + attention_mask_data, + [1, reshapedNumFrames], + ); } return result; @@ -1854,8 +1864,9 @@ export class ASTFeatureExtractor extends FeatureExtractor { if (this.config.do_normalize) { // Normalize the input audio spectrogram to have mean=0, std=0.5 const denom = this.std * 2; - for (let i = 0; i < features.data.length; ++i) { - features.data[i] = (features.data[i] - this.mean) / denom; + const features_data = features.data; + for (let i = 0; i < features_data.length; ++i) { + features_data[i] = (features_data[i] - this.mean) / denom; } } From 02c6b153594a4c0e7949592685a1d521e6bd232a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 20 Apr 2024 15:48:47 +0200 Subject: [PATCH 094/473] Early de-referencing --- src/utils/tensor.js | 62 +++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index e99ba2941..70b712ae7 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -151,9 +151,10 @@ export class Tensor { 
* @returns {number} The index of the first occurrence of item in the tensor data. */ indexOf(item) { - for (let index = 0; index < this.data.length; ++index) { + const this_data = this.data; + for (let index = 0; index < this_data.length; ++index) { // Note: == instead of === so we can match Ints with BigInts - if (this.data[index] == item) { + if (this_data[index] == item) { return index; } } @@ -185,10 +186,11 @@ export class Tensor { * @throws {Error} If the tensor has more than one element. */ item() { - if (this.data.length !== 1) { - throw new Error(`a Tensor with ${this.data.length} elements cannot be converted to Scalar`); + const this_data = this.data; + if (this_data.length !== 1) { + throw new Error(`a Tensor with ${this_data.length} elements cannot be converted to Scalar`); } - return this.data[0]; + return this_data[0]; } /** @@ -212,8 +214,9 @@ export class Tensor { * @returns {Tensor} Returns `this`. */ sigmoid_() { - for (let i = 0; i < this.data.length; ++i) { - this.data[i] = 1 / (1 + Math.exp(-this.data[i])); + const this_data = this.data; + for (let i = 0; i < this_data.length; ++i) { + this_data[i] = 1 / (1 + Math.exp(-this_data[i])); } return this; } @@ -233,8 +236,9 @@ export class Tensor { * @returns {Tensor} Returns `this`. */ mul_(val) { - for (let i = 0; i < this.data.length; ++i) { - this.data[i] *= val; + const this_data = this.data; + for (let i = 0; i < this_data.length; ++i) { + this_data[i] *= val; } return this; } @@ -254,8 +258,9 @@ export class Tensor { * @returns {Tensor} Returns `this`. */ add_(val) { - for (let i = 0; i < this.data.length; ++i) { - this.data[i] += val; + const this_data = this.data; + for (let i = 0; i < this_data.length; ++i) { + this_data[i] += val; } return this; } @@ -308,9 +313,10 @@ export class Tensor { let newDims = newOffsets.map(([start, end]) => end - start); let newBufferSize = newDims.reduce((a, b) => a * b); + const this_data = this.data; // Allocate memory // @ts-ignore - let data = new this.data.constructor(newBufferSize); + let data = new this_data.constructor(newBufferSize); // Precompute strides const stride = this.stride(); @@ -322,7 +328,7 @@ export class Tensor { originalIndex += ((num % size) + newOffsets[j][0]) * stride[j]; num = Math.floor(num / size); } - data[i] = this.data[originalIndex]; + data[i] = this_data[originalIndex]; } return new Tensor(this.type, data, newTensorDims); @@ -371,9 +377,11 @@ export class Tensor { throw Error(`Unsupported norm: ${p}`); } + const this_data = this.data; + if (dim === null) { // @ts-ignore - let val = this.data.reduce((a, b) => a + (b ** p), 0) ** (1 / p); + let val = this_data.reduce((a, b) => a + (b ** p), 0) ** (1 / p); return new Tensor(this.type, [val], []); } @@ -386,10 +394,10 @@ export class Tensor { // Create a new array to store the accumulated values // @ts-ignore - const result = new this.data.constructor(this.data.length / this.dims[dim]); + const result = new this_data.constructor(this_data.length / this.dims[dim]); // Iterate over the data array - for (let i = 0; i < this.data.length; ++i) { + for (let i = 0; i < this_data.length; ++i) { // Calculate the index in the resulting array let resultIndex = 0; @@ -405,7 +413,7 @@ export class Tensor { } // Accumulate the value at the current index - result[resultIndex] += (this.data[i]) ** p; + result[resultIndex] += (this_data[i]) ** p; } if (p !== 1) { @@ -432,7 +440,8 @@ export class Tensor { const norm = this.norm(p, dim, true); - for (let i = 0; i < this.data.length; ++i) { + const this_data = this.data; + 
for (let i = 0; i < this_data.length; ++i) { // Calculate the index in the resulting array let resultIndex = 0; @@ -448,7 +457,7 @@ export class Tensor { } // Divide by normalized value - this.data[i] /= norm.data[resultIndex]; + this_data[i] /= norm.data[resultIndex]; } return this; @@ -578,8 +587,9 @@ export class Tensor { } neg_() { - for (let i = 0; i < this.data.length; ++i) { - this.data[i] = -this.data[i]; + const this_data = this.data; + for (let i = 0; i < this_data.length; ++i) { + this_data[i] = -this_data[i]; } return this; } @@ -591,8 +601,9 @@ export class Tensor { * In-place version of @see {@link Tensor.clamp} */ clamp_(min, max) { - for (let i = 0; i < this.data.length; ++i) { - this.data[i] = Math.min(Math.max(this.data[i], min), max); + const this_data = this.data; + for (let i = 0; i < this_data.length; ++i) { + this_data[i] = Math.min(Math.max(this_data[i], min), max); } return this; } @@ -611,8 +622,9 @@ export class Tensor { * In-place version of @see {@link Tensor.round} */ round_() { - for (let i = 0; i < this.data.length; ++i) { - this.data[i] = Math.round(this.data[i]); + const this_data = this.data; + for (let i = 0; i < this_data.length; ++i) { + this_data[i] = Math.round(this_data[i]); } return this; } From bf009d728cafc2782cdf05c3579b42cb70a255c6 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 20 Apr 2024 16:00:18 +0200 Subject: [PATCH 095/473] Update remaining model constructors --- src/models.js | 109 +++++++++++++++++++++++--------------------------- 1 file changed, 49 insertions(+), 60 deletions(-) diff --git a/src/models.js b/src/models.js index a63fa1f93..c9a63789b 100644 --- a/src/models.js +++ b/src/models.js @@ -3812,12 +3812,12 @@ export class CLIPSegForImageSegmentation extends CLIPSegPreTrainedModel { } export class GPT2PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `GPT2PreTrainedModel` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -3845,12 +3845,12 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { } export class GPTNeoPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `GPTNeoPreTrainedModel` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -3871,12 +3871,12 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { } export class GPTNeoXPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `GPTNeoXPreTrainedModel` class. 
- * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -3898,12 +3898,12 @@ export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { } export class GPTJPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `GPTJPreTrainedModel` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -3926,12 +3926,12 @@ export class GPTJForCausalLM extends GPTJPreTrainedModel { } export class GPTBigCodePreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `GPTBigCodePreTrainedModel` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -3953,12 +3953,12 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { } export class CodeGenPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `CodeGenPreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -3990,12 +3990,12 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { } export class LlamaPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `LlamaPreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. 
* @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -4023,12 +4023,12 @@ export class LlamaForCausalLM extends LlamaPreTrainedModel { } export class Qwen2PreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `Qwen2PreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -4050,16 +4050,15 @@ export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { } ////////////////////////////////////////////////// // Phi models - export class PhiPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `PhiPreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -4087,12 +4086,12 @@ export class PhiForCausalLM extends PhiPreTrainedModel { } export class BloomPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `BloomPreTrainedModel` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -4120,12 +4119,12 @@ export class BloomForCausalLM extends BloomPreTrainedModel { } export class MptPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `MptPreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. 
*/ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -4154,12 +4153,12 @@ export class MptForCausalLM extends MptPreTrainedModel { } export class OPTPreTrainedModel extends PreTrainedModel { /** * Creates a new instance of the `OPTPreTrainedModel` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. + * @param {Object} config The model configuration. + * @param {Record} sessions The inference sessions for the model. * @param {GenerationConfig} generation_config The generation configuration. */ - constructor(config, session, generation_config) { - super(config, session); + constructor(config, sessions, generation_config) { + super(config, sessions); this.generation_config = generation_config; // config doesn't contain pad_token_id, so we assume it is the eos_token_id @@ -4798,16 +4797,6 @@ export class SamPreTrainedModel extends PreTrainedModel { } * ``` */ export class SamModel extends SamPreTrainedModel { - /** - * Creates a new instance of the `SamModel` class. - * @param {Object} config The configuration object specifying the hyperparameters and other model settings. - * @param {Object} vision_encoder The ONNX session containing the vision encoder model. - * @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model. - */ - constructor(config, vision_encoder, prompt_encoder_mask_decoder) { - super(config, vision_encoder); - this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder; - } /** * Compute image embeddings and positional image embeddings, given the pixel values of an image. 
@@ -4868,7 +4857,7 @@ export class SamModel extends SamPreTrainedModel { // Returns: // - iou_scores: tensor.float32[batch_size,point_batch_size,3] // - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256] - return await sessionRun(this.prompt_encoder_mask_decoder, { + return await sessionRun(this.sessions['prompt_encoder_mask_decoder'], { input_points: model_inputs.input_points, input_labels: model_inputs.input_labels, image_embeddings: model_inputs.image_embeddings, From 0637743d2d55ce8501c3caef8b64f536e6a6a9c8 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 20 Apr 2024 17:59:15 +0200 Subject: [PATCH 096/473] Implement `interpolate_4d` with ORT sessions --- src/ops/registry.js | 44 ++++++++++++++++ src/processors.js | 69 ++++++++++++------------- src/utils/tensor.js | 52 +++++++++++++++++++ tests/tensor_ops.test.js | 109 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 238 insertions(+), 36 deletions(-) create mode 100644 src/ops/registry.js create mode 100644 tests/tensor_ops.test.js diff --git a/src/ops/registry.js b/src/ops/registry.js new file mode 100644 index 000000000..f81a4f36e --- /dev/null +++ b/src/ops/registry.js @@ -0,0 +1,44 @@ +import { createInferenceSession } from "../backends/onnx.js"; +import { Tensor } from "../utils/tensor.js"; + +const wrap = async (session_bytes, session_options, name) => { + const session = await createInferenceSession( + session_bytes, + session_options, + ); + return async (inputs) => { + const ortFeed = Object.fromEntries(Object.entries(inputs).map(([k, v]) => [k, v.ort_tensor])); + const outputs = await session.run(ortFeed); + return new Tensor(outputs[name]); + } +} + +// In-memory registry of initialized ONNX operators +export class TensorOpRegistry { + static session_options = { + // TODO: Allow for multiple execution providers + // executionProviders: ['webgpu'], + }; + + static get bilinear_interpolate_4d() { + if (!this._bilinear_interpolate_4d) { + this._bilinear_interpolate_4d = wrap( + new Uint8Array([8, 9, 18, 0, 58, 128, 1, 10, 40, 10, 1, 120, 10, 0, 10, 0, 10, 1, 115, 18, 1, 121, 34, 6, 82, 101, 115, 105, 122, 101, 42, 17, 10, 4, 109, 111, 100, 101, 34, 6, 108, 105, 110, 101, 97, 114, 160, 1, 3, 18, 1, 114, 90, 31, 10, 1, 120, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 90, 15, 10, 1, 115, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 4, 98, 31, 10, 1, 121, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 66, 2, 16, 20]), + this.session_options, + 'y', + ); + } + return this._bilinear_interpolate_4d; + } + + static get bicubic_interpolate_4d() { + if (!this._bicubic_interpolate_4d) { + this._bicubic_interpolate_4d = wrap( + new Uint8Array([8, 9, 18, 0, 58, 127, 10, 39, 10, 1, 120, 10, 0, 10, 0, 10, 1, 115, 18, 1, 121, 34, 6, 82, 101, 115, 105, 122, 101, 42, 16, 10, 4, 109, 111, 100, 101, 34, 5, 99, 117, 98, 105, 99, 160, 1, 3, 18, 1, 114, 90, 31, 10, 1, 120, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 90, 15, 10, 1, 115, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 4, 98, 31, 10, 1, 121, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 66, 2, 16, 20]), + this.session_options, + 'y', + ); + } + return this._bicubic_interpolate_4d; + } +} \ No newline at end of file diff --git a/src/processors.js b/src/processors.js index a2a49d33b..c93712ddc 100644 --- a/src/processors.js +++ 
b/src/processors.js @@ -40,7 +40,7 @@ import { } from './utils/maths.js'; -import { Tensor, permute, cat, interpolate, stack } from './utils/tensor.js'; +import { Tensor, permute, cat, interpolate, stack, interpolate_4d } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; import { @@ -1332,17 +1332,17 @@ export class SamImageProcessor extends ImageFeatureExtractor { /** * Remove padding and upscale masks to the original image size. * @param {Tensor} masks Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. - * @param {number[][]} original_sizes The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format. - * @param {number[][]} reshaped_input_sizes The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + * @param {[number, number][]} original_sizes The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format. + * @param {[number, number][]} reshaped_input_sizes The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. * @param {Object} options Optional parameters for post-processing. * @param {number} [options.mask_threshold] The threshold to use for binarizing the masks. * @param {boolean} [options.binarize] Whether to binarize the masks. * @param {Object} [options.pad_size] The target size the images were padded to before being passed to the model. If `null`, the target size is assumed to be the processor's `pad_size`. * @param {number} [options.pad_size.height] The height the images were padded to. * @param {number} [options.pad_size.width] The width the images were padded to. - * @returns {Tensor[]} Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size. + * @returns {Promise} Batched masks in batch_size, num_channels, height, width) format, where (height, width) is given by original_size. */ - post_process_masks(masks, original_sizes, reshaped_input_sizes, { + async post_process_masks(masks, original_sizes, reshaped_input_sizes, { mask_threshold = 0.0, binarize = true, pad_size = null, @@ -1353,47 +1353,44 @@ export class SamImageProcessor extends ImageFeatureExtractor { pad_size = pad_size ?? 
this.pad_size; + /** @type {[number, number]} */ const target_image_size = [pad_size.height, pad_size.width]; for (let i = 0; i < original_sizes.length; ++i) { const original_size = original_sizes[i]; const reshaped_input_size = reshaped_input_sizes[i]; - const mask = masks[i]; // [b, c, h, w] - - // TODO: improve - const interpolated_masks = []; - for (let j = 0; j < mask.dims[0]; ++j) { - const m = mask[j]; // 3d tensor - - // Upscale mask to padded size - let interpolated_mask = interpolate(m, target_image_size, 'bilinear', false); - - // Crop mask - interpolated_mask = interpolated_mask.slice(null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]); - - // Downscale mask - interpolated_mask = interpolate(interpolated_mask, original_size, 'bilinear', false); - - if (binarize) { - const data = interpolated_mask.data; - const binarizedMaskData = new Uint8Array(data.length); - for (let i = 0; i < data.length; ++i) { - if (data[i] > mask_threshold) { - binarizedMaskData[i] = 1; - } + // Upscale mask to padded size + let interpolated_mask = (await interpolate_4d( + masks[i], + { mode: 'bilinear', size: target_image_size } + )); + + // Crop mask + interpolated_mask = interpolated_mask.slice(null, null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]); + + // Downscale mask + interpolated_mask = (await interpolate_4d( + interpolated_mask, + { mode: 'bilinear', size: original_size } + )); + + if (binarize) { + const data = interpolated_mask.data; + const binarizedMaskData = new Uint8Array(data.length); + for (let i = 0; i < data.length; ++i) { + if (data[i] > mask_threshold) { + binarizedMaskData[i] = 1; } - interpolated_mask = new Tensor( - 'bool', - binarizedMaskData, - interpolated_mask.dims - ) } - - interpolated_masks.push(interpolated_mask); + interpolated_mask = new Tensor( + 'bool', + binarizedMaskData, + interpolated_mask.dims + ) } - output_masks.push(stack(interpolated_masks)); + output_masks.push(interpolated_mask); } return output_masks; diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 70b712ae7..10d8a1014 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -16,6 +16,8 @@ import { Tensor as ONNXTensor, isONNXTensor, } from '../backends/onnx.js'; +import { TensorOpRegistry } from '../ops/registry.js'; + const DataTypeMap = Object.freeze({ float32: Float32Array, float16: Uint16Array, @@ -754,6 +756,56 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a return new Tensor(input.type, output, [in_channels, out_height, out_width]); } + +/** + * Down/up samples the input. + * Inspired by https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. + * @param {Tensor} input the input tensor + * @param {Object} options the options for the interpolation + * @param {[number, number]|[number, number, number]|[number, number, number, number]} [options.size=null] output spatial size. + * @param {"bilinear"|"bicubic"} [options.mode='bilinear'] algorithm used for upsampling + * @returns {Promise} The interpolated tensor. 
+ */ +export async function interpolate_4d(input, { + size = null, + mode = 'bilinear', +} = {}) { + + // Error checking + if (input.dims.length !== 4) { + throw new Error('`interpolate_4d` currently only supports 4D input.'); + } + if (!size) { + // TODO: support scale_factor + throw new Error('`interpolate_4d` requires a `size` argument.'); + } + + // Fill in missing dimensions + let targetDims; + if (size.length === 2) { + targetDims = [...input.dims.slice(0, 2), ...size]; + } else if (size.length === 3) { + targetDims = [input.dims[0], ...size]; + } else if (size.length === 4) { + targetDims = size; + } else { + throw new Error('`size` must be of length 2, 3, or 4.'); + } + + let op; + if (mode === 'bilinear') { + op = await TensorOpRegistry.bilinear_interpolate_4d; + } else if (mode === 'bicubic') { + op = await TensorOpRegistry.bicubic_interpolate_4d; + } else { + throw new Error(`Unsupported mode: ${mode}`); + } + + const sizeTensor = new Tensor('int64', new BigInt64Array(targetDims.map(BigInt)), [targetDims.length]); + return await op({ x: input, s: sizeTensor }); +} + + /** * Perform mean pooling of the last hidden state followed by a normalization step. * @param {Tensor} last_hidden_state Tensor of shape [batchSize, seqLength, embedDim] diff --git a/tests/tensor_ops.test.js b/tests/tensor_ops.test.js new file mode 100644 index 000000000..821c559f0 --- /dev/null +++ b/tests/tensor_ops.test.js @@ -0,0 +1,109 @@ +import { Tensor, interpolate_4d } from '../src/utils/tensor.js'; +import { init } from './init.js'; + +// Initialise the testing environment +init(); + +function expectToBeCloseToArray(actual, expected) { + expect(actual.length).toEqual(expected.length) + actual.forEach((x, i) => expect(x).toBeCloseTo(expected[i])) +} + +describe('Tensor operations', () => { + + describe('interpolate', () => { + const input = new Tensor('float32', new Float32Array(2 * 3 * 4 * 5).map((_, i) => i), [2, 3, 4, 5]); + + const size = [2, 3, 3, 2]; + it('bilinear', async () => { + const resized = await interpolate_4d( + input, + { mode: 'bilinear', size }, + ); + const target = new Float32Array([ + [ + [ + [1.5833335, 4.0833335], + [8.25, 10.75], + [14.916668, 17.416668] + ], + [ + [21.583332, 24.083334], + [28.25, 30.75], + [34.916668, 37.416668] + ], + [ + [41.583332, 44.083332], + [48.25, 50.75], + [54.916668, 57.416668] + ] + ], + [ + [ + [61.583332, 64.083336], + [68.25, 70.75], + [74.916664, 77.41667] + ], + [ + [81.58333, 84.083336], + [88.25, 90.75], + [94.91667, 97.41667] + ], + [ + [101.583336, 104.08333], + [108.25, 110.75], + [114.916664, 117.416664] + ] + ] + ].flat(Infinity)); + + expectToBeCloseToArray(target, resized.data); + }); + + it('bicubic', async () => { + const resized = await interpolate_4d( + input, + { mode: 'bicubic', size }, + ); + + const target = new Float32Array([ + [ + [ + [1.2987545, 3.9628172], + [8.167969, 10.832031], + [15.037184, 17.701244] + ], + [ + [21.298756, 23.962818], + [28.167969, 30.832031], + [35.037186, 37.701252] + ], + [ + [41.298756, 43.96282], + [48.16797, 50.83203], + [55.037193, 57.701256] + ] + ], + [ + [ + [61.29875, 63.96282], + [68.16797, 70.83203], + [75.03719, 77.701256] + ], + [ + [81.29875, 83.96282], + [88.16797, 90.83203], + [95.03721, 97.70126] + ], + [ + [101.29875, 103.962814], + [108.16797, 110.83203], + [115.03721, 117.70127] + ] + ] + ].flat(Infinity)); + + expectToBeCloseToArray(target, resized.data); + }); + }); +}); From bcbc2fe61febffbe845de19da1d8d17ecff22a28 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 20 Apr 
2024 17:59:54 +0200 Subject: [PATCH 097/473] Only use proxying when not in web-worker env --- src/backends/onnx.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 225c90f08..81637942e 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -68,7 +68,7 @@ export function deviceToExecutionProviders(device) { * Create an ONNX inference session. * @param {Uint8Array} buffer The ONNX model buffer. * @param {Object} session_options ONNX inference session options. - * @returns {Promise} The ONNX inference session. + * @returns {Promise} The ONNX inference session. */ export async function createInferenceSession(buffer, session_options) { return await InferenceSession.create(buffer, session_options); @@ -96,7 +96,8 @@ if (ONNX_ENV?.wasm) { ONNX_ENV.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.3/dist/'; // Proxy the WASM backend to prevent the UI from freezing - ONNX_ENV.wasm.proxy = true; + // NOTE: This is only needed when running in a non-worker browser environment. + ONNX_ENV.wasm.proxy = !apis.IS_WEBWORKER_ENV; // https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) { From 520fcfcb9a82a2242b0585cb4defa4c42bd661d4 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 21 Apr 2024 00:52:27 +0200 Subject: [PATCH 098/473] Pass text inputs to text-generation pipeline --- src/pipelines.js | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index 2f877785b..99662a97a 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -970,17 +970,16 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli : generate_kwargs.return_full_text ?? true; this.tokenizer.padding_side = 'left'; - const { input_ids, attention_mask } = this.tokenizer(inputs, { + const text_inputs = this.tokenizer(inputs, { add_special_tokens, padding: true, truncation: true, }); const outputTokenIds = /** @type {Tensor} */(await this.model.generate({ - inputs: input_ids, - - // TODO: add back? 
- // inputs_attention_mask: attention_mask, + // inputs: input_ids, + // attention_mask, + ...text_inputs, ...generate_kwargs })); @@ -3189,6 +3188,7 @@ export async function pipeline( revision = 'main', device = null, dtype = null, + model_file_name = null, session_options = {}, } = {} ) { @@ -3218,6 +3218,7 @@ export async function pipeline( revision, device, dtype, + model_file_name, session_options, } From aeb9d87d8a423d5cb00258af048051d11e2f8a68 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 21 Apr 2024 01:22:04 +0200 Subject: [PATCH 099/473] Update segment-anything web demo --- examples/segment-anything-client/index.css | 2 +- examples/segment-anything-client/index.html | 6 ++- examples/segment-anything-client/package.json | 17 ++++++++ .../segment-anything-client/vite.config.js | 6 +++ examples/segment-anything-client/worker.js | 42 ++++++++++--------- 5 files changed, 50 insertions(+), 23 deletions(-) create mode 100644 examples/segment-anything-client/package.json create mode 100644 examples/segment-anything-client/vite.config.js diff --git a/examples/segment-anything-client/index.css b/examples/segment-anything-client/index.css index a896b8846..fc556bcac 100644 --- a/examples/segment-anything-client/index.css +++ b/examples/segment-anything-client/index.css @@ -23,7 +23,7 @@ body, align-items: center; } -h1 { +h1, h3 { text-align: center; } diff --git a/examples/segment-anything-client/index.html b/examples/segment-anything-client/index.html index 5e8a2e9b9..f016cf423 100644 --- a/examples/segment-anything-client/index.html +++ b/examples/segment-anything-client/index.html @@ -6,11 +6,13 @@ - Transformers.js - Segment Anything + Transformers.js - Segment Anything Web -
<h1>Segment Anything w/ 🤗 Transformers.js</h1>
+    <h1>Segment Anything Web</h1>
+    <h3>In-browser image segmentation w/ 🤗
+        Transformers.js</h3>
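
Taken together, patches 095 and 096 change how downstream code drives SAM: model sessions are now looked up by name on `this.sessions`, and `post_process_masks` has become asynchronous because it awaits the ONNX-backed `interpolate_4d` op. Below is a minimal usage sketch of the updated flow; the checkpoint id, image URL, and prompt point are illustrative placeholders, and the exact processor call signature in this alpha is an assumption rather than a confirmed API.

```js
// Sketch only: checkpoint id, image URL, and prompt point are illustrative,
// and the processor call signature is assumed for this alpha release.
import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';

const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');

const raw_image = await RawImage.read('https://example.com/car.png');
const input_points = [[[450, 600]]]; // a single (x, y) prompt point

// Preprocess, then run the vision encoder and prompt-encoder/mask-decoder sessions.
const inputs = await processor(raw_image, input_points);
const outputs = await model(inputs);

// post_process_masks now returns a Promise (it awaits interpolate_4d internally),
// so callers such as the demo's worker.js must await it.
const masks = await processor.post_process_masks(
    outputs.pred_masks,
    inputs.original_sizes,
    inputs.reshaped_input_sizes,
);
// masks: one upscaled (and, by default, binarized) mask tensor per input image
```

The key difference from earlier releases is the added `await` on post-processing; the rest of the flow is unchanged.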