From 4bfeb7b403b36996783ca4c4c885ff32c5c06692 Mon Sep 17 00:00:00 2001
From: ftgreat
Date: Wed, 25 Dec 2024 10:31:59 +0800
Subject: [PATCH] Add Aquila2 model

Signed-off-by: ftgreat
---
 paddlenlp/transformers/aquila2/__init__.py    |   17 +
 .../transformers/aquila2/configuration.py     |  157 ++
 paddlenlp/transformers/aquila2/fusion_ops.py  |  257 +++
 paddlenlp/transformers/aquila2/modeling.py    | 1952 +++++++++++++++++
 paddlenlp/transformers/aquila2/tokenizer.py   |  340 +++
 5 files changed, 2723 insertions(+)
 create mode 100644 paddlenlp/transformers/aquila2/__init__.py
 create mode 100644 paddlenlp/transformers/aquila2/configuration.py
 create mode 100644 paddlenlp/transformers/aquila2/fusion_ops.py
 create mode 100644 paddlenlp/transformers/aquila2/modeling.py
 create mode 100644 paddlenlp/transformers/aquila2/tokenizer.py

diff --git a/paddlenlp/transformers/aquila2/__init__.py b/paddlenlp/transformers/aquila2/__init__.py
new file mode 100644
index 000000000000..9bfec4bc1968
--- /dev/null
+++ b/paddlenlp/transformers/aquila2/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration import *
+from .modeling import *
+from .tokenizer import *
diff --git a/paddlenlp/transformers/aquila2/configuration.py b/paddlenlp/transformers/aquila2/configuration.py
new file mode 100644
index 000000000000..2bdbde520536
--- /dev/null
+++ b/paddlenlp/transformers/aquila2/configuration.py
@@ -0,0 +1,157 @@
+# Copyright © 2024 BAAI. All rights reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Aquila model configuration""" + +# modified from https://github.com/PaddlePaddle/PaddleNLP/blob/7947bca07f0dfb37172a4c0040defd0cdbbc10a0/paddlenlp/transformers/llama/configuration.py + +from ..configuration_utils import PretrainedConfig + +__all__ = [ + "AquilaConfig", + "AQUILA_PRETRAINED_INIT_CONFIGURATION", + "AQUILA_PRETRAINED_RESOURCE_FILES_MAP", +] + +AQUILA_PRETRAINED_INIT_CONFIGURATION = {} + +AQUILA_PRETRAINED_RESOURCE_FILES_MAP = { + "model_state": {}, +} + + +class AquilaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~AquilaModel`]. It is used to instantiate an Aquila + model according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Aquila model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~AquilaModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Enable rope fusion or not. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. 
+ ```""" + model_type = "aquila" + attribute_map = { + "n_positions": "max_position_embeddings", + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "n_inner": "intermediate_size", + "activation_function": "hidden_act", + } + pretrained_init_configuration = AQUILA_PRETRAINED_INIT_CONFIGURATION + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + max_position_embeddings=2048, + seq_length=2048, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + initializer_range=0.02, + rms_norm_eps=1e-6, + rope_theta=10000.0, + use_cache=True, + fuse_attention_qkv=False, + fuse_attention_ffn=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + alibi=False, + rope_scaling_factor=1.0, + rope_scaling_type=None, + long_sequence_strategy_type=None, + long_sequence_strategy_name=None, + long_sequence_init_args=None, + use_long_sequence_strategies=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.seq_length = seq_length + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + + self.use_cache = use_cache + self.fuse_attention_qkv = fuse_attention_qkv + self.fuse_attention_ffn = fuse_attention_ffn + + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.alibi = alibi + + self.rope_scaling_factor = rope_scaling_factor + self.rope_scaling_type = rope_scaling_type + + self.long_sequence_strategy_type = long_sequence_strategy_type + self.long_sequence_strategy_name = long_sequence_strategy_name + self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args + self.use_long_sequence_strategies = use_long_sequence_strategies + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def rope(self): + return not self.alibi diff --git a/paddlenlp/transformers/aquila2/fusion_ops.py b/paddlenlp/transformers/aquila2/fusion_ops.py new file mode 100644 index 000000000000..9ff276785905 --- /dev/null +++ b/paddlenlp/transformers/aquila2/fusion_ops.py @@ -0,0 +1,257 @@ +# Copyright © 2024 BAAI. All rights reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Copied from https://github.com/PaddlePaddle/PaddleNLP/blob/7947bca07f0dfb37172a4c0040defd0cdbbc10a0/paddlenlp/transformers/llama/fusion_ops.py +import os + +import paddle +import paddle.nn.functional as F + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +from paddle.utils import try_import + +from paddlenlp.utils.tools import get_env_device + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None +try: + if get_env_device() in ["npu", "gcu"]: + from paddle.base import core + + for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")): + if lib.endswith(".so"): + paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib) + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +from paddlenlp.transformers.ring_flash_attention import RingFlashAttention + + +def fusion_rope( + query_states, + key_states, + value_states, + hidden_states, + position_ids, + past_key_value, + rotary_emb, + context_parallel_degree=-1, +): + if get_env_device() != "gcu": + assert past_key_value is None, "fuse rotary not support cache kv for now" + batch_size, seq_length, num_heads, head_dim = query_states.shape + _, kv_seq_len, num_key_value_heads, _ = key_states.shape + if context_parallel_degree > 1: + assert get_env_device() == "gpu", "context parallel only support cuda device for now" + kv_seq_len *= context_parallel_degree + if get_env_device() != "gcu": + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + if get_env_device() == "npu": + query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0] + key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0] + elif get_env_device() == "gcu": + cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len) + query_states, key_states = core.eager._run_custom_op( + "fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True + ) + else: + # paddle version > 2.6 or develop support q and k/v with different num_heads + paddle_version = float(paddle.__version__[:3]) + if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads): + query_states, _, _ = fused_rotary_position_embedding( + query_states, + None, + None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + key_states, _, _ = fused_rotary_position_embedding( + key_states, + None, + None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + return query_states, key_states + + +def rms_norm_fused(x_in, w, eps, use_fast_ln=False): + if use_fast_ln: + fast_ln = try_import("fast_ln") + return fast_ln.fast_rms_norm(x_in, w, eps)[0] + else: + fused_ln = try_import("fused_ln") + return fused_ln.fused_rms_norm(x_in, w, eps)[0] + + +def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False): + if get_env_device() == "npu": + return core.eager._run_custom_op("rms_norm_npu", 
hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "gcu": + return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "xpu": + try: + import paddle_xpu_nn # noqa: F821 + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0] + except ImportError: + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature" + ) + return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln) + + +def fusion_flash_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi=None, + attn_mask_startend_row_indices=None, + sequence_parallel=False, + reshard_layer=None, + npu_is_casual=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + version = paddle.version.full_version + if version != "0.0.0" and version <= "2.5.2": + if alibi is not None: + raise ValueError("Flash Attention doesn't support alibi") + if config.context_parallel_degree > 1: + raise ValueError(f"Context parallel is not implemented in version {version}") + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + causal=True, + return_softmax=output_attentions, + ) + else: + if alibi is not None: + alibi = alibi.reshape([bsz, num_heads, 1, -1]) + attention_mask = attention_mask.cast(alibi.dtype) + alibi + if get_env_device() == "npu": + if config.context_parallel_degree > 1: + raise ValueError("Context parallel is not implemented for npu") + attn_output = core.eager._run_custom_op( + "flash_attention_npu", + query_states, + key_states, + value_states, + None, + attention_mask, + 0.0, + attention_mask is None, + True, + False, + npu_is_casual, + )[0] + elif get_env_device() == "gcu": + if config.context_parallel_degree > 1: + raise ValueError("Context parallel is not implemented for gcu") + attn_output = core.eager._run_custom_op( + "fused_sdp_flash_attention_gcu", + query_states, + key_states, + value_states, + attention_mask, + 0.0, + attention_mask is None, + True, + )[0] + else: + if config.context_parallel_degree > 1: + attn_output = RingFlashAttention.apply( + query_states, + key_states, + value_states, + attn_mask=None, + is_causal=True, + ) + else: + if attn_mask_startend_row_indices is not None: + assert alibi is None, "flash_attention_with_sparse_mask not support alibi" + if len(attn_mask_startend_row_indices.shape) == 2: + attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1) + attn_output = F.flash_attention_with_sparse_mask( + query_states, + key_states, + value_states, + attn_mask_start_row_indices=attn_mask_startend_row_indices, + is_causal=True, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + ) + attn_weights = None + + if reshard_layer is not None: + # attn_output shape: [bs, seqlen, num_head/sep, head_dim] + attn_output = reshard_layer( + attn_output, + split_axis=1, + concat_axis=2, + ) + # attn_output shape: [bs, seqlen/sep, num_head, head_dim] + assert ( + config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0 + ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}" + q_len = q_len // config.sep_parallel_degree + num_heads = num_heads * config.sep_parallel_degree + + if 
sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output diff --git a/paddlenlp/transformers/aquila2/modeling.py b/paddlenlp/transformers/aquila2/modeling.py new file mode 100644 index 000000000000..6ce902fa3c61 --- /dev/null +++ b/paddlenlp/transformers/aquila2/modeling.py @@ -0,0 +1,1952 @@ +# Copyright © 2024 BAAI. All rights reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# modified from PaddleNLP https://github.com/PaddlePaddle/PaddleNLP/blob/7947bca07f0dfb37172a4c0040defd0cdbbc10a0/paddlenlp/transformers/llama/modeling.py + +"""Paddle Aquila model""" +from __future__ import annotations + +import math +import os +import warnings +from functools import partial +from typing import Optional, Tuple + +import paddle +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.autograd import PyLayer +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +from paddlenlp.transformers import linear_utils +from paddlenlp.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddlenlp.transformers.linear_utils import Linear +from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model +from paddlenlp.transformers.segment_parallel_utils import ReshardLayer +from paddlenlp.utils.log import logger +from paddlenlp.utils.tools import get_env_device + +from .configuration import ( + AQUILA_PRETRAINED_INIT_CONFIGURATION, + AQUILA_PRETRAINED_RESOURCE_FILES_MAP, + AquilaConfig, +) + +try: + if get_env_device() in ["npu", "gcu"]: + + for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")): + if lib.endswith(".so"): + paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib) + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None +from . 
import fusion_ops + +rms_norm_fused = fusion_ops.rms_norm_fused + +__all__ = [ + "AquilaModel", + "AquilaPretrainedModel", + "AquilaForCausalLM", + "AquilaPretrainingCriterion", +] + + +def _get_interleave(n): + def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def get_use_casual_mask(): + """Get the value of the 'USE_CASUAL_MASK' environment variable.""" + return os.getenv("USE_CASUAL_MASK", "False") == "True" + + +def build_alibi_tensor( + bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1 +) -> Tensor: + batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1] + slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32") + alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand( + [num_heads, -1, -1] + ) + alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1]) + return paddle.cast(alibi, dtype) + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def assign_kv_heads(num_kv_heads: int, num_gpus: int): + # Initialize the assignment list + """ + Assign kv heads to different GPUs in the Tensor Parallel Setup + + Examples: + assign_kv_heads(num_kv_heads=1, num_gpus=2): [[0], [0]] + assign_kv_heads(num_kv_heads=2, num_gpus=2): [[0], [1]] + assign_kv_heads(num_kv_heads=4, num_gpus=2): [[0,1], [2,3]] + assign_kv_heads(num_kv_heads=1, num_gpus=4): [[0],[0],[0],[0]] + assign_kv_heads(num_kv_heads=2, num_gpus=4): [[0],[0],[1],[1]] + assign_kv_heads(num_kv_heads=4, num_gpus=4): [[0],[1],[2],[3]] + """ + assignment_list = [[] for _ in range(num_gpus)] + # Case 1: more heads than cards + if num_kv_heads > num_gpus: + num_heads_per_card = num_kv_heads // num_gpus + for i in range(num_gpus): + for j in range(num_heads_per_card): + assignment_list[i].append(i * num_heads_per_card + j) + # Case 2: more cards than heads. each card get only 1 head. 
+ else: + num_card_per_heads = num_gpus // num_kv_heads + for i in range(num_kv_heads): + for j in range(num_card_per_heads): + assignment_list[i * num_card_per_heads + j].append(i) + return assignment_list + + +def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=False) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=False) + return logits + + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi=None, + attn_mask_startend_row_indices=None, + sequence_parallel=False, + reshard_layer=None, + npu_is_casual=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + return fusion_ops.fusion_flash_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + attn_mask_startend_row_indices, + sequence_parallel, + reshard_layer, + npu_is_casual, + ) + + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + + else: + if config.context_parallel_degree > 1: + raise ValueError("Context parallel requires `use_flash_attention=True`") + + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next tranpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + # matmul and devide by sqrt(head_dim) + attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) + # then add alibi bias + if alibi is not None: + alibi = alibi.reshape([bsz, num_heads, 1, -1]) + attn_weights = attn_weights + alibi + + if paddle.in_dynamic_mode() and attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + # In sep mode, the attenion mask should be created in the runtime. 
+ if reshard_layer is not None: + attention_mask = None + + # NOTE: we only call get_triangle_upper_mask under PP setup + # FIXME ZHUI when we use pipeline parallel, the attention_mask can be None + # we just make it triangle_upper_mask + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if paddle.in_dynamic_mode() and attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + if not paddle.in_dynamic_mode(): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + else: + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if reshard_layer is not None: + attn_output = reshard_layer( + attn_output, + split_axis=1, + concat_axis=2, + ) + q_len = q_len // config.sep_parallel_degree + num_heads = num_heads * config.sep_parallel_degree + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +def is_casual_mask(attention_mask): + """ + Upper triangular of attention_mask equals to attention_mask is casual + """ + return (paddle.triu(attention_mask) == attention_mask).all().item() + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make casual mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + if get_env_device() == "npu": + mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32") + else: + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
+ """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + if get_env_device() == "npu": + mask = mask[:, None, None, :].astype(dtype) + else: + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask + + +class AquilaRMSNorm(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, hidden_states): + if self.config.use_fused_rms_norm: + return fusion_ops.fusion_rms_norm( + hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm + ) + + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + # hidden_states = hidden_states.astype("float32") + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=1). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + + hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) + return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim]) + + +class AquilaRotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + # [dim / 2] + self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + self._set_cos_sin_cache(seq_len=max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, dim/2] + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + cos = self.cos_cached[:, :seq_len, :, :] + sin = self.sin_cached[:, :seq_len, :, :] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + def get_fused_cos_sin(self, x, seq_len=None): + if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype: + return self.cos_sin_table.cast(x.dtype) + else: + return self.cos_sin_table + + +class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding): + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings * scaling_factor, base) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + t = t / self.scaling_factor + # [seq_len, dim/2] + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) + + +class AquilaNTKScalingRotaryEmbedding(AquilaRotaryEmbedding): + """AquilaNTKScalingRotaryEmbedding extended with NTK scaling. https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + base = base * scaling_factor ** (dim / (dim - 2)) + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings * scaling_factor, base) + + +class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding): + """AquilaRotaryEmbedding extended with Dynamic NTK scaling. 
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base) + + def _scale_cos_sin(self, seq_len): + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, dim/2] + alpha = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + base = self.base * alpha ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + freqs = paddle.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + scale_cos = emb.cos()[None, :, None, :] + scale_sin = emb.sin()[None, :, None, :] + scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) + return scale_cos, scale_sin, scale_cos_sin + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_position_embeddings: + scale_cos, scale_sin, _ = self._scale_cos_sin(seq_len=seq_len) + else: + scale_cos, scale_sin = self.cos_cached, self.sin_cached + cos = scale_cos[:, :seq_len, :, ...] + sin = scale_sin[:, :seq_len, :, ...] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + def get_fused_cos_sin(self, x, seq_len=None): + if seq_len > self.max_position_embeddings: + _, _, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len) + else: + scale_cos_sin = self.cos_sin_table + if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype: + return scale_cos_sin.cast(x.dtype) + else: + return scale_cos_sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + + if position_ids is None: + # Note: Only for AquilaForCausalLMPipe model pretraining + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class AquilaMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.tensor_parallel_degree = config.tensor_parallel_degree + self.fuse_attention_ffn = config.fuse_attention_ffn + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + if config.tensor_parallel_degree > 1: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size * 2, + gather_output=False, + 
has_bias=False, + ) + else: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + def forward(self, x): + if self.fuse_attention_ffn: + # FIXME(yangjianbang): use paddle's native swiglu + if get_env_device() == "xpu": + try: + import paddle_xpu_nn # noqa: F821 + + out = self.gate_up_fused_proj(x) + out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True) + out = self.down_proj(out) + return out + except ImportError: + gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1) + out = self.down_proj(F.silu(gate_out) * up_out) + return out + + x = swiglu(self.gate_up_fused_proj(x)) + else: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) + return out + + +class AquilaAttention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AquilaConfig, layerwise_recompute: bool = False): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // config.num_attention_heads + + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads + + self.max_position_embeddings = config.max_position_embeddings + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + self.fuse_attention_qkv = config.fuse_attention_qkv + + self.kv_indices = None + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + + if self.num_key_value_heads % config.tensor_parallel_degree == 0: + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + else: + if self.fuse_attention_qkv: + # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp + raise ValueError( + f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0" + ) + logger.warning( + f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: 
{config.tensor_parallel_degree}, so we don't spilt key value weight." + ) + self.kv_indices = paddle.to_tensor( + assign_kv_heads(self.num_key_value_heads, config.tensor_parallel_degree)[ + config.tensor_parallel_rank + ] + ) + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." + ) + self.use_fused_rope = False + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + if config.tensor_parallel_degree > 1: + if self.fuse_attention_qkv: + self.qkv_proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + has_bias=False, + gather_output=False, + ) + if self.kv_indices is None: + self.k_proj = ColumnParallelLinear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + has_bias=False, + gather_output=False, + ) + else: + self.k_proj = Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + else: + if self.fuse_attention_qkv: + self.qkv_proj = Linear( + self.hidden_size, + self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + else: + self.q_proj = Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + self.k_proj = Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + if config.tensor_parallel_degree > 1: + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + has_bias=False, + input_is_parallel=True, + ) + else: + self.o_proj = Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + + if config.rope: + if config.use_long_sequence_strategies: + self.rotary_emb = LongSequenceStrategies.build_long_sequence_strategy( + config.long_sequence_strategy_type, + config.long_sequence_strategy_name, + **config.long_sequence_init_args, + ) + else: + self._init_rope() + + self.reshard_layer = None + if config.sep_parallel_degree > 1: + assert self.num_key_value_heads % config.sep_parallel_degree == 0 + assert self.num_heads % config.sep_parallel_degree == 0 + self.reshard_layer = ReshardLayer() + + self.config = config + + def _init_rope(self): + if self.config.rope_scaling_type is None: + self.rotary_emb = AquilaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + ) + elif self.config.rope_scaling_type == "linear": + self.rotary_emb = AquilaLinearScalingRotaryEmbedding( + self.head_dim, + 
max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + base=self.config.rope_theta, + ) + elif self.config.rope_scaling_type == "ntk": + self.rotary_emb = AquilaNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + base=self.config.rope_theta, + ) + elif self.config.rope_scaling_type == "dynamic_ntk": + self.rotary_emb = AquilaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + base=self.config.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}") + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + alibi: Optional[paddle.Tensor] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + npu_is_casual: bool = False, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + + if self.fuse_attention_qkv: + mix_layer = self.qkv_proj(hidden_states) + # NOTE for GQA attention fusion (compatible with MHA and MQA): + # The weight for qkv_proj is in shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim]. + # After the projection, the mix_layer is in shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim]. + # Reshape the mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim], + # where num_groups = num_q_heads // num_kv_heads. + # Split the mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim] + # to represent the q, k and v respectively. + # The q is in the shape like [b, s, num_kv_heads, num_groups * head_dim]. + # The k and v are in the shape like [b, s, num_kv_heads, head_dim]. + # Under MHA, the q is ready for the following calculation since num_kv_heads == num_q_heads, + # But for the GQA or MQA, q should be reshaped into [b, s, num_q_heads, head_dim]. 
+ if self.reshard_layer is not None: + if self.sequence_parallel: + assert self.seq_length % self.config.sep_parallel_degree == 0 + mix_layer = paddle.reshape_( + mix_layer, + [ + -1, + self.seq_length // self.config.sep_parallel_degree, + self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim, + ], + ) + # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim] + mix_layer = self.reshard_layer( + mix_layer, + split_axis=2, + concat_axis=1, + ) + mix_layer = paddle.reshape_( + mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim] + ) # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree + else: + if self.sequence_parallel: + target_shape = [ + -1, + self.seq_length, + self.num_key_value_heads, + (self.num_key_value_groups + 2) * self.head_dim, + ] + else: + target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim] + mix_layer = paddle.reshape_(mix_layer, target_shape) + query_states, key_states, value_states = paddle.split( + mix_layer, + num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim], + axis=-1, + ) + if self.gqa_or_mqa: + query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim]) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.reshard_layer is not None: + if self.sequence_parallel: + assert self.seq_length % self.config.sep_parallel_degree == 0 + query_states = paddle.reshape( + query_states, + [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim], + ) + key_states = paddle.reshape( + key_states, + [ + -1, + self.seq_length // self.config.sep_parallel_degree, + self.num_key_value_heads * self.head_dim, + ], + ) + value_states = paddle.reshape( + value_states, + [ + -1, + self.seq_length // self.config.sep_parallel_degree, + self.num_key_value_heads * self.head_dim, + ], + ) + query_states = self.reshard_layer( + query_states, + split_axis=2, + concat_axis=1, + ) + key_states = self.reshard_layer( + key_states, + split_axis=2, + concat_axis=1, + ) + value_states = self.reshard_layer( + value_states, + split_axis=2, + concat_axis=1, + ) + query_states = paddle.reshape( + query_states, [0, self.seq_length, -1, self.head_dim] + ) # [bs, seq_len, num_head/k, head_dim], k is sep degree + key_states = paddle.reshape(key_states, [0, self.seq_length, -1, self.head_dim]) + value_states = paddle.reshape(value_states, [0, self.seq_length, -1, self.head_dim]) + else: + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(shape=target_query_shape) + key_states = key_states.reshape(shape=target_key_value_shape) + value_states = value_states.reshape(shape=target_key_value_shape) + + kv_seq_len = key_states.shape[-3] + + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + + if self.config.rope: + if self.reshard_layer is not None: + batch_size, seq_length, _, _ = query_states.shape + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + if self.config.context_parallel_degree > 1: + batch_size, seq_length, 
_, _ = query_states.shape + group = fleet.get_hybrid_communicate_group().get_sep_parallel_group() + chunk_size = seq_length // 2 + chunk_num = group.nranks * 2 + rank = group.rank + first_chunk_ids = paddle.arange(rank * chunk_size, (rank + 1) * chunk_size, dtype="int64") + second_chunk_ids = paddle.arange( + (chunk_num - rank - 1) * chunk_size, (chunk_num - rank) * chunk_size, dtype="int64" + ) + position_ids = paddle.concat([first_chunk_ids, second_chunk_ids]).expand((batch_size, seq_length)) + if self.use_fused_rope: + query_states, key_states = fusion_ops.fusion_rope( + query_states, + key_states, + value_states, + hidden_states, + position_ids, + past_key_value, + self.rotary_emb, + self.config.context_parallel_degree, + ) + + else: + if self.config.context_parallel_degree > 1: + kv_seq_len *= self.config.context_parallel_degree + if self.config.use_long_sequence_strategies: + cos, sin = self.rotary_emb(seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + cos, sin = ( + cos.cast(value_states.dtype) if cos.dtype != value_states.dtype else cos, + sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = (key_states, value_states) if use_cache else None + if self.kv_indices is not None: + key_states = paddle.index_select(key_states, self.kv_indices, axis=2) + value_states = paddle.index_select(value_states, self.kv_indices, axis=2) + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + # paddle version > 2.6 or develop support flash-attn with gqa/mqa + paddle_version = float(paddle.__version__[:3]) + if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + scaled_dot_product_attention, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + attn_mask_startend_row_indices, + self.sequence_parallel, + reshard_layer=self.reshard_layer, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = scaled_dot_product_attention( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + attn_mask_startend_row_indices, + self.sequence_parallel, + reshard_layer=self.reshard_layer, + npu_is_casual=npu_is_casual, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. 
+ attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class AquilaDecoderLayer(nn.Layer): + def __init__(self, config, layerwise_recompute: bool = False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = AquilaAttention(config, layerwise_recompute) + self.mlp = AquilaMLP(config) + self.input_layernorm = AquilaRMSNorm(config) + self.post_attention_layernorm = AquilaRMSNorm(config) + self.sequence_parallel = config.sequence_parallel + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + alibi: Optional[paddle.Tensor] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + npu_is_casual: bool = False, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `cache` key value states are returned and can be used to speed up decoding + (see `cache`). 
+ cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + attn_mask_startend_row_indices, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + npu_is_casual=npu_is_casual, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + # remove empty tuple for pipeline parallel + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class AquilaPretrainedModel(PretrainedModel): + config_class = AquilaConfig + base_model_prefix = "aquila" + pretrained_init_configuration = AQUILA_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = AQUILA_PRETRAINED_RESOURCE_FILES_MAP + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + + @classmethod + def _get_name_mappings(cls, config: AquilaConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "AquilaModel" + if "AquilaModel" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "aquila." 
+ mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: AquilaConfig, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: AquilaConfig, is_fuse=False): + # return parameter fuse utils + from paddlenlp.transformers.conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
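+        # e.g. (q_proj.weight, k_proj.weight, v_proj.weight) <-> qkv_proj.weight
+        #      (gate_proj.weight, up_proj.weight)            <-> gate_up_fused_proj.weight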
+ fuse_qkv_keys = ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ) + + fuse_gate_up_keys = ( + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.up_proj.weight", + "layers.0.mlp.gate_up_fused_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys]) + final_actions[keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = partial(fn, split_nums=2) + return final_actions + + def _init_weights(self, layer): + """Initialization hook""" + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.RowParallelLinear, + mpu.ColumnParallelLinear, + linear_utils.RowSequenceParallelLinear, + linear_utils.ColumnSequenceParallelLinear, + AquilaLMHead, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.aquila.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.aquila.config.initializer_range, + shape=layer.weight.shape, + ) + ) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, AquilaMLP): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, AquilaAttention): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + + +@register_base_model +class AquilaModel(AquilaPretrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`AquilaDecoderLayer`] + Args: + config: AquilaConfig + """ + + def __init__(self, config: AquilaConfig): + super().__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.sequence_parallel = config.sequence_parallel + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + self.config = config + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + self.layers = nn.LayerList( + [AquilaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)] + ) + self.norm = AquilaRMSNorm(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length + ) + if get_env_device() == "npu": + expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool") + else: + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + if get_env_device() == "npu": + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32") + expanded_attn_mask = expanded_attn_mask.astype("float32") + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) + elif get_env_device() in ["xpu", "gcu"]: + x = paddle.to_tensor(0.0, dtype=dtype) + y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype) + expanded_attn_mask = expanded_attn_mask.astype(dtype) + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) + else: + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + alibi=None, + attn_mask_startend_row_indices=None, + ): + def create_custom_forward(module): + def 
custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi, + attn_mask_startend_row_indices, + use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + attn_mask_startend_row_indices=None, + **kwargs, + ): + if self.sequence_parallel and use_cache: + raise ValueError("We currently only support sequence parallel without cache.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] + seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): + raise NotImplementedError("Ring FlashAttention dosen't support attention_mask or alibi") + # embed positions + if attn_mask_startend_row_indices is None and attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attn_mask_startend_row_indices is None and self.config.alibi: + if self.config.use_long_sequence_strategies: + alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( + self.config.long_sequence_strategy_type, + self.config.long_sequence_strategy_name, + **self.config.long_sequence_init_args, + ) + alibi = alibi_layer(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + else: + alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + if self.config.tensor_parallel_degree > 1: + block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree + alibi = alibi[ + :, + self.config.tensor_parallel_rank + * block_size : 
(self.config.tensor_parallel_rank + 1) + * block_size, + ] + alibi = alibi.reshape([batch_size * block_size, 1, seq_length_with_past]) + else: + alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length_with_past]) + else: + alibi = None + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + use_casual_mask = get_use_casual_mask() and not self.config.alibi + + if use_casual_mask: + attention_mask = None + elif attn_mask_startend_row_indices is None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + + is_casual = False + + if attn_mask_startend_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu": + if use_casual_mask: + is_casual = True + else: + is_casual = is_casual_mask(attention_mask) + if get_env_device() != "npu": + if is_casual and alibi is None: + attention_mask = None + else: + attention_mask = None if attention_mask is None else attention_mask.astype("bool") + hidden_states = inputs_embeds + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + npu_is_casual=is_casual, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + ) + + +class AquilaPretrainingCriterion(paddle.nn.Layer): + """ + Criterion for Aquila. + It calculates the final loss. 
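+
+    The per-token loss comes from `CrossEntropyLoss(reduction="none")` (or
+    `mpu.ParallelCrossEntropy` when the logits are tensor-parallel split). Positions whose
+    loss is zero (the `ignore_index` labels) are masked out and the rest are averaged:
+
+        loss = sum(masked_lm_loss * (masked_lm_loss > 0)) / count(masked_lm_loss > 0)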
+ """ + + def __init__(self, config): + + super(AquilaPretrainingCriterion, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = ( + config.tensor_parallel_degree > 1 + and config.vocab_size % config.tensor_parallel_degree == 0 + and config.tensor_parallel_output + ) + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + + if self.config.sep_parallel_degree > 1 or self.config.context_parallel_degree > 1: + _hcg = fleet.get_hybrid_communicate_group() + masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) + # skip ignore_index which loss == 0 + # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] + # loss = paddle.mean(masked_lm_loss) + binary_sequence = paddle.where( + masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) + ) + count = paddle.sum(binary_sequence) + if count == 0: + loss = paddle.sum(masked_lm_loss * binary_sequence) + else: + loss = paddle.sum(masked_lm_loss * binary_sequence) / count + + return loss + + +class ConcatMaskedLoss(PyLayer): + @staticmethod + def forward(ctx, inp, axis, group): + inputs = [] + paddle.distributed.all_gather(inputs, inp, group=group) + with paddle.no_grad(): + cat = paddle.concat(inputs, axis=axis) + ctx.args_axis = axis + ctx.args_group = group + return cat + + @staticmethod + def backward(ctx, grad): + axis = ctx.args_axis + group = ctx.args_group + with paddle.no_grad(): + grads = paddle.split(grad, paddle.distributed.get_world_size(group), axis=axis) + grad = grads[paddle.distributed.get_rank(group)] + return grad + + +class AquilaLMHead(nn.Layer): + def __init__(self, config: AquilaConfig): + super(AquilaLMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + if vocab_size != config.vocab_size: + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + else: + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + # Must set distributed attr for Tensor Parallel ! 
+ self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if self.weight.is_distributed: + self.weight.split_axis = 1 + if get_env_device() == "xpu": + try: + from paddle_xpu.layers.nn import ( # noqa: F401 + parallel_matmul as xpu_parallel_matmul, + ) + + self.xpu_parallel_matmul = xpu_parallel_matmul() + except ImportError: + self.xpu_parallel_matmul = None + + def forward(self, hidden_states, tensor_parallel_output=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + seq_length = self.config.seq_length + if self.config.sep_parallel_degree > 1: + assert seq_length % self.config.sep_parallel_degree == 0 + seq_length = seq_length // self.config.sep_parallel_degree + if self.config.context_parallel_degree > 1: + assert seq_length % self.config.context_parallel_degree == 0 + seq_length = seq_length // self.config.context_parallel_degree + hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 + + if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None: + logits = self.xpu_parallel_matmul( + hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training + ) + else: + logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + return logits + + +class AquilaForCausalLM(AquilaPretrainedModel): + enable_to_static_method = True + + def __init__(self, config): + super().__init__(config) + self.config = config + self.config["model_type"] = "aquila" + self.aquila = AquilaModel(config) + self.lm_head = AquilaLMHead(config) + self.criterion = AquilaPretrainingCriterion(config) + + def get_input_embeddings(self): + return self.aquila.embed_tokens + + def set_input_embeddings(self, value): + self.aquila.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.aquila = decoder + + def get_decoder(self): + return self.aquila + + def prepare_inputs_for_generation( + self, input_ids, use_cache=False, past_key_values=None, inputs_embeds=None, **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if 
isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + attn_mask_startend_row_indices=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attn_mask_startend_row_indices is not None and attention_mask is not None: + logger.warning( + "You have provided both attn_mask_startend_row_indices and attention_mask. " + "The attn_mask_startend_row_indices will be used." + ) + attention_mask = None + + outputs = self.aquila( + input_ids, # [bs, seq_len] + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + + hidden_states = outputs[0] # [bs, seq_len, dim] + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/paddlenlp/transformers/aquila2/tokenizer.py b/paddlenlp/transformers/aquila2/tokenizer.py new file mode 100644 index 000000000000..b2bf9fd398ba --- /dev/null +++ b/paddlenlp/transformers/aquila2/tokenizer.py @@ -0,0 +1,340 @@ +# Copyright © 2024 BAAI. All rights reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Aquila.""" + +# modified from https://github.com/PaddlePaddle/PaddleNLP/blob/7947bca07f0dfb37172a4c0040defd0cdbbc10a0/paddlenlp/transformers/qwen2/tokenizer.py + +import json +import os +import unicodedata +from functools import lru_cache +from typing import Optional, Tuple + +import regex as re + +from paddlenlp.transformers.tokenizer_utils import AddedToken, PretrainedTokenizer +from paddlenlp.utils.log import logger + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +__all__ = ["AquilaTokenizer"] + +MAX_MODEL_INPUT_SIZES = {} + +PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class AquilaTokenizer(PretrainedTokenizer): + """ + Construct a Aquila tokenizer. Based on byte-level Byte-Pair-Encoding. + + Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + + >>> tokenizer = AquilaTokenizer.from_pretrained("Aquila/Aquila-tokenizer") + >>> tokenizer("Hello world")["input_ids"] + [9707, 1879] + + >>> tokenizer(" Hello world")["input_ids"] + [21927, 1879] + ``` + This is expected. + + You should not use GPT2Tokenizer instead, because of the different pretokenization rules. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*): + The beginning of sequence token. Not applicable for this tokenizer. 
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding, for example when batching sequences of different lengths. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not the model should cleanup the spaces that were added when splitting the input text during the + tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces. + split_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the special tokens should be split during the tokenization process. The default behavior is + to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") = + ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<', + '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment. + """ + + resource_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + max_model_input_sizes = MAX_MODEL_INPUT_SIZES + + pretrained_resource_files_map = { + "vocab_file": {}, + } + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + clean_up_tokenization_spaces=False, + split_special_tokens=False, + **kwargs, + ): + + if unk_token is None: + logger.info("The `unk_token` parameter needs to be defined: we use `eos_token` by default.") + unk_token = eos_token + + # Aquila vocab does not contain control tokens; added tokens need to be special + bos_token = ( + AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(pad_token, str) + else pad_token + ) + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_merges = [] + with open(merges_file, encoding="utf-8") as merges_handle: + for i, line in enumerate(merges_handle): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + bpe_merges.append(tuple(line.split())) + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + # NOTE: the cache can grow without bound and will get really large for long running processes + # (esp. for texts of language that do not use space between word, e.g. Chinese); technically + # not a memory leak but appears as one. + # GPT2Tokenizer has the same problem, so let's be consistent. 
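+        # The cache maps a raw pre-token (as produced by the pretokenization regex) to its
+        # fully merged BPE string, so repeated pre-tokens skip the merge loop in `bpe()`.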
+ self.cache = {} + + self.pat = re.compile(PRETOKENIZE_REGEX) + + if kwargs.get("add_prefix_space", False): + logger.warning_once( + f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect." + ) + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.added_tokens_encoder.get(token, len(self.encoder))) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index, self.added_tokens_decoder.get(index, self.unk_token)) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def _decode( + self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers + # and cannot be configured elsewhere, but it should default to False for AquilaTokenizer + return super()._decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if 
filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, **kwargs): + text = unicodedata.normalize("NFC", text) + return (text, kwargs)
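
Below is a minimal usage sketch of the classes this patch adds, for illustration only (it is not part of the diff). The checkpoint directory is hypothetical and is assumed to already contain Paddle-format weights laid out as expected by `_get_name_mappings`, plus `vocab.json` / `merges.txt` for the tokenizer; the generation arguments are PaddleNLP's generic `generate()` options, not anything introduced here.

import paddle

from paddlenlp.transformers.aquila2 import AquilaForCausalLM, AquilaTokenizer

# Hypothetical local directory with config.json, converted model weights,
# vocab.json and merges.txt.
ckpt = "./aquila2-7b-paddle"

tokenizer = AquilaTokenizer.from_pretrained(ckpt)
model = AquilaForCausalLM.from_pretrained(ckpt, dtype="float16")
model.eval()

inputs = tokenizer("The Great Wall is located in", return_tensors="pd")
with paddle.no_grad():
    # generate() returns a (token_ids, scores) tuple; only the newly generated
    # tokens are returned, so they can be decoded directly.
    ids, _ = model.generate(**inputs, max_new_tokens=32, decode_strategy="greedy_search")
print(tokenizer.decode(ids[0].tolist(), skip_special_tokens=True))

Under tensor parallelism the same entry points apply; the `_get_tensor_parallel_mappings` and `_get_fuse_or_split_param_mappings` hooks defined above handle resharding and the optional q/k/v (and gate/up) fusion of the loaded weights.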