From f8f4946ea84839a5f28543c9972ff2e2641b4be4 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Thu, 19 Oct 2023 10:05:51 +0200 Subject: [PATCH] Add network argument to vertex runner (#537) This is needed to access private IPs from Vertex pipelines (eg. a private Weaviate endpoint). --- src/fondant/cli.py | 8 ++++++++ src/fondant/runner.py | 7 ++++++- tests/test_cli.py | 4 ++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/fondant/cli.py b/src/fondant/cli.py index 79b930e78..363554d74 100644 --- a/src/fondant/cli.py +++ b/src/fondant/cli.py @@ -464,6 +464,13 @@ def register_run(parent_parser): default=None, ) + vertex_parser.add_argument( + "--network", + help="Network for the job to connect to, useful when peering with Vertex AI. Format " + "should be 'projects/${project_number}/global/networks/${network}'", + default=None, + ) + vertex_parser.set_defaults(func=run_vertex) @@ -525,6 +532,7 @@ def run_vertex(args): project_id=args.project_id, region=args.region, service_account=args.service_account, + network=args.network, ) runner.run(input_spec=spec_ref) diff --git a/src/fondant/runner.py b/src/fondant/runner.py index a30df5762..c79dad0ef 100644 --- a/src/fondant/runner.py +++ b/src/fondant/runner.py @@ -99,6 +99,7 @@ def __init__( project_id: str, region: str, service_account: t.Optional[str] = None, + network: t.Optional[str] = None, ): self.__resolve_imports() @@ -107,6 +108,7 @@ def __init__( location=region, ) self.service_account = service_account + self.network = network def run(self, input_spec: str, *args, **kwargs): job = self.aip.PipelineJob( @@ -114,7 +116,10 @@ def run(self, input_spec: str, *args, **kwargs): template_path=input_spec, enable_caching=False, ) - job.submit(service_account=self.service_account) + job.submit( + service_account=self.service_account, + network=self.network, + ) def get_name_from_spec(self, input_spec: str): """Get the name of the pipeline from the spec.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 64fb2b0e4..da15759fe 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -279,6 +279,7 @@ def test_vertex_run(tmp_path_factory): region="europe-west-1", project_id="project-123", service_account=None, + network=None, ref="some/path", ) run_vertex(args) @@ -286,6 +287,7 @@ def test_vertex_run(tmp_path_factory): project_id="project-123", region="europe-west-1", service_account=None, + network=None, ) with patch("fondant.cli.VertexRunner") as mock_runner, patch( @@ -303,10 +305,12 @@ def test_vertex_run(tmp_path_factory): region="europe-west-1", project_id="project-123", service_account=None, + network=None, ) run_vertex(args) mock_runner.assert_called_once_with( project_id="project-123", region="europe-west-1", service_account=None, + network=None, )