This page documents the XLA and JAX flags that improve LLM performance on GPUs. The XLA flags and their default values are defined in xla/debug_options_flags.cc.
The flags can be set through the environment variable XLA_FLAGS="--xla-flag1=true --xla-flag2=false", either on the command line or in your script.
Please note that some of these flags are experimental, and not all combinations of flags have been tested yet. If you see any unexpected behavior, please let us know.
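As a minimal sketch of setting flags from a script, the snippet below renders a dictionary of flags as an XLA_FLAGS string and exports it. The helper name build_xla_flags is hypothetical, and the two flags are simply examples taken from later on this page. It assumes the variable is set before JAX initializes its backend, so the safest place is before importing jax.

```python
import os


def build_xla_flags(flags):
    """Render a {flag_name: value} mapping as an XLA_FLAGS string (hypothetical helper)."""
    parts = []
    for name, value in flags.items():
        # XLA expects lowercase true/false rather than Python's True/False.
        if isinstance(value, bool):
            value = "true" if value else "false"
        parts.append(f"--{name}={value}")
    return " ".join(parts)


xla_flags = build_xla_flags({
    "xla_gpu_enable_latency_hiding_scheduler": True,
    "xla_gpu_all_reduce_combine_threshold_bytes": 8589934592,
})

# Note: this overwrites any XLA_FLAGS value already present in the environment.
os.environ["XLA_FLAGS"] = xla_flags

# import jax  # import JAX only after the environment has been prepared
```

The same string can be exported from the shell instead; building it in Python mainly helps when the flag set is generated per experiment.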
- XLA_PYTHON_CLIENT_MEM_FRACTION is an XLA environment variable that allocates a fraction of GPU memory to JAX/XLA. Ideally it would be 1, but in practice it must be lower because some memory is used by NVIDIA libraries and the JAX framework. We typically set it to 0.9 or 0.8; at 0.9, XLA gets 90% of GPU memory.
- The xla_gpu_memory_limit_slop_factor flag controls the memory used by XLA for determining its default heuristics for scheduling and rematerialization. The default value is recommended.
The following environment variable restricts the number of CUDA queues to 1, which is useful when a strict ordering of operations is required for best performance. It is recommended for good performance with latency hiding optimizations that use asynchronous collectives.
- CUDA_DEVICE_MAX_CONNECTIONS=1
See NCCL Environment Variables for more details.
- NCCL_PROTO: SIMPLE,LL,LL128
The following variable accelerates the all-reduce collective on NVLink4/H100. It requires additional GPU memory, so you may need to reduce XLA_PYTHON_CLIENT_MEM_FRACTION to avoid OOMs when it is enabled.
- NCCL_NVLS_ENABLE:1
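The environment variables discussed above could be set together from a launch script, as in the sketch below. The value 0.8 for the memory fraction is an illustrative assumption (chosen below 0.9 to leave headroom for NVLS), and the dictionary name gpu_env is hypothetical. The variables need to be in place before JAX, CUDA, and NCCL are initialized.

```python
import os

# Illustrative values; tune the memory fraction for your own workload.
gpu_env = {
    # Fraction of GPU memory handed to JAX/XLA (assumed value, see lead-in).
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.8",
    # Restrict CUDA to a single queue for strict ordering of operations.
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    # Accelerated all-reduce on NVLink4/H100; uses additional GPU memory.
    "NCCL_NVLS_ENABLE": "1",
}

os.environ.update(gpu_env)

# import jax  # import JAX only after the environment has been prepared
```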
To overlap communication with computation for models in JAX/XLA, we must enable the Latency Hiding Scheduler and asynchronous communication.
To enable latency hiding optimizations with XLA, turn on the following flag:
- --xla_gpu_enable_latency_hiding_scheduler=true
To enable asynchronous communication for all collectives, the following flag is recommended; it is set by default in XLA:
- --xla_gpu_enable_highest_priority_async_stream=true
For finer-grained control over which collectives are asynchronous, use the following flag and list the collectives for which asynchronous execution should be disabled:
- --xla_gpu_disable_async_collectives=allreduce,allgather,reducescatter,collectivebroadcast,alltoall,collectivepermute
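As a small sketch, the flag value can be assembled from a list of collective names and checked against the names shown above. The helper disable_async_flag and the constant KNOWN_COLLECTIVES are hypothetical, and the choice of alltoall and collectivepermute is only an example.

```python
# Collective names as they appear in the flag value above.
KNOWN_COLLECTIVES = (
    "allreduce",
    "allgather",
    "reducescatter",
    "collectivebroadcast",
    "alltoall",
    "collectivepermute",
)


def disable_async_flag(collectives):
    """Build the flag that disables async execution for the given collectives."""
    unknown = [name for name in collectives if name not in KNOWN_COLLECTIVES]
    if unknown:
        raise ValueError(f"unknown collective name(s): {unknown}")
    return "--xla_gpu_disable_async_collectives=" + ",".join(collectives)


# Example: leave the other collectives asynchronous, disable async for these two.
flag = disable_async_flag(["alltoall", "collectivepermute"])
```

Validating the names up front catches typos early instead of leaving them to be discovered when the flag is parsed.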
With FSDP in JAX/XLA, the following additional optimizations are available:
- Scan loop unrolling and loop double buffering:
  - --xla_gpu_enable_while_loop_double_buffering=true
- Optimized pipelining of all-gather and reduce-scatter for latency hiding in FSDP:
  - --xla_gpu_enable_pipelined_all_gather=true
  - --xla_gpu_enable_pipelined_reduce_scatter=true
  - --xla_gpu_enable_pipelined_all_reduce=true
  - --xla_gpu_enable_pipelined_collectives=false (if true, overrides the three flags above)
- Combining tensors that are sharded along different dimensions. Within a transformer layer, tensors can be sharded row-wise or column-wise, and by default XLA generates multiple collective calls for tensors sharded along different dimensions. The following flags combine tensors of all shardings and map them to a single grouped NCCL call that has a large cumulative size and achieves high communication efficiency:
  - --xla_gpu_enable_all_gather_combine_by_dim=false
  - --xla_gpu_enable_reduce_scatter_combine_by_dim=false
- Combine threshold values in XLA that determine when an all-gather (AG) or reduce-scatter (RS) is triggered. These values should be at least as large as the size of the weights (AG) or gradients (RS) in a single transformer layer, since large communication buffers achieve higher link bandwidth utilization. For example, LLAMA2-7B with BF16 weights and gradients has 32 transformer layers, so each layer has ~218M weights at 2 bytes each, and one would want to set these thresholds to at least 436MB:
  - --xla_gpu_all_gather_combine_threshold_bytes=8589934592
  - --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
- Combine threshold value in XLA that determines when an all-reduce (AR) is triggered. It is typically used to overlap the AR of gradients with back-propagation compute. The value should be as large as possible to achieve high efficiency, but as small as possible to achieve maximum overlap. Depending on the interconnect of your system, you may want to try several threshold values, doubling each time, from, say, 16MB up to the total gradient size:
  - --xla_gpu_all_reduce_combine_threshold_bytes=8589934592
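The threshold arithmetic above can be worked through in a few lines. This is a sketch under stated assumptions: the helper names layer_bytes and doubling_sweep are hypothetical, "16MB" is interpreted as 16 MiB, and the total gradient size is taken as 32 layers of the per-layer figure from the LLAMA2-7B example.

```python
MIB = 1024 * 1024


def layer_bytes(params_per_layer, bytes_per_param=2):
    """Size of one layer's weights or gradients; BF16 uses 2 bytes per value."""
    return params_per_layer * bytes_per_param


def doubling_sweep(start_bytes, max_bytes):
    """Candidate thresholds that double from start_bytes and end at max_bytes."""
    values = []
    value = start_bytes
    while value < max_bytes:
        values.append(value)
        value *= 2
    values.append(max_bytes)
    return values


# LLAMA2-7B example from the text: ~218M weights per layer in BF16.
ag_rs_min_threshold = layer_bytes(218_000_000)  # 436_000_000 bytes, i.e. 436MB

# All-reduce sweep: 16 MiB up to the assumed total gradient size of 32 layers.
total_gradient_bytes = 32 * ag_rs_min_threshold
ar_candidates = doubling_sweep(16 * MIB, total_gradient_bytes)

for threshold in ar_candidates:
    print(f"--xla_gpu_all_reduce_combine_threshold_bytes={threshold}")
```

Each printed line is one candidate setting to benchmark; the best value depends on the interconnect, so the sweep only narrows the search.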
The following flags enable overlap of pipeline parallel communication of send/recv with computation.
- --xla_gpu_enable_pipelined_p2p=true (false by default)
- --xla_gpu_collective_permute_decomposer_threshold=1024
- --xla_gpu_lhs_enable_gpu_async_tracker=true
The following flags enable overlap of tensor parallel communication with GEMMs/matmuls by splitting each GEMM into smaller chunks and triggering each chunk's collective right after that chunk's GEMM is done. The threshold determines the GEMM output buffer size at which this optimization becomes active (0 enables collective matmul for all GEMM-collective patterns).
- --xla_gpu_multi_streamed_windowed_einsum=true
- --xla_gpu_threshold_for_windowed_einsum_mib=0
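The chunking idea behind collective matmul can be illustrated conceptually in plain Python. This is not how XLA implements it and performs no communication; it only shows that a GEMM split into column chunks yields the same result, with a comment marking where each chunk's collective would be launched. The function names are hypothetical.

```python
def matmul(a, b):
    """Plain-Python matrix product of row-major nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def chunked_matmul(a, b, num_chunks):
    """Compute a @ b one column chunk of b at a time (num_chunks >= 1)."""
    n_cols = len(b[0])
    chunk = -(-n_cols // num_chunks)  # ceiling division
    out = [[] for _ in a]
    for start in range(0, n_cols, chunk):
        b_chunk = [row[start:start + chunk] for row in b]
        partial = matmul(a, b_chunk)  # GEMM for this chunk only
        # In the real optimization, this chunk's collective would start here
        # and overlap with the next chunk's GEMM.
        for out_row, part_row in zip(out, partial):
            out_row.extend(part_row)
    return out


a = [[1, 2], [3, 4]]
b = [[5, 6, 7], [8, 9, 10]]
full = matmul(a, b)
chunked = chunked_matmul(a, b, num_chunks=2)
```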
The following flag enables use of PGLE (Profile Guided Latency Estimator) with JAX/XLA. Please see the PGLE notes for more details.
- --xla_gpu_pgle_profile_file_or_directory_path=filename
The flag below enables CUDA Graph support for JAX/XLA workloads and is enabled by default.
- --xla_gpu_enable_command_buffer (set to "" to disable)
The following flag removes extra copies introduced by DUS (dynamic update slice) when used in conjunction with custom NVIDIA kernels (like cuBLAS for GEMMs). This happens in particular when used with scan operations.
- --xla_gpu_enable_custom_fusions=true
- --xla_gpu_enable_address_computation_fusion=true
The following flag enables user buffers in NCCL for zero-copy collectives and send/recv. It needs NCCL_NVLS_ENABLE=1 for AG, AR, and RS.
- --xla_gpu_enable_nccl_user_buffers=true
The following flags reduce the memory consumed by NCCL.
- --xla_gpu_enable_nccl_comm_splitting=true
- --xla_gpu_enable_nccl_per_stream_comms=false (see openxla/xla#9845)
For fine-grained performance control, the following flags initialize an NCCL communicator to use only max_nchannels (SMs). The default value of 0 takes NCCL's default for the number of SMs used per collective.
- --xla_gpu_nccl_collective_max_nchannels
- --xla_gpu_nccl_p2p_max_nchannels
The following flags are useful for debugging and further tuning.
- --xla_dump_to=some/path
- --xla_dump_latency_hiding_schedule=true
- --xla_gpu_cudnn_gemm_fusion=true (enables GEMM/bias fusion via cuDNN)
- --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX)
- --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging)
The following flags were previously used but are no longer required.
- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed
- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default
- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used