
Research diary: reversibles #11

Open
justheuristic opened this issue Apr 8, 2022 · 1 comment

justheuristic commented Apr 8, 2022

This issue may or may not contain my notes about implementing and training reversible models.
From correspondence with @TimDettmers, @yhn112 , @borzunov, @mryab

Why? Reversible models are one of the few ways to fit large transformers into low-end DL rigs (single GPU, 8-12 GB VRAM, 16-32 GB RAM). The alternatives are nested checkpointing [slower], checkpoint quantization [may affect convergence, still uses more memory], or checkpoint offloading [RAM is already used up by master params and the optimizer state]. Hence, reversibles are the most memory-efficient training strategy if they can match the quality of regular models.

As of c076ba7, we support reversible=True, which enables the reversible transformer as defined in Reformer. However, this is not the only possible coupling scheme. Existing alternatives are:

source: running average coupling
import torch


class MeanCoupling:
    """Couples the residual stream as a running average over layer outputs."""

    def __init__(self, layer_index: int):
        assert layer_index > 0, "layer with index zero must be applied before reversible (e.g. embeddings)"
        self.layer_index = layer_index

    def forward(self, other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        # running average: out = (i * other_stream + fn_out) / (i + 1)
        i = self.layer_index
        return other_stream * (i / (i + 1)) + fn_out * (1 / (i + 1))

    def inverse(self, forward_output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        # exact inverse of forward: recover other_stream from the output and fn_out
        i = self.layer_index
        return (forward_output - fn_out * (1 / (i + 1))) / (i / (i + 1))
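For comparison, the Reformer-style coupling that reversible=True currently uses is plain addition. A minimal sketch with the same interface as above (not the actual implementation):

import torch


class AdditiveCoupling:
    """Reformer-style coupling: each stream accumulates the other branch's output."""

    def forward(self, other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        return other_stream + fn_out

    def inverse(self, forward_output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        # exact inverse: subtract the recomputed branch output
        return forward_output - fn_out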

Furthermore, it is unclear how best to use the reversible block's two input and two output streams with transformers. If we cannot afford to double the layer sizes, this raises the following questions:

  • which sub-branch of the reversible modules is more informative as the final readout: X, Y, X+Y, X+Y-common_input, some learned gates? (see the sketch after this list)
  • how best to instantiate the two inputs: set both equal to the embeddings, or set one to zeros?
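
A minimal sketch of the candidate stream initializations and readouts (the names init_streams, readout, x_stream, y_stream and common_input are illustrative, not part of the codebase):

import torch


def init_streams(embeddings: torch.Tensor, mode: str = "duplicate"):
    """Initialize the two reversible streams from the embedding output."""
    if mode == "duplicate":
        return embeddings, embeddings.clone()
    elif mode == "one_zero":
        return embeddings, torch.zeros_like(embeddings)
    raise ValueError(mode)


def readout(x_stream: torch.Tensor, y_stream: torch.Tensor, common_input: torch.Tensor,
            mode: str = "x_plus_y_minus_input") -> torch.Tensor:
    """Candidate ways to collapse the two output streams into one hidden state."""
    if mode == "x":
        return x_stream
    elif mode == "y":
        return y_stream
    elif mode == "x_plus_y":
        return x_stream + y_stream
    elif mode == "x_plus_y_minus_input":
        # subtract the duplicated input so it is not counted twice
        return x_stream + y_stream - common_input
    raise ValueError(mode)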

Finally, there are implementation hacks that can affect training throughput and memory requirements:

  • couplings are elementwise operations: should we apply torch.jit.script to them?
  • the FFN layer can be computed a few vectors at a time (chunking over batch size and sequence length), as originally proposed in Reformer (see the sketch below);
  • the attention layer can be computed one head at a time - since it is additive w.r.t. heads - but that may be difficult with some sparse/low-rank projections.
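
A minimal sketch of FFN chunking (chunked_ffn and chunk_size are illustrative names). Note that the full memory benefit only shows up when the backward pass also recomputes the FFN chunk-by-chunk, e.g. inside a reversible or checkpointed block:

import torch
import torch.nn as nn


def chunked_ffn(ffn: nn.Module, hidden: torch.Tensor, chunk_size: int = 1024) -> torch.Tensor:
    """Apply a position-wise FFN a few vectors at a time.

    Works because the FFN is applied independently to each (batch, position) vector,
    so the wide intermediate activation is only materialized one chunk at a time.
    """
    batch, seq_len, dim = hidden.shape
    flat = hidden.reshape(-1, dim)  # flatten batch and sequence dims
    outputs = [ffn(chunk) for chunk in flat.split(chunk_size, dim=0)]
    return torch.cat(outputs, dim=0).reshape(batch, seq_len, dim)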
justheuristic changed the title from "Reversibles" to "Research diary: reversibles" on Apr 8, 2022

justheuristic commented Apr 11, 2022

Research setting:

  • data: openwebtext, preprocessed per the instructions (see data cooking log)
  • model: an equivalent of GPT-3 Medium (verified to match the learning curve of a regular transformer LM) + rotary embeddings
  • training config: based on GPT-3 paper hparams, see details at the end. All model / training hyperparams are kept equal except for reversible type.

Limitations:

  • some runs were stopped prematurely due to poor performance. There is a small but nonzero chance they could recover.
  • we are using an X+Y-Initial hack for reformer, which is standard, but was not part of the original paper.
  • we did not tune hyperparameters individually for each model. There is a chance that some reversibles need different hparams.
  • we used beta=0.9 for momentum-based models (the default from the repo and paper) and did not tune it.

Results:

tl;dr reformer is still the best
[image: training loss curves for all runs]

  • baseline = GPT3-medium + rotary
  • baseline-rev = GPT3-medium with reformer-like layer order - closest, but still marginally worse than baseline. The margin is ~1/5 that of switching to GPT3-XL (double hidden size and heads).
  • baseline-rev-aaff = GPT3-medium with reformer-like layer order, but using [Attn, Attn, FFN, FFN] layers instead of [Attn, FFN, Attn, FFN], so that each of two reversible-residual branches gets both attention and FFN outputs. Starts slightly worse, and never catches up.
  • baseline-rev0.9 = momentum-reversible, as defined in https://github.com/HomebrewNLP/revlib/blob/040f076a722d4d3c5c4877bc54658f9a3f734489/revlib/utils.py#L176 . Using the Y branch as model outputs (aka "the moving average" branch). Trains significantly slower than reformer. Killed after 5k steps.
  • baseline-rev0.9-out0 same as the previous row, but using the X branch as model outputs (aka the cumulative sum of moving averages). Began training faster, but diverged just after passing 4k steps. Hypothesis: revlib's default momentum may be numerically unstable because of the 1 / beta ** num_layers term (see the quick check after this list). Will investigate alternative momentumnet implementations to rule this out.
  • baseline-rev-average - simple running average, using the MeanCoupling code from the first message. Trains significantly slower than the baseline. Killed after 3k steps.
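
A quick back-of-the-envelope check of that term, assuming beta=0.9 and 24 transformer layers; whether each layer counts as one or two couplings depends on how revlib wraps attention and FFN, so both are shown:

beta = 0.9
num_layers = 24

# inverse-momentum rescaling at the deepest layer: 1 / beta ** depth
print(1 / beta ** num_layers)        # ~12.6  (one coupling per layer)
print(1 / beta ** (2 * num_layers))  # ~157   (separate attention and FFN couplings)

Rescaling fp16 activations or gradients by factors of this size would be consistent with the gradient norm behaviour shown below.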

Validation loss closely tracks the training loss.
[image: validation loss curves]

The gradient norm supports the numerical-instability hypothesis for momentum-reversible models.
[image: gradient norm curves]

Details

Model config
{
  "model_type": "lean_gpt",
  "architectures": [ "LeanGPTModel" ],
  "num_hidden_layers": 24,
  "num_hidden_groups": 24,
  "num_inner_groups": 1,
  "hidden_size": 1024,
  "embedding_size": 1024,
  "intermediate_size": 4096,
  "num_attention_heads": 16,
  "vocab_size": 50308,
  "hidden_act": "gelu_fused",
  "position_embedding_type": "rotary",
  "tie_word_embeddings": false,
  "reversible": YOUR_OPTION_HERE,
  "hidden_dropout_prob": 0,
  "attention_probs_dropout_prob": 0,
  "layer_norm_eps": 1e-12,
  "pad_token_id": 1,
  "bos_token_id": 0,
  "eos_token_id": 2
}
Training code: using a slightly modified fairseq fork: https://github.com/justheuristic/junk/tree/fairseq #9f5bff306ee93780e0c9483162dd24e244403919 . The only difference is that it supports training with LeanTransformer; it was validated to match the learning curve of the fairseq transformer LM.
PYTHONPATH=`pwd`:$PYTHONPATH python fairseq_cli/train.py \
    $INPUT_PATH/data-bin/openwebtext --task language_modeling --arch lean_lm --hf-model-config $SOURCE_CODE_PATH/model_config.json \
    --max-tokens 32768 --update-freq 4 --max-update 50000 --tokens-per-sample 2048 --sample-break-mode none \
    --ddp-backend pytorch_ddp --distributed-world-size $NUM_GPUS --seed 4 \
    --amp --fp16-no-flatten-grads --min-loss-scale 1e-10 --fp16-scale-window 250 \
    --lr-scheduler cosine --lr 0.0003 --warmup-init-lr 0.0 --warmup-updates 5000 \
    --optimizer adam --weight-decay 0.1 --clip-norm 1.0 --adam-betas "(0.9, 0.95)" --adam-eps 1e-08 \
    --save-dir $SNAPSHOT_PATH --save-interval-updates 1000 --keep-best-checkpoints 1 --no-epoch-checkpoints --keep-interval-updates 2 \
    --valid-subset valid,valid_1b,valid_lambada,valid_ccnews,valid_wiki,valid_wiki2,valid_ptb --validate-interval-updates 1000 \
    --log-format simple --log-interval 50 --wandb-project $WANDB_PROJECT
Libraries & versions
#!/usr/bin/env bash
set -euxo pipefail

############################################################################
# core libraries
############################################################################

apt-get update --allow-unauthenticated --allow-insecure-repositories

apt-get install -y --no-install-recommends \
    build-essential \
    g++ gdb subversion \
    software-properties-common

apt-get install -y --no-install-recommends \
    wget curl vim nano ssh git libssl-dev

apt-get remove -y swig || true
apt-get install -y --no-install-recommends libstdc++6
apt-get install -y --no-install-recommends swig3.0
ln -s /usr/bin/swig3.0 /usr/bin/swig

############################################################################
# install anaconda (because native python stopped working)
############################################################################

wget https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh
bash Anaconda3-2021.11-Linux-x86_64.sh -b -p /anaconda3
source /anaconda3/bin/activate

############################################################################
# common python libraries (project-specific libs are installed later)
############################################################################

conda update -y conda
conda install -y python=3.8.12 --strict-channel-priority
conda install -y numpy scipy cython pandas h5py numba
pip install --upgrade setuptools

# common + devops
pip install \
    PyYAML==5.4.1 \
    Pillow==8.3.0 \
    docopt==0.6.2 \
    typer==0.3.2 \
    black==21.6b0 \
    bokeh==2.4.0dev1 \
    isort==5.9.1 \
    icecream==2.1.1 \
    flake8==3.9.2 \
    uvloop==0.15.2 \
    packaging==19.0 \
    msgpack==0.5.6 \
    sortedcontainers==2.4.0 \
    configargparse==1.2.3 \
    tqdm==4.48.2 \
    termcolor==1.0.0

# common data science libs
pip install \
    ninja==1.10.0.post1 \
    tensorboardX==2.4 \
    wandb==0.10.33 \
    matplotlib==3.4.2 \
    seaborn==0.11.1 \
    holoviews==1.14.4 \
    plotly==5.1.0 \
    jupyterlab==3.0.16

# pytorch utils
conda install -y cudatoolkit=11.3 -c pytorch
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

pip install https://github.com/huggingface/transformers/archive/3dc82427166239e2764196c07fa4c5dcc25b1590.zip # 4.18.dev0
pip install datasets==2.0.0

pip install \
    torch_optimizer==0.1.0 \
    revlib==1.7.0 \
    bitsandbytes-cuda113==0.26.0 \
    pytorch-lightning==1.3.8 \
    triton==1.0.0 \
    einops==0.3.2 \
    libzero==0.0.5

# domain-specific ML libs
pip install \
    opencv-python==4.4.0.42 \
    albumentations==1.0.0 \
    scikit-image==0.17.2 \
    lmdb==1.2.1 \
    librosa==0.7.0 \
    sentencepiece==0.1.96 \
    nltk==3.6.2 \
    gensim==4.0.1 \
    sacrebleu==1.5.1 \
    sacremoses==0.0.45 \
    subword-nmt==0.3.7 \
    youtokentome==1.0.6

pip uninstall -y enum34

############################################################################
# Set locale
############################################################################
locale-gen ru_RU.UTF-8
update-locale

############################################################################
# Clean
############################################################################
apt-get autoremove
apt-get clean
apt-get autoclean
rm -rf /var/lib/apt/lists/*
rm -rf /tmp/*
rm -rf /.cache
rm -rf /var/cache/apt/*.bin
find /var/log -iname '*.gz' -delete
find /var/log -iname '*.1' -delete

###########################################################################
# project-specific libraries (aka YOUR CODE HERE)
###########################################################################

# hivemind dependencies (version constraints must be quoted, otherwise bash treats ">=" as a redirection)
pip install \
    "prefetch_generator>=1.0.1" \
    "grpcio>=1.33.2" \
    "grpcio-tools>=1.33.2" \
    "multiaddr>=0.0.9" \
    "pymultihash>=0.8.2" \
    "cryptography>=3.4.6" \
    "pydantic>=1.8.1" \
    whatsmyip

pip install razdel==0.5.0


# golang
wget https://golang.org/dl/go1.16.4.linux-amd64.tar.gz
rm -rf /usr/local/go && tar -C /usr/local -xzf go1.16.4.linux-amd64.tar.gz
export PATH=$PATH:/usr/local/go/bin

pip install omegaconf==2.0.5 antlr4-python3-runtime==4.8 hydra-core==1.0.7
