diff --git a/Dockerfile_rocm.ci b/Dockerfile_rocm.ci new file mode 100644 index 0000000000..253b446a29 --- /dev/null +++ b/Dockerfile_rocm.ci @@ -0,0 +1,70 @@ +ARG BASE_DOCKER=rocm/pytorch:latest +FROM $BASE_DOCKER + +ARG PYTORCH_ROCM_ARCH_OVERRIDE="gfx942" +ENV WORKSPACE_DIR=/workspace +ENV STAGE_DIR=/workspace/installs +RUN mkdir -p $WORKSPACE_DIR +RUN mkdir -p ${STAGE_DIR} +WORKDIR $WORKSPACE_DIR + +RUN pip3 install \ +scipy \ +einops \ +flask-restful \ +nltk \ +pytest \ +pytest-cov \ +pytest_mock \ +pytest-csv \ +pytest-random-order \ +sentencepiece \ +wrapt \ +zarr \ +wandb \ +tensorstore==0.1.45 \ +pytest_mock \ +pybind11 \ +setuptools==69.5.1 \ +datasets \ +tiktoken \ +pynvml + +RUN pip3 install "huggingface_hub[cli]" +RUN python3 -m nltk.downloader punkt_tab + + +# Install Causal-Conv1d and its dependencies +WORKDIR ${STAGE_DIR} +ENV CAUSAL_CONV1D_FORCE_BUILD=TRUE +ENV MAMBA_FORCE_BUILD=TRUE +ENV HIP_ARCHITECTURES=${PYTORCH_ROCM_ARCH_OVERRIDE} +RUN git clone https://github.com/Dao-AILab/causal-conv1d causal-conv1d &&\ + cd causal-conv1d &&\ + git show --oneline -s &&\ + pip install . + +# Install mamba +WORKDIR ${STAGE_DIR} +RUN git clone https://github.com/state-spaces/mamba mamba &&\ + cd mamba &&\ + git show --oneline -s &&\ + pip install --no-build-isolation . + +# Clone TE repo and submodules +WORKDIR ${STAGE_DIR} +ENV NVTE_FRAMEWORK=pytorch +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH_OVERRIDE} +ENV NVTE_USE_HIPBLASLT=1 +RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git &&\ + cd TransformerEngine &&\ + pip install . + +WORKDIR $WORKSPACE_DIR +COPY . Megatron-LM +WORKDIR $WORKSPACE_DIR/Megatron-LM +RUN pip install -e . 
+ +# record configuration for posterity +RUN pip list + diff --git a/Dockerfile_rocm.dev b/Dockerfile_rocm.dev new file mode 100644 index 0000000000..d253193b67 --- /dev/null +++ b/Dockerfile_rocm.dev @@ -0,0 +1,72 @@ +ARG BASE_DOCKER=rocm/pytorch:latest +FROM $BASE_DOCKER +ARG PYTORCH_ROCM_ARCH_OVERRIDE="gfx942" +ENV WORKSPACE_DIR=/workspace +ENV STAGE_DIR=/workspace/installs +RUN mkdir -p $WORKSPACE_DIR +RUN mkdir -p ${STAGE_DIR} +WORKDIR $WORKSPACE_DIR + +RUN pip3 install \ +scipy \ +einops \ +flask-restful \ +nltk \ +pytest \ +pytest-cov \ +pytest_mock \ +pytest-csv \ +pytest-random-order \ +sentencepiece \ +wrapt \ +zarr \ +wandb \ +tensorstore==0.1.45 \ +pytest_mock \ +pybind11 \ +setuptools==69.5.1 \ +datasets \ +tiktoken \ +pynvml + +RUN pip3 install "huggingface_hub[cli]" +RUN python3 -m nltk.downloader punkt_tab + + +# Install Causal-Conv1d and its dependencies +WORKDIR ${STAGE_DIR} +ENV CAUSAL_CONV1D_FORCE_BUILD=TRUE +ENV MAMBA_FORCE_BUILD=TRUE +ENV HIP_ARCHITECTURES=${PYTORCH_ROCM_ARCH_OVERRIDE} +RUN git clone https://github.com/Dao-AILab/causal-conv1d causal-conv1d &&\ + cd causal-conv1d &&\ + git show --oneline -s &&\ + pip install . + +# Install mamba +WORKDIR ${STAGE_DIR} +RUN git clone https://github.com/state-spaces/mamba mamba &&\ + cd mamba &&\ + git show --oneline -s &&\ + pip install --no-build-isolation . + +# Clone TE repo and submodules +WORKDIR ${STAGE_DIR} +ENV NVTE_FRAMEWORK=pytorch +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH_OVERRIDE} +ENV NVTE_USE_HIPBLASLT=1 +RUN git clone --recursive https://github.com/ROCm/TransformerEngine &&\ + cd TransformerEngine &&\ + pip install . + +WORKDIR $WORKSPACE_DIR +RUN git clone https://github.com/ROCm/Megatron-LM.git Megatron-LM &&\ + cd Megatron-LM &&\ + git checkout rocm_dev &&\ + pip install -e . 
+ +WORKDIR $WORKSPACE_DIR/Megatron-LM + +# record configuration for posterity +RUN pip list + diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100755 index 0000000000..49c2417d61 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,92 @@ +import org.apache.commons.io.FilenameUtils +import groovy.json.JsonOutput + + +def clean_up_docker_images() { + // Check if the images exist before attempting to remove them + def imageExists = sh(script: "docker images -q ${env.DOCKER_IMAGE}", returnStdout: true).trim() + if (imageExists) { + sh "docker rmi ${env.DOCKER_IMAGE}" + } +} + +def clean_docker_build_cache() { + sh 'docker system prune -f --volumes || true' +} + +pipeline { + agent { + label 'build-only' + } + + parameters { + string(name: 'TEST_NODE_LABEL', defaultValue: 'MI300X_BANFF', description: 'Node or Label to launch Jenkins Job') + string(name: 'GPU_ARCH', defaultValue: 'gfx942', description: 'GPU Architecture') + } + + environment { + REPO_NAME = 'rocm/megatron-lm-private' + CONTAINER_NAME = "megatron-lm-container" + DOCKER_RUN_ARGS = "-v \$(pwd):/workspace/Megatron-LM/output --workdir /workspace/Megatron-LM \ + --entrypoint /workspace/Megatron-LM/run_unit_tests.sh" + DOCKER_RUN_CMD = "docker run --rm -t --network host -u root --group-add video --cap-add=SYS_PTRACE \ + --cap-add SYS_ADMIN --device /dev/fuse --security-opt seccomp=unconfined --security-opt apparmor=unconfined \ + --ipc=host --device=/dev/kfd --device=/dev/dri" + } + + stages { + stage('Build Docker Image') { + steps { + clean_docker_build_cache() + script { + + // Generate a unique UUID for the Docker image name + def uuid = sh(script: 'uuidgen', returnStdout: true).trim() + env.DOCKER_IMAGE = "${REPO_NAME}:${uuid}" + + // Build Docker image + sh "docker build --no-cache -f Dockerfile_rocm.ci --build-arg PYTORCH_ROCM_ARCH_OVERRIDE=${params.GPU_ARCH} -t ${env.DOCKER_IMAGE} ." 
+ + withCredentials([usernamePassword(credentialsId: 'docker-hub-credentials', usernameVariable: 'DOCKER_USERNAME', passwordVariable: 'DOCKER_PASSWORD')]) { + sh "docker push ${env.DOCKER_IMAGE}" + } + } + } + post { + always { + clean_up_docker_images() + } + } + } + + stage('Run Unit Tests') { + agent { + node { + label "${params.TEST_NODE_LABEL}" + } + } + + steps { + script { + // Pull the Docker image from the repository on the test node + withCredentials([usernamePassword(credentialsId: 'docker-hub-credentials', usernameVariable: 'DOCKER_USERNAME', passwordVariable: 'DOCKER_PASSWORD')]) { + sh "docker pull ${env.DOCKER_IMAGE}" + } + + wrap([$class: 'AnsiColorBuildWrapper', 'colorMapName': 'xterm']) { + sh "${DOCKER_RUN_CMD} ${DOCKER_RUN_ARGS} --name ${env.CONTAINER_NAME} ${env.DOCKER_IMAGE}" + } + } + } + post { + always { + // Archive test results + script { + archiveArtifacts artifacts: 'test_report.csv', allowEmptyArchive: true + clean_up_docker_images() + } + } + } + } + } +} diff --git a/examples/llama/prepare_bookcorpus_megatron_dataset.py b/examples/llama/prepare_bookcorpus_megatron_dataset.py new file mode 100755 index 0000000000..449d41dfa7 --- /dev/null +++ b/examples/llama/prepare_bookcorpus_megatron_dataset.py @@ -0,0 +1,14 @@ +import argparse +from pathlib import Path +from datasets import load_dataset + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=str, required=False, default="tmp/data", + help="Path to output JSON") + args = parser.parse_args() + out_dir = Path(args.out_dir) + out_dir.mkdir(exist_ok=True, parents=True) + + dataset = load_dataset("bookcorpus", split="train", trust_remote_code=True) + dataset.to_json(out_dir / "bookcorpus_megatron.json") \ No newline at end of file diff --git a/examples/llama/prepare_dataset.sh b/examples/llama/prepare_dataset.sh new file mode 100755 index 0000000000..8f16ea0fda --- /dev/null +++ b/examples/llama/prepare_dataset.sh @@ -0,0 +1,18 @@ 
+TMP_DIR="tmp" +mkdir -p $TMP_DIR +mkdir -p ${TMP_DIR}/data + +DATA_PATH="${TMP_DIR}/data" +TOKENIZER_MODEL=${TMP_DIR}/tokenizer.model + +# Download the tokenizer model +if ! [ -f "$TOKENIZER_MODEL" ]; then +wget -O $TOKENIZER_MODEL https://huggingface.co/NousResearch/Llama-2-7b-chat-hf/resolve/main/tokenizer.model +fi + +python3 prepare_bookcorpus_megatron_dataset.py --out-dir ${DATA_PATH} +python3 tools/preprocess_data.py --input ${DATA_PATH}/bookcorpus_megatron.json --tokenizer-type GPTSentencePieceTokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} --output-prefix ${DATA_PATH}/bookcorpus --workers `nproc` --split-sentences + +python3 tools/preprocess_data.py --input ${DATA_PATH}/bookcorpus_megatron.json --tokenizer-type GPTSentencePieceTokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} --output-prefix ${DATA_PATH}/bookcorpus --workers `nproc` --split-sentences diff --git a/examples/llama/readme.md b/examples/llama/readme.md new file mode 100644 index 0000000000..8ceb357d09 --- /dev/null +++ b/examples/llama/readme.md @@ -0,0 +1,131 @@ +# Llama2/Llama3 Model Pretraining Instructions + +This guide provides the steps for setting up the environment and configuring the script to train Llama2 or Llama3 models. + +--- + +## 1. Environment Setup + +1. **Download Docker Image** + Download the Docker image required for training: + `docker pull ` + +2. **Launch Docker Container** + Start the Docker container: + `docker run -it ` + +--- + +## 2. How to Run + +### 2.1 Single Node Training +To run the training on a single node, go to Megatron-LM folder, use the following command: +```bash +TEE_OUTPUT=1 MBS=2 BS=64 TP=8 TE_FP8=0 SEQ_LENGTH=4096 bash examples/llama/train_llama2.sh +``` + + +### 2.2 Multi-node Training +To run training on multiple nodes, launch the Docker container on each node. 
Follow these steps: + +- **On the Master Node:** + ```bash + TEE_OUTPUT=1 MBS=2 BS=64 TP=8 TE_FP8=0 SEQ_LENGTH=4096 bash examples/llama/train_llama2.sh + ``` + +- **On the Slave Node(s):** + ```bash + TEE_OUTPUT=1 MBS=2 BS=64 TP=8 TE_FP8=0 SEQ_LENGTH=4096 bash examples/llama/train_llama2.sh + ``` + +## 3. Configurations in Script (`Megatron/examples/llama`) + +### 3.1 Network Interface +Update the network interface in the script to match your system’s network interface. +To find your network interface, run (out of container): +```bash +ip a +``` +Then, update the following variables in the script: +```bash +export NCCL_SOCKET_IFNAME=ens50f0np0 +export GLOO_SOCKET_IFNAME=ens50f0np0 +``` + +### 3.2 Dataset +You can use either mock data or real data for training. + +- **Mock Data:** + Replace the data path: + ```bash + --data-path $DATA_PATH \ with + --mock-data + ``` + +- **Real Data:** + Update the `DATA_PATH` to the location where your dataset is stored: + ```bash + DATA_DIR="/root/.cache/data" # Change to where your dataset is stored + DATA_PATH=${DATA_DIR}/bookcorpus_text_sentence + ``` + +### 3.3 Tokenizer + +- **For Llama2 Training:** + Use the `Llama2Tokenizer`. + +- **For Llama3 Training:** + Use the `HuggingFaceTokenizer`. Set the HuggingFace model link in the `TOKENIZER_MODEL` variable: + ```bash + TOKENIZER_MODEL=meta-llama/Llama-3.1-8B # For Llama3 + ``` + +### 3.4 Multi-node Training +If you're running multi-node training, update the following environment variables: + +- **Master Address:** + Change `localhost` to the master node's hostname: + ```bash + MASTER_ADDR="${MASTER_ADDR:-localhost}" + ``` + +- **Number of Nodes:** + Set the number of nodes you want to train on (e.g., 2, 4, 8): + ```bash + NNODES="${NNODES:-1}" + ``` + +- **Node Rank:** + Set the rank of each node (0 for master, 1 for the first slave node, etc.): + ```bash + NODE_RANK="${NODE_RANK:-0}" + ``` + +--- + +## 4. 
Key Variables to Pay Attention To + +- **TE_FP8:** + `0` for BF16 (default), `1` for FP8. + +- **GEMM_TUNING:** + `1` to enable GEMM tuning, which boosts performance by using the best GEMM kernels. + +- **USE_FLASH_ATTN:** + `1` to enable Flash Attention. + +- **ENABLE_PROFILING:** + `1` to enable PyTorch profiling for performance analysis. + +- **transformer-impl:** + `transformer_engine` to use the Transformer Engine (TE). Set to `local` if you want to disable TE. + +- **MODEL_SIZE:** + Set to `7B` or `70B` for Llama2, or `8B` or `70B` for Llama3/3.1. + +- **TOTAL_ITERS:** + Set the total number of iterations (default: 10). + +--- + +That's it! You've now set up the environment and configured the necessary settings for training Llama2 or Llama3 models. diff --git a/examples/llama/train_llama2.sh b/examples/llama/train_llama2.sh new file mode 100644 index 0000000000..6a8d6b71d5 --- /dev/null +++ b/examples/llama/train_llama2.sh @@ -0,0 +1,312 @@ +#!/bin/bash +############################################################################### +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information.
+################################################################################# +# set -x + +# set envs +export GPU_MAX_HW_QUEUES=2 +export TORCH_NCCL_HIGH_PRIORITY=1 +export NCCL_CHECKS_DISABLE=1 +export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 +export NCCL_IB_GID_INDEX=3 +export NCCL_CROSS_NIC=0 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_PROTO=Simple +export RCCL_MSCCL_ENABLE=0 +export TOKENIZERS_PARALLELISM=false +export HSA_NO_SCRATCH_RECLAIM=1 + + +# parsing input arguments +for ARGUMENT in "$@" +do + KEY=$(echo $ARGUMENT | cut -f1 -d=) + + KEY_LENGTH=${#KEY} + VALUE="${ARGUMENT:$KEY_LENGTH+1}" + + export "$KEY"="$VALUE" +done + + +TIME_STAMP=$(date +"%Y-%m-%d_%H-%M-%S") +EXP_NAME="${EXP_NAME:-perf}" + +TEE_OUTPUT="${TEE_OUTPUT:-1}" +USE_FLASH_ATTN="${USE_FLASH_ATTN:-1}" +NO_TRAINING="${NO_TRAINING:-0}" # NO_TRAINING=1: for computing metrics only +ENABLE_PROFILING="${ENABLE_PROFILING:-0}" #enable pytorch profiling +ENABLE_ROPE="${ENABLE_ROPE:-1}" +DISABLE_ROPE_TE="${DISABLE_ROPE_TE:-0}" +echo "NO_TRAINING=$NO_TRAINING" + +CWD=`pwd` +GPUS_PER_NODE=`python3 -c "import torch; print(torch.cuda.device_count())"` + +# single node config, Change for multinode config +MASTER_ADDR="${MASTER_ADDR:-localhost}" +#MASTER_ADDR="${MASTER_ADDR:-tw015}" +MASTER_PORT="${MASTER_PORT:-6020}" +NNODES="${NNODES:-1}" +#NNODES="${NNODES:-2}" +NODE_RANK="${NODE_RANK:-0}" +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +if [ "${NNODES:-1}" -gt 1 ]; then + export NCCL_SOCKET_IFNAME="${NCCL_SOCKET_IFNAME:-ens5}" + export GLOO_SOCKET_IFNAME="${GLOO_SOCKET_IFNAME:-ens50f0}" + echo "NCCL and GLOO socket interfaces set." +else + echo "Single node setup, skipping NCCL and GLOO socket interface settings." 
+fi + +MODEL_SIZE="${MODEL_SIZE:-70}" +TP="${TP:-8}" +PP="${PP:-1}" +CP="${CP:-1}" +MBS="${MBS:-2}" +BS="${BS:-8}" +SEQ_LENGTH="${SEQ_LENGTH:-4096}" +TOTAL_ITERS="${TOTAL_ITERS:-5}" +SEQ_PARALLEL="${SEQ_PARALLEL:-1}" +CONTI_PARAMS="${CONTI_PARAMS:-0}" +TE_FP8="${TE_FP8:-0}" # 0: disable FP8, 1: enable FP8 +GEMM_TUNING="${GEMM_TUNING:-1}" +MCORE="${MCORE:-1}" + +EXPERIMENT_DIR="experiment" +mkdir -p $EXPERIMENT_DIR +CHECKPOINT_PATH=${CHECKPOINT_PATH:-"$EXPERIMENT_DIR/ckpts"} + + +DATA_DIR="${DATA_DIR:-/root/.cache/data}" +DATA_PATH=${DATA_PATH:-"$DATA_DIR/bookcorpus_text_sentence"} + +TOKENIZER_MODEL=$EXPERIMENT_DIR/tokenizer.model +# Download the tokenizer model +if ! [ -f "$TOKENIZER_MODEL" ]; then +wget -O $TOKENIZER_MODEL https://huggingface.co/NousResearch/Llama-2-7b-chat-hf/resolve/main/tokenizer.model +fi + +MAX_POSITION_EMBEDDINGS=128000 + +DEFAULT_LOG_DIR="${EXPERIMENT_DIR}/${NNODES}nodes_rank${NODE_RANK}_train_${MODEL_SIZE}B_mbs${MBS}_bs${BS}_tp${TP}_pp${PP}_cp${CP}_iter${TOTAL_ITERS}/TE_FP8_${TE_FP8}/${TIME_STAMP}" +LOG_DIR="${LOG_DIR:-${DEFAULT_LOG_DIR}}" +TRAIN_LOG="${LOG_DIR}/output_${EXP_NAME}.log" +mkdir -p $LOG_DIR +echo $TRAIN_LOG + +# gemm tuning +if [ "$GEMM_TUNING" -eq 1 ]; then + export TE_HIPBLASLT_TUNING_RUN_COUNT=10 + export TE_HIPBLASLT_TUNING_ALGO_COUNT=50 +fi + +if [ "$SEQ_LENGTH" -le 8192 ]; then + ds_works=8 +else + ds_works=24 +fi + +if [[ $MODEL_SIZE -eq 7 ]]; then #llama2-7B + HIDDEN_SIZE=4096 # e.g. llama-13b: 5120 + FFN_HIDDEN_SIZE=14336 # e.g. llama-13b: 13824 + NUM_LAYERS=32 # e.g. llama-13b: 40 + NUM_HEADS=32 # e.g. llama-13b: 40 + SEQ_LENGTH=$SEQ_LENGTH + NUM_KV_HEADS=8 # llama2 70B uses GQA +elif [[ $MODEL_SIZE -eq 70 ]]; then + HIDDEN_SIZE=8192 # e.g. llama-13b: 5120 + FFN_HIDDEN_SIZE=28672 # e.g. llama-13b: 13824 + NUM_LAYERS=80 # e.g. llama-13b: 40 + NUM_HEADS=64 # e.g. 
llama-13b: 40 + NUM_KV_HEADS=8 # llama3 70B uses GQA + SEQ_LENGTH=$SEQ_LENGTH + MAX_POSITION_EMBEDDINGS=$MAX_POSITION_EMBEDDINGS +else + echo "Model size not supported." + exit 1 +fi + +GROUP_SIZE=$(( ${NUM_HEADS} / ${NUM_KV_HEADS} )) +NUM_GROUPS=$(( ${NUM_HEADS} / ${GROUP_SIZE} )) + +PROFILING_DIR="${LOG_DIR}/trace_${EXP_NAME}" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --context-parallel-size ${CP} \ + --num-layers $NUM_LAYERS \ + --hidden-size $HIDDEN_SIZE \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --num-attention-heads $NUM_HEADS \ + --seq-length $SEQ_LENGTH \ + --max-position-embeddings $MAX_POSITION_EMBEDDINGS \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --no-position-embedding \ + --disable-bias-linear \ + --swiglu \ + --init-method-std 0.02 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --normalization RMSNorm \ + --micro-batch-size $MBS \ + --global-batch-size $BS \ + --train-iters $TOTAL_ITERS \ + --no-async-tensor-model-parallel-allreduce \ + --bf16 \ + --no-masked-softmax-fusion \ + --disable-bias-linear \ +" + +TRAIN_ARGS="--lr 1e-4 \ + --min-lr 1e-5 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --weight-decay 1.0e-1 \ + --clip-grad 1.0 \ + --optimizer adam \ +" +DATA_ARGS=" + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --dataloader-type cyclic \ + --save-interval 200000 \ + --tensorboard-dir $LOG_DIR \ + --log-interval 1 \ + --eval-interval 320000 \ + --eval-iters 10 \ + --num-workers $ds_works \ + --mock-data +" +# --data-path $DATA_PATH \ +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 5000 \ + --log-throughput \ + --no-save-optim \ + --eval-iters -1 +" +# --save $CHECKPOINT_PATH \ + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ +" + +CKPT_LOAD_ARGS="--exit-on-missing-checkpoint \ 
+ --no-load-optim \ + --use-checkpoint-args \ + --no-load-rng" + + +EXTRA_ARGS=" + --group-query-attention \ + --num-query-groups $NUM_GROUPS \ + --no-gradient-accumulation-fusion \ + --distributed-backend nccl \ + --distributed-timeout-minutes 120 \ + --use-distributed-optimizer \ + --overlap-param-gather \ + --overlap-grad-reduce \ +" + +if [ "$ENABLE_PROFILING" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --profile --use-pytorch-profiler --tensorboard-dir $LOG_DIR" +fi + +if [ "$USE_FLASH_ATTN" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --use-flash-attn" +fi + +if [ "$SEQ_PARALLEL" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --sequence-parallel" +fi + +if [ "$CONTI_PARAMS" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --use-contiguous-parameters-in-local-ddp" +fi + +if [ "$MCORE" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --use-mcore-models" +fi + +if [ "$ENABLE_ROPE" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --position-embedding-type rope" +fi + +if [ "$DISABLE_ROPE_TE" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --disable-te-fused-rope" +fi + +if [ "$TE_FP8" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --transformer-impl=transformer_engine \ + --fp8-margin=0 \ + --fp8-format=hybrid \ + --fp8-interval=1 \ + --fp8-amax-history-len=1024 \ + --fp8-amax-compute-algo=max \ + --attention-softmax-in-fp32 \ +" +fi + +run_cmd=" + torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $EXTRA_ARGS \ + $TRAIN_ARGS \ +" + +if [ "$TEE_OUTPUT" -eq 0 ]; then + run_cmd="$run_cmd >& $TRAIN_LOG" +else + run_cmd="$run_cmd |& tee $TRAIN_LOG" +fi + +if [ "$NO_TRAINING" -eq 0 ]; then + eval $run_cmd +fi + + +echo 'import argparse +import numpy as np + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="Process Log") + parser.add_argument("filename") + args = parser.parse_args() + + with open(args.filename) as f: + lines = f.readlines() + lines = lines[2:-1] + lines = [float(a) for a in lines] + mean = np.mean(np.array(lines)) + print(mean)' > mean_log_value.py + + 
+# echo '============================================================================================================' +grep -Eo 'throughput per GPU [^|]*' $TRAIN_LOG | sed -E 's/.*throughput per GPU \(TFLOP\/s\/GPU\): ([0-9\.]+).*/\1/' > tmp.txt +PERFORMANCE=$(python3 mean_log_value.py tmp.txt) +echo "throughput per GPU: $PERFORMANCE" |& tee -a $TRAIN_LOG +rm tmp.txt + +# echo '============================================================================================================' +grep -Eo 'elapsed time per iteration [^|]*' $TRAIN_LOG | sed -E 's/.*elapsed time per iteration \(ms\): ([0-9\.]+).*/\1/' > tmp.txt +ETPI=$(python3 mean_log_value.py tmp.txt) +echo "elapsed time per iteration: $ETPI" |& tee -a $TRAIN_LOG + +TIME_PER_ITER=$(python3 mean_log_value.py tmp.txt 2>/dev/null | awk '{printf "%.6f", $0}') +TGS=$(awk -v bs="$BS" -v sl="$SEQ_LENGTH" -v tpi="$TIME_PER_ITER" -v ws="$WORLD_SIZE" 'BEGIN {printf "%.6f", bs * sl * 1000/ (tpi * ws)}') +echo "tokens/GPU/s: $TGS" |& tee -a $TRAIN_LOG +rm tmp.txt diff --git a/examples/llama/train_llama3.sh b/examples/llama/train_llama3.sh new file mode 100644 index 0000000000..4392268545 --- /dev/null +++ b/examples/llama/train_llama3.sh @@ -0,0 +1,312 @@ +#!/bin/bash +############################################################################### +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
+################################################################################# +#set -x + +# set envs +export GPU_MAX_HW_QUEUES=2 +export TORCH_NCCL_HIGH_PRIORITY=1 +export NCCL_CHECKS_DISABLE=1 +export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 +export NCCL_IB_GID_INDEX=3 +export NCCL_CROSS_NIC=0 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_PROTO=Simple +export RCCL_MSCCL_ENABLE=0 +export TOKENIZERS_PARALLELISM=false +export HSA_NO_SCRATCH_RECLAIM=1 + + +# parsing input arguments +for ARGUMENT in "$@" +do + KEY=$(echo $ARGUMENT | cut -f1 -d=) + + KEY_LENGTH=${#KEY} + VALUE="${ARGUMENT:$KEY_LENGTH+1}" + + export "$KEY"="$VALUE" +done + + +TIME_STAMP=$(date +"%Y-%m-%d_%H-%M-%S") +EXP_NAME="${EXP_NAME:-perf}" + +TEE_OUTPUT="${TEE_OUTPUT:-1}" +USE_FLASH_ATTN="${USE_FLASH_ATTN:-1}" +NO_TRAINING="${NO_TRAINING:-0}" # NO_TRAINING=1: for computing metrics only +ENABLE_PROFILING="${ENABLE_PROFILING:-0}" #enable pytorch profiling +ENABLE_ROPE="${ENABLE_ROPE:-1}" +DISABLE_ROPE_TE="${DISABLE_ROPE_TE:-0}" +echo "NO_TRAINING=$NO_TRAINING" + +CWD=`pwd` +GPUS_PER_NODE=`python3 -c "import torch; print(torch.cuda.device_count())"` + +# single node config, Change for multinode config +MASTER_ADDR="${MASTER_ADDR:-localhost}" +MASTER_PORT="${MASTER_PORT:-6000}" +NNODES="${NNODES:-1}" +NODE_RANK="${NODE_RANK:-0}" +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +if [ "${NNODES:-1}" -gt 1 ]; then + export NCCL_SOCKET_IFNAME="${NCCL_SOCKET_IFNAME:-ens5}" + export GLOO_SOCKET_IFNAME="${GLOO_SOCKET_IFNAME:-ens50f0}" + echo "NCCL and GLOO socket interfaces set." +else + echo "Single node setup, skipping NCCL and GLOO socket interface settings." 
+fi + +MODEL_SIZE="${MODEL_SIZE:-70}" +TP="${TP:-8}" +PP="${PP:-1}" +CP="${CP:-1}" +MBS="${MBS:-2}" +BS="${BS:-8}" +SEQ_LENGTH="${SEQ_LENGTH:-2048}" +TOTAL_ITERS="${TOTAL_ITERS:-10}" +SEQ_PARALLEL="${SEQ_PARALLEL:-1}" +CONTI_PARAMS="${CONTI_PARAMS:-0}" +TE_FP8="${TE_FP8:-0}" # 0: disable FP8, 1: enable FP8 +GEMM_TUNING="${GEMM_TUNING:-1}" +MCORE="${MCORE:-1}" + +EXPERIMENT_DIR="experiment" +mkdir -p $EXPERIMENT_DIR +CHECKPOINT_PATH=${CHECKPOINT_PATH:-"$EXPERIMENT_DIR/ckpts"} + +DATA_DIR="${DATA_DIR:-/root/.cache/data}" +TOKENIZER_MODEL=meta-llama/Llama-3.1-8B +# Download the tokenizer model +# if ! [ -f "$TOKENIZER_MODEL" ]; then +# wget -O $TOKENIZER_MODEL https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/original/tokenizer.model +# fi + +DATA_PATH=${DATA_PATH:-"$DATA_DIR/bookcorpus_text_sentence"} + +MAX_POSITION_EMBEDDINGS=128000 + +DEFAULT_LOG_DIR="${EXPERIMENT_DIR}/${NNODES}nodes_rank${NODE_RANK}_train_${MODEL_SIZE}B_mbs${MBS}_bs${BS}_tp${TP}_pp${PP}_cp${CP}_iter${TOTAL_ITERS}/TE_FP8_${TE_FP8}/${TIME_STAMP}" +LOG_DIR="${LOG_DIR:-${DEFAULT_LOG_DIR}}" +TRAIN_LOG="${LOG_DIR}/output_${EXP_NAME}.log" +mkdir -p $LOG_DIR +echo $TRAIN_LOG + +# gemm tuning +if [ "$GEMM_TUNING" -eq 1 ]; then + export TE_HIPBLASLT_TUNING_RUN_COUNT=10 + export TE_HIPBLASLT_TUNING_ALGO_COUNT=50 +fi + +if [ "$SEQ_LENGTH" -le 8192 ]; then + ds_works=8 +else + ds_works=24 +fi + +if [[ $MODEL_SIZE -eq 8 ]]; then #llama2-7B + HIDDEN_SIZE=4096 # e.g. llama-13b: 5120 + FFN_HIDDEN_SIZE=14336 # e.g. llama-13b: 13824 + NUM_LAYERS=32 # e.g. llama-13b: 40 + NUM_HEADS=32 # e.g. llama-13b: 40 + SEQ_LENGTH=$SEQ_LENGTH + NUM_KV_HEADS=8 # llama2 70B uses GQA +elif [[ $MODEL_SIZE -eq 70 ]]; then + HIDDEN_SIZE=8192 # e.g. llama-13b: 5120 + FFN_HIDDEN_SIZE=28672 # e.g. llama-13b: 13824 + NUM_LAYERS=80 # e.g. llama-13b: 40 + NUM_HEADS=64 # e.g. 
llama-13b: 40 + NUM_KV_HEADS=8 # llama3 70B uses GQA + SEQ_LENGTH=$SEQ_LENGTH + MAX_POSITION_EMBEDDINGS=$MAX_POSITION_EMBEDDINGS +else + echo "Model size not supported." + exit 1 +fi + +GROUP_SIZE=$(( ${NUM_HEADS} / ${NUM_KV_HEADS} )) +NUM_GROUPS=$(( ${NUM_HEADS} / ${GROUP_SIZE} )) + +PROFILING_DIR="${LOG_DIR}/trace_${EXP_NAME}" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --context-parallel-size ${CP} \ + --num-layers $NUM_LAYERS \ + --hidden-size $HIDDEN_SIZE \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --num-attention-heads $NUM_HEADS \ + --seq-length $SEQ_LENGTH \ + --max-position-embeddings $MAX_POSITION_EMBEDDINGS \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --no-position-embedding \ + --disable-bias-linear \ + --swiglu \ + --init-method-std 0.02 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --normalization RMSNorm \ + --micro-batch-size $MBS \ + --global-batch-size $BS \ + --train-iters $TOTAL_ITERS \ + --no-async-tensor-model-parallel-allreduce \ + --bf16 \ + --no-masked-softmax-fusion \ + --disable-bias-linear \ +" + +TRAIN_ARGS="--lr 1e-4 \ + --min-lr 1e-5 \ + --lr-decay-iters 320000 \ + --lr-decay-style cosine \ + --weight-decay 1.0e-1 \ + --clip-grad 1.0 \ + --optimizer adam \ +" + +DATA_ARGS=" + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --dataloader-type cyclic \ + --save-interval 200000 \ + --tensorboard-dir $LOG_DIR \ + --log-interval 1 \ + --eval-interval 320000 \ + --eval-iters 10 \ + --num-workers $ds_works \ + --mock-data +" +#--data-path $DATA_PATH \ +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 5000 \ + --log-throughput \ + --no-save-optim \ + --eval-iters -1 +" +# --save $CHECKPOINT_PATH \ + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ +" + 
+CKPT_LOAD_ARGS="--exit-on-missing-checkpoint \ + --no-load-optim \ + --use-checkpoint-args \ + --no-load-rng" + + +EXTRA_ARGS=" + --group-query-attention \ + --num-query-groups $NUM_GROUPS \ + --no-gradient-accumulation-fusion \ + --distributed-backend nccl \ + --distributed-timeout-minutes 120 \ + --use-distributed-optimizer \ + --overlap-param-gather \ + --overlap-grad-reduce \ +" + +if [ "$ENABLE_PROFILING" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --profile --use-pytorch-profiler --tensorboard-dir $LOG_DIR" +fi + +if [ "$USE_FLASH_ATTN" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --use-flash-attn" +fi + +if [ "$SEQ_PARALLEL" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --sequence-parallel" +fi + +if [ "$CONTI_PARAMS" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --use-contiguous-parameters-in-local-ddp" +fi + +if [ "$MCORE" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --use-mcore-models" +fi + +if [ "$ENABLE_ROPE" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --position-embedding-type rope" +fi + +if [ "$DISABLE_ROPE_TE" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --disable-te-fused-rope" +fi + +if [ "$TE_FP8" -eq 1 ]; then +EXTRA_ARGS="$EXTRA_ARGS --transformer-impl=transformer_engine \ + --fp8-margin=0 \ + --fp8-format=hybrid \ + --fp8-interval=1 \ + --fp8-amax-history-len=1024 \ + --fp8-amax-compute-algo=max \ + --attention-softmax-in-fp32 \ +" +fi + +run_cmd=" + torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $EXTRA_ARGS \ + $TRAIN_ARGS \ +" + +if [ "$TEE_OUTPUT" -eq 0 ]; then + run_cmd="$run_cmd >& $TRAIN_LOG" +else + run_cmd="$run_cmd |& tee $TRAIN_LOG" +fi + +if [ "$NO_TRAINING" -eq 0 ]; then + eval $run_cmd +fi + + +echo 'import argparse +import numpy as np + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="Process Log") + parser.add_argument("filename") + args = parser.parse_args() + + with open(args.filename) as f: + lines = f.readlines() + lines = lines[2:-1] + lines = [float(a) for a in lines] + mean = 
np.mean(np.array(lines)) + print(mean)' > mean_log_value.py + + +# echo '============================================================================================================' +grep -Eo 'throughput per GPU [^|]*' $TRAIN_LOG | sed -E 's/.*throughput per GPU \(TFLOP\/s\/GPU\): ([0-9\.]+).*/\1/' > tmp.txt +PERFORMANCE=$(python3 mean_log_value.py tmp.txt) +echo "throughput per GPU: $PERFORMANCE" |& tee -a $TRAIN_LOG +rm tmp.txt + +# echo '============================================================================================================' +grep -Eo 'elapsed time per iteration [^|]*' $TRAIN_LOG | sed -E 's/.*elapsed time per iteration \(ms\): ([0-9\.]+).*/\1/' > tmp.txt +ETPI=$(python3 mean_log_value.py tmp.txt) +echo "elapsed time per iteration: $ETPI" |& tee -a $TRAIN_LOG + +TIME_PER_ITER=$(python3 mean_log_value.py tmp.txt 2>/dev/null | awk '{printf "%.6f", $0}') +TGS=$(awk -v bs="$BS" -v sl="$SEQ_LENGTH" -v tpi="$TIME_PER_ITER" -v ws="$WORLD_SIZE" 'BEGIN {printf "%.6f", bs * sl * 1000/ (tpi * ws)}') +echo "tokens/GPU/s: $TGS" |& tee -a $TRAIN_LOG +rm tmp.txt + + diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index bf5159c759..8957c1fd30 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -706,6 +706,15 @@ def forward( else: return core_attn_out + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + state_dict = self.state_dict(prefix='', keep_vars=True) + # TE with version>=1.9 introduces an extra state in DotProductAttention Module + if is_te_min_version("1.9.0.dev0") and ('_extra_state' in state_dict): + state_dict.pop('_extra_state') + return make_sharded_tensors_for_checkpoint( + state_dict, prefix, {}, sharded_offsets + ) + if is_te_min_version("1.9.0.dev0"): diff --git a/megatron/core/models/common/embeddings/rope_utils.py b/megatron/core/models/common/embeddings/rope_utils.py old mode 100644 new 
mode 100755 index accb251961..ece04492af --- a/megatron/core/models/common/embeddings/rope_utils.py +++ b/megatron/core/models/common/embeddings/rope_utils.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple, Union if TYPE_CHECKING: from megatron.core.transformer.transformer_config import TransformerConfig @@ -26,6 +26,11 @@ except ImportError: HAVE_APPLY_ROPE_FUSION = False +try: + import transformer_engine.pytorch.cpp_extensions as tex + HAVE_TE = True +except ImportError: + HAVE_TE = False def get_pos_emb_on_this_cp_rank(pos_emb: Tensor, seq_dim: int) -> Tensor: """Get the position embedding on the current context parallel rank. @@ -149,43 +154,162 @@ def apply_rotary_pos_emb( Reroute to the appropriate apply_rotary_pos_emb function depending on fused/unfused kernels, or bshd (conventional) / thd (packed seq) format """ - if config.apply_rope_fusion and not HAVE_APPLY_ROPE_FUSION: - # setting apply_rope_fusion in config to False - # so that subsequent queries to this config also return False - config.apply_rope_fusion = False - if not getattr(apply_rotary_pos_emb, "printed_fused_warning", False): + if not config.disable_te_fused_rope and HAVE_TE and torch.cuda.is_available() and torch.version.hip: + return apply_rotary_pos_emb_fused_te(t = t, freqs = freqs, config = config, cu_seqlens = cu_seqlens) + else: + if config.apply_rope_fusion and not HAVE_APPLY_ROPE_FUSION: + # setting apply_rope_fusion in config to False + # so that subsequent queries to this config also return False + config.apply_rope_fusion = False + if not getattr(apply_rotary_pos_emb, "printed_fused_warning", False): + logger.warning( + "Setting apply_rope_fusion to false because its implementation" + " is not included in Apex. 
Try upgrading to the latest version" + ) + apply_rotary_pos_emb.printed_fused_warning = True + + if getattr(config, "multi_latent_attention", False) and config.rotary_interleaved: logger.warning( - "Setting apply_rope_fusion to false because its implementation" - " is not included in Apex. Try upgrading to the latest version" + "rotary_interleaved is not supported with multi_latent_attention, setting it to False" ) - apply_rotary_pos_emb.printed_fused_warning = True + config.rotary_interleaved = False + + if config.apply_rope_fusion: + if cu_seqlens is None: + return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True) + else: + return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs) + else: + if cu_seqlens is None: + return _apply_rotary_pos_emb_bshd( + t, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) + else: + return _apply_rotary_pos_emb_thd( + t, + cu_seqlens, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + ) + +class FusedRoPEFunc(torch.autograd.Function): + """ + Function for FusedRoPE - if getattr(config, "multi_latent_attention", False) and config.rotary_interleaved: - logger.warning( - "rotary_interleaved is not supported with multi_latent_attention, setting it to False" - ) - config.rotary_interleaved = False + This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and + the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid + the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. 
+ """ - if config.apply_rope_fusion: - if cu_seqlens is None: - return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True) + @staticmethod + def forward( + ctx, + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + cu_seqlens: Union[torch.Tensor, None] = None, + ) -> torch.Tensor: + if freqs.dtype != torch.float32: + freqs = freqs.float() + if tensor_format == "sbhd": + output = tex.fused_rope_forward(t, freqs, False) + elif tensor_format == "bshd": + output = tex.fused_rope_forward( + t.transpose(0, 1), freqs, True + ).transpose(0, 1) + elif tensor_format == "thd": + output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs) else: - return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs) - else: - if cu_seqlens is None: - return _apply_rotary_pos_emb_bshd( - t, - freqs, - rotary_interleaved=config.rotary_interleaved, - multi_latent_attention=config.multi_latent_attention, - mscale=mscale, - ) + raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + ctx.save_for_backward(freqs, cu_seqlens) + ctx.tensor_format = tensor_format + return output + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + freqs, cu_seqlens = ctx.saved_tensors + if ctx.tensor_format == "sbhd": + grad_input = tex.fused_rope_backward(grad_output, freqs, False) + elif ctx.tensor_format == "bshd": + grad_input = tex.fused_rope_backward( + grad_output.transpose(0, 1), freqs, True + ).transpose(0, 1) + elif ctx.tensor_format == "thd": + grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs) else: - return _apply_rotary_pos_emb_thd( - t, - cu_seqlens, - freqs, - rotary_interleaved=config.rotary_interleaved, - multi_latent_attention=config.multi_latent_attention, - mscale=mscale, - ) + raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") + + return grad_input, None, None, None, None + + +def apply_rotary_pos_emb_fused_te( + t: torch.Tensor, + freqs: 
torch.Tensor, + tensor_format: str = "sbhd", + config: TransformerConfig = None, + fused: bool = True, + cu_seqlens: Union[torch.Tensor, None] = None, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which + rotary positional embedding will be applied. + freqs: torch.Tensor + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + fused: bool, default = False + Whether to use a fused applying RoPE implementation. + tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' + is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is + of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. + cu_seqlens: torch.Tensor, default = None. + Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and + dtype torch.int32. Only valid when `tensor_format` is 'thd'. + """ + + if fused: + assert ( + tensor_format != "thd" or cu_seqlens is not None + ), "cu_seqlens must not be None when tensor_format is 'thd'." + return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens) + + assert tensor_format in ("sbhd", "bshd"), ( + "Only formats `sbhd` or `bshd` are supported for input tensor `t` " + f"when fused is False, got {tensor_format}." + ) + + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert cur_seq_len <= max_seq_len, ( + f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
+ ) + freqs = freqs[:cur_seq_len] + if tensor_format == "bshd": + freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] + # cos/sin first then dtype conversion for better precision + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * cos_) + (_rotate_half(t) * sin_) + return torch.cat((t, t_pass), dim=-1) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py old mode 100644 new mode 100755 index 5232faec60..d16ae79cdb --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -183,4 +183,4 @@ def get_rotary_seq_len( rotary_seq_len *= transformer_config.context_parallel_size - return rotary_seq_len + return rotary_seq_len \ No newline at end of file diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py old mode 100644 new mode 100755 index a63171686a..b8968d6cf5 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -165,6 +165,9 @@ class TransformerConfig(ModelParallelConfig): apply_rope_fusion: bool = False """If True, use fused RoPE kernel.""" + disable_te_fused_rope: bool = False + """If True, disable fused RoPE kernel from transformer engine""" + #################### # activation recomputation #################### diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py old mode 100644 new mode 100755 index 87cceac3e3..f01088fd5a --- a/megatron/legacy/fused_kernels/__init__.py +++ b/megatron/legacy/fused_kernels/__init__.py 
@@ -3,9 +3,10 @@ import os import pathlib import subprocess - +import torch from torch.utils import cpp_extension + # Setting this param to a list has a problem of generating different # compilation commands (with diferent order of architectures) and # leading to recompilation of fused kernels. Set it to empty string @@ -16,22 +17,23 @@ def load(args): - # Check if cuda 11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( - cpp_extension.CUDA_HOME - ) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') - if int(bare_metal_minor) >= 8: + if torch.cuda.is_available() and torch.version.cuda: + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( + cpp_extension.CUDA_HOME + ) + if int(bare_metal_major) >= 11: cc_flag.append('-gencode') - cc_flag.append('arch=compute_90,code=sm_90') + cc_flag.append('arch=compute_80,code=sm_80') + if int(bare_metal_minor) >= 8: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_90,code=sm_90') - # Build path - srcpath = pathlib.Path(__file__).parent.absolute() - buildpath = srcpath / "build" - _create_build_dir(buildpath) + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / "build" + _create_build_dir(buildpath) # Helper function to build the kernels. 
def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py old mode 100644 new mode 100755 index e3d876a5f2..9411223126 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -695,6 +695,8 @@ def core_transformer_config_from_args(args, config_class=None): else: kw_args['num_query_groups'] = None kw_args['config_logger_dir'] = args.config_logger_dir + if args.disable_te_fused_rope: + kw_args['disable_te_fused_rope'] = args.disable_te_fused_rope # Return config. return config_class(**kw_args) @@ -853,6 +855,8 @@ def _add_network_size_args(parser): action='store_false', help='Disable position embedding. Deprecated: use --position-embedding-type', dest='add_position_embedding') + group.add_argument('--disable-te-fused-rope', action='store_true', default = False, + help='Disable fused rope from transformer-engine: use --disable_te_fused_rope') group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, help='Pad the vocab size to be divisible by this value.' 'This is added for computational efficieny reasons.') diff --git a/pytest.ini b/pytest.ini index c75f3b9fa4..cb6bfac7d4 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,10 @@ # content of pytest.ini [pytest] markers = - internal: mark a test as a test to private/internal functions. \ No newline at end of file + internal: Mark a test as a test to private/internal functions. + failing_on_rocm: Currently Failing Tests on ROCm. + failing_on_rocm_mi250: Tests failing on MI250. + test_on_rocm: Mark a test that we run on ROCm specifically. 
+ +addopts = + --ignore tests/unit_tests/test_utilities.py diff --git a/run_unit_tests.sh b/run_unit_tests.sh new file mode 100755 index 0000000000..cfa8e0ad3b --- /dev/null +++ b/run_unit_tests.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +set -x +export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +PYTEST_MARKERS="(not flaky and not internal and not failing_on_rocm or test_on_rocm)" + +if [[ "$HIP_ARCHITECTURES" == "gfx90a" ]]; then + PYTEST_MARKERS="$PYTEST_MARKERS and not failing_on_rocm_mi250" +fi + +torchrun --nproc_per_node=8 -m pytest --color=yes -m "$PYTEST_MARKERS" --csv output/test_report.csv tests/unit_tests/ diff --git a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py index 4a8f153ed4..695a257c0f 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py +++ b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py @@ -102,6 +102,7 @@ def teardown_method(self, method): (False, (1, 1, 4), (8, 1, 1), True), ], ) + @pytest.mark.failing_on_rocm @pytest.mark.parametrize("expert_type", expert_type) def test_parallel_reconfiguration_e2e( self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, use_fpsl, expert_type @@ -182,6 +183,7 @@ def test_parallel_reconfiguration_e2e( ], ) @pytest.mark.parametrize("src_module,dest_module", src_dest_expert_type) + @pytest.mark.failing_on_rocm def test_sequential_grouped_mlp_interchangeable( self, tmp_path_dist_ckpt, src_tp_pp_exp, dest_tp_pp_exp, use_glu, src_module, dest_module ): diff --git a/tests/unit_tests/dist_checkpointing/test_fp8.py b/tests/unit_tests/dist_checkpointing/test_fp8.py old mode 100644 new mode 100755 index d2dcb367c7..9f29626acb --- a/tests/unit_tests/dist_checkpointing/test_fp8.py +++ b/tests/unit_tests/dist_checkpointing/test_fp8.py @@ -20,6 +20,7 @@ class TestFP8: @pytest.mark.parametrize('dtype', ['bf16', 'fp16', 'fp8']) @pytest.mark.parametrize('src_rank', [0, 6]) + @pytest.mark.failing_on_rocm def 
test_simple_broadcast(self, dtype, src_rank): Utils.initialize_model_parallel() @@ -52,6 +53,7 @@ def get_ten(dtype: str = 'fp8'): ], ) @pytest.mark.flaky + @pytest.mark.failing_on_rocm def test_fp8_save_load( self, tmp_path_dist_ckpt, use_fpsl, src_tp_pp, dest_tp_pp, load_exchange_algo ): diff --git a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py b/tests/unit_tests/dist_checkpointing/test_nonpersistent.py index 346751e264..263168a63c 100644 --- a/tests/unit_tests/dist_checkpointing/test_nonpersistent.py +++ b/tests/unit_tests/dist_checkpointing/test_nonpersistent.py @@ -30,6 +30,7 @@ def teardown_method(self, method): @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) @pytest.mark.flaky + @pytest.mark.failing_on_rocm def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp): Utils.initialize_model_parallel(tp, pp) num_floating_point_operations_so_far = 0 @@ -119,6 +120,7 @@ def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, tp, pp): class TestLegacySaveAndLoad: @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) @pytest.mark.flaky + @pytest.mark.failing_on_rocm def test_basic_save_load_scenario(self, tmp_path_dist_ckpt, tp, pp): Utils.initialize_model_parallel(tp, pp) num_floating_point_operations_so_far = 0 diff --git a/tests/unit_tests/dist_checkpointing/test_optimizer.py b/tests/unit_tests/dist_checkpointing/test_optimizer.py index 19d1ee9e85..2b47138d46 100644 --- a/tests/unit_tests/dist_checkpointing/test_optimizer.py +++ b/tests/unit_tests/dist_checkpointing/test_optimizer.py @@ -517,6 +517,7 @@ def test_optimizer_resharding( ((2, 1, 2), (1, 1, 8)), ], ) + @pytest.mark.failing_on_rocm def test_chained_optimizer_resharding( self, tmp_path_dist_ckpt, diff --git a/tests/unit_tests/inference/test_modelopt_gpt_model.py b/tests/unit_tests/inference/test_modelopt_gpt_model.py index 380ac7fa16..2cb86e546e 100644 --- a/tests/unit_tests/inference/test_modelopt_gpt_model.py +++ b/tests/unit_tests/inference/test_modelopt_gpt_model.py @@ -1,4 +1,5 
@@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import pytest from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec from megatron.core.inference.modelopt_support.gpt.state_dict_hooks import ( mcore_gpt_load_te_state_dict_pre_hook, @@ -32,6 +33,7 @@ def setup_method(self, method): max_sequence_length=4, ) + @pytest.mark.failing_on_rocm_mi250 def test_load_te_state_dict_pre_hook(self): handle = self.modelopt_gpt_model._register_load_state_dict_pre_hook( mcore_gpt_load_te_state_dict_pre_hook diff --git a/tests/unit_tests/models/test_mamba_model.py b/tests/unit_tests/models/test_mamba_model.py index 913adb538c..cad2b0367c 100644 --- a/tests/unit_tests/models/test_mamba_model.py +++ b/tests/unit_tests/models/test_mamba_model.py @@ -56,6 +56,7 @@ def test_set_input_tensor(self): assert self.model.decoder.input_tensor.shape[1] == micro_batch_size assert self.model.decoder.input_tensor.shape[2] == config.hidden_size + @pytest.mark.failing_on_rocm def test_forward(self): config: TransformerConfig = self.model.config sequence_length = self.model.max_sequence_length @@ -78,6 +79,7 @@ def test_forward(self): assert logits.shape[1] == sequence_length assert logits.shape[2] == self.model.vocab_size + @pytest.mark.failing_on_rocm def test_inference(self): config: TransformerConfig = self.model.config micro_batch_size = 2 diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index 229cead1c3..c8b0139a59 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -110,7 +110,7 @@ def test_cross_check_param_hashes_across_dp_replicas(): # Teardown. 
_deinit_distributed() - +@pytest.mark.failing_on_rocm def test_straggler_detector(): world = int(os.getenv('WORLD_SIZE', '1')) rank = int(os.getenv('RANK', '0')) diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py index 043bdc8c58..8638cf364b 100644 --- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py +++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py @@ -278,6 +278,7 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True): def teardown_method(self, method): Utils.destroy_model_parallel() + @pytest.mark.test_on_rocm @pytest.mark.internal def test_constructor(self): assert isinstance(self.sequential_mlp, MoELayer) @@ -313,6 +314,7 @@ def test_constructor(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.test_on_rocm @pytest.mark.internal def test_gpu_forward_backward(self): self.sequential_mlp.cuda() @@ -356,6 +358,7 @@ def test_gpu_forward_backward(self): torch.testing.assert_close(smm_result, gmm_result) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.test_on_rocm @pytest.mark.internal def test_gpu_forward_backward_with_no_tokens_allocated(self): """Test the case when no token is allocated for groupedGEMM kernels.""" diff --git a/tests/unit_tests/transformer/test_attention.py b/tests/unit_tests/transformer/test_attention.py index 8c13ff3f8c..ef1226d0b0 100644 --- a/tests/unit_tests/transformer/test_attention.py +++ b/tests/unit_tests/transformer/test_attention.py @@ -37,7 +37,7 @@ def test_constructor(self): def test_cpu_forward(self): # we can't currently do this because the global memory buffer is on GPU pass - + @pytest.mark.failing_on_rocm def test_gpu_forward(self): config = self.parallel_attention.config @@ -61,7 +61,7 @@ def test_gpu_forward(self): assert output.shape[1] == micro_batch_size assert output.shape[2] == config.hidden_size assert bias.shape[0] 
== config.hidden_size - + @pytest.mark.failing_on_rocm def test_fused_rope_gpu_forward(self): self.parallel_attention.config.apply_rope_fusion = True config = self.parallel_attention.config @@ -90,7 +90,7 @@ def test_fused_rope_gpu_forward(self): assert output.shape[2] == config.hidden_size assert bias.shape[0] == config.hidden_size self.parallel_attention.config.apply_rope_fusion = False - + @pytest.mark.failing_on_rocm def test_checkpointed_gpu_forward(self): transformer_config = self.transformer_config transformer_config.recompute_granularity = 'selective' diff --git a/tests/unit_tests/transformer/test_retro_attention.py b/tests/unit_tests/transformer/test_retro_attention.py index d7c5a5f155..751e1c74e3 100644 --- a/tests/unit_tests/transformer/test_retro_attention.py +++ b/tests/unit_tests/transformer/test_retro_attention.py @@ -3,6 +3,7 @@ import types import torch +import pytest from megatron.core.models.retro import RetroConfig, get_retro_decoder_block_spec from megatron.core.models.retro.decoder_attention import ( @@ -80,6 +81,7 @@ def setup_method(self, method): def teardown_method(self, method): Utils.destroy_model_parallel() + @pytest.mark.failing_on_rocm def test_constructor(self): config = self.get_config() @@ -101,6 +103,7 @@ def test_constructor(self): assert get_nparams(modules.encoder_bda) == 0 assert get_nparams(modules.encoder_norm) == 32 + @pytest.mark.failing_on_rocm def test_cpu_forward(self): # we can't currently do this because the global memory buffer is on GPU pass @@ -190,7 +193,7 @@ def run_gpu_forward(self, recompute_granularity, use_transformer_engine): config.retro_num_neighbors * micro_batch_size * n_chunks_per_sample, config.hidden_size, ) - + @pytest.mark.failing_on_rocm def test_gpu_forward(self): for recompute_granularity in (None, 'selective'): for use_transformer_engine in (True, False): diff --git a/tests/unit_tests/transformer/test_spec_customization.py b/tests/unit_tests/transformer/test_spec_customization.py index 
a9a245b861..1d458fe88f 100755 --- a/tests/unit_tests/transformer/test_spec_customization.py +++ b/tests/unit_tests/transformer/test_spec_customization.py @@ -132,6 +132,7 @@ def test_build_module(self): bda_op = build_module(self.bda_spec) assert id(bda_op) == id(get_bias_dropout_add) + @pytest.mark.failing_on_rocm def test_sliding_window_attention(self): if not is_te_min_version("1.2.0"): print("SWA not tested because TE version is not >= 1.2.0", file=sys.stderr) diff --git a/tests/unit_tests/transformer/test_transformer_block.py b/tests/unit_tests/transformer/test_transformer_block.py index 02702a9ff7..210a4bc37c 100644 --- a/tests/unit_tests/transformer/test_transformer_block.py +++ b/tests/unit_tests/transformer/test_transformer_block.py @@ -66,12 +66,14 @@ def test_gpu_forward(self): def test_gpu_forward_full_checkpoint(self): self._run_full_checkpoint_test(fp8=None) + @pytest.mark.failing_on_rocm_mi250 def test_gpu_forward_full_checkpoint_fp8(self): self._run_full_checkpoint_test(fp8="e4m3") def test_gpu_forward_selective_checkpoint(self): self._run_selective_checkpoint_test(fp8=None) + @pytest.mark.failing_on_rocm_mi250 def test_gpu_forward_selective_checkpoint_fp8(self): self._run_selective_checkpoint_test(fp8="e4m3") diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index a81fe8ca7e..a9575707b9 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -203,7 +203,7 @@ def get_args(): choices=['BertWordPieceLowerCase','BertWordPieceCase', 'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer', 'Llama2Tokenizer', - 'Llama3Tokenizer', 'MistralTokenizer', 'NullTokenizer'], + 'Llama3Tokenizer', 'MistralTokenizer', 'HuggingFaceTokenizer', 'NullTokenizer'], help='What type of tokenizer to use.') group.add_argument('--tokenizer-model', type=str, default=None, help='YTTM tokenizer model.')