Skip to content

Commit

Permalink
Merge branch 'develop' into feature/PRD-586-ochestrator-urls
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak authored Sep 24, 2024
2 parents d37fb71 + 9626c30 commit a4647dc
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/zenml/integrations/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class LightningIntegration(Integration):
"""Definition of Lightning Integration for ZenML."""

NAME = LIGHTNING
REQUIREMENTS = ["lightning-sdk"]
REQUIREMENTS = ["lightning-sdk>=0.1.17"]

@classmethod
def flavors(cls) -> List[Type[Flavor]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ def is_synchronous(self) -> bool:
"""
return self.synchronous

@property
def is_schedulable(self) -> bool:
"""Whether the orchestrator is schedulable or not.
Returns:
Whether the orchestrator is schedulable or not.
"""
return False


class LightningOrchestratorFlavor(BaseOrchestratorFlavor):
"""Lightning orchestrator flavor."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,29 @@ def _set_lightning_env_vars(
Args:
deployment: The pipeline deployment to prepare or run.
Raises:
ValueError: If the user id and api key or username and organization
"""
settings = cast(
LightningOrchestratorSettings, self.get_settings(deployment)
)
if settings.user_id:
os.environ["LIGHTNING_USER_ID"] = settings.user_id
if settings.api_key:
os.environ["LIGHTNING_API_KEY"] = settings.api_key
if not settings.user_id or not settings.api_key:
raise ValueError(
"Lightning orchestrator requires `user_id` and `api_key` both to be set in the settings."
)
os.environ["LIGHTNING_USER_ID"] = settings.user_id
os.environ["LIGHTNING_API_KEY"] = settings.api_key
if settings.username:
os.environ["LIGHTNING_USERNAME"] = settings.username
elif settings.organization:
os.environ["LIGHTNING_ORG"] = settings.organization
else:
raise ValueError(
"Lightning orchestrator requires either `username` or `organization` to be set in the settings."
)
if settings.teamspace:
os.environ["LIGHTNING_TEAMSPACE"] = settings.teamspace
if settings.organization:
os.environ["LIGHTNING_ORG"] = settings.organization

@property
def config(self) -> LightningOrchestratorConfig:
Expand Down Expand Up @@ -267,9 +276,7 @@ def prepare_or_run_pipeline(
) as code_file:
code_archive.write_archive(code_file)
code_path = code_file.name

filename = f"{orchestrator_run_name}.tar.gz"

# Construct the env variables for the pipeline
env_vars = environment.copy()
orchestrator_run_id = str(uuid4())
Expand Down Expand Up @@ -392,9 +399,7 @@ def _construct_lightning_steps(
f"Installing requirements: {pipeline_requirements_to_string}"
)
studio.run(f"uv pip install {pipeline_requirements_to_string}")
studio.run(
"pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator"
)
studio.run("pip install zenml -y")

for custom_command in settings.custom_commands or []:
studio.run(
Expand Down Expand Up @@ -488,9 +493,7 @@ def _upload_and_run_pipeline(
)
studio.run("pip install uv")
studio.run(f"uv pip install {requirements}")
studio.run(
"pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator"
)
studio.run("pip install zenml -y")
# studio.run(f"pip install {wheel_path.rsplit('/', 1)[-1]}")
for command in settings.custom_commands or []:
output = studio.run(
Expand Down Expand Up @@ -563,9 +566,7 @@ def _run_step_in_new_studio(
)
studio.run("pip install uv")
studio.run(f"uv pip install {details['requirements']}")
studio.run(
"pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator"
)
studio.run("pip install zenml -y")
# studio.run(f"pip install {wheel_path.rsplit('/', 1)[-1]}")
for command in custom_commands or []:
output = studio.run(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def main() -> None:
f"uv pip install {pipeline_requirements_to_string}"
)
logger.info(output)
output = main_studio.run(
"pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator"
)
output = main_studio.run("pip install zenml -y")
logger.info(output)

for command in pipeline_settings.custom_commands or []:
Expand Down Expand Up @@ -250,9 +248,7 @@ def run_step_on_lightning_studio(step_name: str) -> None:
f"uv pip install {step_requirements_to_string}"
)
logger.info(output)
output = studio.run(
"pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator"
)
output = studio.run("pip install zenml -y")
logger.info(output)
for command in step_settings.custom_commands or []:
output = studio.run(
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/integrations/mlflow/steps/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def mlflow_register_model_step(
metadata.zenml_pipeline_run_uuid = pipeline_run_uuid
if metadata.zenml_workspace is None:
metadata.zenml_workspace = zenml_workspace
if metadata.model_extra.get("mlflow_run_id", None) is None:
metadata.model_extra["mlflow_run_id"] = mlflow_run_id

# Register model version
model_version = model_registry.register_model_version(
Expand Down
1 change: 1 addition & 0 deletions src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
description="Template used for the pipeline run.",
)
is_templatable: bool = Field(
default=False,
description="Whether a template can be created from this run.",
)

Expand Down
41 changes: 33 additions & 8 deletions src/zenml/zen_server/cloud_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utils concerning anything concerning the cloud control plane backend."""

import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import requests
Expand All @@ -19,11 +20,9 @@ class ZenMLCloudConfiguration(BaseModel):
"""ZenML Pro RBAC configuration."""

api_url: str

oauth2_client_id: str
oauth2_client_secret: str
oauth2_audience: str
auth0_domain: str

@field_validator("api_url")
@classmethod
Expand Down Expand Up @@ -68,6 +67,8 @@ def __init__(self) -> None:
"""Initialize the RBAC component."""
self._config = ZenMLCloudConfiguration.from_environment()
self._session: Optional[requests.Session] = None
self._token: Optional[str] = None
self._token_expires_at: Optional[datetime] = None

def get(
self, endpoint: str, params: Optional[Dict[str, Any]]
Expand All @@ -91,7 +92,8 @@ def get(

response = self.session.get(url=url, params=params, timeout=7)
if response.status_code == 401:
# Refresh the auth token and try again
# If we get an Unauthorized error from the API serer, we refresh the
# auth token and try again
self._clear_session()
response = self.session.get(url=url, params=params, timeout=7)

Expand Down Expand Up @@ -186,6 +188,8 @@ def session(self) -> requests.Session:
def _clear_session(self) -> None:
"""Clear the authentication session."""
self._session = None
self._token = None
self._token_expires_at = None

def _fetch_auth_token(self) -> str:
"""Fetch an auth token for the Cloud API from auth0.
Expand All @@ -196,8 +200,16 @@ def _fetch_auth_token(self) -> str:
Returns:
Auth token.
"""
if (
self._token is not None
and self._token_expires_at is not None
and datetime.now(timezone.utc) + timedelta(minutes=5)
< self._token_expires_at
):
return self._token

# Get an auth token from auth0
auth0_url = f"https://{self._config.auth0_domain}/oauth/token"
login_url = f"{self._config.api_url}/auth/login"
headers = {"content-type": "application/x-www-form-urlencoded"}
payload = {
"client_id": self._config.oauth2_client_id,
Expand All @@ -207,18 +219,31 @@ def _fetch_auth_token(self) -> str:
}
try:
response = requests.post(
auth0_url, headers=headers, data=payload, timeout=7
login_url, headers=headers, data=payload, timeout=7
)
response.raise_for_status()
except Exception as e:
raise RuntimeError(f"Error fetching auth token from auth0: {e}")

access_token = response.json().get("access_token", "")
json_response = response.json()
access_token = json_response.get("access_token", "")
expires_in = json_response.get("expires_in", 0)

if not access_token or not isinstance(access_token, str):
if (
not access_token
or not isinstance(access_token, str)
or not expires_in
or not isinstance(expires_in, int)
):
raise RuntimeError("Could not fetch auth token from auth0.")

return str(access_token)
self._token = access_token
self._token_expires_at = datetime.now(timezone.utc) + timedelta(
seconds=expires_in
)

assert self._token is not None
return self._token


def cloud_connection() -> ZenMLCloudConnection:
Expand Down

0 comments on commit a4647dc

Please sign in to comment.