
Commit 73852aa

Merge branch 'main' into Obliviour-patch-2

Obliviour authored Jan 17, 2025
2 parents 3e7a4dd + 21706fc
Showing 14 changed files with 116 additions and 16 deletions.
1 change: 0 additions & 1 deletion MaxText/configs/README.md
@@ -103,7 +103,6 @@ Here are some of the most common XLA compiler flags used by MaxText.
| xla_gpu_enable_pipelined_reduce_scatter | Boolean (true/false) | Enable pipelining of reduce-scatter instructions. <br> **Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) |
| xla_gpu_enable_pipelined_all_reduce | Boolean (true/false) | Enable pipelining of all-reduce instructions. <br> **Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) |
| xla_gpu_enable_while_loop_double_buffering | Boolean (true/false) | Enable double-buffering for while loop. <br> **Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) |
- | xla_gpu_enable_triton_softmax_fusion | Boolean (true/false) | Use Triton-based Softmax fusion. <br> **Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) |
| xla_gpu_enable_all_gather_combine_by_dim | Boolean (true/false) | If true, combine only all-gather ops that share the same gather dimension; if false, combine them irrespective of dimension. <br> **Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) |
| xla_gpu_enable_reduce_scatter_combine_by_dim | Boolean (true/false) | If true, combine only reduce-scatter ops that share the same dimension; if false, combine them irrespective of dimension. <br> **Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) |
| xla_disable_hlo_passes | String (comma-separated list of pass names) | Comma-separated list of HLO passes to disable. Each name must exactly match the pass name, with no whitespace around the commas. <br> **Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) |
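
For orientation, here is a minimal sketch (not part of this commit) of how a few of the flags from the table above might be combined. The specific flags and values are illustrative only; `XLA_FLAGS` is typically set before JAX is imported so the compiler picks it up.

```python
import os

# Illustrative only: compose a handful of the XLA GPU flags listed above.
xla_flags = [
    "--xla_gpu_enable_pipelined_all_reduce=true",
    "--xla_gpu_enable_while_loop_double_buffering=true",
    "--xla_disable_hlo_passes=rematerialization",
]
os.environ["XLA_FLAGS"] = " ".join(xla_flags)

import jax  # imported after XLA_FLAGS is set so the flags take effect

print(jax.devices())
```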
2 changes: 1 addition & 1 deletion MaxText/configs/a3/llama_2_7b/16vm.sh
@@ -25,7 +25,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
- --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
+ --xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

2 changes: 1 addition & 1 deletion MaxText/configs/a3/llama_2_7b/1vm.sh
@@ -24,7 +24,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
- --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
+ --xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

2 changes: 1 addition & 1 deletion MaxText/configs/a3/llama_2_7b/2vm.sh
@@ -25,7 +25,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_all_reduce_combine_threshold_bytes=67108864 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
- --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
+ --xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

2 changes: 1 addition & 1 deletion MaxText/configs/a3/llama_2_7b/4vm.sh
@@ -23,7 +23,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_all_reduce_combine_threshold_bytes=536870912 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
- --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
+ --xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

2 changes: 1 addition & 1 deletion MaxText/configs/a3/llama_2_7b/8vm.sh
@@ -24,7 +24,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
- --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
+ --xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

8 changes: 4 additions & 4 deletions MaxText/configs/a3/llama_3.1_405b/128vm.sh
@@ -1,8 +1,8 @@
echo "Running 128vm.sh"
# Example command to invoke this script via XPK, assume you've installed xpk
# COMMAND="bash MaxText/configs/a3/llama_3.1_405b/128vm.sh"
- # COMMAND='export LD_LIBRARY_PATH=/usr/local/cuda-12.6/compat:$LD_LIBRARY_PATH;'"${COMMAND}";
- #
+ # COMMAND='export LD_LIBRARY_PATH=/usr/local/cuda-12.6/compat:$LD_LIBRARY_PATH;'"${COMMAND}";
+ #
# xpk workload create --project=${PROJECT}--cluster=${CLUSTER_NAME} --zone=${ZONE} \
# --workload=${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type=${DEVICE_TYPE} --num-nodes=2 --priority=high \
Expand All @@ -28,7 +28,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
- --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
+ --xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

Expand All @@ -52,5 +52,5 @@ python MaxText/$EXECUTABLE MaxText/configs/models/llama3.1_405b.yml run_name=$RU
dcn_fsdp_parallelism=128 \
ici_fsdp_parallelism=8 \
base_output_directory=$OUTPUT_PATH \
- profiler=xplane
+ profiler=xplane

48 changes: 48 additions & 0 deletions MaxText/configs/trillium/gemma2_27b.sh
@@ -0,0 +1,48 @@
# Gemma2_27b model.
# This config will work out of the box for any number of trillium-256 slices.
#
# Command Flags:
# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
# DATASET_PATH (Required, unless dataset_path is already set in base.yml)
# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
#
# Example to invoke this script:
# bash MaxText/configs/trillium/gemma2_27b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"
#


# Stop execution if any command exits with error
set -e

export EXECUTABLE="train.py" # or train_compile.py
export RUN_PREFLIGHT="true"

# Set environment variables
for ARGUMENT in "$@"; do
  IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
  export "$KEY"="$VALUE"
done

# The setup accommodates two cases:
# 1) Passing the 'RUN_NAME' variable at runtime
# 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
if [ -n "$RUN_NAME" ];
then
  export M_RUN_NAME=$RUN_NAME
fi

# Set up network optimizations
if [ "$RUN_PREFLIGHT" = "true" ]; then
bash preflight.sh
fi

# Train
export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

python3 MaxText/$EXECUTABLE MaxText/configs/base.yml model_name=gemma2-27b \
  steps=15 per_device_batch_size=2 enable_checkpointing=false \
  remat_policy=full ici_fsdp_transpose_parallelism=256 ici_fsdp_parallelism=-1 \
  max_target_length=8192 base_output_directory=$OUTPUT_PATH \
  reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true \
  attention='flash' sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048

47 changes: 47 additions & 0 deletions MaxText/configs/trillium/gemma2_9b.sh
@@ -0,0 +1,47 @@
# Gemma2-9b model.
# This config will work out of the box for any number of trillium-256 slices.
#
# Command Flags:
# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
# DATASET_PATH (Required, unless dataset_path is already set in base.yml)
# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
#
# Example to invoke this script:
# bash MaxText/configs/trillium/gemma2_9b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"
#


# Stop execution if any command exits with error
set -e

export EXECUTABLE="train.py" # or train_compile.py
export RUN_PREFLIGHT="true"

# Set environment variables
for ARGUMENT in "$@"; do
  IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
  export "$KEY"="$VALUE"
done

# The setup accommodates two cases:
# 1) Passing the 'RUN_NAME' variable at runtime
# 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
if [ -n "$RUN_NAME" ];
then
  export M_RUN_NAME=$RUN_NAME
fi

# Set up network optimizations
if [ "$RUN_PREFLIGHT" = "true" ]; then
bash preflight.sh
fi

# Train
export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=114688 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

python3 MaxText/$EXECUTABLE MaxText/configs/base.yml model_name=gemma2-9b \
  steps=15 per_device_batch_size=3 enable_checkpointing=false \
  remat_policy=full ici_fsdp_transpose_parallelism=256 ici_fsdp_parallelism=-1 \
  max_target_length=8192 base_output_directory=$OUTPUT_PATH \
  reuse_example_batch=1 dataset_type=synthetic gcs_metrics=true \
  attention='flash' sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048
7 changes: 4 additions & 3 deletions MaxText/tests/forward_pass_logit_checker.py
@@ -118,11 +118,12 @@ def main(config, test_args):
max_logging.log(f"{golden_logits[2]=}")
max_logging.log(f"{full_train_logits[0, 2, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
+ # The ellipsis indexing is used so the check works with both jax nightly versions newer than 1/9/2025 and the stable releases used in tests; this can be simplified later.
max_logging.log(
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}"
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}"
)

- model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
+ model_probabilities = jax.nn.softmax(full_train_logits[..., 0, :token_size, :], axis=-1)
golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)

max_logging.log(f"{golden_probabilities[1]=}")
Expand All @@ -139,7 +140,7 @@ def main(config, test_args):
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits")
assert jax.numpy.allclose(
- full_train_logits[0, :token_size, :],
+ full_train_logits[..., 0, :token_size, :],
golden_logits[:token_size, :],
rtol=float(test_args.rtol),
atol=float(test_args.atol),
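
To see why the switch to ellipsis indexing covers both logit layouts, here is a small stand-alone sketch (not code from the test itself): `[..., 0, :n, :]` selects the same slice whether or not the logits carry an extra leading axis, so one code path serves both the newer jax nightlies and the stable releases. The array shapes below are illustrative.

```python
import numpy as np

n = 4
logits_3d = np.random.rand(2, 8, 32)        # (batch, seq, vocab)
logits_4d = logits_3d[np.newaxis, ...]      # extra leading axis: (1, batch, seq, vocab)

# The same ellipsis expression works for both ranks.
slice_3d = logits_3d[..., 0, :n, :]         # -> (4, 32)
slice_4d = logits_4d[..., 0, :n, :]         # -> (1, 4, 32), broadcasts against (4, 32)

assert slice_3d.shape == (n, 32)
assert slice_4d.shape == (1, n, 32)
assert np.allclose(slice_3d, slice_4d[0])
```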
5 changes: 5 additions & 0 deletions benchmarks/xla_flags_library.py
@@ -131,6 +131,11 @@
" --xla_latency_hiding_scheduler_rerun=2"
)

+ # Flags to optimize pipeline parallelism over DCN with large host offloading.
+ PIPELINING_FLAGS = (
+     " --xla_tpu_iova_dma_chunk_size_bytes=16777216"  # breaks DMA to/from host into 16M chunks
+ )
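
As a usage illustration (not code from this module), such a flag-string constant is typically just concatenated onto whatever `LIBTPU_INIT_ARGS` a launcher already exports. The helper below is a hypothetical sketch under that assumption.

```python
import os

# Copied from the constant added above; the helper around it is hypothetical.
PIPELINING_FLAGS = " --xla_tpu_iova_dma_chunk_size_bytes=16777216"

def with_pipelining_flags(existing_flags: str) -> str:
    """Append the DCN pipelining tuning to an existing flag string."""
    return existing_flags.rstrip() + PIPELINING_FLAGS

# Example: extend whatever the launcher has already set.
os.environ["LIBTPU_INIT_ARGS"] = with_pipelining_flags(os.environ.get("LIBTPU_INIT_ARGS", ""))
```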

# Disable bundle-aware CostModel which was causing worse perf b/357103386.
# Some fusions in the backward pass of the model were 3x slower without this.
DISABLE_BUNDLE_AWARE_COST_MODEL = (
2 changes: 1 addition & 1 deletion end_to_end/gpu/a3/test_convergence_125m_params.sh
@@ -58,7 +58,7 @@ TRAIN_CMD="python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME h
TRAIN_CMD+=$CMD_DATA

# Train
- export XLA_ARGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization"
+ export XLA_ARGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization"
$TRAIN_CMD

# Assert training loss is smaller than input LOSS_THRESHOLD
2 changes: 1 addition & 1 deletion end_to_end/gpu/a3/test_convergence_1b_params.sh
@@ -57,7 +57,7 @@ TRAIN_CMD="python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME h
TRAIN_CMD+=$CMD_DATA

# Train
- export XLA_ARGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization"
+ export XLA_ARGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization"
$TRAIN_CMD

# Assert training loss is smaller than input LOSS_THRESHOLD
2 changes: 1 addition & 1 deletion end_to_end/gpu/a3/test_llama2_7b.sh
@@ -55,7 +55,7 @@ export XLA_FLAGS="--xla_dump_to=$BASE_OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
- --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
+ --xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

