diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 97edeff1d19a..ddde1bc1f323 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -249,17 +249,34 @@ def copy_from_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: """ return _ffi_api.SessionCopyFromWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member - def copy_to_worker_0(self, host_array: NDArray, remote_array: DRef) -> None: + def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = None) -> DRef: """Copy the controller-side NDArray to worker-0. Parameters ---------- - host_array : numpy.ndarray - The array to be copied from worker-0. - remote_array : NDArray - The NDArray on worker-0. + host_array : NDArray + + The array to be copied to worker-0. + + remote_array : Optional[DRef] + + The destination NDArray on worker-0. + + Returns + ------- + output_array: DRef + + The DRef containing the copied data on worker-0, and + NullOpt on all other workers. If `remote_array` was + provided, this return value is the same as `remote_array`. + Otherwise, it is the newly allocated space. 
+ """ - return _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member + if remote_array is None: + remote_array = self.empty(host_array.shape, host_array.dtype, worker0_only=True) + + _ffi_api.SessionCopyToWorker0(self, host_array, remote_array) # type: ignore # pylint: disable=no-member + return remote_array def load_vm_module( self, @@ -302,6 +319,40 @@ def init_ccl(self, ccl: str, *device_ids): _ffi_api.SessionInitCCL(self, ccl, ShapeTuple(device_ids)) # type: ignore # pylint: disable=no-member self._clear_ipc_memory_pool() + def broadcast(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + """Broadcast an array to all workers + + Parameters + ---------- + src: Union[np.ndarray, NDArray] + + The array to be broadcasted. + + dst: Optional[DRef] + + The output array. If None, an array matching the shape + and dtype of `src` will be allocated on each worker. + + Returns + ------- + output_array: DRef + + The DRef containing the broadcasted data on all workers. + If `dst` was provided, this return value is the same as + `dst`. Otherwise, it is the newly allocated space. + + """ + if not isinstance(src, NDArray): + src = _as_NDArray(src) + + if dst is None: + dst = self.empty(src.shape, src.dtype) + + src_dref = self.copy_to_worker_0(src) + self.broadcast_from_worker0(src_dref, dst) + + return dst + def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef: """Broadcast an array from worker-0 to all other workers. @@ -313,6 +364,45 @@ def broadcast_from_worker0(self, src: DRef, dst: DRef) -> DRef: func = self._get_cached_method("runtime.disco.broadcast_from_worker0") func(src, dst) + def scatter(self, src: Union[np.ndarray, NDArray], dst: Optional[DRef] = None) -> DRef: + """Scatter an array across all workers + + Parameters + ---------- + src: Union[np.ndarray, NDArray] + + The array to be scattered. 
The first dimension of this + array, `src.shape[0]`, must be equal to the number of + workers. + + dst: Optional[DRef] + + The output array. If None, an array with compatible shape + and the same dtype as `src` will be allocated on each + worker. + + Returns + ------- + output_array: DRef + + The DRef containing the scattered data on all workers. + If `dst` was provided, this return value is the same as + `dst`. Otherwise, it is the newly allocated space. + + """ + assert src.shape[0] == self.num_workers + + if not isinstance(src, NDArray): + src = _as_NDArray(src) + + if dst is None: + dst = self.empty(src.shape[1:], src.dtype) + + src_dref = self.copy_to_worker_0(src) + self.scatter_from_worker0(src_dref, dst) + + return dst + def scatter_from_worker0(self, from_array: DRef, to_array: DRef) -> None: """Scatter an array from worker-0 to all other workers. diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py index b94bfdb2bb59..5831f245dfaf 100644 --- a/tests/python/disco/test_ccl.py +++ b/tests/python/disco/test_ccl.py @@ -103,33 +103,87 @@ def test_allgather(session_kind, ccl): @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) -def test_broadcast_from_worker0(session_kind, ccl): +@pytest.mark.parametrize("use_explicit_output", [True, False]) +def test_broadcast_from_worker0(session_kind, ccl, use_explicit_output): devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) array = np.arange(12, dtype="float32").reshape(3, 4) - d_array = sess.empty((3, 4), "float32", worker0_only=True) - d_array.debug_copy_from(0, array) - dst_array = sess.empty((3, 4), "float32") - sess.broadcast_from_worker0(d_array, dst_array) + + if use_explicit_output: + src_array = sess.empty((3, 4), "float32", worker0_only=True) + src_array.debug_copy_from(0, array) + dst_array = sess.empty((3, 4), "float32") + sess.broadcast_from_worker0(src_array, dst_array) + else: + dst_array = 
sess.broadcast(array) + result = dst_array.debug_get_from_remote(1).numpy() np.testing.assert_equal(result, array) @pytest.mark.parametrize("session_kind", _all_session_kinds) @pytest.mark.parametrize("ccl", _ccl) -def test_scatter(session_kind, ccl, capfd): +@pytest.mark.parametrize("use_explicit_output", [True, False]) +def test_scatter(session_kind, ccl, use_explicit_output, capfd): + devices = [0, 1] + sess = session_kind(num_workers=len(devices)) + sess.init_ccl(ccl, *devices) + + array = np.arange(36, dtype="float32").reshape(2, 6, 3) + + if use_explicit_output: + d_src = sess.empty((2, 6, 3), "float32", worker0_only=True) + d_dst = sess.empty((6, 3), "float32") + d_src.debug_copy_from(0, array) + sess.scatter_from_worker0(d_src, d_dst) + else: + d_dst = sess.scatter(array) + + np.testing.assert_equal( + d_dst.debug_get_from_remote(0).numpy(), + array[0, :, :], + ) + np.testing.assert_equal( + d_dst.debug_get_from_remote(1).numpy(), + array[1, :, :], + ) + + captured = capfd.readouterr() + assert ( + not captured.err + ), "No warning messages should be generated from disco.Session.scatter_from_worker0" + + +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("ccl", _ccl) +def test_scatter_with_implicit_reshape(session_kind, ccl, capfd): + """Scatter may perform an implicit reshape + + Scattering elements to the workers requires the total number of + elements to be divisible by the number of workers. It does not + necessarily correspond to scattering across the outermost + dimension. Here, the outermost dimension (3) is not divisible + by the number of workers (2), but the scatter may still be + performed. + + This is only allowed when the caller explicitly uses the + `sess.scatter_from_worker0` method, and is not allowed in the + `sess.scatter` method. Because the `sess.scatter` method may + perform an allocation on the disco workers, it requires that the + scatter occur across the outermost dimension. 
+ + """ devices = [0, 1] sess = session_kind(num_workers=len(devices)) sess.init_ccl(ccl, *devices) array = np.arange(36, dtype="float32").reshape(3, 4, 3) + d_src = sess.empty((3, 4, 3), "float32", worker0_only=True) d_dst = sess.empty((3, 3, 2), "float32") - d_src.debug_copy_from(0, array) - sess.scatter_from_worker0(d_src, d_dst) np.testing.assert_equal(