LongNet: Scaling Transformers to 1,000,000,000 Tokens #30

eagle705 opened this issue Jul 10, 2023

  • Jiayu Ding∗ Shuming Ma∗ Li Dong Xingxing Zhang Shaohan Huang Wenhui Wang Furu Wei†

Summary

  • Improves computational efficiency with dilated attention and proposes a sequence-parallel training paradigm
  • Shows that perplexity goes down, but does not evaluate on downstream tasks
  • Claims to follow the scaling law, but does not push the scale very far

Abstract

  • LONGNET, a Transformer variant that can scale sequence length to more than 1 billion tokens, without sacrificing the performance on shorter sequences🤔
  • propose dilated attention (a minimal sketch follows below):
      1. it has linear computation complexity and a logarithmic dependency between tokens;
      2. it can serve as a distributed trainer for extremely long sequences;
      3. dilated attention is a drop-in replacement for standard attention, which can be seamlessly integrated with existing Transformer-based optimization.

image
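
Since the bullets above only list properties, here is a minimal, hedged sketch of the core idea for a single (segment length w, dilation ratio r) pair: split the sequence into segments, keep every r-th token inside each segment, run dense attention on the sparsified segment, and scatter the result back. The function name and shapes are mine, not the authors' code; the full LongNet mixes several (w_i, r_i) pairs (so every position is covered), and adds causal masking and per-head offsets, all omitted here.

```python
# Minimal single-head, non-causal sketch of dilated attention for one (w, r) pair.
import torch
import torch.nn.functional as F

def dilated_attention_single(q, k, v, w, r):
    """q, k, v: (seq_len, d) with seq_len divisible by w; w: segment length, r: dilation ratio."""
    seq_len, d = q.shape
    out = torch.zeros_like(q)
    for start in range(0, seq_len, w):                   # split the sequence into segments of length w
        idx = torch.arange(start, start + w, r)          # keep every r-th token inside the segment
        qs, ks, vs = q[idx], k[idx], v[idx]              # sparsified segment: roughly w/r tokens
        attn = F.softmax(qs @ ks.T / d ** 0.5, dim=-1)   # dense attention inside the sparsified segment
        out[idx] = attn @ vs                             # scatter outputs back to their original positions
    return out  # positions skipped by this (w, r) pair are covered by other pairs in the full model
```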

Introduction

  • Prior work
    • RNN-style models are primarily implemented to increase the length. However, their sequential nature limits parallelization during training, which is essential in long-sequence modeling.
    • More recently, state space models [GGR22, SWL23, FDS+23, PMN+23] are appealing to sequence modeling. It can operate as a CNN during training, and transform to an efficient RNN at test time. While they perform well at long-range benchmarks [TDA+21], their performance on regular lengths is not as good as Transformers, limited mainly by the model expressivity [FPB+23].
    • Another strand of scaling the sequence length is to decrease the complexity of Transformers
      • Sliding-window or local attention, however, sacrifices the ability to recall early tokens, forgetting the prompts at the very beginning of the sequence. Sparse attention reduces the computation by sparsifying the attention matrix, preserving the possibility of recalling long-distant information.
    • Transformer-based variants, including low-rank attention [WLK+20, WCL+20], kernel-based methods [KVPF20, CLD+21, QHS+22], downsampling approaches [LLK+19, JGB+21, MKW+21], recurrent models and retrieval-based methods

image

  • Our solution is LONGNET, which replaces the attention of vanilla Transformers with a novel component named dilated attention.
    • attention allocation decreases exponentially as the distance between tokens grows.
    • In the implementation, LONGNET can be transformed into a dense Transformer, which seamlessly supports the off-the-shelf optimization for Transformers (e.g., kernel fusion, quantization, and distributed training).
    • Taking advantage of the linear complexity, LONGNET can parallelize the training across nodes, breaking the constraint of both computation and memory with a distributed algorithm.

image

LongNet

Preliminary figure
image image
Dilated Attention figure
image image
image image image

Complexity calculation

Explanation of the equations
image image
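
The equations live in the images above; as a hedged reconstruction of the argument (mine, not copied from the paper), the per-pair cost and the geometric-sum bound work out roughly as follows:

```latex
% Cost of one (w, r) pair: N/w segments, each attending over about w/r tokens.
\mathrm{FLOPs}(w, r) \approx \frac{N}{w} \cdot 2\left(\frac{w}{r}\right)^{2} d
                    = \frac{2\,N\,w\,d}{r^{2}}
% With geometric schedules w_i = \alpha^{i} w_0 and r_i = \alpha^{i} r_0 (\alpha > 1),
% the sum over pairs is bounded by a geometric series, giving linear total cost:
\sum_{i=0}^{k-1} \frac{2\,N\,w_i\,d}{r_i^{2}}
   = \frac{2\,N\,w_0\,d}{r_0^{2}} \sum_{i=0}^{k-1} \alpha^{-i}
   \le \frac{2\,N\,w_0\,d}{r_0^{2}} \cdot \frac{\alpha}{\alpha - 1}
   = O(N d)
```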

LONGNET as a Distributed Trainer: Scaling up to 1B Tokens

  • Although the computation complexity of dilated attention has been greatly reduced to O(Nd), it is infeasible to scale the sequence length to the million level on a single GPU device due to the computation and memory constraints

image
image

  • If a segment is shorter than the sequence length supported by one device, compute local attention on that device
  • If a segment is longer than the length supported by one device, gather the keys/values from the other devices with an all-gather and compute on them
  • Finally, compute cross-attention between the local queries and the global key-value pairs collected via the all-gather

image

  • Depending on the stage, all-gather and then split again
  • In the backward pass, the all-gather operation becomes a reduce-scatter
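
A hedged PyTorch-style sketch of the steps above: local attention when the segment fits on one device, otherwise an all-gather of keys/values followed by cross-attention with the local queries. This is not the paper's code; dist must already be initialized, and for training one would need an autograd-aware all-gather whose backward is the reduce-scatter mentioned above.

```python
# Sketch of the distributed attention step described above (illustrative, single head).
import torch
import torch.distributed as dist
import torch.nn.functional as F

def distributed_segment_attention(q, k, v, w):
    """q, k, v: (local_len, d) shard held by this device; w: segment length."""
    local_len, d = q.shape
    if w <= local_len:
        # Segment fits on one device: plain local attention, no communication needed.
        attn = F.softmax(q @ k.T / d ** 0.5, dim=-1)
        return attn @ v
    # Segment spans devices: all-gather the keys/values (queries stay local).
    world = dist.get_world_size()
    k_list = [torch.empty_like(k) for _ in range(world)]
    v_list = [torch.empty_like(v) for _ in range(world)]
    dist.all_gather(k_list, k)   # NOTE: not autograd-aware; training needs a differentiable
    dist.all_gather(v_list, v)   # all-gather whose backward is a reduce-scatter.
    k_all, v_all = torch.cat(k_list), torch.cat(v_list)
    # Cross-attention: local queries against the gathered global key-value pairs.
    attn = F.softmax(q @ k_all.T / d ** 0.5, dim=-1)
    return attn @ v_all
```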

Scaling up to 1B Tokens

  • Starting from 8K, we gradually scale the sequence length until the limit of GPU memory.
  • We reduce the batch size accordingly to keep the number of tokens per batch at 1 billion.
  • Each model of different sequence lengths has up to 3 segment lengths, which are 2,048, the number of tokens per device, and the maximum sequence length

image
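
As a quick illustration of keeping roughly 1B tokens per batch, the batch size shrinks in proportion to the sequence length; the lengths below are my own round powers of two, not the paper's exact table.

```python
# Batch size needed to hold ~1B tokens per batch at each sequence length (illustrative).
TOKENS_PER_BATCH = 1 << 30  # ~1.07B tokens
for seq_len in [1 << 13, 1 << 17, 1 << 21, 1 << 25, 1 << 30]:  # 8K, 128K, 2M, 32M, 1B
    print(f"seq_len={seq_len:>13,}  batch_size={TOKENS_PER_BATCH // seq_len:>9,}")
```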

Experiments on Language Modeling

Setup

  • The backbone architecture is MAGNETO [WMH+22] with XPOS [SDP+22] relative position encoding, except that we replace the standard attention with our dilated attention
  • pre-train the model with The Stack dataset
  • The data is preprocessed with the tiktoken tokenizer with cl100k_base encoding
  • All experiments are conducted based on the torchscale [MWH+22] codebase.
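
For reference, tokenizing with tiktoken's cl100k_base encoding (as mentioned above) looks like the snippet below; the sample string is arbitrary.

```python
# Quick check of the tokenizer used for preprocessing: tiktoken with cl100k_base.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
ids = enc.encode("def hello():\n    return 'world'")
print(len(ids), ids[:8])   # number of tokens and the first few ids
print(enc.decode(ids))     # decodes back to the original string
```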

What is MAGNETO?

image

image

Results

  • For LONGNET, we use segment lengths of w = {2048,4096,8192,16384,32768}, and the dilated ratios are r = {1,2,4,6,12}
  • All of our implementations of attention variants are based on FlashAttention for training efficiency

image
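
A small sketch of the (segment length, dilation ratio) schedule quoted above, just to show how many tokens each pair attends to per segment; the variable names are mine.

```python
# Segment-length / dilation-ratio schedule from the Results section (illustrative printout).
segment_lengths = [2048, 4096, 8192, 16384, 32768]   # w
dilation_ratios = [1, 2, 4, 6, 12]                   # r
for w, r in zip(segment_lengths, dilation_ratios):
    print(f"w={w:>6}  r={r:>2}  tokens attended per segment ~ {w // r}")
```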

Scaling up Model Size

  • To verify whether LONGNET still follows a similar scaling law, we train a series of models of different sizes, from 125 million parameters (40B tokens) to 2.7 billion parameters (300B tokens)
  • This shows that LONGNET still follows the power law, implying that a dense Transformer is not a prerequisite for scaling language models; LONGNET obtains both scalability and efficiency.
    image
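
For context, the "power law" here refers to the usual scaling-law form relating test loss to model size; the constants are not given in these notes.

```latex
% Generic power-law scaling form (constants N_c and \alpha_N are not specified in these notes).
L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}
```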

Conclusion

  • The core of LONGNET is dilated attention, which reduces the computation complexity from quadratic to linear.
  • LONGNET can serve as a distributed trainer that parallelizes the training of a sequence across multiple GPU devices