diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 281f936fd772b..71602cb4e8d6b 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_executable_terminate_timeout_seconds(30); opts.set_xla_gpu_experimental_disable_binary_libraries(false); opts.set_xla_experimental_ignore_channel_id(false); + opts.set_xla_gpu_dot_merger_threshold_mb(32); return opts; } @@ -1951,6 +1952,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_experimental_ignore_channel_id), debug_options->xla_experimental_ignore_channel_id(), "Experimental: ignore channel ids for collective operations.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_dot_merger_threshold_mb", + int32_setter_for(&DebugOptions::set_xla_gpu_dot_merger_threshold_mb), + debug_options->xla_gpu_dot_merger_threshold_mb(), + "Dot merger pass threshold to be set in MB.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 3bc29969e3997..68a26a18bb5f6 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -790,9 +790,12 @@ absl::Status RunOptimizationPasses( // AlgebraicSimplifier may add contracting dimensions to a dot. pipeline.AddPass(); pipeline.AddPass(); - // Only merge "smallish" dots. This threshold was not set carefully, but - // so far we know that 1mb is too small. - pipeline.AddPass(/*max_size_to_merge=*/int64_t{32} << 20); + // Only merge "smallish" dots. This threshold defaults to 32MB today, with + // a flag to override. + pipeline.AddPass( + /*max_size_to_merge=*/int64_t{ + debug_options.xla_gpu_dot_merger_threshold_mb()} + << 20); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); diff --git a/xla/xla.proto b/xla/xla.proto index 74eaf1166e459..018320d9c8df6 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -978,7 +978,10 @@ message DebugOptions { // for collectives in the given HLO. bool xla_experimental_ignore_channel_id = 330; - // Next id: 331 + // DotMerger pass threshold size to be used in MB. + int32 xla_gpu_dot_merger_threshold_mb = 331; + + // Next id: 332 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.