Remove autotune sharing.
xla_gpu_shard_autotuning can now be used instead; it is enabled by default.

PiperOrigin-RevId: 705792463
Google-ML-Automation committed Dec 13, 2024
1 parent d0f63da commit a123d4e
Showing 2 changed files with 0 additions and 139 deletions.
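
The commit message points to the XLA-side flag xla_gpu_shard_autotuning as the replacement for the removed JAX option. A minimal sketch of toggling it explicitly, assuming the standard XLA_FLAGS environment-variable mechanism; the flag name comes from the commit message, and since it is enabled by default, setting it is only needed to override that default:

import os

# Hedged sketch (not part of this commit): XLA debug flags must be in place
# before JAX initializes its backends. --xla_gpu_shard_autotuning is the flag
# named in the commit message; "false" opts out of the new default behavior,
# "true" makes the default explicit.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=true"
)

import jax  # import JAX only after XLA_FLAGS is set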
125 changes: 0 additions & 125 deletions jax/_src/compiler.py
@@ -18,8 +18,6 @@

from collections.abc import Sequence
import logging
import os
import tempfile
import time
from typing import Any, Callable
import warnings
@@ -449,22 +447,6 @@ def compile_or_get_cached(
        cache_key,
        min_device_process_id
    )
  elif (
      config.share_autotune_config_between_hosts.value
      and is_multi_process
      and distributed.global_state.client is not None
  ):
    log_persistent_cache_miss(module_name, cache_key)
    return _compile_and_write_autotune_config(
        backend,
        computation,
        compile_options,
        host_callbacks,
        distributed.global_state.client,
        module_name,
        cache_key,
        min_device_process_id
    )
  else:
    log_persistent_cache_miss(module_name, cache_key)
    return _compile_and_write_cache(
@@ -608,113 +590,6 @@ def _share_fdo_profiles(

_share_fdo_profiles.modules_profiles = {}


# The process with the first_process_id should compile the module and write an
# autotune config to the K-V storage.
def _compile_and_write_autotune_config(
    backend: xc.Client,
    computation: ir.Module,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
    global_client: lib.xla_extension.DistributedRuntimeClient,
    module_name: str,
    cache_key: str,
    first_process_id: int
) -> xc.LoadedExecutable:
  share_timeout = config.share_binary_between_hosts_timeout_ms.value
  debug_options = compile_options.executable_build_options.debug_options

  if _compile_and_write_autotune_config.autotune_configs_dir is None:
    _compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()

  autotune_tmp_file = os.path.join(
      _compile_and_write_autotune_config.autotune_configs_dir, cache_key
  )

  if os.path.exists(autotune_tmp_file):
    logger.debug(
        "Compiling module: %s. Use existing autotune config file: %s",
        module_name,
        autotune_tmp_file,
    )
    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
    return _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )

  if distributed.global_state.process_id == first_process_id:
    debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
    logger.debug("Process %d compiling and dumping autotune for module: %s",
                 first_process_id, module_name)
    executable = _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )

    logger.debug(
        "Writing autotune config for module %s to %s",
        module_name,
        autotune_tmp_file,
    )
    with open(autotune_tmp_file, "rb") as f:
      autotune_config = f.read()

    autotune_config = compilation_cache.compress_executable(autotune_config)
    global_client.key_value_set_bytes(cache_key, autotune_config)
    logger.debug(
        "Autotune config for module %s with size %d shared by cache_key %s",
        module_name,
        len(autotune_config),
        cache_key,
    )
  else:
    logger.debug(
        "Compiling module %s, waiting for config to be shared by cache_key %s "
        "from process %d",
        module_name,
        cache_key,
        first_process_id
    )
    autotune_config = global_client.blocking_key_value_get_bytes(
        cache_key, share_timeout
    )

    logger.debug(
        "Received autotune config for module %s of size %d",
        module_name,
        len(autotune_config),
    )
    autotune_config = compilation_cache.decompress_executable(autotune_config)
    with open(autotune_tmp_file, "wb") as f:
      f.write(autotune_config)

    logger.debug(
        "Compiling module %s, using autotune config from %s",
        module_name,
        autotune_tmp_file,
    )
    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
    executable = _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )
  return executable

_compile_and_write_autotune_config.autotune_configs_dir = None

# The process with the first_process_id should compile the module and write it
# to the K-V storage.
def _compile_and_share_module(
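Both the removed helper above and the surviving _compile_and_share_module follow, per their comments, the same coordinator pattern: the process with first_process_id produces an artifact and publishes it in the distributed key-value store under the cache key, while every other process blocks until the bytes arrive. A minimal sketch of that pattern, using only the client calls that appear in this diff; the helper name and the produce_artifact callback are placeholders, not JAX APIs:

from typing import Callable

def _share_bytes_via_kv_store(
    global_client,                          # distributed runtime client, as above
    process_id: int,
    first_process_id: int,
    cache_key: str,
    produce_artifact: Callable[[], bytes],  # e.g. compile and serialize
    timeout_ms: int,
) -> bytes:
  if process_id == first_process_id:
    artifact = produce_artifact()
    # The coordinator publishes the bytes under the module's cache key.
    global_client.key_value_set_bytes(cache_key, artifact)
    return artifact
  # Every other process waits for the coordinator, up to the configured timeout.
  return global_client.blocking_key_value_get_bytes(cache_key, timeout_ms)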
14 changes: 0 additions & 14 deletions jax/_src/config.py
@@ -1169,20 +1169,6 @@ def _update_jax_memories_thread_local(val):
    ),
)

share_autotune_config_between_hosts = bool_state(
    name='jax_share_autotune_config_between_hosts',
    default=False,
    help=(
        'If set to True, the coordinator process will share autotune configs '
        'with other participants. This will increase overall compilation '
        'time, but will lead to equal compiled modules in each process. '
        'If both jax_share_binary_between_hosts and '
        'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
        "shared when it's possible and autotune config sharing will be used "
        'as a fallback.'
    ),
)

share_binary_between_hosts = bool_state(
    name='jax_share_binary_between_hosts',
    default=False,
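
For callers migrating off the removed option: before this commit, the sharing path was opted into through the JAX config entry deleted above. A hedged sketch of what that opt-in looked like in user code; jax.config.update is the standard JAX config API, but after this commit the option name no longer exists, so such calls have to be removed and the XLA flag shown near the top of this page used instead:

import jax

# Pre-removal opt-in (no longer valid after this commit). Bool config options
# like this one could typically also be set via an upper-cased environment
# variable, e.g. JAX_SHARE_AUTOTUNE_CONFIG_BETWEEN_HOSTS=true.
jax.config.update("jax_share_autotune_config_between_hosts", True)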
