diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 02bf7af2d253..0a55ccd1e421 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -144,7 +144,7 @@ class _LGBMRegressorBase: # type: ignore from dask.bag import from_delayed as dask_bag_from_delayed from dask.dataframe import DataFrame as dask_DataFrame from dask.dataframe import Series as dask_Series - from dask.distributed import Client, default_client, wait + from dask.distributed import Client, Future, default_client, wait DASK_INSTALLED = True except ImportError: DASK_INSTALLED = False @@ -161,6 +161,12 @@ class Client: # type: ignore def __init__(self, *args, **kwargs): pass + class Future: # type: ignore + """Dummy class for dask.distributed.Future.""" + + def __init__(self, *args, **kwargs): + pass + class dask_Array: # type: ignore """Dummy class for dask.array.Array.""" diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 4df81007a968..8aeeac09eed2 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -6,6 +6,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost. """ +import operator import socket from collections import defaultdict from copy import deepcopy @@ -18,7 +19,7 @@ import scipy.sparse as ss from .basic import LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning -from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat, +from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, Future, LGBMNotFittedError, concat, dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait) from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomObjectiveFunction, @@ -38,18 +39,21 @@ _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] -class _HostWorkers: +class _RemoteSocket: + def acquire(self) -> int: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(('', 0)) + return self.socket.getsockname()[1] - def __init__(self, default: str, all_workers: List[str]): - self.default = default - self.all_workers = all_workers + def release(self) -> None: + self.socket.close() - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, type(self)) - and self.default == other.default - and self.all_workers == other.all_workers - ) + +def _acquire_port() -> Tuple[_RemoteSocket, int]: + s = _RemoteSocket() + port = s.acquire() + return s, port class _DatasetNames(Enum): @@ -83,73 +87,40 @@ def _get_dask_client(client: Optional[Client]) -> Client: return client -def _find_n_open_ports(n: int) -> List[int]: - """Find n random open ports on localhost. - - Returns - ------- - ports : list of int - n random open ports on localhost. - """ - sockets = [] - for _ in range(n): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('', 0)) - sockets.append(s) - ports = [] - for s in sockets: - ports.append(s.getsockname()[1]) - s.close() - return ports - - -def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWorkers]: - """Group all worker addresses by hostname. - - Returns - ------- - host_to_workers : dict - mapping from hostname to all its workers. - """ - host_to_workers: Dict[str, _HostWorkers] = {} - for address in worker_addresses: - hostname = urlparse(address).hostname - if not hostname: - raise ValueError(f"Could not parse host name from worker address '{address}'") - if hostname not in host_to_workers: - host_to_workers[hostname] = _HostWorkers(default=address, all_workers=[address]) - else: - host_to_workers[hostname].all_workers.append(address) - return host_to_workers - - def _assign_open_ports_to_workers( client: Client, - host_to_workers: Dict[str, _HostWorkers] -) -> Dict[str, int]: + workers: List[str], +) -> Tuple[Dict[str, Future], Dict[str, int]]: """Assign an open port to each worker. Returns ------- + worker_to_socket_future: dict + mapping from worker address to a future pointing to the remote socket. worker_to_port: dict - mapping from worker address to an open port. + mapping from worker address to an open port in the worker's host. """ - host_ports_futures = {} - for hostname, workers in host_to_workers.items(): - n_workers_in_host = len(workers.all_workers) - host_ports_futures[hostname] = client.submit( - _find_n_open_ports, - n=n_workers_in_host, - workers=[workers.default], - pure=False, + # Acquire port in worker + worker_to_future = {} + for worker in workers: + worker_to_future[worker] = client.submit( + _acquire_port, + workers=[worker], allow_other_workers=False, + pure=False, ) - found_ports = client.gather(host_ports_futures) - worker_to_port = {} - for hostname, workers in host_to_workers.items(): - for worker, port in zip(workers.all_workers, found_ports[hostname]): - worker_to_port[worker] = port - return worker_to_port + + # schedule futures to retrieve each element of the tuple + worker_to_socket_future = {} + worker_to_port_future = {} + for worker, socket_future in worker_to_future.items(): + worker_to_socket_future[worker] = client.submit(operator.itemgetter(0), socket_future) + worker_to_port_future[worker] = client.submit(operator.itemgetter(1), socket_future) + + # retrieve ports + worker_to_port = client.gather(worker_to_port_future) + + return worker_to_socket_future, worker_to_port def _concat(seq: List[_DaskPart]) -> _DaskPart: @@ -190,6 +161,7 @@ def _train_part( num_machines: int, return_model: bool, time_out: int, + remote_socket: _RemoteSocket, **kwargs: Any ) -> Optional[LGBMModel]: network_params = { @@ -320,6 +292,8 @@ def _train_part( kwargs['eval_class_weight'] = [eval_class_weight[i] for i in eval_component_idx] model = model_factory(**params) + if remote_socket is not None: + remote_socket.release() try: if is_ranker: model.fit( @@ -777,6 +751,7 @@ def _train( machines = params.pop("machines") # figure out network params + worker_to_socket_future: Dict[str, Future] = {} worker_addresses = worker_map.keys() if machines is not None: _log_info("Using passed-in 'machines' parameter") @@ -802,8 +777,7 @@ def _train( } else: _log_info("Finding random open ports for workers") - host_to_workers = _group_workers_by_host(worker_map.keys()) - worker_address_to_port = _assign_open_ports_to_workers(client, host_to_workers) + worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(client, list(worker_map.keys())) machines = ','.join([ f'{urlparse(worker_address).hostname}:{port}' @@ -831,6 +805,7 @@ def _train( local_listen_port=worker_address_to_port[worker], num_machines=num_machines, time_out=params.get('time_out', 120), + remote_socket=worker_to_socket_future.get(worker, None), return_model=(worker == master_worker), workers=[worker], allow_other_workers=False, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 662020428270..cb69440b3cde 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -519,26 +519,6 @@ def test_classifier_custom_objective(output, task, cluster): assert_eq(p1_proba, p1_proba_local) -def test_group_workers_by_host(): - hosts = [f'0.0.0.{i}' for i in range(2)] - workers = [f'tcp://{host}:{p}' for p in range(2) for host in hosts] - expected = { - host: lgb.dask._HostWorkers( - default=f'tcp://{host}:0', - all_workers=[f'tcp://{host}:0', f'tcp://{host}:1'] - ) - for host in hosts - } - host_to_workers = lgb.dask._group_workers_by_host(workers) - assert host_to_workers == expected - - -def test_group_workers_by_host_unparseable_host_names(): - workers_without_protocol = ['0.0.0.1:80', '0.0.0.2:80'] - with pytest.raises(ValueError, match="Could not parse host name from worker address '0.0.0.1:80'"): - lgb.dask._group_workers_by_host(workers_without_protocol) - - def test_machines_to_worker_map_unparseable_host_names(): workers = {'0.0.0.1:80': {}, '0.0.0.2:80': {}} machines = "0.0.0.1:80,0.0.0.2:80" @@ -546,23 +526,6 @@ def test_machines_to_worker_map_unparseable_host_names(): lgb.dask._machines_to_worker_map(machines=machines, worker_addresses=workers.keys()) -def test_assign_open_ports_to_workers(cluster): - with Client(cluster) as client: - workers = client.scheduler_info()['workers'].keys() - n_workers = len(workers) - host_to_workers = lgb.dask._group_workers_by_host(workers) - for _ in range(25): - worker_address_to_port = lgb.dask._assign_open_ports_to_workers(client, host_to_workers) - found_ports = worker_address_to_port.values() - assert len(found_ports) == n_workers - # check that found ports are different for same address (LocalCluster) - assert len(set(found_ports)) == len(found_ports) - # check that the ports are indeed open - for port in found_ports: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', port)) - - def test_training_does_not_fail_on_port_conflicts(cluster): with Client(cluster) as client: _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array') @@ -1588,15 +1551,17 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c assert 'machines' not in params # model 2 - machines given + workers = list(client.scheduler_info()['workers']) workers_hostname = _get_workers_hostname(cluster) - n_workers = len(client.scheduler_info()['workers']) - open_ports = lgb.dask._find_n_open_ports(n_workers) + remote_sockets, open_ports = lgb.dask._assign_open_ports_to_workers(client, workers) + for s in remote_sockets.values(): + s.release() dask_model2 = dask_model_factory( n_estimators=5, num_leaves=5, machines=",".join([ f"{workers_hostname}:{port}" - for port in open_ports + for port in open_ports.values() ]), )