Skip to content

Commit

Permalink
Add network argument to vertex runner (#537)
Browse files Browse the repository at this point in the history
This is needed to access private IPs from Vertex pipelines (eg. a
private Weaviate endpoint).
  • Loading branch information
RobbeSneyders authored Oct 19, 2023
1 parent 1e6932e commit f8f4946
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/fondant/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,13 @@ def register_run(parent_parser):
default=None,
)

vertex_parser.add_argument(
"--network",
help="Network for the job to connect to, useful when peering with Vertex AI. Format "
"should be 'projects/${project_number}/global/networks/${network}'",
default=None,
)

vertex_parser.set_defaults(func=run_vertex)


Expand Down Expand Up @@ -525,6 +532,7 @@ def run_vertex(args):
project_id=args.project_id,
region=args.region,
service_account=args.service_account,
network=args.network,
)
runner.run(input_spec=spec_ref)

Expand Down
7 changes: 6 additions & 1 deletion src/fondant/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
project_id: str,
region: str,
service_account: t.Optional[str] = None,
network: t.Optional[str] = None,
):
self.__resolve_imports()

Expand All @@ -107,14 +108,18 @@ def __init__(
location=region,
)
self.service_account = service_account
self.network = network

def run(self, input_spec: str, *args, **kwargs):
job = self.aip.PipelineJob(
display_name=self.get_name_from_spec(input_spec),
template_path=input_spec,
enable_caching=False,
)
job.submit(service_account=self.service_account)
job.submit(
service_account=self.service_account,
network=self.network,
)

def get_name_from_spec(self, input_spec: str):
"""Get the name of the pipeline from the spec."""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,15 @@ def test_vertex_run(tmp_path_factory):
region="europe-west-1",
project_id="project-123",
service_account=None,
network=None,
ref="some/path",
)
run_vertex(args)
mock_runner.assert_called_once_with(
project_id="project-123",
region="europe-west-1",
service_account=None,
network=None,
)

with patch("fondant.cli.VertexRunner") as mock_runner, patch(
Expand All @@ -303,10 +305,12 @@ def test_vertex_run(tmp_path_factory):
region="europe-west-1",
project_id="project-123",
service_account=None,
network=None,
)
run_vertex(args)
mock_runner.assert_called_once_with(
project_id="project-123",
region="europe-west-1",
service_account=None,
network=None,
)

0 comments on commit f8f4946

Please sign in to comment.