From 05308891e203329a733bcf29a3452b15b75b5eb4 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:55:40 -0700 Subject: [PATCH] [Core] Pipeline parallel with Ray ADAG (#6837) Support pipeline-parallelism with Ray accelerated DAG. Signed-off-by: Rui Qiao --- Dockerfile | 2 + MANIFEST.in | 1 + requirements-adag.txt | 3 + requirements-test.txt | 3 + tests/distributed/test_pipeline_parallel.py | 51 +++++--- tests/utils.py | 31 ++++- vllm/envs.py | 12 +- vllm/executor/ray_gpu_executor.py | 137 +++++++++++++------- vllm/executor/ray_utils.py | 30 ++++- vllm/worker/worker_base.py | 6 +- 10 files changed, 199 insertions(+), 77 deletions(-) create mode 100644 requirements-adag.txt diff --git a/Dockerfile b/Dockerfile index 7294707046abc..49aaea2949ac6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,6 +42,7 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt @@ -78,6 +79,7 @@ COPY setup.py setup.py COPY cmake cmake COPY CMakeLists.txt CMakeLists.txt COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt COPY pyproject.toml pyproject.toml COPY vllm vllm diff --git a/MANIFEST.in b/MANIFEST.in index 82be639ef4d73..5a41e5e714184 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include LICENSE +include requirements-adag.txt include requirements-common.txt include requirements-cuda.txt include requirements-rocm.txt diff --git a/requirements-adag.txt b/requirements-adag.txt new file mode 100644 index 0000000000000..e77f90fb8f85d --- /dev/null +++ b/requirements-adag.txt @@ -0,0 +1,3 @@ +# Dependencies for Ray accelerated DAG +cupy-cuda12x +ray >= 2.32 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index df247496be16c..5f3fd15c7ee56 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,6 @@ +# Needed for Ray accelerated DAG tests +-r requirements-adag.txt + # testing pytest tensorizer>=2.9.0 diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index f632caba9017e..ab325e0966929 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -15,22 +15,31 @@ @pytest.mark.parametrize( - "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND", - [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - ]) -@fork_new_process_for_each_test + ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " + "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [ + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + ]) def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, - DIST_BACKEND): + DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") @@ -67,8 +76,18 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, if EAGER_MODE: pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") + pp_env = None + if USE_RAY_ADAG: + assert DIST_BACKEND == "ray", ( + "Ray ADAG is only supported with Ray distributed backend") + pp_env = { + "VLLM_USE_RAY_COMPILED_DAG": "1", + "VLLM_USE_RAY_SPMD_WORKER": "1", + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": + str(int(USE_RAY_ADAG_NCCL)), + } - compare_two_settings(MODEL_NAME, pp_args, tp_args) + compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env) @pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [ diff --git a/tests/utils.py b/tests/utils.py index f3ee801ee7742..dd8af8e3afe70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import openai import ray @@ -57,6 +57,7 @@ def __init__( model: str, cli_args: List[str], *, + env_dict: Optional[Dict[str, str]] = None, auto_port: bool = True, ) -> None: if auto_port: @@ -77,6 +78,8 @@ def __init__( # the current process might initialize cuda, # to be safe, we should use spawn method env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args, env=env, stdout=sys.stdout, @@ -89,6 +92,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() + try: + self.proc.wait(3) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() def _wait_for_server(self, *, url: str, timeout: float): # run health check @@ -127,10 +135,21 @@ def get_async_client(self): ) -def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): +def compare_two_settings(model: str, + arg1: List[str], + arg2: List[str], + env1: Optional[Dict[str, str]] = None, + env2: Optional[Dict[str, str]] = None): """ - Launch API server with two different sets of arguments and compare the - results of the API calls. The arguments are after the model name. + Launch API server with two different sets of arguments/environments + and compare the results of the API calls. + + Args: + model: The model to test. + arg1: The first set of arguments to pass to the API server. + arg2: The second set of arguments to pass to the API server. + env1: The first set of environment variables to pass to the API server. + env2: The second set of environment variables to pass to the API server. """ tokenizer = AutoTokenizer.from_pretrained(model) @@ -138,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] results = [] - for args in (arg1, arg2): - with RemoteOpenAIServer(model, args) as server: + for args, env in ((arg1, env1), (arg2, env2)): + with RemoteOpenAIServer(model, args, env_dict=env) as server: client = server.get_client() # test models list diff --git a/vllm/envs.py b/vllm/envs.py index 9bcb26f8e5a64..5b8a65bd6545c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -38,6 +38,7 @@ VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -273,13 +274,20 @@ def get_default_config_root(): # execution on all workers. # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. "VLLM_USE_RAY_SPMD_WORKER": - lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)), + lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. "VLLM_USE_RAY_COMPILED_DAG": - lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)), + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), + + # If the env var is set, it uses NCCL for communication in + # Ray's compiled DAG. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) + ), # Use dedicated multiprocess context for workers. # Both spawn and fork work diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 14007e5518d4a..46d216910a08a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -105,12 +105,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The remaining workers are the actual ray actors. self.workers: List[RayWorkerWrapper] = [] + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( ray_remote_kwargs) + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) # Create the workers. driver_ip = get_ip() + logger.info("driver_ip: %s", driver_ip) worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): @@ -142,42 +149,49 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Else, added to the list of workers. self.workers.append(worker) + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " "GPU node.") + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + # Get the set of GPU IDs used on each node. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) - # the order in `worker_node_and_gpu_ids` does not necessarily match - # the machine boundaries. We need to make sure that workers in the - # same node are assigned consecutive ranks. - # examples: - # [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])] # noqa - - # initialize worker ranks with -1 (unassigned) - worker_ranks = [-1 for x in worker_node_and_gpu_ids] - current_rank = 0 - while -1 in worker_ranks: - # whenever we find an unassigned worker, find the node - index = worker_ranks.index(-1) - current_node_id = worker_node_and_gpu_ids[index][0] - # assign ranks to all workers in the same node - for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): - if node_id == current_node_id: - worker_ranks[i] = current_rank - current_rank += 1 - # with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3] - node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids - for worker_rank, (node_id, gpu_ids) in zip(worker_ranks, - worker_node_and_gpu_ids): - node_workers[node_id].append(worker_rank) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) # `gpu_ids` can be a list of strings or integers. # convert them to integers for consistency. # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), @@ -202,16 +216,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) - if len(node_gpus) == 1: - # in single node case, we don't need to get the IP address. - # the loopback address is sufficient - # NOTE: a node may have several IP addresses, one for each - # network interface. `get_ip()` might return any of them, - # while they might not work for communication inside the node - # if the network setup is complicated. Using the loopback address - # solves this issue, as it always works for communication inside - # the node. - driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) @@ -221,8 +225,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank=node_workers[node_id].index(rank), rank=rank, distributed_init_method=distributed_init_method, - ) for rank, (node_id, - _) in zip(worker_ranks, worker_node_and_gpu_ids) + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) @@ -231,6 +234,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + # This is the list of workers that are rank 0 of each TP group EXCEPT # global rank 0. These are the workers that will broadcast to the # rest of the workers. @@ -241,9 +257,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self.non_driver_workers: List[RayWorkerWrapper] = [] # Enforce rank order for correct rank to return final output. - for rank, worker in sorted(zip(worker_ranks[1:], self.workers)): - # We need to skip the driver worker, which we - # do by skipping worker_ranks[0] which is always 0. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 if rank % self.parallel_config.tensor_parallel_size == 0: self.tp_driver_workers.append(worker) else: @@ -376,16 +392,47 @@ def _compiled_ray_dag(self, enable_asyncio: bool): raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") - from ray.dag import InputNode, MultiOutputNode assert self.parallel_config.use_ray + from ray.dag import InputNode, MultiOutputNode + from ray.experimental.channel.torch_tensor_type import TorchTensorType - # Right now, compiled DAG requires at least 1 arg. We send - # a dummy value for now. It will be fixed soon. + logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) with InputNode() as input_data: - forward_dag = MultiOutputNode([ - worker.execute_model_spmd.bind( # type: ignore[attr-defined] - input_data) for worker in self.workers - ]) + # Example DAG: PP=2, TP=4 + # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 + # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501 + # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501 + # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501 + + # All workers in the first TP group will take in the + # ExecuteModelRequest as input. + outputs = [input_data for _ in self.pp_tp_workers[0]] + for pp_rank, tp_group in enumerate(self.pp_tp_workers): + # Each PP worker takes in the output of the previous PP worker, + # and the TP group executes in SPMD fashion. + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + + last_pp_rank = len(self.pp_tp_workers) - 1 + if pp_rank < last_pp_rank: + # Specify how intermediate tensors should be passed + # between pp stages, no need to specify for the last + # pp stage. + transport = "nccl" \ + if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \ + else "auto" + outputs = [ + output.with_type_hint( + TorchTensorType(transport=transport)) + for output in outputs + ] + + forward_dag = MultiOutputNode(outputs) + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) def __del__(self): diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 58b864070f727..ac948331e81e0 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,8 +1,8 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.worker.worker_base import WorkerWrapperBase @@ -31,9 +31,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): - """Used only when SPMD worker and compiled DAG are both - enabled.""" + def execute_model_spmd( + self, req_or_tuple: Union[ExecuteModelRequest, + Tuple[ExecuteModelRequest, + IntermediateTensors]]): + """Execute model in SPMD fashion: used only when SPMD worker and + compiled DAG are both enabled. + + Args: + req_or_tuple: The request to execute the model, or a tuple + containing the request and intermediate tensors. + """ # TODO(swang): This is needed right now because Ray aDAG executes # on a background thread, so we need to reset torch's current # device. @@ -42,7 +50,17 @@ def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - return self.worker._execute_model_spmd(execute_model_req) + if isinstance(req_or_tuple, tuple): + execute_model_req, intermediate_tensors = req_or_tuple + else: + execute_model_req = req_or_tuple + intermediate_tensors = None + + output = self.worker._execute_model_spmd(execute_model_req, + intermediate_tensors) + if isinstance(output, IntermediateTensors): + return execute_model_req, output + return output ray_import_err = None diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8a4d1958c65a0..e56440693b895 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -285,7 +285,9 @@ def execute_model( return output def _execute_model_spmd( - self, execute_model_req: ExecuteModelRequest + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None ) -> Optional[List[SamplerOutput]]: """ Execute model in Single Program Multiple Data (SPMD) fashion. @@ -309,7 +311,7 @@ def _execute_model_spmd( return self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None) + if self.kv_cache is not None else None, intermediate_tensors) class WorkerWrapperBase: