YaRN: Efficient Context Window Extension of Large Language Models #35
Note

  • Accepted at ICLR 2024
  • Whereas absolute position embeddings are added to the input, RoPE is a multiplicative scheme
    • Absolute embeddings are added once at the input layer, while RoPE is multiplied in at every attention computation (i.e., at every layer)
  • RoPE is folded into the attention formulation itself and encodes positional information with a rotation matrix
    • Specifically, the proposed RoPE encodes the absolute position with a rotation matrix and meanwhile incorporates the explicit relative position dependency in self-attention formulation
    • Proposed in RoFormer (2021)
      [figure]

Abstract

  • RoPE does not extrapolate beyond the trained context length
  • YaRN (Yet another RoPE extensioN method) is proposed
    • compute-efficient method to extend the context window of such models, requiring 10x less tokens and 2.5x less training steps than previous methods
      • i.e., the model can be trained with 10x fewer tokens and 2.5x fewer training steps
  • YaRN is also shown to extrapolate beyond the context lengths seen in the fine-tuning dataset
  • The context length is extended up to 128k
  • github: https://github.com/jquesnelle/yarn

Introduction

  • The original Transformer architecture used an absolute sinusoidal position encoding, which was later improved to a learnable absolute position encoding
  • Since then, relative positional encoding schemes [32] have further increased the performance of Transformers. Currently, the most popular relative positional encodings are T5 Relative Bias [30], RoPE [34], XPos [35], and ALiBi [27].
  • two improvements of the "NTK-aware" interpolation have been proposed, with different emphasis:
    • the "Dynamic NTK" interpolation method [14] for pre-trained models without fine-tuning.
    • the "NTK-by-parts" interpolation method [7] which performs the best when fine-tuned on a small amount of longer-context data.
  • The "NTK-aware" interpolation and the "Dynamic NTK" interpolation have already seen their presence in the open-source models such as Code Llama [31] (using "NTK-aware" interpolation) and Qwen 7B [2] (using "Dynamic NTK").
  • In this paper, in addition to making a complete account of the previous unpublished works on the "NTK-aware", the "Dynamic NTK" and the "NTK-by-part" interpolations, we present YaRN (Yet another RoPE extensioN method), an improved method to efficiently extend the context window of models trained with Rotary Position Embeddings (RoPE) including the LLaMA [38], the GPT-NeoX [5], and the PaLM [10] families of models.
    • So YaRN's stated direction is to complete and improve upon the existing NTK-based work
  • YaRN reaches state-of-the-art performances in context window extensions after fine-tuning on less than ∼0.1% of the original pre-training data
    • If pre-training used ~1T tokens, 0.1% would be about 1B tokens?
  • In the meantime, by combining with the inference-time technique called Dynamic Scaling, the Dynamic-YaRN allows for more than 2x context window extension without any fine-tuning.
    • Dynamic-YaRN covers more than a 2x context window extension without any fine-tuning
[figure] Notably, only 10 documents were used in the evaluation above (and the sliding window of 256 seems surprisingly small?!)

Background and Related Work

Rotary Position Embeddings

[figure]

2D case

  • Each pair of dimensions is rotated (the general form is a block-diagonal sparse rotation matrix, but it can be implemented more efficiently)
[figure]

The attention score can then be rewritten as follows:

[figure]

RoPE formulation as written in the LLaMA 2 Long paper

  • θ here appears to be the base frequency (see the sketch below)
[figure]
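To make the rotary formulation concrete, here is a minimal NumPy sketch (not the authors' code) of the RoPE frequencies θ_d = b^(−2d/|D|) and of rotating a query/key vector by its position; the final check illustrates that the resulting dot product depends only on the relative offset m − n.

```python
import numpy as np

def rope_freqs(dim: int, base: float = 10000.0) -> np.ndarray:
    """Per-pair angular frequencies theta_d = base ** (-2d / dim)."""
    return base ** (-np.arange(0, dim, 2) / dim)

def apply_rope(x: np.ndarray, pos: float, freqs: np.ndarray) -> np.ndarray:
    """Rotate each consecutive (even, odd) dimension pair of x by pos * theta_d."""
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(x)
    out[0::2] = x[0::2] * cos - x[1::2] * sin
    out[1::2] = x[0::2] * sin + x[1::2] * cos
    return out

# The q·k score depends only on the relative offset (here 3), not on absolute positions.
dim = 8
freqs = rope_freqs(dim)
q, k = np.random.randn(dim), np.random.randn(dim)
assert np.allclose(apply_rope(q, 5, freqs) @ apply_rope(k, 2, freqs),
                   apply_rope(q, 105, freqs) @ apply_rope(k, 102, freqs))
```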

Position Interpolation

  • Positions are rescaled by the ratio of trained length to used length, so the longer input is fed to the model within the frequency range it was trained on (see the sketch below)
[figure]
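A minimal sketch of Position Interpolation (PI), reusing the `apply_rope` helper and variables from the previous sketch: every position is divided by the scale factor s = L_extended / L_trained, so the extended range is squeezed back into the positions seen during training.

```python
def apply_rope_pi(x, pos, freqs, scale):
    """PI evaluates RoPE at the interpolated (fractional) position pos / scale."""
    return apply_rope(x, pos / scale, freqs)

# e.g. a model trained on a 4096 context used at 8192 tokens -> s = 2,
# so position 6000 is fed to RoPE as position 3000.
q_interp = apply_rope_pi(q, 6000, freqs, scale=8192 / 4096)
```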

Related work

  • Concurrently with our work, LM-Infinite [16] proposes similar ideas to YaRN, but focuses on "on-the-fly" length generalization for non-fine-tuned models. Since they also modify the attention mechanism of the models, it is not an embedding interpolation method and is not immediately compatible with Flash Attention 2.

Methodology

  • Whereas PI stretches all RoPE dimensions equally, we find that the theoretical interpolation bound described by PI [9] is insufficient at predicting the complex dynamics between RoPE and the LLM’s internal embeddings. In the following subsections, we describe the main issues with PI we have individually identified and solved, so as to give the readers the context, origin and justifications of each method which we use in concert to obtain the full YaRN method.

Loss of High Frequency information - "NTK-aware" interpolation

  • If we look at RoPE only from an information encoding perspective, it was shown in [36], using Neural Tangent Kernel (NTK) theory, that deep neural networks have trouble learning high frequency information if the input dimension is low and the corresponding embeddings lack high frequency components.

  • Here we can see the similarities: a token’s positional information is one-dimensional, and RoPE expands it to an n-dimensional complex vector embedding

    • A token's position is one-dimensional information, and RoPE has to expand it into an n-dimensional complex vector embedding
    • From an information-encoding viewpoint, NTK theory says networks struggle to learn high-frequency information when the input dimension is low. What does "high frequency" mean here? It seems to refer to fine-grained distinctions, i.e., the fast-rotating dimensions
      • NTK is discussed in "Fourier features let networks learn high frequency functions in low dimensional domains" (NeurIPS 2020)
  • RoPE closely resembles Fourier Features [36] in many aspects, as it is possible to define RoPE as a special 1D case of a Fourier Feature

  • Stretching the RoPE embeddings indiscriminately results in the loss of important high frequency details which the network needs in order to resolve tokens that are both very similar and very close together (the rotation describing the smallest distance needs to not be too small for the network to be able to detect it)

  • We hypothesise that the slight increase of perplexity for short context sizes after fine-tuning on larger context sizes seen in PI [9] might be related to this problem.

    • There is the known failure mode of passkey retrieval near the front of the context, but is that a different issue? I have not particularly observed this effect myself
  • In order to resolve the problem of losing high frequency information when interpolating the RoPE embeddings, the "NTK-aware" interpolation was developed in [6]. Instead of scaling every dimension of RoPE equally by a factor s, we spread out the interpolation pressure across multiple dimensions by scaling high frequencies less and low frequencies more

    • The previous approach stretched every dimension uniformly, whereas "NTK-aware" interpolation stretches high frequencies less and low frequencies more
  • One can obtain such a transformation in many ways, but the simplest would be to perform a base change on the value of θ.

    • There are many ways to obtain such a transformation, but the simplest is to change the base value of θ (see the sketch below)
      [figure]
  • However, out-of-bound values can occur, so it sometimes performs worse than PI
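"NTK-aware" interpolation in one line: rather than dividing every position by s, change the RoPE base to b · s^(|D|/(|D|−2)). The highest-frequency dimension is left untouched and the lowest-frequency dimension ends up interpolated by roughly 1/s. A minimal sketch, reusing `rope_freqs` from above:

```python
import numpy as np

def ntk_aware_freqs(dim: int, scale: float, base: float = 10000.0) -> np.ndarray:
    """Spread the interpolation pressure across dimensions via a base change."""
    new_base = base * scale ** (dim / (dim - 2))
    return new_base ** (-np.arange(0, dim, 2) / dim)

# Ratio to the original frequencies: 1.0 for the fastest pair, ~1/scale for the slowest.
print(ntk_aware_freqs(128, scale=4.0) / rope_freqs(128))
```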

Loss of Relative Local Distances - "NTK-by-parts" interpolation

  • Increasing the scale factor s (the ratio of the extended length to the trained length) or the base frequency b makes tokens appear closer together and increases the dot product of two vectors, which hurts the LLM's ability to understand internal embeddings that encode small, local relationships
  • Each dimension is treated differently depending on whether it crosses certain thresholds: some are interpolated, some are left untouched, and the rest are blended (tricky; see the sketch below)
    [figure]
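A minimal sketch of the "NTK-by-parts" idea, following the ramp-function formulation in the paper (α = 1, β = 32 are the values suggested for the Llama family): dimensions that complete many full rotations within the trained context L keep their original frequency, dimensions whose wavelength exceeds L are interpolated exactly as in PI, and everything in between is linearly blended.

```python
import numpy as np

def ntk_by_parts_freqs(dim, scale, trained_len, base=10000.0, alpha=1.0, beta=32.0):
    """Blend PI-style interpolation (theta/s) with no interpolation, per dimension."""
    theta = base ** (-np.arange(0, dim, 2) / dim)
    rotations = trained_len * theta / (2 * np.pi)        # r(d) = L / wavelength_d
    ramp = np.clip((rotations - alpha) / (beta - alpha), 0.0, 1.0)
    return (1.0 - ramp) * theta / scale + ramp * theta   # ramp=1: keep, ramp=0: PI
```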

Dynamic Scaling - "Dynamic NTK" interpolation

[figure]

  • Of the two options shown above, the second clearly looks better
    • On each forward pass, the scale factor is updated according to the current sequence length, via the max with 1
    • [figure]
  • Scaling as (2), it allows the model to gracefully degrade instead of immediately breaking when hitting the trained context limit L′. We call this inference-time method the Dynamic Scaling method.
    • It is called "dynamic" because the scale is adjusted continuously at inference time as the sequence length grows
    • When it is combined with "NTK-aware" interpolation, we call it "Dynamic NTK" interpolation. It first appeared in public as a reddit post in [14].
  • One notable fact is that the "Dynamic NTK" interpolation works exceptionally well on models pretrained on L without any finetuning (L′ = L). This is supported by the experiment in Appendix B.3.
    • Works remarkably well even without any fine-tuning
  • Often in the repeated forward-passes, the kv-caching [8] is applied so that we can reuse the previous key-value vectors and improve the overall efficiency.
    • Since the forward pass is run repeatedly, the previous key-value vectors stay the same and can be cached (frameworks such as vLLM handle this kind of caching well)
    • We point out that in some implementations when the RoPE embeddings are cached, some care has to be taken in order to modify it for Dynamic Scaling with kv-caching. The correct implementation should cache the kv-embeddings before applying RoPE, as the RoPE embedding of every token changes when s changes.
      • Because the scale changes, the caching implementation must change accordingly: cache the kv before RoPE is applied (see the sketch below)
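A minimal sketch of Dynamic Scaling: the scale factor is recomputed from the current sequence length on every forward pass, so nothing changes inside the trained window and interpolation only kicks in beyond it. Combining this with the NTK-aware frequencies above gives "Dynamic NTK".

```python
def dynamic_scale(current_len: int, trained_len: int) -> float:
    """s' = max(1, l' / L): no interpolation until the trained length is exceeded."""
    return max(1.0, current_len / trained_len)

# e.g. Dynamic NTK at a generation step with 6000 tokens on a 4096-trained model:
# freqs = ntk_aware_freqs(dim=128, scale=dynamic_scale(6000, 4096))
# NOTE: since the frequencies change whenever s changes, a correct kv-cache must
# store keys before RoPE is applied and re-rotate them with the current s.
```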

YaRN

  • we also observe that introducing a temperature t on the logits before the attention softmax has a uniform impact on perplexity regardless of the data sample and the token position over the extended context window (See Appendix A.2)
  • The reparametrization of RoPE as a set of 2D matrices has a clear benefit on the implementation of this attention scaling
  • So YaRN is the technique that adds this attention scaling on top of the interpolation (in practice, applied to the logits before the softmax; see the sketch below)

[figure]
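A minimal sketch of the extra ingredient YaRN adds on top of "NTK-by-parts": a temperature t on the attention logits, folded into the RoPE step by scaling the rotated q and k by √(1/t), with the paper's fit √(1/t) = 0.1 ln s + 1.

```python
import numpy as np

def yarn_attn_scale(scale: float) -> float:
    """sqrt(1/t) applied to both rotated q and k, so the logits get multiplied by 1/t."""
    return 0.1 * np.log(scale) + 1.0 if scale > 1.0 else 1.0

# Full YaRN ≈ NTK-by-parts frequencies + this scaling, e.g. (reusing the earlier sketches):
# freqs = ntk_by_parts_freqs(dim=128, scale=s, trained_len=4096)
# q_rot = yarn_attn_scale(s) * apply_rope(q, m, freqs)
# k_rot = yarn_attn_scale(s) * apply_rope(k, n, freqs)
```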

  • Why wasn't the dynamic variant used here? Can Dynamic Scaling simply be applied on top later?
    • See Appendix B.3
      • we observe that Dynamic Scaling effectively extends the inference length and Dynamic-YaRN achieves better performance than Dynamic-PI. The resulting chart is in Figure 5.
      • Dynamic Scaling effectively prevents the blow-up of perplexity score beyond pretrained context window;
      • Dynamic-YaRN outperforms Dynamic-PI in terms of long-range perplexity on pretrained Llama-2 without any finetuning.
      • [figure]

Experiments

  • YaRN successfully achieves context window extension of language models using RoPE as its position embedding. Moreover, this result is achieved with only 400 training steps, representing approximately 0.1% of the model’s original pre-training corpus, a 10x reduction from Rozière et al. [31] and 2.5x reduction in training steps from Chen et al. [9], making it highly compute-efficient for training with no additional inference costs.
    • Works well even with a very small amount of training

4.1 Training

  • For training, we extended the Llama 2 [39] 7B and 13B parameter models
    • calculation of the embedding frequencies as described in 3.4 with s = 16 and s = 32
  • We used a learning rate of 2 × 10−5 with no weight decay and a linear warmup of 20 steps along with AdamW [24] β1 = 0.9 and β2 = 0.95. For s = 16 we fine-tuned for 400 steps with global batch size 64 using PyTorch [26] Fully Sharded Data Parallelism [42] and Flash Attention 2 [13] on the PG19 dataset [29] chunked into 64k segments bookended with the BOS and EOS token.
  • For s = 32 we followed the same procedure, but started from the finished s = 16 checkpoint and trained for an additional 200 steps.
    • So the larger-scale (s = 32) model was trained further on top of the s = 16 checkpoint?! (hyperparameters summarized below)
    • Is the fundamental reason fine-tuning is needed that the attention score distribution and the rotary settings change?
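The fine-tuning recipe reported in Section 4.1, collected into one place as a plain summary of the paper's numbers (not a runnable training script):

```python
yarn_finetune_config = {
    "base_models": ["Llama-2 7B", "Llama-2 13B"],
    "scale_factors": [16, 32],              # s = 32 resumes from the s = 16 checkpoint
    "learning_rate": 2e-5,
    "weight_decay": 0.0,
    "warmup_steps": 20,
    "adamw_betas": (0.9, 0.95),
    "steps": {"s=16": 400, "s=32": 200},    # 200 extra steps for s = 32
    "global_batch_size": 64,
    "dataset": "PG19, chunked into 64k-token segments bookended with BOS/EOS",
    "infra": "PyTorch FSDP + Flash Attention 2",
}
```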

4.2 Extrapolation and Transfer Learning

  • In Code Llama [31], a dataset with 16k context was used with a scale factor set to s ≈ 88.6, which corresponds to a context size of 355k.
    • Wow, 355k... much larger than 128k
  • They show that the network extrapolates up to 100k context without ever seeing those context sizes during training
  • YaRN also supports training with a higher scale factor s than the length of the dataset. Due to compute constraints, we test only s = 32 by further fine-tuning the s = 16 model for 200 steps using the same dataset with 64k context.
  • We show in 4.3.1 that the s = 32 model successfully extrapolates up to 128k context using only 64k context during training. Unlike previous "blind" interpolation methods, YaRN is much more efficient at transfer learning when increasing the scale s. This demonstrates successful transfer learning from s = 16 to s = 32 without the network needing to relearn the interpolated embeddings, as the s = 32 model is equivalent to the s = 16 model across the entire context size, despite only being trained on s = 32 for 200 steps.
    • Extrapolation turns out to work well too

4.3 Evaluation

    1. the perplexity scores of fine-tuned models with extended context window,
    2. the passkey retrieval task on fine-tuned models,
    3. the common LLM benchmark results of fine-tuned models,

4.3.1 Long Sequence Language Modeling

  • we use the GovReport [18] and Proof-pile [4] datasets both of which contain many long sequence samples.
  • All perplexity evaluations were calculated using the sliding window method from Press et al. [27] with S = 256.
    • Perplexity is measured with a sliding window of S = 256 (see the sketch after this list)
  • Firstly, we evaluated how the model performed as the context window increased. We selected 10 random samples from Proof-pile with at least 128k tokens each and evaluated the perplexity of each of these samples when truncated at 2k steps from a sequence length of 2k tokens through 128k tokens.
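A hedged sketch of strided sliding-window perplexity in the spirit of Press et al. with stride S = 256; `token_nlls` is a hypothetical callable (not from the paper's code) that returns per-token negative log-likelihoods for one window of tokens.

```python
import math

def sliding_window_ppl(tokens, context_len, token_nlls, stride=256):
    """Advance the window by `stride` tokens and score only the newly added ones."""
    nlls = []
    for start in range(0, len(tokens), stride):
        end = min(start + stride, len(tokens))
        window = tokens[max(0, end - context_len): end]
        per_token = token_nlls(window)            # hypothetical model call
        nlls.extend(per_token[-(end - start):])   # keep only the new tokens' NLLs
    return math.exp(sum(nlls) / len(nlls))
```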

[figure]

[figure]

4.3.2 Passkey Retrieval

  • The passkey retrieval task as defined in [25] measures a model’s ability to retrieve a simple passkey (i.e., a five-digit number) from amongst a large amount of otherwise meaningless text.
  • we performed 10 iterations of the passkey retrieval task with the passkey placed at a random location uniformly distributed across the evaluation context window on different context window sizes ranging from 8k to 128k. Both 7b and 13b models fine-tuned using YaRN at 128k context size passes the passkey retrieval task with very high accuracy (> 99%) within the entire context window size. We show detailed results in Appendix B.2.
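A hedged sketch of how a passkey-retrieval prompt is typically constructed (the filler sentence and exact wording here are illustrative, not the paper's): a five-digit key is hidden at a uniformly random position inside long filler text and the model must repeat it.

```python
import random

def make_passkey_prompt(n_filler_lines: int, rng: random.Random):
    passkey = rng.randint(10000, 99999)           # five-digit passkey
    filler = "The grass is green. The sky is blue. The sun is yellow."
    lines = [filler] * n_filler_lines
    # Insert the passkey sentence at a uniformly random position.
    lines.insert(rng.randint(0, n_filler_lines), f"The pass key is {passkey}. Remember it.")
    prompt = "\n".join(lines) + "\nWhat is the pass key? The pass key is"
    return prompt, passkey

prompt, answer = make_passkey_prompt(2000, random.Random(0))
```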

4.3.3 Standardized Benchmarks

  • The Hugging Face Open LLM Leaderboard
  • Specifically, we use 25-shot ARC-Challenge [11], 10-shot HellaSwag [41], 5-shot MMLU [17], and 0-shot TruthfulQA [23].
  • No major drop compared to the Llama baselines
    [figure]
  • We observe that there is minimal performance degradation between the YaRN models and their respective Llama 2 baselines. We also observe that there was on average a 0.49% drop in scores between the YaRN s = 16 and s = 32 models. From this we conclude that the iterative extension from 64k to 128k results in negligible performance loss.

Conclusion

  • YaRN improves upon all existing RoPE interpolation methods and can act as a drop-in replacement to PI, with no downsides and minimal implementation effort.
  • Furthermore, YaRN allows efficient extrapolation with fine-tuning on shorter datasets and can take advantage of transfer learning for faster convergence, both of which are crucial under compute-constrained scenarios.
  • Finally, we have shown the effectiveness of extrapolation with YaRN where it is able to "train short, and test long".

Notes for understanding the notation

[figure] Euler's formula lets a complex number be written in exponential (polar) form
  • Honestly this is quite difficult, starting from the complex vector space itself
  • In the notation, Re[·] takes the real part of a Hermitian inner product
    • The Hermitian inner product is the inner product on complex vector spaces
      • For the inner product of v and w, v is multiplied by the complex conjugate of w

      • [figure]
      • Because it is a complex inner product, the exponent comes out as m − n (m and n each appear in the exponent of e^{i·}, and since the inner product conjugates one argument, that exponent picks up a minus sign); see the block below

      • [figure]
[figure]
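Putting the notation together in LaTeX (RoFormer's complex-valued view of a single 2D pair): the pair is treated as a complex number, rotated by e^{imθ}, and the attention score is the real part of the Hermitian inner product, which is exactly where the m − n comes from.

```latex
% A 2D pair (x^{(1)}, x^{(2)}) is viewed as the complex number x = x^{(1)} + i x^{(2)}.
f(\mathbf{q}, m) = \mathbf{q}\, e^{i m \theta}, \qquad
f(\mathbf{k}, n) = \mathbf{k}\, e^{i n \theta}

% Hermitian inner product: \langle v, w \rangle = v\,\bar{w} (conjugate the second argument),
% so the absolute positions collapse into the relative offset m - n:
\mathrm{Re}\big[ \langle f(\mathbf{q}, m), f(\mathbf{k}, n) \rangle \big]
  = \mathrm{Re}\big[ \mathbf{q}\, \bar{\mathbf{k}}\, e^{i (m - n) \theta} \big]
```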
