Skip to content

Commit

Permalink
[Llama + LLama2] Add model support (#232)
Browse files Browse the repository at this point in the history
* Add support for llama models

* Fix JSDoc
  • Loading branch information
xenova authored Aug 9, 2023
1 parent 1e157ba commit 46dd490
Showing 1 changed file with 102 additions and 12 deletions.
114 changes: 102 additions & 12 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -1836,13 +1836,13 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel {
}

/**
* Generates the start beams for the given input tokens and output sequence length.
*
* @param {any[]} inputs The input sequence.
* @param {number} numOutputTokens The desired length of the output sequence.
* @param {...*} args Additional arguments to pass to the `seq2seqStartBeams` function.
* @returns {any[]} An array of `Beam` objects representing the start beams.
*/
* Generates the start beams for the given input tokens and output sequence length.
*
* @param {any[]} inputs The input sequence.
* @param {number} numOutputTokens The desired length of the output sequence.
* @param {...*} args Additional arguments to pass to the `seq2seqStartBeams` function.
* @returns {any[]} An array of `Beam` objects representing the start beams.
*/
getStartBeams(inputs, numOutputTokens, ...args) {
return seq2seqStartBeams(this, inputs, numOutputTokens);
}
Expand All @@ -1860,16 +1860,16 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel {
* Updates the given beam with the new predicted token.
* @param {any} beam The beam to update.
* @param {number} newTokenId The index of the predicted token.
*/
*/
updateBeam(beam, newTokenId) {
beam.output_token_ids = [...beam.output_token_ids, newTokenId];
}

/**
* Runs the forward pass of the model on the given inputs.
* @param {any} model_inputs The model inputs.
* @returns {Promise<any>} A Promise that resolves to the model outputs.
*/
* Runs the forward pass of the model on the given inputs.
* @param {any} model_inputs The model inputs.
* @returns {Promise<any>} A Promise that resolves to the model outputs.
*/
async forward(model_inputs) {
return await seq2seqForward(this, model_inputs);
}
Expand Down Expand Up @@ -2884,6 +2884,94 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel {
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// Llama models

/**
 * Base class shared by all Llama model classes. It derives the head count,
 * layer count and per-head dimension from the model configuration.
 */
export class LlamaPreTrainedModel extends PreTrainedModel {
    /**
     * Creates a new instance of the `LlamaPreTrainedModel` class.
     * @param {Object} config The model configuration object.
     * @param {Object} session The ONNX session object.
     */
    constructor(config, session) {
        super(config, session);

        // config doesn't contain pad_token_id, so we assume it is the eos_token_id
        this.config.pad_token_id = this.config.eos_token_id;

        const attentionHeads = this.config.num_attention_heads;
        const keyValueHeads = this.config.num_key_value_heads;

        // Checkpoints that use grouped-query attention (e.g. Llama 2 70B) declare
        // fewer key/value heads than attention heads via `num_key_value_heads`.
        // Prefer that value when the config provides it; otherwise fall back to
        // `num_attention_heads`, which keeps the previous behaviour.
        // NOTE(review): this assumes `num_heads` and `dim_kv` are consumed to
        // shape the cached key/value tensors - confirm against the decoder helpers.
        this.num_heads = keyValueHeads == null ? attentionHeads : keyValueHeads;
        this.num_layers = this.config.num_hidden_layers;

        // The size of a single head is always derived from the number of
        // attention (query) heads, even when there are fewer key/value heads.
        this.dim_kv = this.config.hidden_size / attentionHeads;
    }
}
/**
 * Llama model without a language model head: it exposes raw hidden states
 * only, so it cannot be used for text generation.
 */
export class LlamaModel extends LlamaPreTrainedModel {
    /**
     * Always fails: this class has no language model head, so `.generate()`
     * is not supported. The returned promise rejects with an `Error` that
     * names the class to use instead.
     *
     * @param {...any} args Arguments passed to the generate function (ignored).
     * @returns {Promise<any>} A promise that always rejects.
     * @throws {Error} The current model class is not compatible with `.generate()`
     */
    async generate(...args) {
        const message = "The current model class (LlamaModel) is not compatible with `.generate()`, as it doesn't have a language model head. Please use one of the following classes instead: {'LlamaForCausalLM'}";
        throw new Error(message);
    }
}

/**
 * Llama model for causal language modelling. Every method hands its work to
 * the corresponding shared decoder helper.
 */
export class LlamaForCausalLM extends LlamaPreTrainedModel {

    /**
     * Creates the starting beam(s) for text generation.
     * @param {Tensor} inputTokenIds The input token ids.
     * @param {number} numOutputTokens The number of tokens to be generated.
     * @param {Tensor} inputs_attention_mask Optional input attention mask.
     * @returns {any} Whatever `decoderStartBeams` produces for these inputs.
     */
    getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) {
        const startBeams = decoderStartBeams(
            this,
            inputTokenIds,
            numOutputTokens,
            inputs_attention_mask,
        );
        return startBeams;
    }

    /**
     * Performs one generation step for the given beam.
     * @param {any} beam The current beam being generated.
     * @returns {Promise<any>} Resolves to the result of `decoderRunBeam`.
     */
    async runBeam(beam) {
        const stepOutput = await decoderRunBeam(this, beam);
        return stepOutput;
    }

    /**
     * Records a newly generated token id on the given beam.
     * @param {any} beam The Beam object representing the beam.
     * @param {number} newTokenId The new generated token id to be added to the beam.
     */
    updateBeam(beam, newTokenId) {
        const updateResult = decoderUpdatebeam(beam, newTokenId);
        return updateResult;
    }

    /**
     * Runs the model's forward pass.
     * @param {Object} model_inputs The inputs for the model.
     * @returns {Promise<any>} Resolves to the result of `decoderForward`.
     */
    async forward(model_inputs) {
        const outputs = await decoderForward(this, model_inputs);
        return outputs;
    }

}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
export class ViTPreTrainedModel extends PreTrainedModel { }
export class ViTForImageClassification extends ViTPreTrainedModel {
Expand Down Expand Up @@ -3260,6 +3348,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
['gpt_bigcode', GPTBigCodeModel],
['gpt_neo', GPTNeoModel],
['codegen', CodeGenModel],
['llama', LlamaModel],
]);

const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
Expand Down Expand Up @@ -3300,6 +3389,7 @@ const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
['gpt_bigcode', GPTBigCodeForCausalLM],
['gpt_neo', GPTNeoForCausalLM],
['codegen', CodeGenForCausalLM],
['llama', LlamaForCausalLM],
]);

const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
Expand Down

0 comments on commit 46dd490

Please sign in to comment.