forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Distributed] Rudimentary distributed Python API (#64)
* Add rudimentary non-production distributed Python API * Distributed execution validation Add functionality that validates distributed StableHLO is producing the same results as non-distributed. * Add execution time measurement * Distributed Python API: add call_count to run_ranks * Add setup script for distributed Python API * Add JAX to install setup --------- Co-authored-by: Boian Petkantchin <[email protected]>
- Loading branch information
1 parent
3cf83ba
commit c680dfb
Showing
8 changed files
with
484 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,4 +66,5 @@ | |
from .io import * | ||
from .tracing import * | ||
|
||
from . import distributed | ||
from . import flags |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from .distributed import prepare_shards_io_files, run_ranks | ||
|
||
__all__ = ["prepare_shards_io_files", "run_ranks"] |
86 changes: 86 additions & 0 deletions
86
runtime/bindings/python/iree/runtime/distributed/distributed.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import iree.compiler | ||
import sys | ||
import iree.runtime | ||
from iree.runtime.array_interop import DeviceArray | ||
import os | ||
from numpy.typing import ArrayLike | ||
from typing import List, Tuple | ||
import tempfile | ||
import subprocess | ||
from . import utils | ||
|
||
|
||
def prepare_shards_io_files(
    inputs: List[List[ArrayLike]], out_dir: str
) -> Tuple[List[str], List[str]]:
    """Write every shard's inputs to disk and pick the path for its outputs.

    For shard ``i`` the inputs are written to ``<out_dir>/shard_<i>/input.npy``.
    The matching output path ``<out_dir>/shard_<i>/output.npy`` is only
    computed; no output file is created here.

    Parameters
    ----------
    inputs : Arrays for all shards.
        Axis 0 is shards. Axis 1 is arrays per shard.
    out_dir : Directory under which the per-shard directories are created.

    Returns
    -------
    A pair ``(input_filepaths, output_filepaths)`` holding one path per shard,
    in shard order.
    """
    input_filepaths = []
    output_filepaths = []
    for shard_index, shard_inputs in enumerate(inputs):
        shard_dir = os.path.join(out_dir, f"shard_{shard_index}")
        # exist_ok=True so that an output directory can be reused across calls
        # instead of failing with FileExistsError on the second call.
        os.makedirs(shard_dir, exist_ok=True)
        input_filepath = os.path.join(shard_dir, "input.npy")
        utils.write_numpy_arrays_to_file(filepath=input_filepath, arrays=shard_inputs)
        input_filepaths.append(input_filepath)
        output_filepaths.append(os.path.join(shard_dir, "output.npy"))
    return input_filepaths, output_filepaths
|
||
|
||
def run_ranks(
    num_ranks: int,
    module_filepath: str,
    function: str,
    inputs: List[List[ArrayLike]],
    driver: str,
    call_count: int = 1,
    measure_execution_time: bool = False,
    warmup: int = 0,
) -> List[List[ArrayLike]]:
    """
    Start all ranks with mpirun.
    On all ranks run the function |function| from the given module.

    The per-rank inputs and outputs are exchanged with the rank processes
    through files in a temporary directory that is removed before returning.

    Parameters
    ----------
    num_ranks : Number of processes to start. Must equal len(inputs).
    module_filepath : Path to the module file, forwarded to every rank.
    function : Name of the function to call on every rank.
    inputs : Function inputs for all ranks.
        Axis 0 is ranks. Axis 1 is arguments per rank.
    driver : Driver name, forwarded to every rank.
    call_count : Forwarded as --call_count; must be at least 1.
    measure_execution_time : When true, --measure_execution_time is forwarded
        to every rank.
    warmup : Forwarded as --warmup.

    Returns
    -------
    The output of the function for all ranks.
    Axis 0 is ranks. Axis 1 is results per rank.

    Raises
    ------
    ValueError : If len(inputs) != num_ranks or call_count < 1.
    subprocess.CalledProcessError : If mpirun exits with a non-zero status.
    """
    # Every rank reads inputs[rank] and one output file per entry of `inputs`
    # is read back below, so a mismatch can never succeed. Fail early with a
    # clear message rather than from inside the mpirun child processes.
    if len(inputs) != num_ranks:
        raise ValueError(
            f"Expected inputs for {num_ranks} ranks, but got {len(inputs)}."
        )
    if call_count < 1:
        raise ValueError(f"call_count must be at least 1, but got {call_count}.")

    with tempfile.TemporaryDirectory() as out_dir:
        input_filepaths, output_filepaths = prepare_shards_io_files(
            inputs=inputs, out_dir=out_dir
        )
        # The query result is unused in this process.
        # NOTE(review): presumably kept as an early check that the driver is
        # usable before launching mpirun - confirm.
        hal_driver = iree.runtime.get_driver(driver)
        hal_driver.query_available_devices()

        command = [
            "mpirun",
            "--oversubscribe",
            "-n",
            str(num_ranks),
            sys.executable,
            os.path.join(os.path.dirname(__file__), "run_rank.py"),
            f"--driver={driver}",
            f"--module_filepath={module_filepath}",
            f"--function={function}",
            f"--call_count={call_count}",
        ]
        if measure_execution_time:
            command.append("--measure_execution_time")
        command.append(f"--warmup={warmup}")
        command += ["--inputs", *input_filepaths]
        command += ["--outputs", *output_filepaths]
        subprocess.check_call(command)

        # Read the results while the temporary directory still exists.
        return [
            utils.read_numpy_arrays_from_file(out_file) for out_file in output_filepaths
        ]
132 changes: 132 additions & 0 deletions
132
runtime/bindings/python/iree/runtime/distributed/run_rank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import iree.compiler | ||
import argparse | ||
import iree.runtime | ||
from iree.runtime.array_interop import DeviceArray | ||
from mpi4py import MPI | ||
import utils | ||
import datetime | ||
import numpy as np | ||
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options of the single-rank runner.

    Parameters
    ----------
    argv : Optional list of argument strings to parse. When None (the
        default), argparse falls back to sys.argv[1:].

    Returns
    -------
    An argparse.Namespace with the attributes driver, module_filepath,
    function, call_count, measure_execution_time, warmup, inputs and outputs.
    inputs and outputs are lists of file paths.
    """
    parser = argparse.ArgumentParser(description="Run 1 shard.")
    parser.add_argument("--driver", type=str, default="local-task", help="Device URI.")
    parser.add_argument(
        "--module_filepath", type=str, required=True, help="Path to IREE module."
    )
    parser.add_argument(
        "--function", type=str, required=True, help="Name of function to call."
    )
    parser.add_argument(
        "--call_count",
        type=int,
        default=1,
        help="How many times to call the function during time measurement.",
    )
    parser.add_argument(
        "--measure_execution_time",
        action="store_true",
        default=False,
        help="Measure execution time in seconds f64 and append to results.",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=0,
        help="How many warmup calls to do before the actual call that generates the result.",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        type=str,
        required=True,
        help="Path to IREE module inputs for all ranks in npy format.",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        type=str,
        required=True,
        help="Path to IREE module outputs from all ranks in npy format.",
    )
    return parser.parse_args(argv)
|
||
|
||
def run_module(
    device: iree.runtime.HalDevice,
    module_filepath: str,
    function: str,
    call_count: int,
    input_filepath: str,
    output_filepath: str,
    measure_execution_time: bool,
    warmup: int,
):
    """Load a module, call one of its functions and write the results to a file.

    Parameters
    ----------
    device : Device the module is loaded on and the inputs are transferred to.
    module_filepath : Path to the module flatbuffer file.
    function : Name of the module function to call.
    call_count : How many times the function is called after warmup. The
        results of the last call are the ones written out. Must be at least 1.
    input_filepath : File the function arguments are read from.
    output_filepath : File the results are written to.
    measure_execution_time : When true, all ranks are synchronized with an MPI
        barrier before the calls, and the average time per call in seconds is
        appended to the results as a float array.
    warmup : Number of untimed calls made before the counted calls.

    Raises
    ------
    ValueError : If call_count < 1.
    """
    # Function-scope import: the monotonic clock is only needed here.
    import time

    # Validate before any work is done and outside of the timed region. An
    # explicit raise is used because assert statements are removed under -O.
    if call_count < 1:
        raise ValueError(f"call_count must be at least 1, but got {call_count}.")

    config = iree.runtime.Config(device=device)
    with open(module_filepath, "rb") as f:
        vm_flatbuffer = f.read()
    vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer)
    bound_module = iree.runtime.load_vm_module(vm_module, config)
    input_args = utils.read_numpy_arrays_from_file(input_filepath)
    input_args_on_device = [
        iree.runtime.asdevicearray(device, arr) for arr in input_args
    ]
    # Look the function up once so the lookup is not part of the timed loop.
    entry_point = getattr(bound_module, function)

    for _ in range(warmup):
        entry_point(*input_args_on_device)

    if measure_execution_time:
        # Sync all ranks
        MPI.COMM_WORLD.barrier()
        # perf_counter is monotonic, so the measurement is not distorted by
        # system clock adjustments.
        start_time = time.perf_counter()
    for _ in range(call_count):
        results = entry_point(*input_args_on_device)
    if measure_execution_time:
        elapsed_seconds = time.perf_counter() - start_time

    if isinstance(results, DeviceArray):
        results = [results]
    if measure_execution_time:
        # A tuple cannot be appended to.
        if isinstance(results, tuple):
            results = list(results)
        results.append(np.array(elapsed_seconds / call_count, dtype=float))
    utils.write_numpy_arrays_to_file(filepath=output_filepath, arrays=results)
|
||
|
||
def run_rank(
    driver: str,
    module_filepath: str,
    function: str,
    inputs: list,
    outputs: list,
    call_count: int,
    measure_execution_time: bool,
    warmup: int,
):
    """Run the function for the MPI rank of the current process.

    Parameters
    ----------
    driver : Name of the driver whose devices are used.
    module_filepath : Path to the module file.
    function : Name of the function to call.
    inputs : Input file paths for all ranks, indexed by rank.
    outputs : Output file paths for all ranks, indexed by rank.
    call_count : Forwarded to run_module.
    measure_execution_time : Forwarded to run_module.
    warmup : Forwarded to run_module.

    Raises
    ------
    ValueError : If there is no input or output path for this rank.
    RuntimeError : If the driver reports no available devices.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    if rank >= len(inputs) or rank >= len(outputs):
        raise ValueError(
            f"Rank {rank} has no input/output file: got {len(inputs)} input "
            f"and {len(outputs)} output paths."
        )

    hal_driver = iree.runtime.get_driver(driver)
    device_infos = hal_driver.query_available_devices()
    if not device_infos:
        # Without this check the modulo below fails with ZeroDivisionError.
        raise RuntimeError(f'Driver "{driver}" reports no available devices.')
    # The device index wraps around, so ranks share devices when there are
    # more ranks than devices.
    device_info = device_infos[rank % len(device_infos)]
    device = hal_driver.create_device(device_info["device_id"])

    run_module(
        device=device,
        module_filepath=module_filepath,
        function=function,
        call_count=call_count,
        input_filepath=inputs[rank],
        output_filepath=outputs[rank],
        measure_execution_time=measure_execution_time,
        warmup=warmup,
    )
|
||
|
||
if __name__ == "__main__":
    # Script entry point. The parsed option names match the parameter names of
    # run_rank, so the namespace is forwarded as keyword arguments.
    cli_options = vars(parse_args())
    run_rank(**cli_options)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/bin/bash
# Installs the packages used by the distributed Python API: the NVIDIA CUDA
# apt keyring, lld, OpenMPI, NCCL, and the mpi4py and JAX Python packages.

set -e

# Repository name is the distribution ID followed by the version without dots.
distribution=$(. /etc/os-release; echo "${ID}${VERSION_ID}" | sed -e 's/\.//g')
keyring_deb=/tmp/cuda-keyring_1.0-1_all.deb
wget -O "${keyring_deb}" \
  "https://developer.download.nvidia.com/compute/cuda/repos/${distribution}/x86_64/cuda-keyring_1.0-1_all.deb"
sudo dpkg -i "${keyring_deb}"
# apt-get rather than apt: apt does not offer a stable interface for scripts.
sudo apt-get update
# For CMake to find CUDA when using LLD.
sudo apt-get -y install lld

sudo apt-get -y install libopenmpi-dev
sudo apt-get -y install libnccl-dev=2.18.1-1+cuda12.1
# Quoted so that the shell cannot expand [cpu] as a glob pattern.
pip install mpi4py "jax[cpu]"
Oops, something went wrong.