diff --git a/src/transformers_neuronx/__init__.py b/src/transformers_neuronx/__init__.py index c30c320..786e225 100644 --- a/src/transformers_neuronx/__init__.py +++ b/src/transformers_neuronx/__init__.py @@ -20,6 +20,7 @@ from transformers_neuronx.config import NeuronConfig, QuantizationConfig, ContinuousBatchingConfig, GenerationConfig from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter +from transformers_neuronx.qwen2.model import Qwen2ForSampling from transformers_neuronx.bloom.model import BloomForSampling from transformers_neuronx.llama.model import LlamaForSampling from transformers_neuronx.gpt2.model import GPT2ForSamplingWithContextBroadcasting diff --git a/src/transformers_neuronx/modeling_auto.py b/src/transformers_neuronx/modeling_auto.py index 0269713..67a031d 100644 --- a/src/transformers_neuronx/modeling_auto.py +++ b/src/transformers_neuronx/modeling_auto.py @@ -12,6 +12,7 @@ "mistral": transformers_neuronx.MistralForSampling, "mixtral": transformers_neuronx.MixtralForSampling, "opt": transformers_neuronx.OPTForSampling, + "qwen2": transformers_neuronx.Qwen2ForSampling, } @@ -24,6 +25,7 @@ transformers.MistralConfig: "mistral", transformers.MixtralConfig: "mixtral", transformers.OPTConfig: "opt", + transformers.Qwen2Config: "qwen2", } diff --git a/src/transformers_neuronx/qwen2/config.py b/src/transformers_neuronx/qwen2/config.py new file mode 100644 index 0000000..699a7fc --- /dev/null +++ b/src/transformers_neuronx/qwen2/config.py @@ -0,0 +1,57 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from transformers_neuronx import utils + +class Qwen2Config: + + def __init__( + self, + config, + n_positions, + batch_size, + amp, + tp_degree, + **kwargs + ): + + # Extract configs used for building HLO + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + self.vocab_size = config.vocab_size + self.hidden_act = config.hidden_act + self.bos_token_id = config.bos_token_id + self.eos_token_id = config.eos_token_id + self.max_position_embeddings = config.max_position_embeddings + self.rms_norm_eps = config.rms_norm_eps + self.rotary_percentage = getattr(config, "rotary_percentage", 1) + self.rope_theta = getattr(config, "rope_theta", 10000) + self.position_interpolation_factor = getattr(config, "position_interpolation_factor", None) + self.rope_scaling = getattr(config, "rope_scaling", None) + rope_scaling_type = self.rope_scaling.get("rope_type", self.rope_scaling.get("type", None)) if self.rope_scaling is not None else None + if self.rope_scaling is not None and rope_scaling_type not in {'default', 'llama3'}: + raise ValueError(f"Only 'default' and 'llama3' rope scaling types are currently supported. Received {rope_scaling_type}") + + utils.maybe_override_attributes(self, kwargs) + + # Add required Neuron configs + self.n_positions = n_positions + self.batch_size = batch_size + self.amp = amp + self.tp_degree = tp_degree + self.model_type = 'qwen2' \ No newline at end of file diff --git a/src/transformers_neuronx/qwen2/hlo.py b/src/transformers_neuronx/qwen2/hlo.py new file mode 100644 index 0000000..1601b96 --- /dev/null +++ b/src/transformers_neuronx/qwen2/hlo.py @@ -0,0 +1,498 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +from typing import Optional + +from transformers_neuronx import hlo, utils +from transformers_neuronx import constants +from transformers_neuronx import utils +from transformers_neuronx.layers import transformer, rotary, attention, attention_utils, flash_decoding +from transformers_neuronx.qwen2.config import Qwen2Config +from transformers_neuronx.config import NeuronConfig +from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB +from transformers_neuronx.hlo import quantize_kv_cache_direct_cast, dequantize_kv_cache_direct_cast + +from transformers_neuronx.nki.compile import nki_call + + +class Qwen2ForSamplingNoEmbeddingHlo: + + def __init__(self, + config: Qwen2Config, + neuron_config: Optional[NeuronConfig] = None + ): + self.config = config + self.neuron_config = neuron_config + self.n_positions = None + self.num_active_blocks = None + + @property + def shard_over_batch(self): + # Property access allows fallback configuration to be enabled after construction + return ( + self.neuron_config is not None + and self.neuron_config.group_query_attention == constants.GQA.SHARD_OVER_BATCH + ) + + def inputs(self, scribe, dtype, n_active_tokens, batch_size): + tensors, dims = transformer.inputs( + scribe, dtype, batch_size, n_active_tokens, self.config.hidden_size, self.neuron_config, self.config.tp_degree) + + return tensors, dims + + def token_tree_inputs(self, scribe, dtype, n_active_tokens, batch_size): + tensors, dims = self.inputs(scribe, dtype, n_active_tokens, batch_size) + s32 = scribe.s32 + cache_2d = self.neuron_config and self.neuron_config.use_2d_cache_ids + # Allow tree based speculation inputs + if cache_2d: + position_sizes = batch_size, n_active_tokens + previous_cache_ids = s32[position_sizes].Parameter(parameter_number=4) + reorder_mapping = s32[position_sizes].Parameter(parameter_number=5) + else: + previous_cache_ids = s32[n_active_tokens].Parameter(parameter_number=4) + reorder_mapping = s32[n_active_tokens].Parameter(parameter_number=5) + seq_slice_dim = 1 if cache_2d else 0 + + return (*tensors, previous_cache_ids, reorder_mapping), (*dims, seq_slice_dim, seq_slice_dim) + + def embedding(self, input_ids, cache_ids, start_ids, last_token_id, *weights): + if self.neuron_config.shard_over_sequence and self.neuron_config.on_device_embedding: + *rst, embed_weight = weights + else: + embed_weight, *rst = weights + dtype = getattr(input_ids.scribe, self.config.amp) + if self.neuron_config.on_device_embedding and self.neuron_config.sequence_parallel_norm: + hidden = hlo.embedding(embed_weight, input_ids, tp_degree=1, dtype=dtype) + else: + hidden = hlo.embedding(embed_weight, input_ids, tp_degree=self.config.tp_degree, dtype=dtype) + if self.config.hidden_size % self.config.tp_degree != 0: + hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0) + if self.neuron_config.attention_layout == LAYOUT_HSB: + hidden = hlo.transpose210(hidden) + return hidden + + def token_tree_embedding(self, input_ids, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights): + return self.embedding(input_ids, cache_ids, start_ids, last_token_id, *weights) + + def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights): + # TODO: move this fallback calculation to decoder.py + if self.num_active_blocks is None and self.neuron_config.optimized_paged_attention: + max_model_len = self.neuron_config.continuous_batching.max_model_len + max_num_seqs 
= self.neuron_config.continuous_batching.max_num_seqs + block_size = self.neuron_config.continuous_batching.block_size + self.num_active_blocks = (max_model_len * max_num_seqs // block_size) - 2 + + if self.neuron_config.optimized_paged_attention and len(last_token_id.sizes) == 2: + # For decoding with multiple KV cache blocks: + # - cache_ids are used as context_lens + # - start_ids are used as slot_mapping + # - last_token_id is used as block_tables + # The function below transforms 2D block_tables into 1D active block table + last_token_id = attention_utils.active_block_tables( + block_tables=last_token_id, context_lens=cache_ids, + num_active_blocks=self.num_active_blocks, neuron_config=self.neuron_config) + max_num_seqs = self.neuron_config.continuous_batching.max_num_seqs + block_size = self.neuron_config.continuous_batching.block_size + block_to_seq = attention_utils.block_to_seq_indexing( + context_lens=cache_ids, num_seqs=max_num_seqs, num_blocks=self.num_active_blocks, block_size=block_size) + else: + block_to_seq = None + + head_dim = self.config.attention_head_size + pos_embed = rotary.hlo_rotary_embedding( + hidden.dtype, int(head_dim * self.config.rotary_percentage), cache_ids, + base=self.config.rope_theta, + interpolation_factor=self.config.position_interpolation_factor, + rope_scaling=self.config.rope_scaling + ) + core_id = None + + # flash decoding + if self.neuron_config.shard_over_sequence: + core_id, *rst = weights + n_kv_heads = self.config.num_key_value_heads if hasattr(self.config, "num_key_value_heads") else self.config.num_attention_heads + cores_per_kv_head = self.config.tp_degree // n_kv_heads + self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree + cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id(cache_ids, start_ids, + core_id, self.n_positions, + cores_per_kv_head=self.cores_per_kv_head) + else: + mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions, + last_token_id=last_token_id, num_active_blocks=self.num_active_blocks, neuron_config=self.neuron_config) + + + return hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, mask, active_mask, core_id + + def token_tree_pre_layer(self, hidden, cache_ids, start_ids, last_token_id, previous_cache_ids, reorder_mapping, *weights): + hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, mask, active_mask, core_id = self.pre_layer(hidden, cache_ids, start_ids, last_token_id, *weights) + if self.neuron_config.on_device_embedding: + embed_weight, token_tree_mask = weights + else: + token_tree_mask, *rst = weights + active_mask = hlo.token_tree_attention_mask(token_tree_mask, active_mask) + return hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, previous_cache_ids, reorder_mapping, mask, active_mask, core_id + + def layer( + self, hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, mask, active_mask, core_id, + attn_k_cache, attn_v_cache, + pre_attn_ln_weight, pre_attn_ln_bias, + fused_pre_attn_ln_qkv_weight, + attn_q_weight, attn_q_scales, attn_q_bias, + attn_k_weight, attn_k_scales, attn_k_bias, + attn_v_weight, attn_v_scales, attn_v_bias, + attn_out_weight, attn_out_scales, attn_out_bias, + post_attn_ln_weight, post_attn_ln_bias, + pre_mlp_ln_weight, pre_mlp_ln_bias, + mlp_in_weight, mlp_in_scales, mlp_in_bias, + mlp_out_weight, mlp_out_scales, mlp_out_bias, + post_mlp_ln_weight, post_mlp_ln_bias, + in0_weight=None, in0_scales=None, + in1_weight=None, 
in1_scales=None, + out_weight=None, out_scales=None, + ): + eps = self.config.rms_norm_eps + is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH + if self.neuron_config and self.neuron_config.fused_rmsnorm_qkv and active_mask is None: + assert fused_pre_attn_ln_qkv_weight is not None + attn_output, out_attn_k_cache, out_attn_v_cache = self.fused_rmsnorm_qkv( + hidden, None, eps, + cache_ids, start_ids, last_token_id, block_to_seq, pos_embed, mask, active_mask, core_id, + attn_k_cache, attn_v_cache, + fused_pre_attn_ln_qkv_weight, attn_q_scales, attn_q_bias, + attn_k_weight, attn_k_scales, attn_k_bias, # should be none + attn_v_weight, attn_v_scales, attn_v_bias, # should be none + attn_out_weight, attn_out_scales, attn_out_bias + ) + else: + ln_hidden = hlo.rms_norm(hidden, pre_attn_ln_weight, eps, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree) if is_bsh else hlo.rms_norm(hidden, pre_attn_ln_weight, eps, dim=0, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree) + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + ln_hidden, cache_ids, start_ids, last_token_id, block_to_seq, pos_embed, mask, active_mask, core_id, + attn_k_cache, attn_v_cache, + attn_q_weight, attn_q_scales, attn_q_bias, + attn_k_weight, attn_k_scales, attn_k_bias, + attn_v_weight, attn_v_scales, attn_v_bias, + attn_out_weight, attn_out_scales, attn_out_bias + ) + hidden = hlo.add(attn_output, hidden) + gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp + rms_norm_dim = 2 if is_bsh else 0 + norm_hidden = hlo.rms_norm(hidden, pre_mlp_ln_weight, eps, dim=rms_norm_dim, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree) + if self.neuron_config.fuse_mlp: + assert all(map(lambda x: not(x), [in0_weight, in1_weight, out_weight, in0_scales, in1_scales, out_scales])) ,\ + f"in0, in1 and out weights have to be None" + in0_weight, in0_scales = mlp_in_weight, mlp_in_scales + out_weight, out_scales = mlp_out_weight, mlp_out_scales + + mlp_hidden = gated_mlp( + norm_hidden, + in0_weight, in1_weight, out_weight, + in0_scales=in0_scales, + in1_scales=in1_scales, + out_scales=out_scales, + activation_function='silu', + tp_degree=self.config.tp_degree, + neuron_config=self.neuron_config + ) + res_hidden = hlo.add(mlp_hidden, hidden) + return res_hidden, out_attn_k_cache, out_attn_v_cache + + def token_tree_layer( + self, hidden, last_token_id, pos_embed, cache_ids, start_ids, block_to_seq, + previous_cache_ids, reorder_mapping, + mask, active_mask, core_id, + attn_k_cache, attn_v_cache, + pre_attn_ln_weight, pre_attn_ln_bias, + fused_pre_attn_ln_qkv_weight, + attn_q_weight, attn_q_scales, attn_q_bias, + attn_k_weight, attn_k_scales, attn_k_bias, + attn_v_weight, attn_v_scales, attn_v_bias, + attn_out_weight, attn_out_scales, attn_out_bias, + post_attn_ln_weight, post_attn_ln_bias, + pre_mlp_ln_weight, pre_mlp_ln_bias, + mlp_in_weight, mlp_in_scales, mlp_in_bias, + mlp_out_weight, mlp_out_scales, mlp_out_bias, + post_mlp_ln_weight, post_mlp_ln_bias, + in0_weight, in0_scales, + in1_weight, in1_scales, + out_weight, out_scales, + ): + eps = self.config.rms_norm_eps + is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH + ln_hidden = hlo.rms_norm(hidden, pre_attn_ln_weight, eps, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree) if is_bsh else hlo.rms_norm(hidden, pre_attn_ln_weight, eps, dim=0, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree) + reordered_attn_k_cache, 
reordered_attn_v_cache = attention.reorder_kv_cache(attn_k_cache, attn_v_cache, previous_cache_ids, reorder_mapping, neuron_config=self.neuron_config) + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + ln_hidden, cache_ids, start_ids, last_token_id, block_to_seq, pos_embed, mask, active_mask, core_id, + reordered_attn_k_cache, reordered_attn_v_cache, + attn_q_weight, attn_q_scales, attn_q_bias, + attn_k_weight, attn_k_scales, attn_k_bias, + attn_v_weight, attn_v_scales, attn_v_bias, + attn_out_weight, attn_out_scales, attn_out_bias + ) + hidden = hlo.add(attn_output, hidden) + gated_mlp = hlo.gated_mlp_bsh if is_bsh else hlo.gated_mlp + rms_norm_dim = 2 if is_bsh else 0 + norm_hidden = hlo.rms_norm(hidden, pre_mlp_ln_weight, eps, dim=rms_norm_dim, neuron_config=self.neuron_config, tp_degree=self.config.tp_degree) + mlp_hidden = gated_mlp( + norm_hidden, + in0_weight, in1_weight, out_weight, + in0_scales=in0_scales, + in1_scales=in1_scales, + out_scales=out_scales, + activation_function='silu', + tp_degree=self.config.tp_degree, + neuron_config=self.neuron_config + ) + res_hidden = hlo.add(mlp_hidden, hidden) + return res_hidden, out_attn_k_cache, out_attn_v_cache + + def ln_lm_head(self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias, return_all_outputs=True): + logits = transformer.rms_lm_head(self.config.tp_degree, hidden, last_token_id, rms_weight, lm_head_weight, lm_head_bias, return_all_outputs, eps=self.config.rms_norm_eps, neuron_config=self.neuron_config) + return logits + + def fused_rmsnorm_qkv( + self, hidden, pre_attn_ln_weight, eps, + cache_ids, start_ids, last_token_id, block_to_seq, pos_embed, mask, active_mask, core_id, + attn_k_cache, attn_v_cache, + attn_q_weight, attn_q_scales, attn_q_bias, + attn_k_weight, attn_k_scales, attn_k_bias, # should be none + attn_v_weight, attn_v_scales, attn_v_bias, # should be none + attn_out_weight, attn_out_scales, attn_out_bias + ): + # TODO: refactor below + from neuronxcc.nki._private_kernels.fused_linear import fused_rms_norm_qkv + def _kernel(h, w, output): + return fused_rms_norm_qkv(h, w, output, eps=eps) + + n_seqs, n_active_tokens, _ = hidden.sizes + d_head = self.config.attention_head_size + tp_degree = self.config.tp_degree + + # Compute the expected number of KV heads (Used in case fused QKV is used) + n_kv_heads_tp = None + if self.config.num_key_value_heads is not None: + n_head = self.config.num_attention_heads + n_kv_head = self.config.num_key_value_heads + n_head, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config) + n_kv_heads_tp = n_kv_head_padded // tp_degree + + _, hidden_size_tp = attn_q_weight.sizes + + n_total_heads_tp = hidden_size_tp // d_head + n_heads_tp = n_total_heads_tp - 2 * n_kv_heads_tp + # Q hidden size + hidden_size_tp = d_head * n_heads_tp + + nki_output = nki_call(_kernel, + hidden, attn_q_weight, + output_HloShapes=[hidden.dtype[hidden.sizes[0], hidden.sizes[1], attn_q_weight.sizes[-1]]]) + slice_lim = nki_output.sizes[-1] // (n_heads_tp + 2 * n_kv_heads_tp) + query = hlo.slice_along(nki_output, -1, n_heads_tp*slice_lim, start=0) + key = hlo.slice_along(nki_output, -1, (n_heads_tp+n_kv_heads_tp)*slice_lim, start=n_heads_tp*slice_lim) + value = hlo.slice_along(nki_output, -1, (n_heads_tp+2*n_kv_heads_tp)*slice_lim, start=(n_heads_tp+n_kv_heads_tp)*slice_lim) + + # shard over head (qwen2/hlo.py) + active_q_sizes = n_active_tokens, n_seqs, n_heads_tp, d_head + active_kv_sizes = n_active_tokens, n_seqs, n_kv_heads_tp, 
d_head + query = hlo.reshape(query, active_q_sizes) + key = hlo.reshape(key, active_kv_sizes) + value = hlo.reshape(value, active_kv_sizes) + assert all([attn_q_scales is None, + attn_q_bias is None, + attn_k_weight is None, + attn_k_scales is None, + attn_k_bias is None, + attn_v_weight is None, + attn_v_scales is None, + attn_v_bias is None]) + + # Pass QKV tuple since it will not be computed in the attention block + attn_output, out_attn_k_cache, out_attn_v_cache = self.attention( + nki_output, cache_ids, start_ids, last_token_id, block_to_seq, pos_embed, mask, active_mask, core_id, + attn_k_cache, attn_v_cache, + attn_q_weight, None, None, + None, None, None, + None, None, None, + attn_out_weight, attn_out_scales, attn_out_bias, + qkv_tuple=(query, key, value), + ) + return attn_output, out_attn_k_cache, out_attn_v_cache + + + def attention( + self, + hidden, cache_ids, start_ids, last_token_id, block_to_seq, pos_embed, mask, active_mask, core_id, + cached_keys, cached_values, + q_weight, q_scales, q_bias, + k_weight, k_scales, k_bias, + v_weight, v_scales, v_bias, + out_weight, out_scales, out_bias, + qkv_tuple: tuple = None, + ): + d_head = self.config.attention_head_size + tp_degree = self.config.tp_degree + + # Compute the expected number of KV heads (Used in case fused QKV is used) + n_kv_heads_tp = None + if self.config.num_key_value_heads is not None: + n_head = self.config.num_attention_heads + n_kv_head = self.config.num_key_value_heads + n_head, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config) + n_kv_heads_tp = n_kv_head_padded // tp_degree + + # Q = (hidden @ wQ) + bQ + # K = (hidden @ wK) + bK + # V = (hidden @ wV) + bV + if qkv_tuple: + # If computed already, skip computation here + assert active_mask is None + query, key, value = qkv_tuple + else: + query, key, value = attention.query_key_value( + hidden, + q_weight, q_scales, q_bias, + k_weight, k_scales, k_bias, + v_weight, v_scales, v_bias, + d_head, + neuron_config=self.neuron_config, + tp_degree=tp_degree, # TODO: include tp_degree into neuron_config + shard_over_batch=self.shard_over_batch, + n_kv_heads_tp=n_kv_heads_tp, + ) + + # Q = Rotate(Q) + # K = Rotate(K) + query, key = rotary.rotate_half(query, key, pos_embed, self.config.rotary_percentage, + tp_degree=tp_degree, shard_over_batch=self.shard_over_batch) + + # Q = Q / sqrt(d_head) + query = attention.scale(query, d_head) + + # In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV. 
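+ # When the KV cache itself uses the BSH layout, Q/K/V are transposed from SBH below and batch_dim switches from 1 to 0, so the cache gather/update logic that follows indexes the batch dimension consistently with the cache layout.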
+ bsh_cache_layout = False + batch_dim = 1 + if self.neuron_config is not None: + bsh_cache_layout = self.neuron_config.cache_layout == constants.LAYOUT_BSH + if bsh_cache_layout: + query, key, value = attention_utils.transpose_qkv(query, key, value) + batch_dim = 0 + + + # Single Token Generation ("Prefetch"-style) and speculative forward + if active_mask is not None: + + n_active_tokens = key.sizes[1] if bsh_cache_layout else key.sizes[0] + if n_active_tokens > 1 and self.neuron_config and self.neuron_config.continuous_batching: + # For speculative forward + continuous batching, slice out the samples in the batch + # corresponding to the batch size of the speculative head + slice_sizes = [1] * len(cached_keys.sizes) + if cached_keys.sizes[batch_dim] == 1: + # Use hlo.select for batch size 1 as index select is prohibitively slow + # TODO: revert to hlo.index_select once it's faster P126527643 + cached_keys_s = hlo.select(cached_keys, batch_dim, hlo.reshape(start_ids, slice_sizes), keepdim=True) + cached_values_s = hlo.select(cached_values, batch_dim, hlo.reshape(start_ids, slice_sizes), keepdim=True) + else: + cached_keys_s = hlo.index_select(cached_keys, batch_dim, start_ids) + cached_values_s = hlo.index_select(cached_values, batch_dim, start_ids) + if self.neuron_config and self.neuron_config.kv_cache_quant: + cached_keys_s = dequantize_kv_cache_direct_cast(cached_keys_s, self.neuron_config) + cached_values_s = dequantize_kv_cache_direct_cast(cached_values_s, self.neuron_config) + elif self.neuron_config and self.neuron_config.paged_attention: + # For decoding with multiple KV cache blocks, last_token_id is used as the block_tables + cached_keys_s = attention_utils.gather_blocks(cached_keys, block_tables=last_token_id, neuron_config=self.neuron_config) + cached_values_s = attention_utils.gather_blocks(cached_values, block_tables=last_token_id, neuron_config=self.neuron_config) + if self.neuron_config and self.neuron_config.kv_cache_quant: + cached_keys_s = dequantize_kv_cache_direct_cast(cached_keys_s, self.neuron_config) + cached_values_s = dequantize_kv_cache_direct_cast(cached_values_s, self.neuron_config) + elif self.neuron_config and self.neuron_config.kv_cache_quant: + cached_keys_s = dequantize_kv_cache_direct_cast(cached_keys, self.neuron_config) + cached_values_s = dequantize_kv_cache_direct_cast(cached_values, self.neuron_config) + else: + cached_keys_s = cached_keys + cached_values_s = cached_values + # Communication 1: all-gather query from cores + if (n_active_tokens != self.n_positions) and self.neuron_config.shard_over_sequence: + query = flash_decoding.gather_query_group(query, self.cores_per_kv_head, + n_head, + tp_degree) + + # Sp = Q @ Kp + prior_scores = attention.score(query, cached_keys_s, n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, block_to_seq=block_to_seq, neuron_config=self.neuron_config) + prior_scores = attention.mask(prior_scores, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch) + + # Sa = Q @ Ka + active_score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, neuron_config=self.neuron_config) + active_score = attention.mask(active_score, active_mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch) + + # C = softmax(Sa, Sp) @ (Va, Vp) + if self.neuron_config.shard_over_sequence: + dtype = query.dtype + context = flash_decoding.context(prior_scores, active_score, cached_values_s, value, core_id, mask, active_mask, + n_kv_heads=self.config.num_key_value_heads,
n_heads=n_head, dtype=dtype, + tp_degree=tp_degree, neuron_config=self.neuron_config, + shard_over_batch=self.shard_over_batch) + cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids,value, key,self.cores_per_kv_head,core_id,dim=0) + + else: + context = attention.context(prior_scores, active_score, cached_values_s, value, + n_kv_heads=self.config.num_key_value_heads, tp_degree=tp_degree, + context_lens=cache_ids, num_active_blocks=self.num_active_blocks, + block_to_seq=block_to_seq, + neuron_config=self.neuron_config) + + # KCache[I], VCache[I] = K, V + updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids, + key, value, start_ids, neuron_config=self.neuron_config) + + # Multi-Token Context Encoding + else: + _, batch_size, _, _ = query.sizes + if self.neuron_config.lhs_aligned or batch_size == 1: + context = attention.flash_attention(query, key, value) + else: + # do not use flash attention for lhs padded (right aligned) batch > 1 case + # because it does not correctly take mask into account + context = None + + if context is None: + # S = Q @ K + + score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, neuron_config=self.neuron_config) + score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch) + context = attention.context_combined(score, value, n_kv_heads=self.config.num_key_value_heads, + tp_degree=tp_degree, neuron_config=self.neuron_config) + + if self.neuron_config.shard_over_sequence: + cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids, + value, + key, + self.cores_per_kv_head, + core_id,dim=0) + # KCache, VCache = K, V + if cached_keys.sizes == key.sizes: + if self.neuron_config and self.neuron_config.kv_cache_quant: + updated_keys = quantize_kv_cache_direct_cast(key, self.neuron_config) + updated_values = quantize_kv_cache_direct_cast(value, self.neuron_config) + else: + updated_keys, updated_values = key, value + else: + updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids, + key, value, start_ids, neuron_config=self.neuron_config) + + # O = (C @ wO) + bO + output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config) + return output, updated_keys, updated_values \ No newline at end of file diff --git a/src/transformers_neuronx/qwen2/model.py b/src/transformers_neuronx/qwen2/model.py new file mode 100644 index 0000000..7597aa0 --- /dev/null +++ b/src/transformers_neuronx/qwen2/model.py @@ -0,0 +1,416 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import torch +import os +from transformers_neuronx import decoder +from transformers_neuronx import module +from transformers_neuronx import ops +from transformers_neuronx import sampling +from transformers_neuronx import utils +from transformers_neuronx import bucket +from transformers_neuronx import base +from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB, KV_SHARD_PAD +from transformers_neuronx.config import NeuronConfig +from transformers_neuronx.qwen2.config import Qwen2Config +from transformers_neuronx.qwen2.modules import Qwen2ForCausalLM +from transformers_neuronx.qwen2.hlo import Qwen2ForSamplingNoEmbeddingHlo +import warnings + +class Qwen2ForSampling(base.NeuronModelBase): + + def __init__(self, config, *, n_positions=2048, batch_size=1, amp='f32', tp_degree=2, + context_length_estimate=None, context_unroll=None, unroll=None, + neuron_config=None, prefixed_length=0, **kwargs): + config = Qwen2Config(config, n_positions, batch_size, amp, tp_degree) + super().__init__(Qwen2ForCausalLM, config) + self.context_pre_hook = None + self.context_hook = None + self.config = config + self.neuron_config = neuron_config if neuron_config else NeuronConfig() + if self.neuron_config.shard_over_sequence: + n_kv_head = self.config.num_key_value_heads + kv_shard_degree = self.config.tp_degree // n_kv_head + assert kv_shard_degree <= KV_SHARD_PAD, f"kv_shard degree {kv_shard_degree} is higher than the default {KV_SHARD_PAD}" + warnings.warn(f"shard over sequence enabled, increasing n_positions {n_positions} by {KV_SHARD_PAD}") + if isinstance(n_positions, list): + npos = sorted(n_positions) + npos[-1] += KV_SHARD_PAD + else: + npos = n_positions + KV_SHARD_PAD + self.config.n_positions = npos + config.n_positions = npos + n_positions = npos + if self.neuron_config.on_device_generation: + self.neuron_config.on_device_generation.vocab_size = self.config.vocab_size + + self.layers_after_partition = self.neuron_config.auto_layer_partition(config.num_hidden_layers) + self.prefixed_length = prefixed_length + + if context_unroll is None: + context_unroll = len(self.layers_after_partition) + self.context_unroll = context_unroll + + if unroll is None: + unroll = len(self.layers_after_partition) + self.unroll = unroll + + self.token_buckets = bucket.token_sizes(n_positions) + self.context_buckets = bucket.context_sizes(context_length_estimate, self.token_buckets) + # input length should be divisible by tp_degree to activate sequence parallelism + if neuron_config and neuron_config.sequence_parallel_norm: + for bucket_size in self.context_buckets: + if bucket_size > neuron_config.sequence_parallel_norm_threshold and bucket_size % self.config.tp_degree != 0: + raise ValueError(f"Sequence parallel normalization requires the bucket size ({bucket_size}) to be divisible by the tensor parallel degree ({self.config.tp_degree})") + self.window_context_buckets = [] + if prefixed_length: + if prefixed_length not in self.context_buckets: + self.context_buckets.append(prefixed_length) + self.context_buckets = sorted(self.context_buckets) + + self.batch_sizes = bucket.batch_sizes(batch_size) + self.context_batch_sizes = [1] if self.neuron_config and self.neuron_config.continuous_batching else self.batch_sizes + hlo_builder = Qwen2ForSamplingNoEmbeddingHlo(config, neuron_config=self.neuron_config) + self.decoder_param_set = decoder.DecoderLmHeadForSamplingNoEmbedding( + tp_degree=tp_degree, n_positions_list=self.token_buckets, n_active_tokens=1, batch_size=self.batch_sizes, + attention_head_size=config.attention_head_size, amp=amp, + num_layers=len(self.layers_after_partition), n_head=config.num_attention_heads, n_kv_head=config.num_key_value_heads, + unroll=unroll, neuron_config=self.neuron_config, allow_pad=True, + builder=hlo_builder + ) + self.decoder_lm_head = self.decoder_param_set.init_token_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self) + self.decoder_lm_head_for_context = self.decoder_param_set.init_context_decoder(unroll=self.context_unroll, buckets=self.context_buckets, model_obj=self) + self.decoder_lm_head_for_speculation = {} + self.decoder_lm_head_for_window_context = {} + + def load_weights(self): + self.materialize_embeddings() + ops.init() + + for layer_id, layer in enumerate(self.chkpt_model.model.layers): + if layer_id not in self.layers_after_partition: + continue + layer.materialize() + attn = layer.self_attn + mlp = layer.mlp + if self.neuron_config and self.neuron_config.quant: + is_unit_scale = self.neuron_config.quant.is_unit_scale(layer_id) + else: + is_unit_scale = False + new_layer = self.decoder_lm_head.new_layer(is_unit_scale=is_unit_scale) + new_layer.add_pre_attention_layer_norm(layer.input_layernorm.weight.detach(), None) + new_layer.add_attention_query(attn.q_proj.weight.detach().T, attn.q_proj.bias.detach()) + new_layer.add_attention_key(attn.k_proj.weight.detach().T, attn.k_proj.bias.detach()) + new_layer.add_attention_value(attn.v_proj.weight.detach().T, attn.v_proj.bias.detach()) + if self.neuron_config and self.neuron_config.attn_output_transposed: + new_layer.add_attention_output(attn.o_proj.weight.T.detach(), None, sharding=0, transposed=True) + else: + new_layer.add_attention_output(attn.o_proj.weight.detach(), None, sharding=1, transposed=False) + new_layer.add_pre_mlp_layer_norm(layer.post_attention_layernorm.weight.detach(), None) + + # Note: Automatic MLP padding is safe since zeros are *only* introduced to intermediary state + if self.neuron_config.fuse_mlp: + assert all(getattr(mlp, attr, None) for attr in ['gate_proj', 'up_proj']),\ + "fuse_mlp requires gate_proj and up_proj weights" + assert all(getattr(mlp, attr, None).weight.shape[0] % self.config.tp_degree == 0 + for attr in ['gate_proj', 'up_proj']),\ + f"mlp weights are not divisible by tp_degree {self.config.tp_degree}" + mlp_in_weight = utils.interleave_mlp(mlp.gate_proj.weight, mlp.up_proj.weight, + tp_degree=self.config.tp_degree, dim=0) + new_layer.add_mlp_input(mlp_in_weight.T.detach(), None) + if self.neuron_config.mlp_out_weight_transpose: + new_layer.add_mlp_output( + mlp.down_proj.weight.T.detach(), None, + sharding=0, + transposed=True, + ) + else: + new_layer.add_mlp_output( + mlp.down_proj.weight.detach(), None, + sharding=1, + transposed=False, + ) + else: + new_layer.add_parameter(mlp.gate_proj.weight.T, sharding=1, allow_pad=True, + allow_quantize=True, allow_transform=True) + new_layer.add_parameter(mlp.up_proj.weight.T, sharding=1, allow_pad=True, + allow_quantize=True, allow_transform=True) + if self.neuron_config.weight_tiling: + new_layer.add_parameter(mlp.down_proj.weight.T, sharding=0, allow_pad=True, + allow_quantize=True, allow_transform=True) + else: + if self.neuron_config.mlp_out_weight_transpose: + new_layer.add_parameter(mlp.down_proj.weight.T, sharding=0, allow_pad=True, + allow_quantize=True) + else: + new_layer.add_parameter(mlp.down_proj.weight, sharding=1, allow_pad=True, + allow_quantize=True, out_feature_dim=0) + new_layer.to_neuron() + layer.nullify() + if self.neuron_config.shard_over_sequence: + self.decoder_lm_head.add_pre_layer_parameter(torch.arange(self.config.tp_degree), sharding=0) + # For pipeline parallel, we need to load ln and lm_head for now even if the pipeline stage doesn't compute them, because + # 1) we need the ln_lm_head hlo for pp0 to get the logits shape and dtype + # 2) we don't need these for intermediate pp stages, but to keep things simple, just include ln_lm_head for all pp stages for now + # 3) to get ln_lm_head hlo, we need to do weight loading and sharding + # 4) this will introduce extra memory allocation, but the ln_lm_head i/o tensor is much smaller and we can get rid of it once we can construct the hlo in init + ln_f = self.chkpt_model.model.norm + ln_f.materialize() + self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), None) + + lm_head = self.chkpt_model.lm_head + lm_head.materialize() + self.decoder_lm_head.add_lm_head(lm_head.weight.detach().T) + if self.neuron_config.on_device_embedding: + if self.neuron_config.sequence_parallel_norm: + self.decoder_lm_head.add_pre_layer_parameter(self.chkpt_model.model.embed_tokens.weight, sharding=None, allow_pad=True) + else: + self.decoder_lm_head.add_pre_layer_parameter(self.chkpt_model.model.embed_tokens.weight, sharding=1, allow_pad=True) + lm_head.nullify() + + self.decoder_lm_head.to_neuron() + self.init_rest_of_model() + + def materialize_embeddings(self): + # Materialize the embedding to CPU + self.chkpt_model.model.embed_tokens.materialize() + + def init_rest_of_model(self): + # Pipeline parallel doesn't support executor right now + if not self.neuron_config.is_pp(): + self.decoder_lm_head.use_executor = True + + if self.context_buckets: + for context_length_estimate in self.context_buckets: + for batch_size in self.context_batch_sizes: + model = self.decoder_lm_head.build_weight_shared(share_caches=True, + new=self.decoder_lm_head_for_context[context_length_estimate, batch_size]) + # PERF: No latency improvement seen in multi-layer models from executor + # Pipeline parallel doesn't support executor right now + if self.context_unroll == self.config.num_hidden_layers and not self.neuron_config.is_pp(): + model.use_executor = True + self.decoder_lm_head_for_context[context_length_estimate, batch_size] = model + + if self.decoder_lm_head_for_speculation: + for i, k in enumerate(self.decoder_lm_head_for_speculation): + model = self.decoder_lm_head.build_weight_shared(share_caches=True, + new=self.decoder_lm_head_for_speculation[k], + embed_weight=self.chkpt_model.model.embed_tokens.weight) + self.decoder_lm_head_for_speculation[k] = model + + if self.decoder_lm_head_for_window_context: + for i, k in enumerate(self.decoder_lm_head_for_window_context): + model = self.decoder_lm_head.build_weight_shared(share_caches=True, + new=self.decoder_lm_head_for_window_context[k]) + self.decoder_lm_head_for_window_context[k] = model + + + def set_prefixed(self, input_ids): + self.prefixed_input_ids = input_ids[:, :self.prefixed_length] + prefixed_length = self.prefixed_length + self.prefixed_length = 0 + self.forward(self.prefixed_input_ids) + self.prefixed_length = prefixed_length + + def preprocess_and_embed(self, input_ids, cache_ids=None, start_ids=None, **kwargs): + padded_inputs, *rst = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids, **kwargs) + if not self.neuron_config.on_device_embedding: + input_embeddings = self.chkpt_model.model.embed_tokens(padded_inputs) + if self.neuron_config.attention_layout == LAYOUT_HSB: + input_embeddings =
input_embeddings.transpose(0, -1).contiguous() + else: + # embedding layer is on device and will be computed as part of self._forward(), so don't compute here + input_embeddings = None + return padded_inputs, input_embeddings, *rst + + def forward(self, input_ids, cache_ids=None, start_ids=None, last_token_id=None, input_embeddings=None, **kwargs): + if last_token_id is not None: # preprocess_and_embed() has already been invoked + rst = cache_ids, start_ids, last_token_id + else: # invoke preprocess_and_embed() + input_ids, input_embeddings, *rst = self.preprocess_and_embed(input_ids, cache_ids, start_ids, **kwargs) + # either input_embeddings are generated (off device embedding), or input_ids will be padded from preprocess_and_embed (on device embedding) + inputs = input_embeddings if input_embeddings is not None else input_ids + logits = self._forward(inputs, *rst) + logits = self._postprocess(logits, start_ids=start_ids, **kwargs) + return logits + + def speculative_forward(self, input_ids, cache_ids=None, start_ids=None, speculation_length=None): + if self.neuron_config and self.neuron_config.continuous_batching: + inputs, *args = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids) + else: + batch_size, *_ = input_ids.shape + if start_ids is None: + start_ids = torch.zeros(batch_size, dtype=torch.int32) + if cache_ids is None: + batch_size, context_length = input_ids.shape + cache_ids = torch.arange(context_length, dtype=torch.int32) + if self.neuron_config.use_2d_cache_ids: + cache_ids = cache_ids.unsqueeze(0).expand(batch_size, context_length) + + inputs, *args = input_ids, cache_ids, start_ids + + batch_size, seq_len = input_ids.shape + if speculation_length is None: + model = self.decoder_lm_head + elif speculation_length not in self.decoder_lm_head_for_speculation.keys(): + # auto-infer speculation bucket, if needed + speculation_buckets = [k for (k, batch_size) in self.decoder_lm_head_for_speculation.keys()] + speculation_length = bucket.find(speculation_buckets, seq_len) + model = self.decoder_lm_head_for_speculation[speculation_length, batch_size] + if input_ids.shape[-1] > speculation_length: + input_ids = input_ids[:, :speculation_length] + else: + model = self.decoder_lm_head_for_speculation[speculation_length, batch_size] + + if not self.neuron_config.on_device_embedding: + inputs = self.chkpt_model.model.embed_tokens(inputs) + if self.neuron_config.attention_layout == LAYOUT_HSB: + inputs = inputs.transpose(0, -1).contiguous() + with torch.inference_mode(): + logits = model(inputs, *args) + logits = self._cast_logits(logits) + logits = logits[:self.config.vocab_size, -speculation_length:, :] + logits = logits.transpose(0, 1) + return logits + + + def tree_speculative_forward(self, input_ids, cache_ids=None, start_ids=None, speculation_length=None, previous_cache_ids=None, reorder_mapping=None): + if self.neuron_config and self.neuron_config.continuous_batching: + inputs, *args = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids) + else: + batch_size, *_ = input_ids.shape + if start_ids is None: + start_ids = torch.zeros(batch_size, dtype=torch.int32) + if cache_ids is None: + batch_size, context_length = input_ids.shape + cache_ids = torch.arange(context_length, dtype=torch.int32) + if self.neuron_config.use_2d_cache_ids: + cache_ids = cache_ids.unsqueeze(0).expand(batch_size, context_length) + if previous_cache_ids is None: + batch_size, context_length = input_ids.shape + previous_cache_ids = torch.arange(context_length, 
dtype=torch.int32) + if self.neuron_config.use_2d_cache_ids: + previous_cache_ids = previous_cache_ids.unsqueeze(0).expand(batch_size, context_length) + if reorder_mapping is None: + batch_size, context_length = input_ids.shape + reorder_mapping = torch.arange(context_length, dtype=torch.int32) + if self.neuron_config.use_2d_cache_ids: + reorder_mapping = reorder_mapping.unsqueeze(0).expand(batch_size, context_length) + inputs, *args = input_ids, cache_ids, start_ids, previous_cache_ids, reorder_mapping + + batch_size, seq_len = input_ids.shape + if speculation_length is None: + model = self.decoder_lm_head + inputs, *args = input_ids, cache_ids, start_ids + elif speculation_length not in self.decoder_lm_head_for_speculation.keys(): + # auto-infer speculation bucket, if needed + speculation_buckets = [k for (k, batch_size) in self.decoder_lm_head_for_speculation.keys()] + speculation_length = bucket.find(speculation_buckets, seq_len) + model = self.decoder_lm_head_for_speculation[speculation_length, batch_size] + if input_ids.shape[-1] > speculation_length: + input_ids = input_ids[:, :speculation_length] + else: + model = self.decoder_lm_head_for_speculation[speculation_length, batch_size] + + if not self.neuron_config.on_device_embedding: + inputs = self.chkpt_model.model.embed_tokens(inputs) + if self.neuron_config.attention_layout == LAYOUT_HSB: + inputs = inputs.transpose(0, -1).contiguous() + with torch.inference_mode(): + logits = model(inputs, *args) + logits = self._cast_logits(logits) + logits = logits[:self.config.vocab_size, -speculation_length:, :] + logits = logits.transpose(0, 1) + return logits + + + def sample(self, input_ids, sequence_length, cache_ids=None, start_ids=None, + top_k=50, top_p=1.0, eos_token_override=None, temperature=1.0, streamer=None, stopping_criteria_list=None, no_repeat_ngram_size=None, **kwargs): + + if self.neuron_config.on_device_generation: + return sampling.sample_tokens(self, input_ids, start_ids, sequence_length=sequence_length, + config=self.neuron_config.on_device_generation, streamer=streamer, cache_ids=cache_ids) + + if self.context_pre_hook is not None: + self.context_pre_hook() + batch_size, context_length = input_ids.shape + if batch_size not in self.batch_sizes: + raise ValueError(f"Model not compiled for batch_size : {batch_size}. 
Acceptable batch_size is one of the following {self.batch_sizes}") + prefixed_length = self.prefixed_length + + if context_length < prefixed_length: + self.prefixed_length = 0 + else: + input_ids = input_ids[:, prefixed_length:] + context_length -= prefixed_length + sequence_length -= prefixed_length + + result = sampling.sample_llama( + self, input_ids, start_ids, sequence_length, + eos_token_id=self.config.eos_token_id if eos_token_override is None else eos_token_override, + top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer, + stopping_criteria_list=stopping_criteria_list, no_repeat_ngram_size=no_repeat_ngram_size, cache_ids=cache_ids, + ) + + return result + +class FIDQwen2ForSampling(Qwen2ForSampling): + + def __init__(self, config, *, n_positions=2048, batch_size=1, amp='f32', tp_degree=2, + context_length_estimate=None, context_unroll=None, unroll=None, + neuron_config=None, reorder_cache=False, **kwargs): + # Force batch_size=1 in NEFF + super().__init__(config, n_positions=n_positions, batch_size=1, amp=amp, + tp_degree=tp_degree, context_length_estimate=context_length_estimate, + context_unroll=context_unroll, unroll=unroll, neuron_config=neuron_config, + reorder_cache=False, **kwargs) + assert len(self.decoder_lm_head.batch_size) == 1, "FIDQwen2ForSampling does not support compilation for \ + multiple batch sizes" + self.batch_size = self.decoder_lm_head.batch_size[0] + self.bos_token_id = self.config.bos_token_id + + + def sample(self, input_ids, sequence_length, start_ids=None, top_k=50, streamer=None): + """ Sample function + input_ids: shape [batch_size, context_length] + + input_ids at different batch indices each represent a single (context + query). + They will be fused together to generate a single output sequence. + """ + + # In FID-Qwen2, context encoding is first done w/o generating any output token for the context + # Here, the batch dimension holds the different (context + query) pairs of a single run + + offset = 0 + fused_batch_size = 1 + batch_size, context_length = input_ids.shape + + # The context length estimate is chosen based on a single (context + query) + estimate = bucket.find(self.context_buckets, context_length) + + if batch_size * context_length >= sequence_length: + raise ValueError(f"sequence_length [{sequence_length}] should be larger than fused input contexts [{context_length} x {batch_size}]") + if batch_size * estimate >= sequence_length: + raise ValueError(f"sequence_length [{sequence_length}] should be larger than fused input context estimates [{estimate} x {batch_size}]") + + + # Flatten input_ids + context_length = batch_size * context_length + input_ids = input_ids.reshape(fused_batch_size, context_length) + + # Run the model + result = sampling.sample_llama(self, input_ids, start_ids, sequence_length, + eos_token_id=self.config.eos_token_id, top_k=top_k, streamer=streamer) + + return result \ No newline at end of file diff --git a/src/transformers_neuronx/qwen2/modules.py b/src/transformers_neuronx/qwen2/modules.py new file mode 100644 index 0000000..3b50db3 --- /dev/null +++ b/src/transformers_neuronx/qwen2/modules.py @@ -0,0 +1,85 @@ +# Copyright Amazon Web Services and its Affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from transformers_neuronx import dtypes +from transformers_neuronx import module +from transformers_neuronx import utils + + +class Qwen2ForCausalLM(module.PretrainedModel): + + def __init__(self, config): + super().__init__() + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.model = Qwen2Model(config) + self.lm_head = module.LowMemoryLazyLinear(config.vocab_size, dtype=dtype, bias=False) + + def get_tied_parameters(self): + return [(self.model.embed_tokens.weight, self.lm_head.weight)] + + def get_base_model(self): + return self.model + + +class Qwen2Model(module.LowMemoryModule): + + def __init__(self, config): + super().__init__() + self.embed_tokens = module.LowMemoryEmbedding(config.vocab_size, config.hidden_size) + self.layers = module.LowMemoryModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = Qwen2RMSNorm(config) + + +class Qwen2RMSNorm(module.LowMemoryModule): + + def __init__(self, config) -> None: + super().__init__() + self.weight = module.UninitializedParameter() + + +class Qwen2DecoderLayer(module.LowMemoryModule): + + def __init__(self, config): + super().__init__() + self.self_attn = Qwen2Attention(config) + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config) + self.post_attention_layernorm = Qwen2RMSNorm(config) + + +class Qwen2Attention(module.LowMemoryModule): + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.q_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=True, dtype=dtype) + self.k_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=True, dtype=dtype) + self.v_proj = module.LowMemoryLazyLinear(self.num_heads * self.head_dim, bias=True, dtype=dtype) + self.o_proj = module.LowMemoryLazyLinear(self.hidden_size, bias=False, dtype=dtype) + + +class Qwen2MLP(module.LowMemoryModule): + + def __init__(self, config): + super().__init__() + dtype, _, _ = utils.parse_amp(config.amp) + dtype = dtypes.to_torch_dtype(dtype) + self.gate_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) + self.up_proj = module.LowMemoryLazyLinear(config.intermediate_size, bias=False, dtype=dtype) + self.down_proj = module.LowMemoryLazyLinear(config.hidden_size, bias=False, dtype=dtype) \ No newline at end of file
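For reference, a minimal usage sketch of the new Qwen2ForSampling entry point, following the same from_pretrained()/to_neuron()/sample() flow as the existing *ForSampling classes; the checkpoint id, tp_degree, amp, and sequence settings below are illustrative placeholders, not part of this change.

import torch
from transformers import AutoTokenizer
from transformers_neuronx import NeuronConfig, Qwen2ForSampling

checkpoint = 'Qwen/Qwen2-7B-Instruct'  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the checkpoint, shard the weights across NeuronCores, and compile.
neuron_config = NeuronConfig()  # defaults; enable optional features as needed
model = Qwen2ForSampling.from_pretrained(
    checkpoint, batch_size=1, tp_degree=8, amp='bf16',
    n_positions=2048, neuron_config=neuron_config)
model.to_neuron()

# Generate up to sequence_length tokens with top-k sampling.
input_ids = tokenizer('Hello, my name is', return_tensors='pt').input_ids
with torch.inference_mode():
    generated = model.sample(input_ids, sequence_length=128, top_k=50)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))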