From 052ebcd0810ef879c714f5dd3abf932ff1d2e11a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2024 10:09:31 -0700 Subject: [PATCH] Introduce hermetic CUDA in Google ML projects. 1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel will download CUDA, CUDNN and NCCL redistributions in the cache, and use them during build and test phases. [Default location of CUNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/) [Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/) [Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history) 2) To include hermetic CUDA rules in your project, add the following in the WORKSPACE of the downstream project dependent on XLA. Note: use `@local_tsl` instead of `@tsl` in Tensorflow project. ``` load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", ) cuda_json_init_repository() load( "@cuda_redist_json//:distributions.bzl", "CUDA_REDISTRIBUTIONS", "CUDNN_REDISTRIBUTIONS", ) load( "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", "cuda_redist_init_repositories", "cudnn_redist_init_repository", ) cuda_redist_init_repositories( cuda_redistributions = CUDA_REDISTRIBUTIONS, ) cudnn_redist_init_repository( cudnn_redistributions = CUDNN_REDISTRIBUTIONS, ) load( "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure", ) cuda_configure(name = "local_config_cuda") load( "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", ) nccl_redist_init_repository() load( "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure", ) nccl_configure(name = "local_config_nccl") ``` PiperOrigin-RevId: 616865795 --- .bazelrc | 28 +- WORKSPACE | 47 ++ build_tools/build.py | 10 + build_tools/configure/BUILD | 1 + build_tools/configure/configure.py | 157 ++--- build_tools/configure/configure_test.py | 42 +- .../configure/testdata/cuda_clang.bazelrc | 7 +- .../testdata/default_cuda_clang.bazelrc | 19 + .../configure/testdata/nvcc_clang.bazelrc | 7 +- .../configure/testdata/nvcc_gcc.bazelrc | 7 +- docs/build_from_source.md | 11 +- docs/developer_guide.md | 10 + docs/hermetic_cuda.md | 544 ++++++++++++++++++ third_party/tsl/.bazelrc | 28 +- third_party/tsl/WORKSPACE | 47 ++ third_party/tsl/opensource_only.files | 25 + .../tsl/third_party/gpus/check_cuda_libs.py | 3 + .../gpus/compiler_common_tools.bzl | 174 ++++++ .../tsl/third_party/gpus/crosstool/BUILD.tpl | 11 +- .../tsl/third_party/gpus/cuda/BUILD.tpl | 41 +- .../third_party/gpus/cuda/BUILD.windows.tpl | 4 + .../tsl/third_party/gpus/cuda/hermetic/BUILD | 0 .../third_party/gpus/cuda/hermetic/BUILD.tpl | 261 +++++++++ .../gpus/cuda/hermetic/cuda_cccl.BUILD.tpl | 15 + .../gpus/cuda/hermetic/cuda_configure.bzl | 521 +++++++++++++++++ .../gpus/cuda/hermetic/cuda_cublas.BUILD.tpl | 44 ++ .../gpus/cuda/hermetic/cuda_cudart.BUILD.tpl | 126 ++++ .../gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl | 73 +++ .../gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl | 80 +++ .../gpus/cuda/hermetic/cuda_cufft.BUILD.tpl | 29 + .../gpus/cuda/hermetic/cuda_cupti.BUILD.tpl | 59 ++ .../gpus/cuda/hermetic/cuda_curand.BUILD.tpl | 26 + .../cuda/hermetic/cuda_cusolver.BUILD.tpl | 34 ++ .../cuda/hermetic/cuda_cusparse.BUILD.tpl | 27 + .../hermetic/cuda_json_init_repository.bzl | 125 ++++ .../gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl | 75 +++ .../cuda/hermetic/cuda_nvjitlink.BUILD.tpl | 17 + .../gpus/cuda/hermetic/cuda_nvml.BUILD.tpl | 10 + .../gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl | 9 + .../gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl | 20 + .../gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl | 13 + .../cuda_redist_init_repositories.bzl | 491 ++++++++++++++++ .../cuda/hermetic/cuda_redist_versions.bzl | 197 +++++++ .../tsl/third_party/gpus/cuda_configure.bzl | 172 +----- .../tsl/third_party/gpus/find_cuda_config.py | 3 + .../tsl/third_party/gpus/rocm_configure.bzl | 5 +- .../tsl/third_party/gpus/sycl_configure.bzl | 5 +- .../tsl/third_party/nccl/build_defs.bzl.tpl | 11 +- .../tsl/third_party/nccl/hermetic/BUILD | 0 .../nccl/hermetic/cuda_nccl.BUILD.tpl | 30 + .../nccl/hermetic/nccl_configure.bzl | 183 ++++++ .../hermetic/nccl_redist_init_repository.bzl | 145 +++++ .../tsl/third_party/nccl/nccl_configure.bzl | 11 +- .../toolchains/remote_config/configs.bzl | 44 +- .../toolchains/remote_config/rbe_config.bzl | 20 +- third_party/tsl/tsl/platform/default/BUILD | 5 + .../platform/default/cuda_libdevice_path.cc | 28 +- third_party/tsl/workspace2.bzl | 4 - tools/toolchains/remote_config/configs.bzl | 44 +- tools/toolchains/remote_config/rbe_config.bzl | 20 +- xla/lit.bzl | 63 +- xla/service/gpu/tests/BUILD | 1 + xla/stream_executor/cuda/BUILD | 34 +- xla/tools/BUILD | 1 + xla/tools/hlo_opt/BUILD | 1 + xla/tsl/cuda/BUILD.bazel | 84 ++- xla/tsl/tsl.bzl | 11 + 67 files changed, 3995 insertions(+), 405 deletions(-) create mode 100644 build_tools/configure/testdata/default_cuda_clang.bazelrc create mode 100644 docs/hermetic_cuda.md create mode 100644 third_party/tsl/third_party/gpus/compiler_common_tools.bzl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/BUILD create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl create mode 100644 third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl create mode 100644 third_party/tsl/third_party/nccl/hermetic/BUILD create mode 100644 third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl create mode 100644 third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl create mode 100644 third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl diff --git a/.bazelrc b/.bazelrc index 3ea641dd518adf..9e565e91a1b903 100644 --- a/.bazelrc +++ b/.bazelrc @@ -219,11 +219,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -232,22 +237,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -543,9 +548,6 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --config=nvcc_clang @@ -630,7 +632,6 @@ build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/cla # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -644,7 +645,6 @@ build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute diff --git a/WORKSPACE b/WORKSPACE index 9d046e22949091..028dcdc7ef1a8e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -52,3 +52,50 @@ xla_workspace1() load(":workspace0.bzl", "xla_workspace0") xla_workspace0() + +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/build_tools/build.py b/build_tools/build.py index 770ba19a9a885c..e56849c84b828b 100755 --- a/build_tools/build.py +++ b/build_tools/build.py @@ -411,6 +411,16 @@ def main(): r"s/@sigbuild-r2\.17-clang_/@sigbuild-r2.17-clang-cudnn9_/g", "github/xla/.bazelrc", ], + check=True, + ) + sh( + [ + "sed", + "-i", + r"s/8\.9\.7\.29/9.1.1/g", + "github/xla/.bazelrc", + ], + check=True, ) sh(["nvidia-smi"]) diff --git a/build_tools/configure/BUILD b/build_tools/configure/BUILD index 6b84ba404c9043..ed518510f5eae3 100644 --- a/build_tools/configure/BUILD +++ b/build_tools/configure/BUILD @@ -33,6 +33,7 @@ py_test( data = [ "testdata/clang.bazelrc", "testdata/cuda_clang.bazelrc", + "testdata/default_cuda_clang.bazelrc", "testdata/gcc.bazelrc", "testdata/nvcc_clang.bazelrc", "testdata/nvcc_gcc.bazelrc", diff --git a/build_tools/configure/configure.py b/build_tools/configure/configure.py index f19ac30eda44c0..43e0f234d49cfd 100755 --- a/build_tools/configure/configure.py +++ b/build_tools/configure/configure.py @@ -27,11 +27,6 @@ the clang in your path. If that isn't the correct clang, you can override like `./configure.py --backend=cpu --clang_path=`. -NOTE(ddunleavy): Lots of these things should probably be outside of configure.py -but are here because of complexity in `cuda_configure.bzl` and the TF bazelrc. -Once XLA has it's own bazelrc, and cuda_configure.bzl is replaced or refactored, -we can probably make this file smaller. - TODO(ddunleavy): add more thorough validation. """ import argparse @@ -45,18 +40,9 @@ import sys from typing import Optional -_REQUIRED_CUDA_LIBRARIES = ["cublas", "cuda", "cudnn"] _DEFAULT_BUILD_AND_TEST_TAG_FILTERS = ("-no_oss",) # Assume we are being invoked from the symlink at the root of the repo _XLA_SRC_ROOT = pathlib.Path(__file__).absolute().parent -_FIND_CUDA_CONFIG = str( - _XLA_SRC_ROOT - / "third_party" - / "tsl" - / "third_party" - / "gpus" - / "find_cuda_config.py" -) _XLA_BAZELRC_NAME = "xla_configure.bazelrc" _KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} @@ -224,11 +210,12 @@ class DiscoverablePathsAndVersions: ld_library_path: Optional[str] = None # CUDA specific - cublas_version: Optional[str] = None - cuda_toolkit_path: Optional[str] = None + cuda_version: Optional[str] = None cuda_compute_capabilities: Optional[list[str]] = None cudnn_version: Optional[str] = None - nccl_version: Optional[str] = None + local_cuda_path: Optional[str] = None + local_cudnn_path: Optional[str] = None + local_nccl_path: Optional[str] = None def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): """Gets paths and versions as needed by the config. @@ -247,7 +234,7 @@ def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): ) # Notably, we don't use `_find_executable_or_die` for lld, as it changes - # which commands it accepts based on it's name! ld.lld is symlinked to a + # which commands it accepts based on its name! ld.lld is symlinked to a # different executable just called lld, which should not be invoked # directly. self.lld_path = self.lld_path or shutil.which("ld.lld") @@ -261,64 +248,6 @@ def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): if not self.cuda_compute_capabilities: self.cuda_compute_capabilities = _get_cuda_compute_capabilities_or_die() - self._get_cuda_libraries_paths_and_versions_if_needed(config) - - def _get_cuda_libraries_paths_and_versions_if_needed( - self, config: "XLAConfigOptions" - ): - """Gets cuda paths and versions if user left any unspecified. - - This uses `find_cuda_config.py` to find versions for all libraries in - `_REQUIRED_CUDA_LIBRARIES`. - - Args: - config: config that determines which libraries should be found. - """ - should_find_nccl = config.using_nccl and self.nccl_version is None - any_cuda_config_unset = any([ - self.cublas_version is None, - self.cuda_toolkit_path is None, - self.cudnn_version is None, - should_find_nccl, - ]) - - maybe_nccl = ["nccl"] if should_find_nccl else [] - - if any_cuda_config_unset: - logging.info( - "Some CUDA config versions and paths were not provided, " - "so trying to find them using find_cuda_config.py" - ) - try: - find_cuda_config_proc = subprocess.run( - [ - sys.executable, - _FIND_CUDA_CONFIG, - *_REQUIRED_CUDA_LIBRARIES, - *maybe_nccl, - ], - capture_output=True, - check=True, - text=True, - ) - except subprocess.CalledProcessError as e: - logging.info("Command %s failed. Is CUDA installed?", e.cmd) - logging.info("Dumping %s ouptut:\n %s", e.cmd, e.output) - raise e - - cuda_config = dict( - tuple(line.split(": ")) - for line in find_cuda_config_proc.stdout.strip().split("\n") - ) - - self.cublas_version = self.cublas_version or cuda_config["cublas_version"] - self.cuda_toolkit_path = ( - self.cuda_toolkit_path or cuda_config["cuda_toolkit_path"] - ) - self.cudnn_version = self.cudnn_version or cuda_config["cudnn_version"] - if should_find_nccl: - self.nccl_version = self.nccl_version or cuda_config["nccl_version"] - @dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) class XLAConfigOptions: @@ -391,18 +320,31 @@ def to_bazelrc_lines( ) # Lines needed for CUDA backend regardless of CUDA/host compiler + if dpav.cuda_version: + rc.append( + f"build:cuda --repo_env HERMETIC_CUDA_VERSION={dpav.cuda_version}" + ) rc.append( - f"build --action_env CUDA_TOOLKIT_PATH={dpav.cuda_toolkit_path}" - ) - rc.append(f"build --action_env TF_CUBLAS_VERSION={dpav.cublas_version}") - rc.append( - "build --action_env" - f" TF_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" + "build:cuda --repo_env" + f" HERMETIC_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" ) - rc.append(f"build --action_env TF_CUDNN_VERSION={dpav.cudnn_version}") - if self.using_nccl: - rc.append(f"build --action_env TF_NCCL_VERSION={dpav.nccl_version}") - else: + if dpav.cudnn_version: + rc.append( + f"build:cuda --repo_env HERMETIC_CUDNN_VERSION={dpav.cudnn_version}" + ) + if dpav.local_cuda_path: + rc.append( + f"build:cuda --repo_env LOCAL_CUDA_PATH={dpav.local_cuda_path}" + ) + if dpav.local_cudnn_path: + rc.append( + f"build:cuda --repo_env LOCAL_CUDNN_PATH={dpav.local_cudnn_path}" + ) + if dpav.local_nccl_path: + rc.append( + f"build:cuda --repo_env LOCAL_NCCL_PATH={dpav.local_nccl_path}" + ) + if not self.using_nccl: rc.append("build --config nonccl") elif self.backend == Backend.ROCM: pass @@ -489,13 +431,35 @@ def _parse_args(): parser.add_argument("--lld_path", help=path_help) # CUDA specific - find_cuda_config_help = ( - "Optional: will be found using `find_cuda_config.py` if flag is not set." + parser.add_argument( + "--cuda_version", + help="Optional: CUDA will be downloaded by Bazel if the flag is set", + ) + parser.add_argument( + "--cudnn_version", + help="Optional: CUDNN will be downloaded by Bazel if the flag is set", + ) + parser.add_argument( + "--local_cuda_path", + help=( + "Optional: Local CUDA dir will be used in dependencies if the flag" + " is set" + ), + ) + parser.add_argument( + "--local_cudnn_path", + help=( + "Optional: Local CUDNN dir will be used in dependencies if the flag" + " is set" + ), + ) + parser.add_argument( + "--local_nccl_path", + help=( + "Optional: Local NCCL dir will be used in dependencies if the flag" + " is set" + ), ) - parser.add_argument("--cublas_version", help=find_cuda_config_help) - parser.add_argument("--cuda_toolkit_path", help=find_cuda_config_help) - parser.add_argument("--cudnn_version", help=find_cuda_config_help) - parser.add_argument("--nccl_version", help=find_cuda_config_help) return parser.parse_args() @@ -523,11 +487,12 @@ def main(): gcc_path=args.gcc_path, lld_path=args.lld_path, ld_library_path=args.ld_library_path, - cublas_version=args.cublas_version, - cuda_compute_capabilities=args.cuda_compute_capabilities, - cuda_toolkit_path=args.cuda_toolkit_path, + cuda_version=args.cuda_version, cudnn_version=args.cudnn_version, - nccl_version=args.nccl_version, + cuda_compute_capabilities=args.cuda_compute_capabilities, + local_cuda_path=args.local_cuda_path, + local_cudnn_path=args.local_cudnn_path, + local_nccl_path=args.local_nccl_path, ) ) diff --git a/build_tools/configure/configure_test.py b/build_tools/configure/configure_test.py index d79a97d28d02e4..8457ff40aea3ee 100644 --- a/build_tools/configure/configure_test.py +++ b/build_tools/configure/configure_test.py @@ -34,12 +34,20 @@ # CUDA specific paths and versions _CUDA_SPECIFIC_PATHS_AND_VERSIONS = { - "cublas_version": "12.3", - "cuda_toolkit_path": "/usr/local/cuda-12.2", + "cuda_version": '"12.1.1"', "cuda_compute_capabilities": ["7.5"], - "cudnn_version": "8", + "cudnn_version": '"8.6"', + "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", +} +_CUDA_COMPUTE_CAPABILITIES_AND_LD_LIBRARY_PATH = { + "cuda_compute_capabilities": [ + "sm_50", + "sm_60", + "sm_70", + "sm_80", + "compute_90", + ], "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", - "nccl_version": "2", } @@ -66,6 +74,11 @@ def setUpClass(cls): with (testdata / "cuda_clang.bazelrc").open() as f: cls.cuda_clang_bazelrc_lines = [line.strip() for line in f.readlines()] + with (testdata / "default_cuda_clang.bazelrc").open() as f: + cls.default_cuda_clang_bazelrc_lines = [ + line.strip() for line in f.readlines() + ] + with (testdata / "nvcc_clang.bazelrc").open() as f: cls.nvcc_clang_bazelrc_lines = [line.strip() for line in f.readlines()] @@ -138,6 +151,27 @@ def test_cuda_clang_bazelrc(self): self.assertEqual(bazelrc_lines, self.cuda_clang_bazelrc_lines) + def test_default_cuda_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.CLANG, + using_nccl=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + clang_major_version=17, + **_CUDA_COMPUTE_CAPABILITIES_AND_LD_LIBRARY_PATH, + ) + ) + + self.assertEqual(bazelrc_lines, self.default_cuda_clang_bazelrc_lines) + def test_nvcc_clang_bazelrc(self): config = XLAConfigOptions( backend=Backend.CUDA, diff --git a/build_tools/configure/testdata/cuda_clang.bazelrc b/build_tools/configure/testdata/cuda_clang.bazelrc index 8f164ae15a117c..502bc8541c1285 100644 --- a/build_tools/configure/testdata/cuda_clang.bazelrc +++ b/build_tools/configure/testdata/cuda_clang.bazelrc @@ -3,10 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config cuda_clang build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/build_tools/configure/testdata/default_cuda_clang.bazelrc b/build_tools/configure/testdata/default_cuda_clang.bazelrc new file mode 100644 index 00000000000000..4623f6f52073fa --- /dev/null +++ b/build_tools/configure/testdata/default_cuda_clang.bazelrc @@ -0,0 +1,19 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-18/bin/clang +build --repo_env CC=/usr/lib/llvm-18/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang +build --config cuda_clang +build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/build_tools/configure/testdata/nvcc_clang.bazelrc b/build_tools/configure/testdata/nvcc_clang.bazelrc index 237d615bb84631..8cd19224698311 100644 --- a/build_tools/configure/testdata/nvcc_clang.bazelrc +++ b/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -3,10 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config nvcc_clang build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/build_tools/configure/testdata/nvcc_gcc.bazelrc b/build_tools/configure/testdata/nvcc_gcc.bazelrc index d96eb482396a5c..be90a87545368b 100644 --- a/build_tools/configure/testdata/nvcc_gcc.bazelrc +++ b/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -1,9 +1,8 @@ build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc build --config cuda -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/docs/build_from_source.md b/docs/build_from_source.md index c273f7f3cdf8c0..8b30f9995d08e3 100644 --- a/docs/build_from_source.md +++ b/docs/build_from_source.md @@ -65,12 +65,11 @@ docker exec xla_gpu ./configure.py --backend=CUDA docker exec xla_gpu bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` -If you want to build XLA targets with GPU support without Docker you need to -install the following additional dependencies: -[`cuda-12.3`](https://developer.nvidia.com/cuda-downloads), -[`cuDNN-8.9`](https://developer.nvidia.com/cudnn). +For more details regarding +[TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) -Then configure and build targets using the following commands: +You can build XLA targets with GPU support without Docker as well. Configure and +build targets using the following commands: ``` ./configure.py --backend=CUDA @@ -79,4 +78,4 @@ bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` For more details regarding -[TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) +[hermetic CUDA you can check out this document.](docs/hermetic_cuda.md) diff --git a/docs/developer_guide.md b/docs/developer_guide.md index 53b3efcd8cab5c..b736309b7fbc59 100644 --- a/docs/developer_guide.md +++ b/docs/developer_guide.md @@ -64,6 +64,16 @@ docker exec xla ./configure.py --backend=CUDA docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` +**NB:** please note that with hermetic CUDA rules, you don't have to build XLA +in Docker. You can build XLA for GPU on your machine without GPUs and without +NVIDIA driver installed: + +```sh +./configure.py --backend=CUDA + +bazel build --test_output=all --spawn_strategy=sandboxed //xla/... +``` + Your first build will take quite a while because it has to build the entire stack, including XLA, MLIR, and StableHLO. diff --git a/docs/hermetic_cuda.md b/docs/hermetic_cuda.md new file mode 100644 index 00000000000000..732b4a9c97e4db --- /dev/null +++ b/docs/hermetic_cuda.md @@ -0,0 +1,544 @@ +# Hermetic CUDA overview + +Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s +locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, +and then use CUDA libraries and tools as dependencies in various Bazel targets. +This enables more reproducible builds for Google ML projects and supported CUDA +versions. + +## Supported hermetic CUDA, CUDNN versions + +The supported CUDA versions are specified in `CUDA_REDIST_JSON_DICT` +dictionary, +[third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + +The supported CUDNN versions are specified in `CUDNN_REDIST_JSON_DICT` +dictionary, +[third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + +The `.bazelrc` files of individual projects have `HERMETIC_CUDA_VERSION`, +`HERMETIC_CUDNN_VERSION` environment variables set to the versions used by +default when `--config=cuda` is specified in Bazel command options. + +## Environment variables controlling the hermetic CUDA/CUDNN versions + +`HERMETIC_CUDA_VERSION` environment variable should consist of major, minor and +patch CUDA version, e.g. `12.3.2`. +`HERMETIC_CUDNN_VERSION` environment variable should consist of major, minor and +patch CUDNN version, e.g. `9.1.1`. + +Three ways to set the environment variables for Bazel commands: + +``` +# Add an entry to your `.bazelrc` file +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# OR pass it directly to your specific build command +bazel build --config=cuda \ +--repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ +--repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# OR set the environment variable globally in your shell: +export HERMETIC_CUDA_VERSION="12.3.2" +export LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +export HERMETIC_CUDNN_VERSION="9.1.1" +``` + +If `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` are not present, the +hermetic CUDA/CUDNN repository rules will look up `TF_CUDA_VERSION` and +`TF_CUDNN_VERSION` environment variables values. This is made for the backward +compatibility with non-hermetic CUDA/CUDNN repository rules. + +The mapping between CUDA version and NCCL distribution version to be downloaded +is specified in [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + +## Upgrade hermetic CUDA/CUDNN version +1. Create and submit a pull request with updated `CUDA_REDIST_JSON_DICT`, + `CUDA_REDIST_JSON_DICT` dictionaries in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + + Update `CUDA_NCCL_WHEELS` in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + if needed. + + Update `REDIST_VERSIONS_TO_BUILD_TEMPLATES` in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + if needed. + +2. For RBE executions: update `TF_CUDA_VERSION` and/or `TF_CUDNN_VERSION` in + [toolchains/remote_config/rbe_config.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl). + +3. For RBE executions: update `cuda_version`, `cudnn_version`, `TF_CUDA_VERSION` + and `TF_CUDNN_VERSION` in + [toolchains/remote_config/configs.bzl](https://github.com/openxla/xla/blob/main/tools/toolchains/remote_config/configs.bzl). + +4. For each Google ML project create a separate pull request with updated + `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` in `.bazelrc` file. + + The PR presubmit job executions will launch bazel tests and download hermetic + CUDA/CUDNN distributions. Verify that the presubmit jobs passed before + submitting the PR. + +## Pointing to CUDA/CUDNN/NCCL redistributions on local file system + +You can use the local CUDA/CUDNN/NCCL dirs as a source of redistributions. The following additional environment variables are required: + +``` +LOCAL_CUDA_PATH +LOCAL_CUDNN_PATH +LOCAL_NCCL_PATH +``` + +Example: + +``` +# Add an entry to your `.bazelrc` file +build:cuda --repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" +build:cuda --repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +build:cuda --repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + +# OR pass it directly to your specific build command +bazel build --config=cuda \ +--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ +--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ +--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + +# OR set the environment variable globally in your shell: +export LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" +export LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +export LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" +``` + +The structure of the folders inside CUDA dir should be the following (as if the archived redistributions were unpacked into one place): + +``` +/ + include/ + bin/ + lib/ + nvvm/ +``` + +The structure of the folders inside CUDNN dir should be the following: + +``` + + include/ + lib/ +``` + +The structure of the folders inside NCCL dir should be the following: + +``` + + include/ + lib/ +``` + +## Custom CUDA/CUDNN archives and NCCL wheels + +There are three options that allow usage of custom CUDA/CUDNN distributions. + +### Custom CUDA/CUDNN redistribution JSON files + +This option allows to use custom distributions for all CUDA/CUDNN dependencies +in Google ML projects. + +1. Create `cuda_redist.json` and/or `cudnn_redist.json` files. + + `cuda_redist.json` show follow the format below: + + ``` + { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + } + }, + } + ``` + + `cudnn_redist.json` show follow the format below: + + ``` + { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + } + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + } + } + } + } + ``` + + The `relative_path` field can be replaced with `full_path` for the full URLs + and absolute local paths starting with `file:///`. + +2. In the downstream project dependent on XLA, update the hermetic cuda JSON + repository call in `WORKSPACE` file. Both web links and local file paths are + allowed. Example: + + ``` + _CUDA_JSON_DICT = { + "12.4.0": [ + "file:///home/user/Downloads/redistrib_12.4.0_updated.json", + ], + } + + _CUDNN_JSON_DICT = { + "9.0.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.0.0.json", + ], + } + + cuda_json_init_repository( + cuda_json_dict = _CUDA_JSON_DICT, + cudnn_json_dict = _CUDNN_JSON_DICT, + ) + ``` + + If JSON files contain relative paths to distributions, the path prefix should + be updated in `cuda_redist_init_repositories()` and + `cudnn_redist_init_repository()` calls. Example + + ``` + cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, + cuda_redist_path_prefix = "file:///usr/Downloads/dists/", + ) + ``` + +### Custom CUDA/CUDNN distributions + +This option allows to use custom distributions for some CUDA/CUDNN dependencies +in Google ML projects. + +1. In the downstream project dependent on XLA, remove the lines below: + + ``` + <...> + "CUDA_REDIST_JSON_DICT", + <...> + "CUDNN_REDIST_JSON_DICT", + <...> + + cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT, + ) + + load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", + ) + ``` + +2. In the same `WORKSPACE` file, create dictionaries with distribution paths. + + The dictionary with CUDA distributions show follow the format below: + + ``` + _CUSTOM_CUDA_REDISTRIBUTIONS = { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + } + }, + } + ``` + + The dictionary with CUDNN distributions show follow the format below: + + ``` + _CUSTOM_CUDNN_REDISTRIBUTIONS = { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + } + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + } + } + } + } + ``` + + The `relative_path` field can be replaced with `full_path` for the full URLs + and absolute local paths starting with `file:///`. + +2. In the same `WORKSPACE` file, pass the created dictionaries to the repository + rule. If the dictionaries contain relative paths to distributions, the path + prefix should be updated in `cuda_redist_init_repositories()` and + `cudnn_redist_init_repository()` calls. + + ``` + cuda_redist_init_repositories( + cuda_redistributions = _CUSTOM_CUDA_REDISTRIBUTIONS, + cuda_redist_path_prefix = "file:///home/usr/Downloads/dists/", + ) + + cudnn_redist_init_repository( + cudnn_redistributions = _CUSTOM_CUDNN_REDISTRIBUTIONS, + cudnn_redist_path_prefix = "file:///home/usr/Downloads/dists/cudnn/" + ) + ``` +### Combination of the options above + +In the example below, `CUDA_REDIST_JSON_DICT` is merged with custom JSON data in +`_CUDA_JSON_DICT`, and `CUDNN_REDIST_JSON_DICT` is merged with +`_CUDNN_JSON_DICT`. + +The distributions data in `_CUDA_DIST_DICT` overrides the content of resulting +CUDA JSON file, and the distributions data in `_CUDNN_DIST_DICT` overrides the +content of resulting CUDNN JSON file. The NCCL wheels data is merged from +`CUDA_NCCL_WHEELS` and `_NCCL_WHEEL_DICT`. + +``` +load( + //third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDA_NCCL_WHEELS", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_PATH_PREFIX", + "CUDNN_REDIST_JSON_DICT", +) + +_CUDA_JSON_DICT = { + "12.4.0": [ + "file:///usr/Downloads/redistrib_12.4.0_updated.json", + ], +} + +_CUDNN_JSON_DICT = { + "9.0.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.0.0.json", + ], +} + +cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT | _CUDA_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT | _CUDNN_JSON_DICT, +) + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) + +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +_CUDA_DIST_DICT = { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + }, + }, + "libcusolver": { + "linux-x86_64": { + "full_path": "file:///usr/Downloads/dists/libcusolver-linux-x86_64-11.6.0.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "libcusolver-linux-sbsa-11.6.0.99-archive.tar.xz", + }, + }, +} + +_CUDNN_DIST_DICT = { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + }, + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + }, + }, + }, +} + +cudnn_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS | _CUDA_DIST_DICT, + cuda_redist_path_prefix = "file:///usr/Downloads/dists/", +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS | _CUDNN_DIST_DICT, + cudnn_redist_path_prefix = "file:///usr/Downloads/dists/cudnn/" +) + +load( + "//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +_NCCL_WHEEL_DICT = { + "12.4.0": { + "x86_64-unknown-linux-gnu": { + "url": "https://files.pythonhosted.org/packages/38/00/d0d4e48aef772ad5aebcf70b73028f88db6e5640b36c38e90445b7a57c45/nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", + }, + }, +} + +nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS | _NCCL_WHEEL_DICT, +) +``` + +## DEPRECATED: Non-hermetic CUDA/CUDNN usage +Though non-hermetic CUDA/CUDNN usage is deprecated, it might be used for +some experiments currently unsupported officially (for example, building wheels +on Windows with CUDA). + +Here are the steps to use non-hermetic CUDA installed locally in Google ML +projects: + +1. Delete calls to hermetic CUDA repository rules from the `WORKSPACE` + file of the project dependent on XLA. + +2. Add the calls to non-hermetic CUDA repository rules to the bottom of the + `WORKSPACE` file. + + For XLA and JAX: + ``` + load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") + cuda_configure(name = "local_config_cuda") + load("@tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure") + nccl_configure(name = "local_config_nccl") + ``` + + For Tensorflow: + ``` + load("@local_tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") + cuda_configure(name = "local_config_cuda") + load("@local_tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure") + nccl_configure(name = "local_config_nccl") + ``` + +3. Set the following environment variables directly in your shell or in + `.bazelrc` file as shown below: + ``` + build:cuda --action_env=TF_CUDA_VERSION= + build:cuda --action_env=TF_CUDNN_VERSION= + build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES= + build:cuda --action_env=LD_LIBRARY_PATH= + build:cuda --action_env=CUDA_TOOLKIT_PATH= + build:cuda --action_env=TF_CUDA_PATHS= + build:cuda --action_env=NCCL_INSTALL_PATH= + ``` + + Note that `TF_CUDA_VERSION` and `TF_CUDNN_VERSION` should consist of major and + minor versions only (e.g. `12.3` for CUDA and `9.1` for CUDNN). + +4. Now you can run `bazel` command to use locally installed CUDA and CUDNN. + + For XLA, no changes in the command options are needed. + + For JAX, use `--override_repository=tsl=` flag in the Bazel command + options. + + For Tensorflow, use `--override_repository=local_tsl=` flag in the + Bazel command options. + +## Configure hermetic CUDA + +1. In the downstream project dependent on XLA, add the following lines to the + bottom of the `WORKSPACE` file: + + Note: use @local_tsl instead of @tsl in Tensorflow project. + + ``` + load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", + ) + + cuda_json_init_repository() + + load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", + ) + load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", + ) + + cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, + ) + + cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, + ) + + load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", + ) + + cuda_configure(name = "local_config_cuda") + + load( + "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", + ) + + nccl_redist_init_repository() + + load( + "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", + ) + + nccl_configure(name = "local_config_nccl") + ``` + +2. To select specific versions of hermetic CUDA and CUDNN, set the + `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` environment variables + respectively. Use only supported versions. You may set the environment + variables directly in your shell or in `.bazelrc` file as shown below: + ``` + build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" + build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + build:cuda --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" + ``` + +3. To enable Hermetic CUDA during test execution, or when running a binary via + bazel, make sure to add `--@local_config_cuda//cuda:include_hermetic_cuda_libs=true` + flag to your bazel command. You can provide it either directly in a shell or + in `.bazelrc`: + ``` + test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true + ``` + The flag is needed to make sure that CUDA dependencies are properly provided + to test executables. The flag is false by default to avoid unwanted coupling + of Google-released Python wheels to CUDA binaries. diff --git a/third_party/tsl/.bazelrc b/third_party/tsl/.bazelrc index 3ea641dd518adf..9e565e91a1b903 100644 --- a/third_party/tsl/.bazelrc +++ b/third_party/tsl/.bazelrc @@ -219,11 +219,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -232,22 +237,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -543,9 +548,6 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --config=nvcc_clang @@ -630,7 +632,6 @@ build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/cla # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -644,7 +645,6 @@ build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute diff --git a/third_party/tsl/WORKSPACE b/third_party/tsl/WORKSPACE index 19350e3dbba762..a83a9e63f4143a 100644 --- a/third_party/tsl/WORKSPACE +++ b/third_party/tsl/WORKSPACE @@ -50,3 +50,50 @@ tsl_workspace1() load(":workspace0.bzl", "tsl_workspace0") tsl_workspace0() + +load( + "//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/third_party/tsl/opensource_only.files b/third_party/tsl/opensource_only.files index 1d52b26e043f7b..f93d02d633d3c7 100644 --- a/third_party/tsl/opensource_only.files +++ b/third_party/tsl/opensource_only.files @@ -21,6 +21,7 @@ third_party/git/BUILD.tpl: third_party/git/BUILD: third_party/git/git_configure.bzl: third_party/gpus/BUILD: +third_party/gpus/compiler_common_tools.bzl: third_party/gpus/crosstool/BUILD.rocm.tpl: third_party/gpus/crosstool/BUILD.sycl.tpl: third_party/gpus/crosstool/BUILD.tpl: @@ -38,6 +39,27 @@ third_party/gpus/cuda/LICENSE: third_party/gpus/cuda/build_defs.bzl.tpl: third_party/gpus/cuda/cuda_config.h.tpl: third_party/gpus/cuda/cuda_config.py.tpl: +third_party/gpus/cuda/hermetic/BUILD.tpl: +third_party/gpus/cuda/hermetic/BUILD: +third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_configure.bzl: +third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl: +third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl: third_party/gpus/cuda_configure.bzl: third_party/gpus/find_cuda_config:.py third_party/gpus/rocm/BUILD.tpl: @@ -67,6 +89,9 @@ third_party/nccl/archive.BUILD: third_party/nccl/archive.patch: third_party/nccl/build_defs.bzl.tpl: third_party/nccl/generated_names.bzl.tpl: +third_party/nccl/hermetic/BUILD: +third_party/nccl/hermetic/cuda_nccl.BUILD.tpl: +third_party/nccl/hermetic/nccl_configure.bzl: third_party/nccl/nccl_configure.bzl: third_party/nccl/system.BUILD.tpl: third_party/nvtx/BUILD: diff --git a/third_party/tsl/third_party/gpus/check_cuda_libs.py b/third_party/tsl/third_party/gpus/check_cuda_libs.py index afd6380b0ac203..b1a10a86b9aac6 100644 --- a/third_party/tsl/third_party/gpus/check_cuda_libs.py +++ b/third_party/tsl/third_party/gpus/check_cuda_libs.py @@ -14,6 +14,9 @@ # ============================================================================== """Verifies that a list of libraries is installed on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + Takes a list of arguments with every two subsequent arguments being a logical tuple of (path, check_soname). The path to the library and either True or False to indicate whether to check the soname field on the shared library. diff --git a/third_party/tsl/third_party/gpus/compiler_common_tools.bzl b/third_party/tsl/third_party/gpus/compiler_common_tools.bzl new file mode 100644 index 00000000000000..bd07f49ec457bb --- /dev/null +++ b/third_party/tsl/third_party/gpus/compiler_common_tools.bzl @@ -0,0 +1,174 @@ +"""Common compiler functions. """ + +load( + "//third_party/remote_config:common.bzl", + "err_out", + "raw_exec", + "realpath", +) + +def to_list_of_strings(elements): + """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. + + This is to be used to put a list of strings into the bzl file templates + so it gets interpreted as list of strings in Starlark. + + Args: + elements: list of string elements + + Returns: + single string of elements wrapped in quotes separated by a comma.""" + quoted_strings = ["\"" + element + "\"" for element in elements] + return ", ".join(quoted_strings) + +_INC_DIR_MARKER_BEGIN = "#include <...>" + +# OSX add " (framework directory)" at the end of line, strip it. +_OSX_FRAMEWORK_SUFFIX = " (framework directory)" +_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) + +# TODO(dzc): Once these functions have been factored out of Bazel's +# cc_configure.bzl, load them from @bazel_tools instead. +def _cxx_inc_convert(path): + """Convert path returned by cc -E xc++ in a complete path.""" + path = path.strip() + if path.endswith(_OSX_FRAMEWORK_SUFFIX): + path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() + return path + +def _normalize_include_path(repository_ctx, path): + """Normalizes include paths before writing them to the crosstool. + + If path points inside the 'crosstool' folder of the repository, a relative + path is returned. + If path points outside the 'crosstool' folder, an absolute path is returned. + """ + path = str(repository_ctx.path(path)) + crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) + + if path.startswith(crosstool_folder): + # We drop the path to "$REPO/crosstool" and a trailing path separator. + return path[len(crosstool_folder) + 1:] + return path + +def _is_compiler_option_supported(repository_ctx, cc, option): + """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" + result = repository_ctx.execute([ + cc, + option, + "-o", + "/dev/null", + "-c", + str(repository_ctx.path("tools/cpp/empty.cc")), + ]) + return result.stderr.find(option) == -1 + +def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sys_root): + """Compute the list of default C or C++ include directories.""" + if lang_is_cpp: + lang = "c++" + else: + lang = "c" + sysroot = [] + if tf_sys_root: + sysroot += ["--sysroot", tf_sys_root] + result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + + sysroot) + stderr = err_out(result) + index1 = stderr.find(_INC_DIR_MARKER_BEGIN) + if index1 == -1: + return [] + index1 = stderr.find("\n", index1) + if index1 == -1: + return [] + index2 = stderr.rfind("\n ") + if index2 == -1 or index2 < index1: + return [] + index2 = stderr.find("\n", index2 + 1) + if index2 == -1: + inc_dirs = stderr[index1 + 1:] + else: + inc_dirs = stderr[index1 + 1:index2].strip() + + print_resource_dir_supported = _is_compiler_option_supported( + repository_ctx, + cc, + "-print-resource-dir", + ) + + if print_resource_dir_supported: + resource_dir = repository_ctx.execute( + [cc, "-print-resource-dir"], + ).stdout.strip() + "/share" + inc_dirs += "\n" + resource_dir + + compiler_includes = [ + _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) + for p in inc_dirs.split("\n") + ] + + # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc + # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) + # but Bazel might encounter either (usually reported by the compiler) + # especially when a compiler wrapper (e.g. ccache) is used. + # So we need to also include paths where symlinks are not resolved. + + # Try to find real path to CC installation to "see through" compiler wrappers + # GCC has the path to g++ + index1 = result.stderr.find("COLLECT_GCC=") + if index1 != -1: + index1 = result.stderr.find("=", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname + else: + # Clang has the directory + index1 = result.stderr.find("InstalledDir: ") + if index1 != -1: + index1 = result.stderr.find(" ", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname + else: + # Fallback to the CC path + cc_topdir = repository_ctx.path(cc).dirname.dirname + + # We now have the compiler installation prefix, e.g. /symlink/gcc + # And the resolved installation prefix, e.g. /opt/gcc + cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() + cc_topdir = str(cc_topdir).strip() + + # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. + # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] + # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] + if cc_topdir_resolved != cc_topdir: + unresolved_compiler_includes = [ + cc_topdir + inc[len(cc_topdir_resolved):] + for inc in compiler_includes + if inc.startswith(cc_topdir_resolved) + ] + compiler_includes = compiler_includes + unresolved_compiler_includes + return compiler_includes + +def get_cxx_inc_directories(repository_ctx, cc, tf_sys_root): + """Compute the list of default C and C++ include directories.""" + + # For some reason `clang -xc` sometimes returns include paths that are + # different from the ones from `clang -xc++`. (Symlink and a dir) + # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists + includes_cpp = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + True, + tf_sys_root, + ) + includes_c = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + False, + tf_sys_root, + ) + + return includes_cpp + [ + inc + for inc in includes_c + if inc not in includes_cpp + ] diff --git a/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl b/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl index 8eda7a1cf6ac2b..b9553d9b99ecfe 100644 --- a/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl +++ b/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl @@ -2,6 +2,7 @@ # Update cuda_configure.bzl#verify_build_defines when adding new variables. load(":cc_toolchain_config.bzl", "cc_toolchain_config") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") licenses(["restricted"]) @@ -133,9 +134,17 @@ filegroup( srcs = [], ) +filegroup( + name = "cuda_nvcc_files", + srcs = %{cuda_nvcc_files}, +) + filegroup( name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + srcs = [ + ":cuda_nvcc_files", + ":clang/bin/crosstool_wrapper_driver_is_not_gcc" + ], ) filegroup( diff --git a/third_party/tsl/third_party/gpus/cuda/BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/BUILD.tpl index 44cdbe34b25f86..094431dcedfc12 100644 --- a/third_party/tsl/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/tsl/third_party/gpus/cuda/BUILD.tpl @@ -1,6 +1,10 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Please use `hermetic/cuda_configure` instead. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like @@ -144,7 +148,6 @@ cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], data = ["cuda/lib/%{cusolver_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -220,7 +223,6 @@ cc_library( name = "cusparse", srcs = ["cuda/lib/%{cusparse_lib}"], data = ["cuda/lib/%{cusparse_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -242,6 +244,41 @@ py_library( srcs = ["cuda/cuda_config.py"], ) +# Build setting that is always true (i.e. it can not be changed on the +# command line). It is used to create the config settings below that are +# always or never satisfied. +bool_setting( + name = "true_setting", + visibility = ["//visibility:private"], + build_setting_default = True, +) + +# Config settings whether TensorFlow is built with hermetic CUDA. +# These configs are never satisfied. +config_setting( + name = "hermetic_cuda_tools", + flag_values = {":true_setting": "False"}, +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":true_setting": "False"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + %{copy_rules} cc_library( diff --git a/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl index dee0e898d9ae7a..6b25c8398a7144 100644 --- a/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl @@ -1,3 +1,7 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Hermetic CUDA repository rule doesn't support Windows. +# Please use `hermetic/cuda_configure`. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD b/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl new file mode 100644 index 00000000000000..2302f4a3d8f063 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -0,0 +1,261 @@ +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cc_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + ], + deps = [":cudart_headers", + ":cublas_headers", + ":cccl_headers", + ":nvtx_headers", + ":nvcc_headers", + ":cusolver_headers", + ":cufft_headers", + ":cusparse_headers", + ":curand_headers", + ":cupti_headers", + ":nvml_headers"], +) + +cc_library( + name = "cudart_static", + srcs = ["@cuda_cudart//:static"], + linkopts = [ + "-ldl", + "-lpthread", + %{cudart_static_linkopt} + ], +) + +alias( + name = "cuda_driver", + actual = "@cuda_cudart//:cuda_driver", +) + +alias( + name = "cudart_headers", + actual = "@cuda_cudart//:headers", +) + +alias( + name = "cudart", + actual = "@cuda_cudart//:cudart", +) + +alias( + name = "nvtx_headers", + actual = "@cuda_nvtx//:headers", +) + +alias( + name = "nvml_headers", + actual = "@cuda_nvml//:headers", +) + +alias( + name = "nvcc_headers", + actual = "@cuda_nvcc//:headers", +) + +alias( + name = "cccl_headers", + actual = "@cuda_cccl//:headers", +) + +alias( + name = "cublas_headers", + actual = "@cuda_cublas//:headers", +) + +alias( + name = "cusolver_headers", + actual = "@cuda_cusolver//:headers", +) + +alias( + name = "cufft_headers", + actual = "@cuda_cufft//:headers", +) + +alias( + name = "cusparse_headers", + actual = "@cuda_cusparse//:headers", +) + +alias( + name = "curand_headers", + actual = "@cuda_curand//:headers", +) + +alias( + name = "cublas", + actual = "@cuda_cublas//:cublas", +) + +alias( + name = "cublasLt", + actual = "@cuda_cublas//:cublasLt", +) + +alias( + name = "cusolver", + actual = "@cuda_cusolver//:cusolver", +) + +alias( + name = "cudnn", + actual = "@cuda_cudnn//:cudnn", +) + +alias( + name = "cudnn_header", + actual = "@cuda_cudnn//:headers", +) + +alias( + name = "cufft", + actual = "@cuda_cufft//:cufft", +) + +alias( + name = "curand", + actual = "@cuda_curand//:curand", +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = ":cuda_headers", +) + +alias( + name = "cupti_headers", + actual = "@cuda_cupti//:headers", +) + +alias( + name = "cupti_dsos", + actual = "@cuda_cupti//:cupti", +) + +alias( + name = "cusparse", + actual = "@cuda_cusparse//:cusparse", +) + +alias( + name = "cuda-nvvm", + actual = "@cuda_nvcc//:nvvm", +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +# Config setting whether TensorFlow is built with hermetic CUDA. +alias( + name = "hermetic_cuda_tools", + actual = "@local_config_cuda//:is_cuda_enabled", +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":include_hermetic_cuda_libs": "True"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + +cc_library( + # This is not yet fully supported, but we need the rule + # to make bazel query happy. + name = "nvptxcompiler", +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl new file mode 100644 index 00000000000000..85c0cbbb196fef --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl @@ -0,0 +1,15 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + hdrs = glob([ + %{comment}"include/cub/**", + %{comment}"include/cuda/**", + %{comment}"include/nv/**", + %{comment}"include/thrust/**", + ]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl new file mode 100644 index 00000000000000..270b73c3884855 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -0,0 +1,521 @@ +"""Repository rule for hermetic CUDA autoconfiguration. + +`cuda_configure` depends on the following environment variables: + + * `TF_NEED_CUDA`: Whether to enable building with CUDA. + * `TF_NVCC_CLANG`: Whether to use clang for C++ and NVCC for Cuda compilation. + * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for + both host and device code compilation. + * `TF_SYSROOT`: The sysroot to use when compiling. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + * `HERMETIC_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default + is `3.5,5.2`. If not specified, the value will be determined by the + `TF_CUDA_COMPUTE_CAPABILITIES`. + * `PYTHON_BIN_PATH`: The python binary path +""" + +load( + "//third_party/gpus:compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", + "which", +) + +def _find_cc(repository_ctx): + """Find the C++ compiler.""" + cc_path_envvar = _CLANG_CUDA_COMPILER_PATH + cc_name = "clang" + + cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) + if cc_name_from_env: + cc_name = cc_name_from_env + if cc_name.startswith("/"): + # Return the absolute path. + return cc_name + cc = which(repository_ctx, cc_name) + if cc == None: + fail(("Cannot find {}, either correct your path or set the {}" + + " environment variable").format(cc_name, cc_path_envvar)) + return cc + +def _auto_configure_fail(msg): + """Output failure message when cuda configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _verify_build_defines(params): + """Verify all variables that crosstool/BUILD.tpl expects are substituted. + + Args: + params: dict of variables that will be passed to the BUILD.tpl template. + """ + missing = [] + for param in [ + "cxx_builtin_include_directories", + "extra_no_canonical_prefixes_flags", + "host_compiler_path", + "host_compiler_prefix", + "host_compiler_warnings", + "linker_bin_path", + "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", + "unfiltered_compile_flags", + "win_compiler_deps", + ]: + if ("%{" + param + "}") not in params: + missing.append(param) + + if missing: + _auto_configure_fail( + "BUILD.tpl template is missing these variables: " + + str(missing) + + ".\nWe only got: " + + str(params) + + ".", + ) + +def get_cuda_version(repository_ctx): + return (get_host_environ(repository_ctx, HERMETIC_CUDA_VERSION) or + get_host_environ(repository_ctx, TF_CUDA_VERSION)) + +def enable_cuda(repository_ctx): + """Returns whether to build with CUDA support.""" + return int(get_host_environ(repository_ctx, TF_NEED_CUDA, False)) + +def _flag_enabled(repository_ctx, flag_name): + return get_host_environ(repository_ctx, flag_name) == "1" + +def _use_nvcc_and_clang(repository_ctx): + # Returns the flag if we need to use clang for C++ and NVCC for Cuda. + return _flag_enabled(repository_ctx, _TF_NVCC_CLANG) + +def _tf_sysroot(repository_ctx): + return get_host_environ(repository_ctx, _TF_SYSROOT, "") + +def _py_tmpl_dict(d): + return {"%{cuda_config}": str(d)} + +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "\"\"," if cpu_value == "Darwin" else "\"-lrt\"," + +def _compute_capabilities(repository_ctx): + """Returns a list of strings representing cuda compute capabilities. + + Args: + repository_ctx: the repo rule's context. + + Returns: + list of cuda architectures to compile for. 'compute_xy' refers to + both PTX and SASS, 'sm_xy' refers to SASS only. + """ + capabilities = (get_host_environ( + repository_ctx, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + ) or + get_host_environ( + repository_ctx, + _TF_CUDA_COMPUTE_CAPABILITIES, + )) + capabilities = (capabilities or "compute_35,compute_52").split(",") + + # Map old 'x.y' capabilities to 'compute_xy'. + if len(capabilities) > 0 and all([len(x.split(".")) == 2 for x in capabilities]): + # If all capabilities are in 'x.y' format, only include PTX for the + # highest capability. + cc_list = sorted([x.replace(".", "") for x in capabilities]) + capabilities = [ + "sm_%s" % x + for x in cc_list[:-1] + ] + ["compute_%s" % cc_list[-1]] + for i, capability in enumerate(capabilities): + parts = capability.split(".") + if len(parts) != 2: + continue + capabilities[i] = "compute_%s%s" % (parts[0], parts[1]) + + # Make list unique + capabilities = dict(zip(capabilities, capabilities)).keys() + + # Validate capabilities. + for capability in capabilities: + if not capability.startswith(("compute_", "sm_")): + _auto_configure_fail("Invalid compute capability: %s" % capability) + for prefix in ["compute_", "sm_"]: + if not capability.startswith(prefix): + continue + if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(): + continue + if len(capability) == len(prefix) + 3 and capability.endswith("90a"): + continue + _auto_configure_fail("Invalid compute capability: %s" % capability) + + return capabilities + +def _compute_cuda_extra_copts(compute_capabilities): + copts = ["--no-cuda-include-ptx=all"] + for capability in compute_capabilities: + if capability.startswith("compute_"): + capability = capability.replace("compute_", "sm_") + copts.append("--cuda-include-ptx=%s" % capability) + copts.append("--cuda-gpu-arch=%s" % capability) + + return str(copts) + +def _get_cuda_config(repository_ctx): + """Detects and returns information about the CUDA installation on the system. + + Args: + repository_ctx: The repository context. + + Returns: + A struct containing the following fields: + cuda_version: The version of CUDA on the system. + cudart_version: The CUDA runtime version on the system. + cudnn_version: The version of cuDNN on the system. + compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. + """ + + return struct( + cuda_version = get_cuda_version(repository_ctx), + cupti_version = repository_ctx.read(repository_ctx.attr.cupti_version), + cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version), + cublas_version = repository_ctx.read(repository_ctx.attr.cublas_version), + cusolver_version = repository_ctx.read(repository_ctx.attr.cusolver_version), + curand_version = repository_ctx.read(repository_ctx.attr.curand_version), + cufft_version = repository_ctx.read(repository_ctx.attr.cufft_version), + cusparse_version = repository_ctx.read(repository_ctx.attr.cusparse_version), + cudnn_version = repository_ctx.read(repository_ctx.attr.cudnn_version), + compute_capabilities = _compute_capabilities(repository_ctx), + cpu_value = get_cpu_value(repository_ctx), + ) + +_DUMMY_CROSSTOOL_BZL_FILE = """ +def error_gpu_disabled(): + fail("ERROR: Building with --config=cuda but TensorFlow is not configured " + + "to build with GPU support. Please re-run ./configure and enter 'Y' " + + "at the prompt to build with GPU support.") + + native.genrule( + name = "error_gen_crosstool", + outs = ["CROSSTOOL"], + cmd = "echo 'Should not be run.' && exit 1", + ) + + native.filegroup( + name = "crosstool", + srcs = [":CROSSTOOL"], + output_licenses = ["unencumbered"], + ) +""" + +_DUMMY_CROSSTOOL_BUILD_FILE = """ +load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled") + +error_gpu_disabled() +""" + +def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + + # Set up BUILD file for cuda/. + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "False", + "%{cuda_extra_copts}": "[]", + "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + }, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": "", + "%{cudart_version}": "", + "%{cupti_version}": "", + "%{cublas_version}": "", + "%{cusolver_version}": "", + "%{curand_version}": "", + "%{cufft_version}": "", + "%{cusparse_version}": "", + "%{cudnn_version}": "", + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({}), + ) + + # If cuda_configure is not configured to build with GPU support, and the user + # attempts to build with --config=cuda, add a dummy build rule to intercept + # this and fail with an actionable error message. + repository_ctx.file( + "crosstool/error_gpu_disabled.bzl", + _DUMMY_CROSSTOOL_BZL_FILE, + ) + repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) + +def _create_local_cuda_repository(repository_ctx): + """Creates the repository containing files set up to build with CUDA.""" + cuda_config = _get_cuda_config(repository_ctx) + + # Set up BUILD file for cuda/ + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + cuda_config.compute_capabilities, + ), + "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt( + cuda_config.cpu_value, + ), + }, + ) + + is_nvcc_and_clang = _use_nvcc_and_clang(repository_ctx) + tf_sysroot = _tf_sysroot(repository_ctx) + + # Set up crosstool/ + cc = _find_cc(repository_ctx) + host_compiler_includes = get_cxx_inc_directories( + repository_ctx, + cc, + tf_sysroot, + ) + + cuda_defines = {} + + # We do not support hermetic CUDA on Windows. + # This ensures the CROSSTOOL file parser is happy. + cuda_defines.update({ + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + "%{win_compiler_deps}": ":empty", + }) + + cuda_defines["%{builtin_sysroot}"] = tf_sysroot + cuda_defines["%{cuda_toolkit_path}"] = repository_ctx.attr.nvcc_binary.workspace_root + cuda_defines["%{compiler}"] = "clang" + cuda_defines["%{host_compiler_prefix}"] = "/usr/bin" + cuda_defines["%{linker_bin_path}"] = "" + cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" + cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( + host_compiler_includes, + ) + cuda_defines["%{cuda_nvcc_files}"] = "if_cuda([\"@{nvcc_archive}//:bin\", \"@{nvcc_archive}//:nvvm\"])".format( + nvcc_archive = repository_ctx.attr.nvcc_binary.repo_name, + ) + + if not is_nvcc_and_clang: + cuda_defines["%{host_compiler_path}"] = str(cc) + cuda_defines["%{host_compiler_warnings}"] = """ + # Some parts of the codebase set -Werror and hit this warning, so + # switch it off for now. + "-Wno-invalid-partial-specialization" + """ + cuda_defines["%{compiler_deps}"] = ":cuda_nvcc_files" + repository_ctx.file( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + "", + ) + else: + cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{host_compiler_warnings}"] = "" + + nvcc_relative_path = "%s/%s" % ( + repository_ctx.attr.nvcc_binary.workspace_root, + repository_ctx.attr.nvcc_binary.name, + ) + cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + + wrapper_defines = { + "%{cpu_compiler}": str(cc), + "%{cuda_version}": cuda_config.cuda_version, + "%{nvcc_path}": nvcc_relative_path, + "%{host_compiler_path}": str(cc), + "%{use_clang_compiler}": "True", + } + repository_ctx.template( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + repository_ctx.attr.crosstool_wrapper_driver_is_not_gcc_tpl, + wrapper_defines, + ) + + _verify_build_defines(cuda_defines) + + # Only expand template variables in the BUILD file + repository_ctx.template( + "crosstool/BUILD", + repository_ctx.attr.crosstool_build_tpl, + cuda_defines, + ) + + # No templating of cc_toolchain_config - use attributes and templatize the + # BUILD file. + repository_ctx.template( + "crosstool/cc_toolchain_config.bzl", + repository_ctx.attr.cc_toolchain_config_tpl, + {}, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": cuda_config.cuda_version, + "%{cudart_version}": cuda_config.cudart_version, + "%{cupti_version}": cuda_config.cupti_version, + "%{cublas_version}": cuda_config.cublas_version, + "%{cusolver_version}": cuda_config.cusolver_version, + "%{curand_version}": cuda_config.curand_version, + "%{cufft_version}": cuda_config.cufft_version, + "%{cusparse_version}": cuda_config.cusparse_version, + "%{cudnn_version}": cuda_config.cudnn_version, + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": ", ".join([ + cc.split("_")[1] + for cc in cuda_config.compute_capabilities + ]), + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({ + "cuda_version": cuda_config.cuda_version, + "cudnn_version": cuda_config.cudnn_version, + "cuda_compute_capabilities": cuda_config.compute_capabilities, + "cpu_compiler": str(cc), + }), + ) + +def _cuda_autoconf_impl(repository_ctx): + """Implementation of the cuda_autoconf repository rule.""" + build_file = repository_ctx.attr.local_config_cuda_build_file + + if not enable_cuda(repository_ctx): + _create_dummy_repository(repository_ctx) + else: + _create_local_cuda_repository(repository_ctx) + + repository_ctx.symlink(build_file, "BUILD") + +_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH" +_PYTHON_BIN_PATH = "PYTHON_BIN_PATH" +_HERMETIC_CUDA_COMPUTE_CAPABILITIES = "HERMETIC_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +HERMETIC_CUDA_VERSION = "HERMETIC_CUDA_VERSION" +TF_CUDA_VERSION = "TF_CUDA_VERSION" +TF_NEED_CUDA = "TF_NEED_CUDA" +_TF_NVCC_CLANG = "TF_NVCC_CLANG" +_TF_SYSROOT = "TF_SYSROOT" + +_ENVIRONS = [ + _CLANG_CUDA_COMPILER_PATH, + TF_NEED_CUDA, + _TF_NVCC_CLANG, + TF_CUDA_VERSION, + HERMETIC_CUDA_VERSION, + _TF_CUDA_COMPUTE_CAPABILITIES, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + _TF_SYSROOT, + _PYTHON_BIN_PATH, + "TMP", + "TMPDIR", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", +] + +cuda_configure = repository_rule( + implementation = _cuda_autoconf_impl, + environ = _ENVIRONS, + attrs = { + "environ": attr.string_dict(), + "cublas_version": attr.label(default = Label("@cuda_cublas//:version.txt")), + "cudart_version": attr.label(default = Label("@cuda_cudart//:version.txt")), + "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.txt")), + "cufft_version": attr.label(default = Label("@cuda_cufft//:version.txt")), + "cupti_version": attr.label(default = Label("@cuda_cupti//:version.txt")), + "curand_version": attr.label(default = Label("@cuda_curand//:version.txt")), + "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.txt")), + "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.txt")), + "nvcc_binary": attr.label(default = Label("@cuda_nvcc//:bin/nvcc")), + "local_config_cuda_build_file": attr.label(default = Label("//third_party/gpus:local_config_cuda.BUILD")), + "build_defs_tpl": attr.label(default = Label("//third_party/gpus/cuda:build_defs.bzl.tpl")), + "cuda_build_tpl": attr.label(default = Label("//third_party/gpus/cuda/hermetic:BUILD.tpl")), + "cuda_config_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.h.tpl")), + "cuda_config_py_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.py.tpl")), + "crosstool_wrapper_driver_is_not_gcc_tpl": attr.label(default = Label("//third_party/gpus/crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl")), + "crosstool_build_tpl": attr.label(default = Label("//third_party/gpus/crosstool:BUILD.tpl")), + "cc_toolchain_config_tpl": attr.label(default = Label("//third_party/gpus/crosstool:cc_toolchain_config.bzl.tpl")), + }, +) +"""Detects and configures the hermetic CUDA toolchain. + +Add the following to your WORKSPACE file: + +```python +cuda_configure(name = "local_config_cuda") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl new file mode 100644 index 00000000000000..510235d801de4e --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -0,0 +1,44 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cublas_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublas.so.%{libcublas_version}", + deps = [":cublasLt"], +) + +cc_import( + name = "cublasLt_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublasLt.so.%{libcublaslt_version}", +) +%{multiline_comment} +cc_library( + name = "cublas", + visibility = ["//visibility:public"], + %{comment}deps = [":cublas_shared_library"], +) + +cc_library( + name = "cublasLt", + visibility = ["//visibility:public"], + %{comment}deps = [":cublasLt_shared_library"], +) + +cc_library( + name = "headers", + %{comment}hdrs = [ + %{comment}"include/cublas.h", + %{comment}"include/cublasLt.h", + %{comment}"include/cublas_api.h", + %{comment}"include/cublas_v2.h", + %{comment}], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl new file mode 100644 index 00000000000000..f7ba469b42b76a --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -0,0 +1,126 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) + +filegroup( + name = "static", + srcs = ["lib/libcudart_static.a"], + visibility = ["@local_config_cuda//cuda:__pkg__"], +) +%{multiline_comment} +# TODO: Replace system provided library with hermetic NVIDIA driver library. +cc_import( + name = "cuda_driver_shared_library", + interface_library = "lib/stubs/libcuda.so", + system_provided = 1, +) + +cc_import( + name = "cudart_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcudart.so.%{libcudart_version}", +) +%{multiline_comment} +cc_library( + name = "cuda_driver", + %{comment}deps = [":cuda_driver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart", + %{comment}deps = [ + %{comment}":cuda_driver", + %{comment}":cudart_shared_library", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/builtin_types.h", + %{comment}"include/channel_descriptor.h", + %{comment}"include/common_functions.h", + %{comment}"include/cooperative_groups/**", + %{comment}"include/cooperative_groups.h", + %{comment}"include/cuComplex.h", + %{comment}"include/cuda.h", + %{comment}"include/cudaEGL.h", + %{comment}"include/cudaEGLTypedefs.h", + %{comment}"include/cudaGL.h", + %{comment}"include/cudaGLTypedefs.h", + %{comment}"include/cudaProfilerTypedefs.h", + %{comment}"include/cudaTypedefs.h", + %{comment}"include/cudaVDPAU.h", + %{comment}"include/cudaVDPAUTypedefs.h", + %{comment}"include/cuda_awbarrier.h", + %{comment}"include/cuda_awbarrier_helpers.h", + %{comment}"include/cuda_awbarrier_primitives.h", + %{comment}"include/cuda_bf16.h", + %{comment}"include/cuda_bf16.hpp", + %{comment}"include/cuda_device_runtime_api.h", + %{comment}"include/cuda_egl_interop.h", + %{comment}"include/cuda_fp16.h", + %{comment}"include/cuda_fp16.hpp", + %{comment}"include/cuda_fp8.h", + %{comment}"include/cuda_fp8.hpp", + %{comment}"include/cuda_gl_interop.h", + %{comment}"include/cuda_occupancy.h", + %{comment}"include/cuda_pipeline.h", + %{comment}"include/cuda_pipeline_helpers.h", + %{comment}"include/cuda_pipeline_primitives.h", + %{comment}"include/cuda_runtime.h", + %{comment}"include/cuda_runtime_api.h", + %{comment}"include/cuda_surface_types.h", + %{comment}"include/cuda_texture_types.h", + %{comment}"include/cuda_vdpau_interop.h", + %{comment}"include/cudart_platform.h", + %{comment}"include/device_atomic_functions.h", + %{comment}"include/device_atomic_functions.hpp", + %{comment}"include/device_double_functions.h", + %{comment}"include/device_functions.h", + %{comment}"include/device_launch_parameters.h", + %{comment}"include/device_types.h", + %{comment}"include/driver_functions.h", + %{comment}"include/driver_types.h", + %{comment}"include/host_config.h", + %{comment}"include/host_defines.h", + %{comment}"include/library_types.h", + %{comment}"include/math_constants.h", + %{comment}"include/math_functions.h", + %{comment}"include/mma.h", + %{comment}"include/nvfunctional", + %{comment}"include/sm_20_atomic_functions.h", + %{comment}"include/sm_20_atomic_functions.hpp", + %{comment}"include/sm_20_intrinsics.h", + %{comment}"include/sm_20_intrinsics.hpp", + %{comment}"include/sm_30_intrinsics.h", + %{comment}"include/sm_30_intrinsics.hpp", + %{comment}"include/sm_32_atomic_functions.h", + %{comment}"include/sm_32_atomic_functions.hpp", + %{comment}"include/sm_32_intrinsics.h", + %{comment}"include/sm_32_intrinsics.hpp", + %{comment}"include/sm_35_atomic_functions.h", + %{comment}"include/sm_35_intrinsics.h", + %{comment}"include/sm_60_atomic_functions.h", + %{comment}"include/sm_60_atomic_functions.hpp", + %{comment}"include/sm_61_intrinsics.h", + %{comment}"include/sm_61_intrinsics.hpp", + %{comment}"include/surface_functions.h", + %{comment}"include/surface_indirect_functions.h", + %{comment}"include/surface_types.h", + %{comment}"include/texture_fetch_functions.h", + %{comment}"include/texture_indirect_functions.h", + %{comment}"include/texture_types.h", + %{comment}"include/vector_functions.h", + %{comment}"include/vector_functions.hpp", + %{comment}"include/vector_types.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl new file mode 100644 index 00000000000000..165c5b1579e73f --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -0,0 +1,73 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_infer.so.%{libcudnn_ops_infer_version}", +) + +cc_import( + name = "cudnn_cnn_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_infer.so.%{libcudnn_cnn_infer_version}", +) + +cc_import( + name = "cudnn_ops_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_train.so.%{libcudnn_ops_train_version}", +) + +cc_import( + name = "cudnn_cnn_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_train.so.%{libcudnn_cnn_train_version}", +) + +cc_import( + name = "cudnn_adv_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_infer.so.%{libcudnn_adv_infer_version}", +) + +cc_import( + name = "cudnn_adv_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_train.so.%{libcudnn_adv_train_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_ops_infer", + %{comment}":cudnn_ops_train", + %{comment}":cudnn_cnn_infer", + %{comment}":cudnn_cnn_train", + %{comment}":cudnn_adv_infer", + %{comment}":cudnn_adv_train", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl new file mode 100644 index 00000000000000..7f36054a51bb5b --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -0,0 +1,80 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops.so.%{libcudnn_ops_version}", +) + +cc_import( + name = "cudnn_cnn", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn.so.%{libcudnn_cnn_version}", +) + +cc_import( + name = "cudnn_adv", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv.so.%{libcudnn_adv_version}", +) + +cc_import( + name = "cudnn_graph", + hdrs = [":headers"], + shared_library = "lib/libcudnn_graph.so.%{libcudnn_graph_version}", +) + +cc_import( + name = "cudnn_engines_precompiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_precompiled.so.%{libcudnn_engines_precompiled_version}", +) + +cc_import( + name = "cudnn_engines_runtime_compiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_runtime_compiled.so.%{libcudnn_engines_runtime_compiled_version}", +) + +cc_import( + name = "cudnn_heuristic", + hdrs = [":headers"], + shared_library = "lib/libcudnn_heuristic.so.%{libcudnn_heuristic_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_engines_precompiled", + %{comment}":cudnn_ops", + %{comment}":cudnn_graph", + %{comment}":cudnn_cnn", + %{comment}":cudnn_adv", + %{comment}":cudnn_engines_runtime_compiled", + %{comment}":cudnn_heuristic", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl new file mode 100644 index 00000000000000..48ccb0ea3cd197 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -0,0 +1,29 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cufft_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcufft.so.%{libcufft_version}", +) +%{multiline_comment} +cc_library( + name = "cufft", + %{comment}deps = [":cufft_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudalibxt.h", + %{comment}"include/cufft*.h" + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl new file mode 100644 index 00000000000000..3efe76f470953f --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -0,0 +1,59 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cupti_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcupti.so.%{libcupti_version}", +) +%{multiline_comment} +cc_library( + name = "cupti", + %{comment}deps = [":cupti_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/Openacc/**", + %{comment}"include/Openmp/**", + %{comment}"include/cuda_stdint.h", + %{comment}"include/cupti.h", + %{comment}"include/cupti_activity.h", + %{comment}"include/cupti_activity_deprecated.h", + %{comment}"include/cupti_callbacks.h", + %{comment}"include/cupti_checkpoint.h", + %{comment}"include/cupti_driver_cbid.h", + %{comment}"include/cupti_events.h", + %{comment}"include/cupti_metrics.h", + %{comment}"include/cupti_nvtx_cbid.h", + %{comment}"include/cupti_pcsampling.h", + %{comment}"include/cupti_pcsampling_util.h", + %{comment}"include/cupti_profiler_target.h", + %{comment}"include/cupti_result.h", + %{comment}"include/cupti_runtime_cbid.h", + %{comment}"include/cupti_sass_metrics.h", + %{comment}"include/cupti_target.h", + %{comment}"include/cupti_version.h", + %{comment}"include/generated_cudaGL_meta.h", + %{comment}"include/generated_cudaVDPAU_meta.h", + %{comment}"include/generated_cuda_gl_interop_meta.h", + %{comment}"include/generated_cuda_meta.h", + %{comment}"include/generated_cuda_runtime_api_meta.h", + %{comment}"include/generated_cuda_vdpau_interop_meta.h", + %{comment}"include/generated_cudart_removed_meta.h", + %{comment}"include/generated_nvtx_meta.h", + %{comment}"include/nvperf_common.h", + %{comment}"include/nvperf_cuda_host.h", + %{comment}"include/nvperf_host.h", + %{comment}"include/nvperf_target.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/extras/CUPTI/include", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl new file mode 100644 index 00000000000000..50e5a8f18a96fd --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -0,0 +1,26 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "curand_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcurand.so.%{libcurand_version}", +) +%{multiline_comment} +cc_library( + name = "curand", + %{comment}deps = [":curand_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob(["include/curand*.h"]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl new file mode 100644 index 00000000000000..943a08ebeb96e1 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -0,0 +1,34 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusolver_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusolver.so.%{libcusolver_version}", + deps = [ + "@cuda_nvjitlink//:nvjitlink", + "@cuda_cusparse//:cusparse", + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + ], +) +%{multiline_comment} +cc_library( + name = "cusolver", + %{comment}deps = [":cusolver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cusolver*.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl new file mode 100644 index 00000000000000..46b24366ce1c04 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -0,0 +1,27 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusparse_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusparse.so.%{libcusparse_version}", + deps = ["@cuda_nvjitlink//:nvjitlink"], +) +%{multiline_comment} +cc_library( + name = "cusparse", + %{comment}deps = [":cusparse_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = ["include/cusparse.h"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl new file mode 100644 index 00000000000000..fdda3aaf92cea5 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl @@ -0,0 +1,125 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistributions JSON repository initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_JSON_DICT", +) + +def _get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_json_file_content(repository_ctx, url_to_sha256, json_file_name): + if len(url_to_sha256) > 1: + (url, sha256) = url_to_sha256 + else: + url = url_to_sha256[0] + sha256 = "" + repository_ctx.download( + url = tf_mirror_urls(url), + sha256 = sha256, + output = json_file_name, + ) + return repository_ctx.read(repository_ctx.path(json_file_name)) + +def _cuda_redist_json_impl(repository_ctx): + cuda_version = (_get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + _get_env_var(repository_ctx, "TF_CUDA_VERSION")) + local_cuda_path = _get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + cudnn_version = (_get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + _get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + local_cudnn_path = _get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + supported_cuda_versions = repository_ctx.attr.cuda_json_dict.keys() + if (cuda_version and not local_cuda_path and + (cuda_version not in supported_cuda_versions)): + fail( + ("The supported CUDA versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add JSON URL for" + + " CUDA version={version}.") + .format( + supported_versions = supported_cuda_versions, + version = cuda_version, + ), + ) + supported_cudnn_versions = repository_ctx.attr.cudnn_json_dict.keys() + if cudnn_version and not local_cudnn_path and (cudnn_version not in supported_cudnn_versions): + fail( + ("The supported CUDNN versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDNN_VERSION" + + " environment variable or add JSON URL for" + + " CUDNN version={version}.") + .format( + supported_versions = supported_cudnn_versions, + version = cudnn_version, + ), + ) + cuda_redistributions = "{}" + cudnn_redistributions = "{}" + if cuda_version and not local_cuda_path: + cuda_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cuda_json_dict[cuda_version], + "redistrib_cuda_%s.json" % cuda_version, + ) + if cudnn_version and not local_cudnn_path: + cudnn_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cudnn_json_dict[cudnn_version], + "redistrib_cudnn_%s.json" % cudnn_version, + ) + + repository_ctx.file( + "distributions.bzl", + """CUDA_REDISTRIBUTIONS = {cuda_redistributions} + +CUDNN_REDISTRIBUTIONS = {cudnn_redistributions} +""".format( + cuda_redistributions = cuda_redistributions, + cudnn_redistributions = cudnn_redistributions, + ), + ) + repository_ctx.file( + "BUILD", + "", + ) + +cuda_redist_json = repository_rule( + implementation = _cuda_redist_json_impl, + attrs = { + "cuda_json_dict": attr.string_list_dict(mandatory = True), + "cudnn_json_dict": attr.string_list_dict(mandatory = True), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "HERMETIC_CUDNN_VERSION", + "TF_CUDA_VERSION", + "TF_CUDNN_VERSION", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", + ], +) + +def cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT): + cuda_redist_json( + name = "cuda_redist_json", + cuda_json_dict = cuda_json_dict, + cudnn_json_dict = cudnn_json_dict, + ) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl new file mode 100644 index 00000000000000..7757a92a90b795 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl @@ -0,0 +1,75 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "bin/nvcc", +]) + +filegroup( + name = "nvvm", + srcs = [ + "nvvm/libdevice/libdevice.10.bc", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "nvlink", + srcs = [ + "bin/nvlink", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "fatbinary", + srcs = [ + "bin/fatbinary", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin2c", + srcs = [ + "bin/bin2c", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "ptxas", + srcs = [ + "bin/ptxas", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin", + srcs = glob([ + "bin/**", + "nvvm/bin/**", + ]), + visibility = ["//visibility:public"], +) + +filegroup( + name = "link_stub", + srcs = [ + "bin/crt/link.stub", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/crt/**", + %{comment}"include/fatbinary_section.h", + %{comment}"include/nvPTXCompiler.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl new file mode 100644 index 00000000000000..9784a84471f1a7 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -0,0 +1,17 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nvjitlink_shared_library", + shared_library = "lib/libnvJitLink.so.%{libnvjitlink_version}", +) +%{multiline_comment} +cc_library( + name = "nvjitlink", + %{comment}deps = [":nvjitlink_shared_library"], + visibility = ["//visibility:public"], +) + diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl new file mode 100644 index 00000000000000..23ee30f09f8ff3 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl @@ -0,0 +1,10 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = ["include/nvml.h"], + include_prefix = "third_party/gpus/cuda/nvml/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl new file mode 100644 index 00000000000000..986ef0c8f76166 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl @@ -0,0 +1,9 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +filegroup( + name = "nvprune", + srcs = [ + "bin/nvprune", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl new file mode 100644 index 00000000000000..de18489b455b79 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl @@ -0,0 +1,20 @@ +licenses(["restricted"]) # NVIDIA proprietary license +%{multiline_comment} +cc_import( + name = "nvrtc_main", + shared_library = "lib/libnvrtc.so.%{libnvrtc_version}", +) + +cc_import( + name = "nvrtc_builtins", + shared_library = "lib/libnvrtc-builtins.so.%{libnvrtc-builtins_version}", +) +%{multiline_comment} +cc_library( + name = "nvrtc", + %{comment}deps = [ + %{comment}":nvrtc_main", + %{comment}":nvrtc_builtins", + %{comment}], + visibility = ["//visibility:public"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl new file mode 100644 index 00000000000000..3457f41a502dee --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl @@ -0,0 +1,13 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nvToolsExt*.h", + %{comment}"include/nvtx3/**", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl new file mode 100644 index 00000000000000..d2015e737540c3 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -0,0 +1,491 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDNN_REDIST_PATH_PREFIX", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +OS_ARCH_DICT = { + "amd64": "x86_64-unknown-linux-gnu", + "aarch64": "aarch64-unknown-linux-gnu", +} +_REDIST_ARCH_DICT = { + "linux-x86_64": "x86_64-unknown-linux-gnu", + "linux-sbsa": "aarch64-unknown-linux-gnu", +} + +SUPPORTED_ARCHIVE_EXTENSIONS = [ + ".zip", + ".jar", + ".war", + ".aar", + ".tar", + ".tar.gz", + ".tgz", + ".tar.xz", + ".txz", + ".tar.zst", + ".tzst", + ".tar.bz2", + ".tbz", + ".ar", + ".deb", + ".whl", +] + +def get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def get_archive_name(url): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the archive name without extension.""" + filename = _get_file_name(url) + for extension in SUPPORTED_ARCHIVE_EXTENSIONS: + if filename.endswith(extension): + return filename[:-len(extension)] + return filename + +LIB_EXTENSION = ".so." + +def _get_lib_name_and_version(path): + extension_index = path.rfind(LIB_EXTENSION) + last_slash_index = path.rfind("/") + lib_name = path[last_slash_index + 1:extension_index] + lib_version = path[extension_index + len(LIB_EXTENSION):] + return (lib_name, lib_version) + +def _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_dir_path = repository_ctx.path("lib") + if not lib_dir_path.exists: + return [] + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]).lower() + lib_dir_content = lib_dir_path.readdir() + return [ + str(f) + for f in lib_dir_content + if (LIB_EXTENSION in str(f) and + main_lib_name in str(f).lower()) + ] + +def get_lib_name_to_version_dict(repository_ctx): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns a dict of library names and major versions.""" + lib_name_to_version_dict = {} + for path in _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_name, lib_version = _get_lib_name_and_version(path) + key = "%%{%s_version}" % lib_name.lower() + + # We need to find either major or major.minor version if there is no + # file with major version. E.g. if we have the following files: + # libcudart.so + # libcudart.so.12 + # libcudart.so.12.3.2, + # we will save save {"%{libcudart_version}": "12"}. + if len(lib_version.split(".")) == 1: + lib_name_to_version_dict[key] = lib_version + if (len(lib_version.split(".")) == 2 and + key not in lib_name_to_version_dict): + lib_name_to_version_dict[key] = lib_version + return lib_name_to_version_dict + +def create_dummy_build_file(repository_ctx, use_comment_symbols = True): + repository_ctx.template( + "BUILD", + repository_ctx.attr.build_templates[0], + { + "%{multiline_comment}": "'''" if use_comment_symbols else "", + "%{comment}": "#" if use_comment_symbols else "", + }, + ) + +def _get_build_template(repository_ctx, major_lib_version): + template = None + for i in range(0, len(repository_ctx.attr.versions)): + for dist_version in repository_ctx.attr.versions[i].split(","): + if dist_version == major_lib_version: + template = repository_ctx.attr.build_templates[i] + break + if not template: + fail("No build template found for {} version {}".format( + repository_ctx.name, + major_lib_version, + )) + return template + +def get_major_library_version(repository_ctx, lib_name_to_version_dict): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the major library version provided the versions dict.""" + major_version = "" + if len(lib_name_to_version_dict) == 0: + return major_version + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]) + key = "%%{%s_version}" % main_lib_name + major_version = lib_name_to_version_dict[key] + return major_version + +def create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_lib_version): + # buildifier: disable=function-docstring-args + """Creates a BUILD file for the repository.""" + if len(major_lib_version) == 0: + build_template_content = repository_ctx.read( + repository_ctx.attr.build_templates[0], + ) + if "_version}" not in build_template_content: + create_dummy_build_file(repository_ctx, use_comment_symbols = False) + else: + create_dummy_build_file(repository_ctx) + return + build_template = _get_build_template( + repository_ctx, + major_lib_version.split(".")[0], + ) + repository_ctx.template( + "BUILD", + build_template, + lib_name_to_version_dict | { + "%{multiline_comment}": "", + "%{comment}": "", + }, + ) + +def _create_symlinks(repository_ctx, local_path, dirs): + for dir in dirs: + repository_ctx.symlink( + "{path}/{dir}".format( + path = local_path, + dir = dir, + ), + dir, + ) + +def use_local_path(repository_ctx, local_path, dirs): + # buildifier: disable=function-docstring-args + """Creates repository using local redistribution paths.""" + _create_symlinks( + repository_ctx, + local_path, + dirs, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _use_local_cuda_path(repository_ctx, local_cuda_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDA repository.""" + use_local_path( + repository_ctx, + local_cuda_path, + ["include", "lib", "bin", "nvvm"], + ) + +def _use_local_cudnn_path(repository_ctx, local_cudnn_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDNN repository.""" + use_local_path(repository_ctx, local_cudnn_path, ["include", "lib"]) + +def _download_redistribution(repository_ctx, arch_key, path_prefix): + (url, sha256) = repository_ctx.attr.url_dict[arch_key] + + # If url is not relative, then appending prefix is not needed. + if not (url.startswith("http") or url.startswith("file:///")): + url = path_prefix + url + archive_name = get_archive_name(url) + file_name = _get_file_name(url) + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + if repository_ctx.attr.override_strip_prefix: + strip_prefix = repository_ctx.attr.override_strip_prefix + else: + strip_prefix = archive_name + repository_ctx.extract( + archive = file_name, + stripPrefix = strip_prefix, + ) + repository_ctx.delete(file_name) + +def _use_downloaded_cuda_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDA redistribution and initializes hermetic CUDA repository.""" + major_version = "" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cuda_version: + # If no CUDA version is found, comment out all cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cuda_redist_path_prefix, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version(repository_ctx, lib_name_to_version_dict) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _cuda_repo_impl(repository_ctx): + local_cuda_path = get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + if local_cuda_path: + _use_local_cuda_path(repository_ctx, local_cuda_path) + else: + _use_downloaded_cuda_redistribution(repository_ctx) + +cuda_repo = repository_rule( + implementation = _cuda_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cuda_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDA_PATH", + ], +) + +def _use_downloaded_cudnn_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDNN redistribution and initializes hermetic CUDNN repository.""" + cudnn_version = None + major_version = "" + cudnn_version = (get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cudnn_version: + # If no CUDNN version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + arch_key = "cuda{version}_{arch}".format( + version = cuda_version.split(".")[0], + arch = arch_key, + ) + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cudnn_redist_path_prefix, + ) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _cudnn_repo_impl(repository_ctx): + local_cudnn_path = get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + if local_cudnn_path: + _use_local_cudnn_path(repository_ctx, local_cudnn_path) + else: + _use_downloaded_cudnn_redistribution(repository_ctx) + +cudnn_repo = repository_rule( + implementation = _cudnn_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cudnn_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDNN_VERSION", + "TF_CUDNN_VERSION", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDNN_PATH", + ], +) + +def _get_redistribution_urls(dist_info): + url_dict = {} + for arch in _REDIST_ARCH_DICT.keys(): + if "relative_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["relative_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + if "full_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["full_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + for cuda_version, data in dist_info[arch].items(): + # CUDNN JSON might contain paths for each CUDA version. + path_key = "relative_path" + if path_key not in data.keys(): + path_key = "full_path" + url_dict["{cuda_version}_{arch}".format( + cuda_version = cuda_version, + arch = _REDIST_ARCH_DICT[arch], + )] = [data[path_key], data.get("sha256", "")] + return url_dict + +def get_version_and_template_lists(version_to_template): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns lists of versions and templates provided in the dict.""" + template_to_version_map = {} + for version, template in version_to_template.items(): + if template not in template_to_version_map.keys(): + template_to_version_map[template] = [version] + else: + template_to_version_map[template].append(version) + version_list = [] + template_list = [] + for template, versions in template_to_version_map.items(): + version_list.append(",".join(versions)) + template_list.append(Label(template)) + return (version_list, template_list) + +def cudnn_redist_init_repository( + cudnn_redistributions, + cudnn_redist_path_prefix = CUDNN_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDNN repository.""" + if "cudnn" in cudnn_redistributions.keys(): + url_dict = _get_redistribution_urls(cudnn_redistributions["cudnn"]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates["cudnn"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cudnn_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cudnn_redist_path_prefix = cudnn_redist_path_prefix, + ) + +def cuda_redist_init_repositories( + cuda_redistributions, + cuda_redist_path_prefix = CUDA_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDA repositories.""" + for redist_name, _ in redist_versions_to_build_templates.items(): + if redist_name in ["cudnn", "cuda_nccl"]: + continue + if redist_name in cuda_redistributions.keys(): + url_dict = _get_redistribution_urls(cuda_redistributions[redist_name]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates[redist_name] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cuda_redist_path_prefix = cuda_redist_path_prefix, + ) diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl new file mode 100644 index 00000000000000..7f94a983661df7 --- /dev/null +++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -0,0 +1,197 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistribution versions.""" + +CUDA_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cuda/redist/" +CUDNN_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redist/" + +CUDA_REDIST_JSON_DICT = { + "11.8": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_11.8.0.json", + "941a950a4ab3b95311c50df7b3c8bca973e0cdda76fc2f4b456d2d5e4dac0281", + ], + "12.1.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.1.1.json", + "bafea3cb83a4cf5c764eeedcaac0040d0d3c5db3f9a74550da0e7b6ac24d378c", + ], + "12.3.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.1.json", + "b3cc4181d711cf9b6e3718f323b23813c24f9478119911d7b4bceec9b437dbc3", + ], + "12.3.2": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.2.json", + "1b6eacf335dd49803633fed53ef261d62c193e5a56eee5019e7d2f634e39e7ef", + ], +} + +CUDNN_REDIST_JSON_DICT = { + "8.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.6.0.json", + "7f6f50bed4fd8216dc10d6ef505771dc0ecc99cce813993ab405cb507a21d51d", + ], + "8.9.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.6.json", + "6069ef92a2b9bb18cebfbc944964bd2b024b76f2c2c35a43812982e0bc45cf0c", + ], + "8.9.7.29": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.7.29.json", + "a0734f26f068522464fa09b2f2c186dfbe6ad7407a88ea0c50dd331f0c3389ec", + ], + "9.1.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.1.1.json", + "d22d569405e5683ff8e563d00d6e8c27e5e6a902c564c23d752b22a8b8b3fe20", + ], +} + +# The versions are different for x86 and aarch64 architectures because only +# NCCL release versions 2.20.3 and 2.20.5 have the wheels for aarch64. +CUDA_12_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", + }, + "aarch64-unknown-linux-gnu": { + "version": "2.20.5", + "url": "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", + "sha256": "1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", + }, +} + +CUDA_11_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/ac/9a/8b6a28b3b87d5fddab0e92cd835339eb8fbddaa71ae67518c8c1b3d05bae/nvidia_nccl_cu11-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "49d8350629c7888701d1fd200934942671cb5c728f49acc5a0b3a768820bed29", + }, +} + +CUDA_NCCL_WHEELS = { + "11.8": CUDA_11_NCCL_WHEEL_DICT, + "12.1.1": CUDA_12_NCCL_WHEEL_DICT, + "12.3.1": CUDA_12_NCCL_WHEEL_DICT, + "12.3.2": CUDA_12_NCCL_WHEEL_DICT, +} + +REDIST_VERSIONS_TO_BUILD_TEMPLATES = { + "cuda_nccl": { + "repo_name": "cuda_nccl", + "version_to_template": { + "2": "//third_party/nccl/hermetic:cuda_nccl.BUILD.tpl", + }, + }, + "cudnn": { + "repo_name": "cuda_cudnn", + "version_to_template": { + "9": "//third_party/gpus/cuda/hermetic:cuda_cudnn9.BUILD.tpl", + "8": "//third_party/gpus/cuda/hermetic:cuda_cudnn.BUILD.tpl", + }, + }, + "libcublas": { + "repo_name": "cuda_cublas", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + }, + }, + "cuda_cudart": { + "repo_name": "cuda_cudart", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + }, + }, + "libcufft": { + "repo_name": "cuda_cufft", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + "10": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + }, + }, + "cuda_cupti": { + "repo_name": "cuda_cupti", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + }, + }, + "libcurand": { + "repo_name": "cuda_curand", + "version_to_template": { + "10": "//third_party/gpus/cuda/hermetic:cuda_curand.BUILD.tpl", + }, + }, + "libcusolver": { + "repo_name": "cuda_cusolver", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cusolver.BUILD.tpl", + }, + }, + "libcusparse": { + "repo_name": "cuda_cusparse", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + }, + }, + "libnvjitlink": { + "repo_name": "cuda_nvjitlink", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl", + }, + }, + "cuda_nvrtc": { + "repo_name": "cuda_nvrtc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + }, + }, + "cuda_cccl": { + "repo_name": "cuda_cccl", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + }, + }, + "cuda_nvcc": { + "repo_name": "cuda_nvcc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + }, + }, + "cuda_nvml_dev": { + "repo_name": "cuda_nvml", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + }, + }, + "cuda_nvprune": { + "repo_name": "cuda_nvprune", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + }, + }, + "cuda_nvtx": { + "repo_name": "cuda_nvtx", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + }, + }, +} diff --git a/third_party/tsl/third_party/gpus/cuda_configure.bzl b/third_party/tsl/third_party/gpus/cuda_configure.bzl index f4ed97ac4eb07a..a25b60c5f4f6fe 100644 --- a/third_party/tsl/third_party/gpus/cuda_configure.bzl +++ b/third_party/tsl/third_party/gpus/cuda_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for CUDA autoconfiguration. +NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + `cuda_configure` depends on the following environment variables: * `TF_NEED_CUDA`: Whether to enable building with CUDA. @@ -53,6 +55,11 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" @@ -67,20 +74,6 @@ _TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -def to_list_of_strings(elements): - """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. - - This is to be used to put a list of strings into the bzl file templates - so it gets interpreted as list of strings in Starlark. - - Args: - elements: list of string elements - - Returns: - single string of elements wrapped in quotes separated by a comma.""" - quoted_strings = ["\"" + element + "\"" for element in elements] - return ", ".join(quoted_strings) - def verify_build_defines(params): """Verify all variables that crosstool/BUILD.tpl expects are substituted. @@ -238,156 +231,6 @@ def find_cc(repository_ctx, use_cuda_clang): " environment variable").format(target_cc_name, cc_path_envvar)) return cc -_INC_DIR_MARKER_BEGIN = "#include <...>" - -# OSX add " (framework directory)" at the end of line, strip it. -_OSX_FRAMEWORK_SUFFIX = " (framework directory)" -_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) - -def _cxx_inc_convert(path): - """Convert path returned by cc -E xc++ in a complete path.""" - path = path.strip() - if path.endswith(_OSX_FRAMEWORK_SUFFIX): - path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() - return path - -def _normalize_include_path(repository_ctx, path): - """Normalizes include paths before writing them to the crosstool. - - If path points inside the 'crosstool' folder of the repository, a relative - path is returned. - If path points outside the 'crosstool' folder, an absolute path is returned. - """ - path = str(repository_ctx.path(path)) - crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) - - if path.startswith(crosstool_folder): - # We drop the path to "$REPO/crosstool" and a trailing path separator. - return path[len(crosstool_folder) + 1:] - return path - -def _is_compiler_option_supported(repository_ctx, cc, option): - """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" - result = repository_ctx.execute([ - cc, - option, - "-o", - "/dev/null", - "-c", - str(repository_ctx.path("tools/cpp/empty.cc")), - ]) - return result.stderr.find(option) == -1 - -def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): - """Compute the list of default C or C++ include directories.""" - if lang_is_cpp: - lang = "c++" - else: - lang = "c" - sysroot = [] - if tf_sysroot: - sysroot += ["--sysroot", tf_sysroot] - result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + - sysroot) - stderr = err_out(result) - index1 = stderr.find(_INC_DIR_MARKER_BEGIN) - if index1 == -1: - return [] - index1 = stderr.find("\n", index1) - if index1 == -1: - return [] - index2 = stderr.rfind("\n ") - if index2 == -1 or index2 < index1: - return [] - index2 = stderr.find("\n", index2 + 1) - if index2 == -1: - inc_dirs = stderr[index1 + 1:] - else: - inc_dirs = stderr[index1 + 1:index2].strip() - - print_resource_dir_supported = _is_compiler_option_supported( - repository_ctx, - cc, - "-print-resource-dir", - ) - - if print_resource_dir_supported: - resource_dir = repository_ctx.execute( - [cc, "-print-resource-dir"], - ).stdout.strip() + "/share" - inc_dirs += "\n" + resource_dir - - compiler_includes = [ - _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) - for p in inc_dirs.split("\n") - ] - - # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc - # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) - # but Bazel might encounter either (usually reported by the compiler) - # especially when a compiler wrapper (e.g. ccache) is used. - # So we need to also include paths where symlinks are not resolved. - - # Try to find real path to CC installation to "see through" compiler wrappers - # GCC has the path to g++ - index1 = result.stderr.find("COLLECT_GCC=") - if index1 != -1: - index1 = result.stderr.find("=", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname - else: - # Clang has the directory - index1 = result.stderr.find("InstalledDir: ") - if index1 != -1: - index1 = result.stderr.find(" ", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname - else: - # Fallback to the CC path - cc_topdir = repository_ctx.path(cc).dirname.dirname - - # We now have the compiler installation prefix, e.g. /symlink/gcc - # And the resolved installation prefix, e.g. /opt/gcc - cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() - cc_topdir = str(cc_topdir).strip() - - # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. - # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] - # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] - if cc_topdir_resolved != cc_topdir: - unresolved_compiler_includes = [ - cc_topdir + inc[len(cc_topdir_resolved):] - for inc in compiler_includes - if inc.startswith(cc_topdir_resolved) - ] - compiler_includes = compiler_includes + unresolved_compiler_includes - return compiler_includes - -def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): - """Compute the list of default C and C++ include directories.""" - - # For some reason `clang -xc` sometimes returns include paths that are - # different from the ones from `clang -xc++`. (Symlink and a dir) - # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - True, - tf_sysroot, - ) - includes_c = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - False, - tf_sysroot, - ) - - return includes_cpp + [ - inc - for inc in includes_c - if inc not in includes_cpp - ] - def auto_configure_fail(msg): """Output failure message when cuda configuration fails.""" red = "\033[0;31m" @@ -1293,6 +1136,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cuda_nvcc_files}"] = "[]" if is_cuda_clang and not is_nvcc_and_clang: cuda_defines["%{host_compiler_path}"] = str(cc) cuda_defines["%{host_compiler_warnings}"] = """ diff --git a/third_party/tsl/third_party/gpus/find_cuda_config.py b/third_party/tsl/third_party/gpus/find_cuda_config.py index b88694af5c014d..68623bf671da71 100644 --- a/third_party/tsl/third_party/gpus/find_cuda_config.py +++ b/third_party/tsl/third_party/gpus/find_cuda_config.py @@ -14,6 +14,9 @@ # ============================================================================== """Prints CUDA library and header directories and versions found on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + The script searches for CUDA library and header files on the system, inspects them to determine their version and prints the configuration to stdout. The paths to inspect and the required versions are specified through environment diff --git a/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/tsl/third_party/gpus/rocm_configure.bzl index da4b2b976ffde0..482b3eca709b7f 100644 --- a/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -22,12 +22,15 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "enable_cuda", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) load( ":sycl_configure.bzl", diff --git a/third_party/tsl/third_party/gpus/sycl_configure.bzl b/third_party/tsl/third_party/gpus/sycl_configure.bzl index 05330b2fe53195..dd80694e7274f5 100644 --- a/third_party/tsl/third_party/gpus/sycl_configure.bzl +++ b/third_party/tsl/third_party/gpus/sycl_configure.bzl @@ -16,11 +16,14 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" diff --git a/third_party/tsl/third_party/nccl/build_defs.bzl.tpl b/third_party/tsl/third_party/nccl/build_defs.bzl.tpl index ffc3062b39533d..c5be6e5b63f22a 100644 --- a/third_party/tsl/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/tsl/third_party/nccl/build_defs.bzl.tpl @@ -5,7 +5,6 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") # CUDA toolkit version as tuple (e.g. '(11, 1)'). _cuda_version = %{cuda_version} -_cuda_clang = %{cuda_clang} def _rdc_copts(): """Returns copts for compiling relocatable device code.""" @@ -121,25 +120,25 @@ _device_link = rule( "gpu_archs": attr.string_list(), "nvlink_args": attr.string_list(), "_nvlink": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvlink"), + default = Label("%{nvlink_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_fatbinary": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/fatbinary"), + default = Label("%{fatbinary_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_bin2c": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/bin2c"), + default = Label("%{bin2c_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_link_stub": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/crt/link.stub"), + default = Label("%{link_stub_label}"), allow_single_file = True, ), }, @@ -189,7 +188,7 @@ _prune_relocatable_code = rule( "input": attr.label(mandatory = True, allow_files = True), "gpu_archs": attr.string_list(), "_nvprune": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvprune"), + default = Label("%{nvprune_label}"), allow_single_file = True, executable = True, cfg = "host", diff --git a/third_party/tsl/third_party/nccl/hermetic/BUILD b/third_party/tsl/third_party/nccl/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl b/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl new file mode 100644 index 00000000000000..61d7809bcdaad1 --- /dev/null +++ b/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl @@ -0,0 +1,30 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nccl_shared_library", + shared_library = "lib/libnccl.so.%{libnccl_version}", + hdrs = [":headers"], + deps = ["@local_config_cuda//cuda:cuda_headers", ":headers"], +) +%{multiline_comment} +cc_library( + name = "nccl", + %{comment}deps = [":nccl_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nccl*.h", + %{comment}]), + include_prefix = "third_party/nccl", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl new file mode 100644 index 00000000000000..75f5a10b6fe24e --- /dev/null +++ b/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl @@ -0,0 +1,183 @@ +"""Repository rule for hermetic NCCL configuration. + +`nccl_configure` depends on the following environment variables: + + * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should + be used, "0" if NCCL should be linked in statically. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + +""" + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "TF_NEED_CUDA", + "enable_cuda", + "get_cuda_version", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", +) + +_TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" + +_NCCL_DUMMY_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_STUB_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl_via_stub", + }), + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +def _create_local_nccl_repository(repository_ctx): + cuda_version = get_cuda_version(repository_ctx).split(".")[:2] + nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + + if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0": + repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) + else: + repository_ctx.file("BUILD", _NCCL_ARCHIVE_STUB_BUILD_CONTENT) + + repository_ctx.template("generated_names.bzl", repository_ctx.attr.generated_names_tpl, {}) + repository_ctx.template( + "build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), + "%{nvlink_label}": "@cuda_nvcc//:nvlink", + "%{fatbinary_label}": "@cuda_nvcc//:fatbinary", + "%{bin2c_label}": "@cuda_nvcc//:bin2c", + "%{link_stub_label}": "@cuda_nvcc//:link_stub", + "%{nvprune_label}": "@cuda_nvprune//:nvprune", + }, + ) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"%s\"" % nccl_version) + +def _nccl_autoconf_impl(repository_ctx): + if (not enable_cuda(repository_ctx) or + get_cpu_value(repository_ctx) != "Linux"): + # Add a dummy build file to make bazel query happy. + repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") + else: + _create_local_nccl_repository(repository_ctx) + +_ENVIRONS = [ + TF_NEED_CUDA, + TF_CUDA_VERSION, + _TF_NCCL_USE_STUB, + HERMETIC_CUDA_VERSION, + "LOCAL_NCCL_PATH", +] + +nccl_configure = repository_rule( + environ = _ENVIRONS, + implementation = _nccl_autoconf_impl, + attrs = { + "environ": attr.string_dict(), + "nccl_version": attr.label(default = Label("@cuda_nccl//:version.txt")), + "generated_names_tpl": attr.label(default = Label("//third_party/nccl:generated_names.bzl.tpl")), + "build_defs_tpl": attr.label(default = Label("//third_party/nccl:build_defs.bzl.tpl")), + }, +) +"""Downloads and configures the hermetic NCCL configuration. + +Add the following to your WORKSPACE file: + +```python +nccl_configure(name = "local_config_nccl") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl b/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl new file mode 100644 index 00000000000000..244cb851ddf591 --- /dev/null +++ b/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl @@ -0,0 +1,145 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic NCCL repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "OS_ARCH_DICT", + "create_build_file", + "create_dummy_build_file", + "get_archive_name", + "get_env_var", + "get_lib_name_to_version_dict", + "get_major_library_version", + "get_version_and_template_lists", + "use_local_path", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_NCCL_WHEELS", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +def _use_downloaded_nccl_wheel(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads NCCL wheel and inits hermetic NCCL repository.""" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + major_version = "" + if not cuda_version: + # If no CUDA version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch = OS_ARCH_DICT[repository_ctx.os.arch] + dict_key = "{cuda_version}-{arch}".format( + cuda_version = cuda_version, + arch = arch, + ) + supported_versions = repository_ctx.attr.url_dict.keys() + if dict_key not in supported_versions: + fail( + ("The supported NCCL versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add NCCL distribution for" + + " CUDA version={version}, OS={arch}.") + .format( + supported_versions = supported_versions, + version = cuda_version, + arch = arch, + ), + ) + sha256 = repository_ctx.attr.sha256_dict[dict_key] + url = repository_ctx.attr.url_dict[dict_key] + + archive_name = get_archive_name(url) + file_name = archive_name + ".zip" + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + repository_ctx.extract( + archive = file_name, + stripPrefix = repository_ctx.attr.strip_prefix, + ) + repository_ctx.delete(file_name) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _use_local_nccl_path(repository_ctx, local_nccl_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic NCCL repository.""" + use_local_path(repository_ctx, local_nccl_path, ["include", "lib"]) + +def _cuda_nccl_repo_impl(repository_ctx): + local_nccl_path = get_env_var(repository_ctx, "LOCAL_NCCL_PATH") + if local_nccl_path: + _use_local_nccl_path(repository_ctx, local_nccl_path) + else: + _use_downloaded_nccl_wheel(repository_ctx) + +cuda_nccl_repo = repository_rule( + implementation = _cuda_nccl_repo_impl, + attrs = { + "sha256_dict": attr.string_dict(mandatory = True), + "url_dict": attr.string_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "strip_prefix": attr.string(), + }, + environ = ["HERMETIC_CUDA_VERSION", "TF_CUDA_VERSION", "LOCAL_NCCL_PATH"], +) + +def nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes NCCL repository.""" + nccl_artifacts_dict = {"sha256_dict": {}, "url_dict": {}} + for cuda_version, nccl_wheel_info in cuda_nccl_wheels.items(): + for arch in OS_ARCH_DICT.values(): + if arch in nccl_wheel_info.keys(): + cuda_version_to_arch_key = "%s-%s" % (cuda_version, arch) + nccl_artifacts_dict["sha256_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch].get("sha256", "") + nccl_artifacts_dict["url_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch]["url"] + repo_data = redist_versions_to_build_templates["cuda_nccl"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_nccl_repo( + name = repo_data["repo_name"], + sha256_dict = nccl_artifacts_dict["sha256_dict"], + url_dict = nccl_artifacts_dict["url_dict"], + versions = versions, + build_templates = templates, + strip_prefix = "nvidia/nccl", + ) diff --git a/third_party/tsl/third_party/nccl/nccl_configure.bzl b/third_party/tsl/third_party/nccl/nccl_configure.bzl index a62c29caf27a40..724e2bcfe62fb4 100644 --- a/third_party/tsl/third_party/nccl/nccl_configure.bzl +++ b/third_party/tsl/third_party/nccl/nccl_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for NCCL configuration. +NB: DEPRECATED! Use `hermetic/nccl_configure` rule instead. + `nccl_configure` depends on the following environment variables: * `TF_NCCL_VERSION`: Installed NCCL version or empty to build from source. @@ -8,7 +10,6 @@ files. * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is `/usr/local/cuda,usr/`. - * `TF_CUDA_CLANG`: "1" if using Clang, "0" if using NVCC. * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should be used, "0" if NCCL should be linked in statically. @@ -33,7 +34,6 @@ _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" _TF_NCCL_VERSION = "TF_NCCL_VERSION" _TF_NEED_CUDA = "TF_NEED_CUDA" _TF_CUDA_PATHS = "TF_CUDA_PATHS" -_TF_CUDA_CLANG = "TF_CUDA_CLANG" _TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" _DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR" @@ -129,7 +129,11 @@ def _create_local_nccl_repository(repository_ctx): _label("build_defs.bzl.tpl"), { "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), - "%{cuda_clang}": repr(get_host_environ(repository_ctx, _TF_CUDA_CLANG)), + "%{nvlink_label}": "@local_config_cuda//cuda:cuda/bin/nvlink", + "%{fatbinary_label}": "@local_config_cuda//cuda:cuda/bin/fatbinary", + "%{bin2c_label}": "@local_config_cuda//cuda:cuda/bin/bin2c", + "%{link_stub_label}": "@local_config_cuda//cuda:cuda/bin/crt/link.stub", + "%{nvprune_label}": "@local_config_cuda//cuda:cuda/bin/nvprune", }, ) else: @@ -181,7 +185,6 @@ _ENVIRONS = [ _TF_CUDA_COMPUTE_CAPABILITIES, _TF_NEED_CUDA, _TF_CUDA_PATHS, - _TF_CUDA_CLANG, ] remote_nccl_configure = repository_rule( diff --git a/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 86f22652fdc88e..9a4dfa2aafdc51 100644 --- a/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,8 +710,8 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_NEED_TENSORRT": "0", @@ -749,8 +749,8 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_NEED_TENSORRT": "0", @@ -788,8 +788,8 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", @@ -826,8 +826,8 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", @@ -864,8 +864,8 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", diff --git a/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl b/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl index 51b9ea2f960071..dbd7bad8d855c6 100644 --- a/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl @@ -1,8 +1,8 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/third_party/tsl/tsl/platform/default/BUILD b/third_party/tsl/tsl/platform/default/BUILD index 01cf593888c077..04e749d757c49b 100644 --- a/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/tsl/tsl/platform/default/BUILD @@ -2,6 +2,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load( "@xla//xla/tsl:tsl.bzl", + "if_hermetic_cuda_tools", "if_not_fuchsia", "if_not_windows", "if_oss", @@ -59,6 +60,9 @@ cc_library( srcs = ["cuda_libdevice_path.cc"], hdrs = ["//tsl/platform:cuda_libdevice_path.h"], compatible_with = [], + data = if_hermetic_cuda_tools([ + "@cuda_nvcc//:nvvm", + ]), tags = [ "manual", "no_oss", @@ -66,6 +70,7 @@ cc_library( ], deps = [ "//tsl/platform", + "//tsl/platform:env", "//tsl/platform:logging", "//tsl/platform:path", "//tsl/platform:types", diff --git a/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc b/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc index 46321e74b5dc38..ac0a804b4dfd42 100644 --- a/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc +++ b/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc @@ -31,6 +31,7 @@ limitations under the License. #if !defined(PLATFORM_GOOGLE) #include "third_party/gpus/cuda/cuda_config.h" +#include "tsl/platform/env.h" #endif #include "tsl/platform/logging.h" @@ -38,8 +39,25 @@ namespace tsl { std::vector CandidateCudaRoots() { #if !defined(PLATFORM_GOOGLE) - auto roots = std::vector{TF_CUDA_TOOLKIT_PATH, - std::string("/usr/local/cuda")}; + auto roots = std::vector{}; + std::string runfiles_suffix = "runfiles"; + + // The CUDA candidate root for c++ targets. + std::string executable_path = tsl::Env::Default()->GetExecutablePath(); + std::string cuda_nvcc_dir = + io::JoinPath(executable_path + "." + runfiles_suffix, "cuda_nvcc"); + roots.emplace_back(cuda_nvcc_dir); + + // The CUDA candidate root for python targets. + std::string runfiles_dir = tsl::Env::Default()->GetRunfilesDir(); + std::size_t runfiles_ind = runfiles_dir.rfind(runfiles_suffix); + cuda_nvcc_dir = io::JoinPath( + runfiles_dir.substr(0, runfiles_ind + runfiles_suffix.length()), + "cuda_nvcc"); + roots.emplace_back(cuda_nvcc_dir); + + roots.emplace_back(TF_CUDA_TOOLKIT_PATH); + roots.emplace_back(std::string("/usr/local/cuda")); #if defined(PLATFORM_POSIX) && !defined(__APPLE__) Dl_info info; @@ -53,13 +71,17 @@ std::vector CandidateCudaRoots() { // relative to the current binary for the wheel-based nvcc package. for (auto path : {"../nvidia/cuda_nvcc", "../../nvidia/cuda_nvcc"}) roots.emplace_back(io::JoinPath(dir, path)); + + // Also add the path to the copy of libdevice.10.bc that we include within + // the Python wheel. + roots.emplace_back(io::JoinPath(dir, "cuda")); } #endif // defined(PLATFORM_POSIX) && !defined(__APPLE__) for (auto root : roots) VLOG(3) << "CUDA root = " << root; return roots; #else // !defined(PLATFORM_GOOGLE) - return {std::string("/usr/local/cuda")}; + return {}; #endif //! defined(PLATFORM_GOOGLE) } diff --git a/third_party/tsl/workspace2.bzl b/third_party/tsl/workspace2.bzl index 8b8f1de82400a3..01d6000db3950d 100644 --- a/third_party/tsl/workspace2.bzl +++ b/third_party/tsl/workspace2.bzl @@ -17,14 +17,12 @@ load("//third_party/eigen3:workspace.bzl", eigen3 = "repo") load("//third_party/farmhash:workspace.bzl", farmhash = "repo") load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo") load("//third_party/git:git_configure.bzl", "git_configure") -load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/gpus:sycl_configure.bzl", "sycl_configure") load("//third_party/hwloc:workspace.bzl", hwloc = "repo") load("//third_party/implib_so:workspace.bzl", implib_so = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") load("//third_party/nasm:workspace.bzl", nasm = "repo") -load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo") load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo") @@ -69,9 +67,7 @@ def _tf_toolchains(): # Note that we check the minimum bazel version in WORKSPACE. clang6_configure(name = "local_config_clang6") cc_download_clang_toolchain(name = "local_config_download_clang") - cuda_configure(name = "local_config_cuda") tensorrt_configure(name = "local_config_tensorrt") - nccl_configure(name = "local_config_nccl") git_configure(name = "local_config_git") syslibs_configure(name = "local_config_syslibs") python_configure(name = "local_config_python") diff --git a/tools/toolchains/remote_config/configs.bzl b/tools/toolchains/remote_config/configs.bzl index 86f22652fdc88e..9a4dfa2aafdc51 100644 --- a/tools/toolchains/remote_config/configs.bzl +++ b/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,8 +710,8 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_NEED_TENSORRT": "0", @@ -749,8 +749,8 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_NEED_TENSORRT": "0", @@ -788,8 +788,8 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", @@ -826,8 +826,8 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", @@ -864,8 +864,8 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", diff --git a/tools/toolchains/remote_config/rbe_config.bzl b/tools/toolchains/remote_config/rbe_config.bzl index 51b9ea2f960071..ec2ac4cc8ea430 100644 --- a/tools/toolchains/remote_config/rbe_config.bzl +++ b/tools/toolchains/remote_config/rbe_config.bzl @@ -1,8 +1,8 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") +load("@local_config_cuda//cuda/hermetic:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/xla/lit.bzl b/xla/lit.bzl index d6ec58096671f3..bbee57e4246e46 100644 --- a/xla/lit.bzl +++ b/xla/lit.bzl @@ -1,7 +1,7 @@ """Helper rules for writing LIT tests.""" load("@bazel_skylib//lib:paths.bzl", "paths") -load("//xla/tsl:tsl.bzl", "if_oss") +load("//xla/tsl:tsl.bzl", "if_hermetic_cuda_tools", "if_oss") def enforce_glob(files, **kwargs): """A utility to enforce that a list matches a glob expression. @@ -50,6 +50,7 @@ def lit_test_suite( timeout = None, default_tags = None, tags_override = None, + hermetic_cuda_data_dir = None, **kwargs): """Creates one lit test per source file and a test suite that bundles them. @@ -74,6 +75,8 @@ def lit_test_suite( timeout: timeout argument passed to the individual tests. default_tags: string list. Tags applied to all tests. tags_override: string_dict. Tags applied in addition to only select tests. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. **kwargs: additional keyword arguments to pass to all generated rules. See https://llvm.org/docs/CommandGuide/lit.html for details on lit @@ -105,6 +108,7 @@ def lit_test_suite( env = env, timeout = timeout, tags = default_tags + tags_override.get(test_file, []), + hermetic_cuda_data_dir = hermetic_cuda_data_dir, **kwargs ) @@ -114,6 +118,23 @@ def lit_test_suite( **kwargs ) +def lit_script_with_xla_gpu_cuda_data_dir( + name, + input_file, + output_file, + xla_gpu_cuda_data_dir): + """Adds a line to the LIT script to set the XLA_FLAGS environment variable.""" + return native.genrule( + name = name, + srcs = [input_file], + outs = [output_file], + cmd = if_hermetic_cuda_tools( + """echo -e '// RUN: export XLA_FLAGS=\"--xla_gpu_cuda_data_dir={}\"' > $@; +cat $< >> $@;""".format(xla_gpu_cuda_data_dir), + "cat $< >> $@;", + ), + ) + def lit_test( name, test_file, @@ -124,6 +145,7 @@ def lit_test( visibility = None, env = None, timeout = None, + hermetic_cuda_data_dir = None, **kwargs): """Runs a single test file with LLVM's lit tool. @@ -146,6 +168,8 @@ def lit_test( env: string_dict. Environment variables available during test execution. See the common Bazel test attribute. timeout: bazel test timeout string, as per common bazel definitions. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. **kwargs: additional keyword arguments to pass to all generated rules. See https://llvm.org/docs/CommandGuide/lit.html for details on lit @@ -170,12 +194,19 @@ def lit_test( tools_on_path_target_name, "lit_bin", ) + lib_dir = paths.join( + native.package_name(), + tools_on_path_target_name, + "lit_lib", + ) _tools_on_path( name = tools_on_path_target_name, testonly = True, srcs = tools, bin_dir = bin_dir, + lib_dir = lib_dir, + deps = ["//xla/stream_executor/cuda:all_runtime"], visibility = ["//visibility:private"], **kwargs ) @@ -195,6 +226,18 @@ def lit_test( ) # copybara:comment_end + + if hermetic_cuda_data_dir: + output_file = "with_xla_gpu_cuda_data_dir_{}".format(test_file) + rule_name = "script_{}".format(output_file) + lit_script_with_xla_gpu_cuda_data_dir( + rule_name, + test_file, + output_file, + hermetic_cuda_data_dir, + ) + test_file = output_file + native_test( name = name, src = lit_name, @@ -275,6 +318,22 @@ def _tools_on_path_impl(ctx): " {} and {} conflict".format(runfiles_symlinks[bin_path], exe)) runfiles_symlinks[bin_path] = exe + # The loop below symlinks the libraries that are used by the tools. + for dep in ctx.attr.deps: + linker_inputs = dep[CcInfo].linking_context.linker_inputs.to_list() + for linker_input in linker_inputs: + if len(linker_input.libraries) == 0: + continue + lib = linker_input.libraries[0].dynamic_library + if not lib: + continue + lib_path = paths.join(ctx.attr.lib_dir, lib.basename) + if lib_path in runfiles_symlinks: + fail("All libs used by lit tests must have unique basenames, as" + + " they are added to the path." + + " {} and {} conflict".format(runfiles_symlinks[lib_path], lib)) + runfiles_symlinks[lib_path] = lib + return [ DefaultInfo(runfiles = ctx.runfiles( symlinks = runfiles_symlinks, @@ -286,6 +345,8 @@ _tools_on_path = rule( attrs = { "srcs": attr.label_list(allow_files = True, mandatory = True), "bin_dir": attr.string(mandatory = True), + "lib_dir": attr.string(mandatory = True), + "deps": attr.label_list(), }, doc = "Symlinks srcs into a single lit_bin directory. All basenames must be unique.", ) diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index be6ff484a1b806..0ab927e906951d 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -672,6 +672,7 @@ lit_test_suite( "//xla/tools/hlo_opt:gpu_specs/v100.txtpb", ], default_tags = tf_cuda_tests_tags(), + hermetic_cuda_data_dir = "%S/../../../../../cuda_nvcc", tags_override = { "element_wise_row_vectorization.hlo": ["no_rocm"], "scatter_bf16.hlo": ["no_rocm"], diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 0ce0da73ac6be9..6466cb0b5f5068 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -29,7 +29,14 @@ load( "tf_additional_gpu_compilation_copts", ) load("//xla/tests:build_defs.bzl", "xla_test") -load("//xla/tsl:tsl.bzl", "if_google", "if_nccl", "internal_visibility", "tsl_copts") +load( + "//xla/tsl:tsl.bzl", + "if_google", + "if_hermetic_cuda_tools", + "if_nccl", + "internal_visibility", + "tsl_copts", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -134,9 +141,21 @@ cuda_only_cc_library( # Buildozer can not remove dependencies inside select guards, so we have to use # an intermediate target. -cc_library(name = "ptxas_wrapper") +cc_library( + name = "ptxas_wrapper", + data = if_hermetic_cuda_tools( + ["@cuda_nvcc//:ptxas"], + [], + ), +) -cc_library(name = "nvlink_wrapper") +cc_library( + name = "nvlink_wrapper", + data = if_hermetic_cuda_tools( + ["@cuda_nvcc//:nvlink"], + [], + ), +) cuda_only_cc_library( name = "cuda_driver", @@ -723,7 +742,7 @@ xla_cc_test( name = "nvjitlink_test", srcs = ["nvjitlink_test.cc"], args = if_google([ - # nvjitlink allocates memory and only keeps a pointer past the usual offest of 1024 bytes; + # nvjitlink allocates memory and only keeps a pointer past the usual offset of 1024 bytes; # so we need to increase the max pointer offset. -1 means no limit. # This is only relevant for Google's HeapLeakChecker. The newer Leak sanitizer doesn't # have this issue. @@ -761,6 +780,13 @@ cuda_only_cc_library( # "@local_config_cuda//cuda:runtime_ptxas", # ], # copybara:uncomment_end + # copybara:comment_begin + data = if_hermetic_cuda_tools([ + "@cuda_nvcc//:fatbinary", + "@cuda_nvcc//:nvlink", + "@cuda_nvcc//:ptxas", + ]), + # copybara:comment_end visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", diff --git a/xla/tools/BUILD b/xla/tools/BUILD index 4d2c5617438e25..3dda241bece01c 100644 --- a/xla/tools/BUILD +++ b/xla/tools/BUILD @@ -314,6 +314,7 @@ xla_cc_binary( xla_cc_binary( name = "hlo-opt", testonly = True, + linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], deps = [ "//xla/tools/hlo_opt:opt_main", ], diff --git a/xla/tools/hlo_opt/BUILD b/xla/tools/hlo_opt/BUILD index 58802ed8809bab..88271fe3b0b4ee 100644 --- a/xla/tools/hlo_opt/BUILD +++ b/xla/tools/hlo_opt/BUILD @@ -175,6 +175,7 @@ lit_test_suite( cfg = "//xla:lit.cfg.py", data = [":test_utilities"], default_tags = tf_cuda_tests_tags(), + hermetic_cuda_data_dir = "%S/../../../../cuda_nvcc", tags_override = { "gpu_hlo_ptx.hlo": ["no_rocm"], }, diff --git a/xla/tsl/cuda/BUILD.bazel b/xla/tsl/cuda/BUILD.bazel index 5c375cace438a8..c221903990c030 100644 --- a/xla/tsl/cuda/BUILD.bazel +++ b/xla/tsl/cuda/BUILD.bazel @@ -10,6 +10,10 @@ load( "cuda_rpath_flags", "if_cuda_is_configured", ) +load( + "//xla/tsl:tsl.bzl", + "if_hermetic_cuda_libs", +) load("//xla/tsl/cuda:stub.bzl", "cuda_stub") package( @@ -22,7 +26,7 @@ cuda_stub( ) cc_library( - name = "cublas", # buildifier: disable=duplicated-name + name = "cublas_stub", srcs = if_cuda_is_configured([ "cublas_stub.cc", "cublas.tramp.S", @@ -44,13 +48,19 @@ cc_library( ]), ) +alias( + name = "cublas", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cublas//:cublas", ":cublas_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cublasLt", srcs = ["cublasLt.symbols"], ) cc_library( - name = "cublas_lt", + name = "cublas_lt_stub", srcs = if_cuda_is_configured([ "cublasLt_stub.cc", "cublasLt.tramp.S", @@ -68,6 +78,12 @@ cc_library( ]), ) +alias( + name = "cublas_lt", + actual = if_hermetic_cuda_libs("@cuda_cublas//:cublasLt", ":cublas_lt_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cuda", srcs = ["cuda.symbols"], @@ -98,7 +114,7 @@ cuda_stub( ) cc_library( - name = "cudart", # buildifier: disable=duplicated-name + name = "cudart_stub", srcs = select({ # include dynamic loading implementation only when if_cuda_is_configured and build dynamically "@xla//xla/tsl:is_cuda_enabled_and_oss": [ @@ -129,13 +145,19 @@ cc_library( }), ) +alias( + name = "cudart", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cudart//:cudart", ":cudart_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cudnn", srcs = ["cudnn.symbols"], ) cc_library( - name = "cudnn", # buildifier: disable=duplicated-name + name = "cudnn_stub", srcs = if_cuda_is_configured([ "cudnn_stub.cc", "cudnn.tramp.S", @@ -155,12 +177,24 @@ cc_library( ]), ) +alias( + name = "cudnn", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cudnn//:cudnn", ":cudnn_stub"), + visibility = ["//visibility:public"], +) + cc_library( - name = "nccl_rpath", + name = "nccl_rpath_flags", linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), visibility = ["//visibility:public"], ) +alias( + name = "nccl_rpath", + actual = if_hermetic_cuda_libs("@cuda_nccl//:nccl", ":nccl_rpath_flags"), + visibility = ["//visibility:public"], +) + cc_library( name = "tensorrt_rpath", linkopts = if_cuda_is_configured(cuda_rpath_flags("tensorrt")), @@ -173,7 +207,7 @@ cuda_stub( ) cc_library( - name = "cufft", # buildifier: disable=duplicated-name + name = "cufft_stub", srcs = if_cuda_is_configured([ "cufft_stub.cc", "cufft.tramp.S", @@ -192,13 +226,19 @@ cc_library( ]), ) +alias( + name = "cufft", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cufft//:cufft", ":cufft_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cupti", srcs = ["cupti.symbols"], ) cc_library( - name = "cupti", # buildifier: disable=duplicated-name + name = "cupti_stub", srcs = if_cuda_is_configured([ "cupti_stub.cc", "cupti.tramp.S", @@ -219,13 +259,19 @@ cc_library( ]), ) +alias( + name = "cupti", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cupti//:cupti", ":cupti_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cusolver", srcs = ["cusolver.symbols"], ) cc_library( - name = "cusolver", # buildifier: disable=duplicated-name + name = "cusolver_stub", srcs = if_cuda_is_configured([ "cusolver_stub.cc", "cusolver.tramp.S", @@ -244,13 +290,19 @@ cc_library( ]), ) +alias( + name = "cusolver", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cusolver//:cusolver", ":cusolver_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cusparse", srcs = ["cusparse.symbols"], ) cc_library( - name = "cusparse", # buildifier: disable=duplicated-name + name = "cusparse_stub", srcs = if_cuda_is_configured([ "cusparse_stub.cc", "cusparse.tramp.S", @@ -270,13 +322,19 @@ cc_library( ]), ) +alias( + name = "cusparse", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cusparse//:cusparse", ":cusparse_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "nccl", srcs = ["nccl.symbols"], ) cc_library( - name = "nccl_stub", + name = "nccl", # buildifier: disable=duplicated-name srcs = if_cuda_is_configured([ "nccl_stub.cc", "nccl.tramp.S", @@ -296,3 +354,9 @@ cc_library( "@tsl//tsl/platform:load_library", ]), ) + +alias( + name = "nccl_stub", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_nccl//:nccl", ":nccl"), + visibility = ["//visibility:public"], +) diff --git a/xla/tsl/tsl.bzl b/xla/tsl/tsl.bzl index f81dd65aa065a5..8cb4dcdc76f561 100644 --- a/xla/tsl/tsl.bzl +++ b/xla/tsl/tsl.bzl @@ -221,6 +221,17 @@ def if_with_tpu_support(if_true, if_false = []): "//conditions:default": if_false, }) +# These configs are used to determine whether we should use the hermetic CUDA +# tools in cc_libraries. +# They are intended for the OSS builds only. +def if_hermetic_cuda_tools(if_true, if_false = []): # buildifier: disable=unused-variable + """Shorthand for select()'ing on whether we're building with hermetic CUDA tools.""" + return select({"@local_config_cuda//cuda:hermetic_cuda_tools": if_true, "//conditions:default": if_false}) # copybara:comment_replace return if_false + +def if_hermetic_cuda_libs(if_true, if_false = []): # buildifier: disable=unused-variable + """Shorthand for select()'ing on whether we need to include hermetic CUDA libraries.""" + return select({"@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": if_true, "//conditions:default": if_false}) # copybara:comment_replace return if_false + def get_win_copts(is_external = False): WINDOWS_COPTS = [ # copybara:uncomment_begin(no MSVC flags in google)