diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt index 44112ce51db98..a863fd2a5fe06 100644 --- a/runtime/bindings/python/CMakeLists.txt +++ b/runtime/bindings/python/CMakeLists.txt @@ -150,6 +150,11 @@ iree_py_library( "iree/_runtime/scripts/iree_run_trace/__main__.py" "iree/_runtime/scripts/iree_run_module/__main__.py" "iree/_runtime/scripts/iree_tracy_capture/__main__.py" + "iree/runtime/distributed/__init__.py" + "iree/runtime/distributed/distributed.py" + "iree/runtime/distributed/run_rank.py" + "iree/runtime/distributed/sharding_pass_validation.py" + "iree/runtime/distributed/utils.py" PYEXT_DEPS iree_runtime_bindings_python_PyExtRt ) diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py index d594b23c88d00..a9201b863d5f6 100644 --- a/runtime/bindings/python/iree/runtime/__init__.py +++ b/runtime/bindings/python/iree/runtime/__init__.py @@ -66,4 +66,5 @@ from .io import * from .tracing import * +from . import distributed from . import flags diff --git a/runtime/bindings/python/iree/runtime/distributed/__init__.py b/runtime/bindings/python/iree/runtime/distributed/__init__.py new file mode 100644 index 0000000000000..86ee5db110ccb --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .distributed import prepare_shards_io_files, run_ranks + +__all__ = ["prepare_shards_io_files", "run_ranks"] diff --git a/runtime/bindings/python/iree/runtime/distributed/distributed.py b/runtime/bindings/python/iree/runtime/distributed/distributed.py new file mode 100644 index 0000000000000..258e517b2cf23 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/distributed.py @@ -0,0 +1,86 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.compiler +import sys +import iree.runtime +from iree.runtime.array_interop import DeviceArray +import os +from numpy.typing import ArrayLike +from typing import List, Tuple +import tempfile +import subprocess +from . import utils + + +def prepare_shards_io_files( + inputs: List[List[ArrayLike]], out_dir: str +) -> Tuple[List[str], List[str]]: + input_filepaths = [] + output_filepaths = [] + for i in range(len(inputs)): + input_filepath = os.path.join(out_dir, f"shard_{i}", "input.npy") + input_filepaths.append(input_filepath) + os.makedirs(os.path.dirname(input_filepath)) + utils.write_numpy_arrays_to_file(filepath=input_filepath, arrays=inputs[i]) + output_filepath = os.path.join(out_dir, f"shard_{i}", "output.npy") + output_filepaths.append(output_filepath) + return input_filepaths, output_filepaths + + +def run_ranks( + num_ranks: int, + module_filepath: str, + function: str, + inputs: List[List[ArrayLike]], + driver: str, + call_count: int = 1, + measure_execution_time: bool = False, + warmup: int = 0, +) -> List[List[ArrayLike]]: + """ + Start all ranks with mpirun. + On all ranks run the function |function| from the given module. + Parameters + ---------- + inputs : Function inputs for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. + Returns + ------- + The output of the function for all ranks. + Axis 0 is ranks. Axis 1 is arguments per rank. + """ + with tempfile.TemporaryDirectory() as out_dir: + input_filepaths, output_filepaths = prepare_shards_io_files( + inputs=inputs, out_dir=out_dir + ) + hal_driver = iree.runtime.get_driver(driver) + hal_driver.query_available_devices() + subprocess.check_call( + [ + "mpirun", + "--oversubscribe", + "-n", + str(num_ranks), + sys.executable, + os.path.join(os.path.dirname(__file__), "run_rank.py"), + f"--driver={driver}", + f"--module_filepath={module_filepath}", + f"--function={function}", + f"--call_count={call_count}", + ] + + (["--measure_execution_time"] if measure_execution_time else []) + + [ + f"--warmup={warmup}", + "--inputs", + ] + + input_filepaths + + ["--outputs"] + + output_filepaths + ) + return [ + utils.read_numpy_arrays_from_file(out_file) for out_file in output_filepaths + ] diff --git a/runtime/bindings/python/iree/runtime/distributed/run_rank.py b/runtime/bindings/python/iree/runtime/distributed/run_rank.py new file mode 100644 index 0000000000000..7ad00f7256cca --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/run_rank.py @@ -0,0 +1,132 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.compiler +import argparse +import iree.runtime +from iree.runtime.array_interop import DeviceArray +from mpi4py import MPI +import utils +import datetime +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run 1 shard.") + parser.add_argument("--driver", type=str, default="local-task", help="Device URI.") + parser.add_argument( + "--module_filepath", type=str, required=True, help="Path to IREE module." + ) + parser.add_argument( + "--function", type=str, required=True, help="Name of function to call." + ) + parser.add_argument( + "--call_count", + type=int, + default=1, + help="How many times to call the function during time measurement.", + ) + parser.add_argument( + "--measure_execution_time", + action="store_true", + default=False, + help="Measure execution time in seconds f64 and append to results.", + ) + parser.add_argument( + "--warmup", + type=int, + default=0, + help="How many warmup calls to do before the actual call that generates the result.", + ) + parser.add_argument( + "--inputs", + nargs="+", + type=str, + required=True, + help="Path to IREE module inputs for all ranks in npy format.", + ) + parser.add_argument( + "--outputs", + nargs="+", + type=str, + required=True, + help="Path to IREE module outputs form all ranks in npy format.", + ) + return parser.parse_args() + + +def run_module( + device: iree.runtime.HalDevice, + module_filepath: str, + function: str, + call_count: int, + input_filepath: str, + output_filepath: str, + measure_execution_time: bool, + warmup: int, +): + config = iree.runtime.Config(device=device) + with open(module_filepath, "rb") as f: + vm_flatbuffer = f.read() + vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer) + bound_module = iree.runtime.load_vm_module(vm_module, config) + input_args = utils.read_numpy_arrays_from_file(input_filepath) + input_args_on_device = [ + iree.runtime.asdevicearray(device, arr) for arr in input_args + ] + for _ in range(warmup): + getattr(bound_module, function)(*input_args_on_device) + if measure_execution_time: + # Sync all ranks + MPI.COMM_WORLD.barrier() + start_time = datetime.datetime.now() + assert call_count > 0 + for _ in range(call_count): + results = getattr(bound_module, function)(*input_args_on_device) + if measure_execution_time: + end_time = datetime.datetime.now() + if isinstance(results, DeviceArray): + results = [results] + if measure_execution_time: + if isinstance(results, tuple): + results = list(results) + results.append( + np.array((end_time - start_time).total_seconds() / call_count, dtype=float) + ) + utils.write_numpy_arrays_to_file(filepath=output_filepath, arrays=results) + + +def run_rank( + driver: str, + module_filepath: str, + function: str, + inputs: str, + outputs: str, + call_count: int, + measure_execution_time: bool, + warmup: int, +): + rank = MPI.COMM_WORLD.Get_rank() + hal_driver = iree.runtime.get_driver(driver) + device_infos = hal_driver.query_available_devices() + device = hal_driver.create_device( + device_infos[rank % len(device_infos)]["device_id"] + ) + run_module( + device=device, + module_filepath=module_filepath, + function=function, + call_count=call_count, + input_filepath=inputs[rank], + output_filepath=outputs[rank], + measure_execution_time=measure_execution_time, + warmup=warmup, + ) + + +if __name__ == "__main__": + args = parse_args() + run_rank(**vars(args)) diff --git a/runtime/bindings/python/iree/runtime/distributed/setup.sh b/runtime/bindings/python/iree/runtime/distributed/setup.sh new file mode 100644 index 0000000000000..83dca488caa4f --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/setup.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -e + +distribution=$(. /etc/os-release;echo $ID$VERSION_ID | sed -e 's/\.//g') +wget -O /tmp/cuda-keyring_1.0-1_all.deb \ + https://developer.download.nvidia.com/compute/cuda/repos/$distribution/x86_64/cuda-keyring_1.0-1_all.deb +sudo dpkg -i /tmp/cuda-keyring_1.0-1_all.deb +sudo apt update +# For CMake to find CUDA when using LLD. +sudo apt -y install lld + +sudo apt -y install libopenmpi-dev +sudo apt -y install libnccl-dev=2.18.1-1+cuda12.1 +pip install mpi4py jax[cpu] diff --git a/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py new file mode 100644 index 0000000000000..4465204516237 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/sharding_pass_validation.py @@ -0,0 +1,210 @@ +import iree.compiler +import iree.runtime +import os +from iree.runtime.distributed import run_ranks +import subprocess +from pathlib import Path +from jax._src.lib import xla_client +from jaxlib.xla_client import HloSharding +from typing import List, Tuple, Union +from numpy.typing import ArrayLike +import jax +from jax._src.sharding_impls import GSPMDSharding +import jax._src.interpreters.pxla as pxla +import numpy as np +from datetime import timedelta + +xla_extension = xla_client._xla + + +def compile_mlir(mlir_filepath: str, output_filepath: str, use_cache: bool, **kwargs): + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + iree.compiler.compile_file( + input_file=mlir_filepath, output_file=output_filepath, **kwargs + ) + + +def extract_args_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + return [ + HloSharding.from_proto(sharding) + for sharding in xla_computation.get_hlo_module().spmd_parameters_shardings + ] + + +def extract_results_sharding( + xla_computation: xla_extension.XlaComputation, +) -> List[HloSharding]: + sharding = HloSharding.from_proto( + xla_computation.get_hlo_module().spmd_output_sharding + ) + if len(sharding.tuple_elements()): + return sharding.tuple_elements() + else: + return [sharding] + + +def shard_arg(arg: ArrayLike, sharding: HloSharding) -> List[ArrayLike]: + gspmd_sharding = GSPMDSharding(devices=jax.local_devices(), op_sharding=sharding) + indices = gspmd_sharding.devices_indices_map(arg.shape).values() + sharded_array = pxla.shard_arg( + arg, devices=jax.local_devices(), arg_indices=indices, sharding=gspmd_sharding + ) + return [shard.data for shard in sharded_array.global_shards] + + +def shard_args( + args: List[ArrayLike], shardings: List[HloSharding] +) -> List[List[ArrayLike]]: + assert len(args) == len(shardings) + return [shard_arg(arg, sharding) for arg, sharding in zip(args, shardings)] + + +def assemble_shards(shards: List[ArrayLike], sharding: HloSharding) -> ArrayLike: + if sharding.is_replicated(): + return shards[0] + else: + raise NotImplementedError() + + +def propagate_shardings_and_spmd_partition( + mlir_filepath: str, + output_filepath: str, + num_devices: int, + use_cache: bool, + allow_spmd_sharding_propagation_to_output: int = 1, +): + res = subprocess.run( + [ + "stablehlo-opt", + ( + "--pass-pipeline=builtin.module(stablehlo-xla-sharding-propagation-and-spmd-partitioner{" + "is_spmd=1 " + f"allow_spmd_sharding_propagation_to_output={allow_spmd_sharding_propagation_to_output} " + "allow_spmd_sharding_propagation_to_parameters=1 " + f"num_partitions={num_devices} " + "num_replicas=1})" + ), + mlir_filepath, + ], + check=True, + stdout=subprocess.PIPE, + ) + Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) + if use_cache and os.path.exists(output_filepath): + return + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + with open(output_filepath, "wb") as f: + f.write(res.stdout) + + +def swap_shard_axis(arrays: List[ArrayLike]) -> List[List[ArrayLike]]: + """Swap axis 0 with 1.""" + if len(arrays) == 0: + return [] + expected_shards = len(arrays[0]) + res = [[] for _ in range(expected_shards)] + for arr in arrays: + assert len(arr) == expected_shards + for shard in range(expected_shards): + res[shard].append(arr[shard]) + return res + + +def execute_distributed( + num_ranks: int, + mlir_filepath: str, + iree_module_filepath: str, + function: str, + inputs: List[ArrayLike], + driver: str, + measure_execution_time: bool = False, +) -> Union[List[ArrayLike], Tuple[List[ArrayLike], timedelta]]: + with open(mlir_filepath, "r") as f: + mlir_str = f.read() + xla_computation = xla_extension.mlir.mlir_module_to_xla_computation( + mlir_module=mlir_str, use_tuple_args=False, return_tuple=False + ) + args_sharding = extract_args_sharding(xla_computation) + results_sharding = extract_results_sharding(xla_computation) + sharded_args = shard_args(args=inputs, shardings=args_sharding) + sharded_args = swap_shard_axis(sharded_args) + sharded_results = run_ranks( + num_ranks=num_ranks, + module_filepath=iree_module_filepath, + function=function, + inputs=sharded_args, + driver=driver, + ) + sharded_results = swap_shard_axis(sharded_results) + if measure_execution_time: + sharded_results, execution_times = sharded_results + res = [ + assemble_shards(shards=result_shards, sharding=sharding) + for result_shards, sharding in zip(sharded_results, results_sharding) + ] + if measure_execution_time: + res = res, timedelta(seconds=np.max(execution_times)) + return res + + +def validate_sharding_passes( + mlir_filepath: str, + mlir_with_sharding_annotations_filepath: str, + inputs: List[ArrayLike], + function: str, + num_devices: int, + use_cache: bool, + driver: str, + target_backend: str, + output_prefix_path: str, + allow_spmd_sharding_propagation_to_output: int = 1, +): + # Single instance. + iree_module_filepath = ( + f"{output_prefix_path}{os.path.basename(mlir_filepath)}.{driver}.vmfb" + ) + compile_mlir( + mlir_filepath=mlir_filepath, + output_filepath=iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + iree_module = iree.runtime.load_vm_flatbuffer_file( + path=iree_module_filepath, driver=driver + ) + results = iree_module[function](*inputs) + if isinstance(results, iree.runtime.DeviceArray): + results = [results] + + # Distributed. + spmd_mlir_filepath = f"{output_prefix_path}{os.path.basename(mlir_with_sharding_annotations_filepath)}.spmd.mlir" + propagate_shardings_and_spmd_partition( + mlir_filepath=mlir_with_sharding_annotations_filepath, + output_filepath=spmd_mlir_filepath, + num_devices=num_devices, + use_cache=use_cache, + allow_spmd_sharding_propagation_to_output=allow_spmd_sharding_propagation_to_output, + ) + spmd_iree_module_filepath = f"{output_prefix_path}{os.path.basename(spmd_mlir_filepath)}.{target_backend}.vmfb" + compile_mlir( + mlir_filepath=spmd_mlir_filepath, + output_filepath=spmd_iree_module_filepath, + use_cache=use_cache, + target_backends=[target_backend], + ) + spmd_results = execute_distributed( + num_ranks=num_devices, + mlir_filepath=spmd_mlir_filepath, + iree_module_filepath=spmd_iree_module_filepath, + function=function, + inputs=inputs, + driver=driver, + ) + + assert len(results) == len(spmd_results) + for result, spmd_result in zip(results, spmd_results): + np.testing.assert_allclose(result, spmd_result, atol=1e-7) diff --git a/runtime/bindings/python/iree/runtime/distributed/utils.py b/runtime/bindings/python/iree/runtime/distributed/utils.py new file mode 100644 index 0000000000000..3581baf354f86 --- /dev/null +++ b/runtime/bindings/python/iree/runtime/distributed/utils.py @@ -0,0 +1,26 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from numpy.typing import ArrayLike +from typing import List +import numpy as np + + +def read_numpy_arrays_from_file(filepath: str) -> List[ArrayLike]: + res = [] + with open(filepath, "rb") as f: + while True: + try: + res.append(np.load(f)) + except EOFError: + break + return res + + +def write_numpy_arrays_to_file(filepath: str, arrays: List[ArrayLike]): + with open(filepath, "wb") as f: + for arr in arrays: + np.save(f, np.asarray(arr), allow_pickle=False)