Skip to content

Commit

Permalink
[Distributed] Rudimentary distributed Python API (#64)
Browse files Browse the repository at this point in the history
* Add rudimentary non-production distributed Python API

* Distributed execution validation

Add functionality that validates distributed StableHLO
is producing the same results as non-distributed.

* Add execution time measurement

* Distributed Python API: add call_count to run_ranks

* Add setup script for distributed Python API

* Add JAX to install setup

---------

Co-authored-by: Boian Petkantchin <[email protected]>
  • Loading branch information
2 people authored and github-actions[bot] committed Jan 25, 2024
1 parent ed8aa6d commit aa5833a
Show file tree
Hide file tree
Showing 8 changed files with 484 additions and 0 deletions.
5 changes: 5 additions & 0 deletions runtime/bindings/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ iree_py_library(
"iree/_runtime/scripts/iree_dump_parameters/__main__.py"
"iree/_runtime/scripts/iree_run_module/__main__.py"
"iree/_runtime/scripts/iree_tracy_capture/__main__.py"
"iree/runtime/distributed/__init__.py"
"iree/runtime/distributed/distributed.py"
"iree/runtime/distributed/run_rank.py"
"iree/runtime/distributed/sharding_pass_validation.py"
"iree/runtime/distributed/utils.py"
PYEXT_DEPS
iree_runtime_bindings_python_PyExtRt
)
Expand Down
1 change: 1 addition & 0 deletions runtime/bindings/python/iree/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@
from .function import *
from .io import *

from . import distributed
from . import flags
9 changes: 9 additions & 0 deletions runtime/bindings/python/iree/runtime/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .distributed import prepare_shards_io_files, run_ranks

__all__ = ["prepare_shards_io_files", "run_ranks"]
86 changes: 86 additions & 0 deletions runtime/bindings/python/iree/runtime/distributed/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import iree.compiler
import sys
import iree.runtime
from iree.runtime.array_interop import DeviceArray
import os
from numpy.typing import ArrayLike
from typing import List, Tuple
import tempfile
import subprocess
from . import utils


def prepare_shards_io_files(
inputs: List[List[ArrayLike]], out_dir: str
) -> Tuple[List[str], List[str]]:
input_filepaths = []
output_filepaths = []
for i in range(len(inputs)):
input_filepath = os.path.join(out_dir, f"shard_{i}", "input.npy")
input_filepaths.append(input_filepath)
os.makedirs(os.path.dirname(input_filepath))
utils.write_numpy_arrays_to_file(filepath=input_filepath, arrays=inputs[i])
output_filepath = os.path.join(out_dir, f"shard_{i}", "output.npy")
output_filepaths.append(output_filepath)
return input_filepaths, output_filepaths


def run_ranks(
num_ranks: int,
module_filepath: str,
function: str,
inputs: List[List[ArrayLike]],
driver: str,
call_count: int = 1,
measure_execution_time: bool = False,
warmup: int = 0,
) -> List[List[ArrayLike]]:
"""
Start all ranks with mpirun.
On all ranks run the function |function| from the given module.
Parameters
----------
inputs : Function inputs for all ranks.
Axis 0 is ranks. Axis 1 is arguments per rank.
Returns
-------
The output of the function for all ranks.
Axis 0 is ranks. Axis 1 is arguments per rank.
"""
with tempfile.TemporaryDirectory() as out_dir:
input_filepaths, output_filepaths = prepare_shards_io_files(
inputs=inputs, out_dir=out_dir
)
hal_driver = iree.runtime.get_driver(driver)
hal_driver.query_available_devices()
subprocess.check_call(
[
"mpirun",
"--oversubscribe",
"-n",
str(num_ranks),
sys.executable,
os.path.join(os.path.dirname(__file__), "run_rank.py"),
f"--driver={driver}",
f"--module_filepath={module_filepath}",
f"--function={function}",
f"--call_count={call_count}",
]
+ (["--measure_execution_time"] if measure_execution_time else [])
+ [
f"--warmup={warmup}",
"--inputs",
]
+ input_filepaths
+ ["--outputs"]
+ output_filepaths
)
return [
utils.read_numpy_arrays_from_file(out_file) for out_file in output_filepaths
]
132 changes: 132 additions & 0 deletions runtime/bindings/python/iree/runtime/distributed/run_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import iree.compiler
import argparse
import iree.runtime
from iree.runtime.array_interop import DeviceArray
from mpi4py import MPI
import utils
import datetime
import numpy as np


def parse_args():
parser = argparse.ArgumentParser(description="Run 1 shard.")
parser.add_argument("--driver", type=str, default="local-task", help="Device URI.")
parser.add_argument(
"--module_filepath", type=str, required=True, help="Path to IREE module."
)
parser.add_argument(
"--function", type=str, required=True, help="Name of function to call."
)
parser.add_argument(
"--call_count",
type=int,
default=1,
help="How many times to call the function during time measurement.",
)
parser.add_argument(
"--measure_execution_time",
action="store_true",
default=False,
help="Measure execution time in seconds f64 and append to results.",
)
parser.add_argument(
"--warmup",
type=int,
default=0,
help="How many warmup calls to do before the actual call that generates the result.",
)
parser.add_argument(
"--inputs",
nargs="+",
type=str,
required=True,
help="Path to IREE module inputs for all ranks in npy format.",
)
parser.add_argument(
"--outputs",
nargs="+",
type=str,
required=True,
help="Path to IREE module outputs form all ranks in npy format.",
)
return parser.parse_args()


def run_module(
device: iree.runtime.HalDevice,
module_filepath: str,
function: str,
call_count: int,
input_filepath: str,
output_filepath: str,
measure_execution_time: bool,
warmup: int,
):
config = iree.runtime.Config(device=device)
with open(module_filepath, "rb") as f:
vm_flatbuffer = f.read()
vm_module = iree.runtime.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer)
bound_module = iree.runtime.load_vm_module(vm_module, config)
input_args = utils.read_numpy_arrays_from_file(input_filepath)
input_args_on_device = [
iree.runtime.asdevicearray(device, arr) for arr in input_args
]
for _ in range(warmup):
getattr(bound_module, function)(*input_args_on_device)
if measure_execution_time:
# Sync all ranks
MPI.COMM_WORLD.barrier()
start_time = datetime.datetime.now()
assert call_count > 0
for _ in range(call_count):
results = getattr(bound_module, function)(*input_args_on_device)
if measure_execution_time:
end_time = datetime.datetime.now()
if isinstance(results, DeviceArray):
results = [results]
if measure_execution_time:
if isinstance(results, tuple):
results = list(results)
results.append(
np.array((end_time - start_time).total_seconds() / call_count, dtype=float)
)
utils.write_numpy_arrays_to_file(filepath=output_filepath, arrays=results)


def run_rank(
driver: str,
module_filepath: str,
function: str,
inputs: str,
outputs: str,
call_count: int,
measure_execution_time: bool,
warmup: int,
):
rank = MPI.COMM_WORLD.Get_rank()
hal_driver = iree.runtime.get_driver(driver)
device_infos = hal_driver.query_available_devices()
device = hal_driver.create_device(
device_infos[rank % len(device_infos)]["device_id"]
)
run_module(
device=device,
module_filepath=module_filepath,
function=function,
call_count=call_count,
input_filepath=inputs[rank],
output_filepath=outputs[rank],
measure_execution_time=measure_execution_time,
warmup=warmup,
)


if __name__ == "__main__":
args = parse_args()
run_rank(**vars(args))
15 changes: 15 additions & 0 deletions runtime/bindings/python/iree/runtime/distributed/setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

set -e

distribution=$(. /etc/os-release;echo $ID$VERSION_ID | sed -e 's/\.//g')
wget -O /tmp/cuda-keyring_1.0-1_all.deb \
https://developer.download.nvidia.com/compute/cuda/repos/$distribution/x86_64/cuda-keyring_1.0-1_all.deb
sudo dpkg -i /tmp/cuda-keyring_1.0-1_all.deb
sudo apt update
# For CMake to find CUDA when using LLD.
sudo apt -y install lld

sudo apt -y install libopenmpi-dev
sudo apt -y install libnccl-dev=2.18.1-1+cuda12.1
pip install mpi4py jax[cpu]
Loading

0 comments on commit aa5833a

Please sign in to comment.