diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index f6a11e74d3c..f1a7c6663c1 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -17,7 +17,7 @@ from typing import Iterable, List, Optional, Tuple -from flwr.driver.grpc_driver import GrpcDriver +from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver from flwr.proto.driver_pb2 import ( CreateWorkloadRequest, GetNodesRequest, @@ -29,9 +29,30 @@ class Driver: - """`Driver` class provides an interface to the Driver API.""" - - def __init__(self) -> None: + """`Driver` class provides an interface to the Driver API. + + Parameters + ---------- + driver_service_address : Optional[str] + The IPv4 or IPv6 address of the Driver API server. + Defaults to `"[::]:9091"`. + certificates : bytes (default: None) + Tuple containing root certificate, server certificate, and private key + to start a secure SSL-enabled server. The tuple is expected to have + three bytes elements in the following order: + + * CA certificate. + * server certificate. + * server private key. + """ + + def __init__( + self, + driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, + certificates: Optional[bytes] = None, + ) -> None: + self.addr = driver_service_address + self.certificates = certificates self.grpc_driver: Optional[GrpcDriver] = None self.workload_id: Optional[int] = None self.node = Node(node_id=0, anonymous=True) @@ -40,7 +61,9 @@ def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]: # Check if the GrpcDriver is initialized if self.grpc_driver is None or self.workload_id is None: # Connect and create workload - self.grpc_driver = GrpcDriver() + self.grpc_driver = GrpcDriver( + driver_service_address=self.addr, certificates=self.certificates + ) self.grpc_driver.connect() res = self.grpc_driver.create_workload(CreateWorkloadRequest()) self.workload_id = res.workload_id