From 7c2c0d9337f3b353576bccc30f61c16abcc633a7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 30 May 2024 06:28:50 -0500 Subject: [PATCH] [Disco][QoL] Implement broadcast/scatter methods for Session (#17035) * [Disco][QoL] Implement broadcast/scatter methods for Session Prior to this commit, use of the `disco.Session` API to broadcast or scatter an array required several steps from the caller. 1. Allocate memory on worker0 2. Transfer data from the controller to worker0 3. Allocate memory on each worker 4. Broadcast/scatter data from worker0 to all workers While exposing these steps is necessary for performance, especially when used repeatedly, it can be tedious/error-prone to use for initialization that is only performed once. This commit adds utility methods `Session.broadcast` and `Session.scatter`, which are implemented in terms of the existing lower-level methods `Session.broadcast_from_worker0` and `Session.scatter_from_worker0`. These methods perform the transfer from the controller to worker0, and from worker0 to all other workers. * lint fix --- python/tvm/runtime/disco/session.py | 102 ++++++++++++++++++++++++++-- tests/python/disco/test_ccl.py | 70 ++++++++++++++++--- 2 files changed, 158 insertions(+), 14 deletions(-) diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 97edeff1d19a..ddde1bc1f323 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -249,17 +249,34 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: """ return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member - def copy_to_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: + def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = None) -> DRef: """Copy the controller-side NDArray to worker-0. Parameters ---------- - host_array : numpy.ndarray - The array to be copied from worker-0. - remote_array : NDArray - The NDArray on worker-0. + host_array : NDArray + + The array to be copied to worker-0. + + remote_array : Optiona[DRef] + + The destination NDArray on worker-0. + + Returns + ------- + output_array: DRef + + The DRef containing the copied data on worker0, and + NullOpt on all other workers. If `remote_array` was + provided, this return value is the same as `remote_array`. + Otherwise, it is the newly allocated space. + """ - return _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member + if remote_array is None: + remote_array = self.empty(host_array.shape, host_array.dtype, worker0_only=True) + + _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member + return remote_array def load_vm_module( self, @@ -302,6 +319,40 @@ def init_ccl(self, ccl: str, *device_ids): _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member self._clear_ipc_memory_pool() + def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + """Broadcast an array to all workers + + Parameters + ---------- + src: Union[np.ndarray, NDArray] + + The array to be broadcasted. + + dst: Optional[DRef] + + The output array. If None, an array matching the shape + and dtype of `src` will be allocated on each worker. + + Returns + ------- + output_array: DRef + + The DRef containing the broadcasted data on all workers. + If `dst` was provided, this return value is the same as + `dst`. Otherwise, it is the newly allocated space. + + """ + if not isinstance(src, NDArray): + src = _as_NDArray(src) + + if dst is None: + dst = self.empty(src.shape, src.dtype) + + src_dref = self.copy_to_worker_0(src) + self.broadcast_from_worker0(src_dref, dst) + + return dst + def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef: """Broadcast an array from worker-0 to all other workers. @@ -313,6 +364,45 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef: func = self._get_cached_method("runtime.disco.broadcast_from_worker0") func(src, dst) + def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + """Scatter an array across all workers + + Parameters + ---------- + src: Union[np.ndarray, NDArray] + + The array to be scattered. The first dimension of this + array, `src.shape[0]`, must be equal to the number of + workers. + + dst: Optional[DRef] + + The output array. If None, an array with compatible shape + and the same dtype as `src` will be allocated on each + worker. + + Returns + ------- + output_array: DRef + + The DRef containing the scattered data on all workers. + If `dst` was provided, this return value is the same as + `dst`. Otherwise, it is the newly allocated space. + + """ + assert src.shape[0] == self.num_workers + + if not isinstance(src, NDArray): + src = _as_NDArray(src) + + if dst is None: + dst = self.empty(src.shape[1:], src.dtype) + + src_dref = self.copy_to_worker_0(src) + self.scatter_from_worker0(src_dref, dst) + + return dst + def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None: """Scatter an array from worker-0 to all other workers. diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index b94bfdb2bb59..5831f245dfaf 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -103,33 +103,87 @@ def test_allgather(session_kind, ccl): @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) -def test_broadcast_from_worker0(session_kind, ccl): +@pytest.mark.parametrize("use_explicit_output", [True, False]) +def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output): devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) array = np.arange(12, dtype="float32").reshape(3, 4) - d_array = sess.empty((3, 4), "float32", worker0_only=True) - d_array.debug_copy_from(0, array) - dst_array = sess.empty((3, 4), "float32") - sess.broadcast_from_worker0(d_array, dst_array) + + if use_explicit_output: + src_array = sess.empty((3, 4), "float32", worker0_only=True) + src_array.debug_copy_from(0, array) + dst_array = sess.empty((3, 4), "float32") + sess.broadcast_from_worker0(src_array, dst_array) + else: + dst_array = sess.broadcast(array) + result = dst_array.debug_get_from_remote(1).numpy() np.testing.assert_equal(result, array) @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) -def test_scatter(session_kind, ccl, capfd): +@pytest.mark.parametrize("use_explicit_output", [True, False]) +def test_scatter(session_kind, ccl, use_explicit_output, capfd): + devices = [0, 1] + sess = session_kind(num_workers=len(devices)) + sess.init_ccl(ccl, *devices) + + array = np.arange(36, dtype="float32").reshape(2, 6, 3) + + if use_explicit_output: + d_src = sess.empty((2, 6, 3), "float32", worker0_only=True) + d_dst = sess.empty((6, 3), "float32") + d_src.debug_copy_from(0, array) + sess.scatter_from_worker0(d_src, d_dst) + else: + d_dst = sess.scatter(array) + + np.testing.assert_equal( + d_dst.debug_get_from_remote(0).numpy(), + array[0, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(1).numpy(), + array[1, :, :], + ) + + captured = capfd.readouterr() + assert ( + not captured.err + ), "No warning messages should be generated from disco.Session.scatter_from_worker0" + + +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_scatter_with_implicit_reshape(session_kind, ccl, capfd): + """Scatter may perform an implicit reshape + + Scattering elements to the workers requires the total number of + elements to be divisible by the number of workers. It does not + necessarily correspond to scattering across the outermost + dimension. Here, the number of workers (2) and the outermost + dimension (3) are not divisible, but the scatter may still be + performed. + + This is only allowed when the caller explicitly uses the + `sess.scatter_from_worker0` method, and is not allowed in + `sess.scatter` method. Because the `sess.scatter` method may + perform an allocation on the disco workers, it requires that the + scatter occur across the outermost dimension. + + """ devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) array = np.arange(36, dtype="float32").reshape(3, 4, 3) + d_src = sess.empty((3, 4, 3), "float32", worker0_only=True) d_dst = sess.empty((3, 3, 2), "float32") - d_src.debug_copy_from(0, array) - sess.scatter_from_worker0(d_src, d_dst) np.testing.assert_equal(