diff --git a/src/zenml/integrations/lightning/__init__.py b/src/zenml/integrations/lightning/__init__.py index 92eb88234f7..eae1140b158 100644 --- a/src/zenml/integrations/lightning/__init__.py +++ b/src/zenml/integrations/lightning/__init__.py @@ -28,7 +28,7 @@ class LightningIntegration(Integration): """Definition of Lightning Integration for ZenML.""" NAME = LIGHTNING - REQUIREMENTS = ["lightning-sdk"] + REQUIREMENTS = ["lightning-sdk>=0.1.17"] @classmethod def flavors(cls) -> List[Type[Flavor]]: diff --git a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py index 4d13493f1cb..77cc2b08959 100644 --- a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +++ b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py @@ -85,6 +85,15 @@ def is_synchronous(self) -> bool: """ return self.synchronous + @property + def is_schedulable(self) -> bool: + """Whether the orchestrator is schedulable or not. + + Returns: + Whether the orchestrator is schedulable or not. + """ + return False + class LightningOrchestratorFlavor(BaseOrchestratorFlavor): """Lightning orchestrator flavor.""" diff --git a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py index 778ad4ecbd4..82366bbdd99 100644 --- a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +++ b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py @@ -103,20 +103,29 @@ def _set_lightning_env_vars( Args: deployment: The pipeline deployment to prepare or run. + + Raises: + ValueError: If the user id and api key or username and organization """ settings = cast( LightningOrchestratorSettings, self.get_settings(deployment) ) - if settings.user_id: - os.environ["LIGHTNING_USER_ID"] = settings.user_id - if settings.api_key: - os.environ["LIGHTNING_API_KEY"] = settings.api_key + if not settings.user_id or not settings.api_key: + raise ValueError( + "Lightning orchestrator requires `user_id` and `api_key` both to be set in the settings." + ) + os.environ["LIGHTNING_USER_ID"] = settings.user_id + os.environ["LIGHTNING_API_KEY"] = settings.api_key if settings.username: os.environ["LIGHTNING_USERNAME"] = settings.username + elif settings.organization: + os.environ["LIGHTNING_ORG"] = settings.organization + else: + raise ValueError( + "Lightning orchestrator requires either `username` or `organization` to be set in the settings." + ) if settings.teamspace: os.environ["LIGHTNING_TEAMSPACE"] = settings.teamspace - if settings.organization: - os.environ["LIGHTNING_ORG"] = settings.organization @property def config(self) -> LightningOrchestratorConfig: @@ -267,9 +276,7 @@ def prepare_or_run_pipeline( ) as code_file: code_archive.write_archive(code_file) code_path = code_file.name - filename = f"{orchestrator_run_name}.tar.gz" - # Construct the env variables for the pipeline env_vars = environment.copy() orchestrator_run_id = str(uuid4()) @@ -392,9 +399,7 @@ def _construct_lightning_steps( f"Installing requirements: {pipeline_requirements_to_string}" ) studio.run(f"uv pip install {pipeline_requirements_to_string}") - studio.run( - "pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator" - ) + studio.run("pip install zenml -y") for custom_command in settings.custom_commands or []: studio.run( @@ -488,9 +493,7 @@ def _upload_and_run_pipeline( ) studio.run("pip install uv") studio.run(f"uv pip install {requirements}") - studio.run( - "pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator" - ) + studio.run("pip install zenml -y") # studio.run(f"pip install {wheel_path.rsplit('/', 1)[-1]}") for command in settings.custom_commands or []: output = studio.run( @@ -563,9 +566,7 @@ def _run_step_in_new_studio( ) studio.run("pip install uv") studio.run(f"uv pip install {details['requirements']}") - studio.run( - "pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator" - ) + studio.run("pip install zenml -y") # studio.run(f"pip install {wheel_path.rsplit('/', 1)[-1]}") for command in custom_commands or []: output = studio.run( diff --git a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py index e6624f04d77..4f0ea6394f3 100644 --- a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +++ b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py @@ -166,9 +166,7 @@ def main() -> None: f"uv pip install {pipeline_requirements_to_string}" ) logger.info(output) - output = main_studio.run( - "pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator" - ) + output = main_studio.run("pip install zenml -y") logger.info(output) for command in pipeline_settings.custom_commands or []: @@ -250,9 +248,7 @@ def run_step_on_lightning_studio(step_name: str) -> None: f"uv pip install {step_requirements_to_string}" ) logger.info(output) - output = studio.run( - "pip uninstall zenml -y && pip install git+https://github.com/zenml-io/zenml.git@feature/lightening-studio-orchestrator" - ) + output = studio.run("pip install zenml -y") logger.info(output) for command in step_settings.custom_commands or []: output = studio.run( diff --git a/src/zenml/integrations/mlflow/steps/mlflow_registry.py b/src/zenml/integrations/mlflow/steps/mlflow_registry.py index 479ef90f1d5..ced40183684 100644 --- a/src/zenml/integrations/mlflow/steps/mlflow_registry.py +++ b/src/zenml/integrations/mlflow/steps/mlflow_registry.py @@ -146,6 +146,8 @@ def mlflow_register_model_step( metadata.zenml_pipeline_run_uuid = pipeline_run_uuid if metadata.zenml_workspace is None: metadata.zenml_workspace = zenml_workspace + if metadata.model_extra.get("mlflow_run_id", None) is None: + metadata.model_extra["mlflow_run_id"] = mlflow_run_id # Register model version model_version = model_registry.register_model_version( diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index ac5e3f2e86f..cc3f5ad945c 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -237,6 +237,7 @@ class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata): description="Template used for the pipeline run.", ) is_templatable: bool = Field( + default=False, description="Whether a template can be created from this run.", ) diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index 20498e944b2..e21f84cf125 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -1,6 +1,7 @@ """Utils concerning anything concerning the cloud control plane backend.""" import os +from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional import requests @@ -19,11 +20,9 @@ class ZenMLCloudConfiguration(BaseModel): """ZenML Pro RBAC configuration.""" api_url: str - oauth2_client_id: str oauth2_client_secret: str oauth2_audience: str - auth0_domain: str @field_validator("api_url") @classmethod @@ -68,6 +67,8 @@ def __init__(self) -> None: """Initialize the RBAC component.""" self._config = ZenMLCloudConfiguration.from_environment() self._session: Optional[requests.Session] = None + self._token: Optional[str] = None + self._token_expires_at: Optional[datetime] = None def get( self, endpoint: str, params: Optional[Dict[str, Any]] @@ -91,7 +92,8 @@ def get( response = self.session.get(url=url, params=params, timeout=7) if response.status_code == 401: - # Refresh the auth token and try again + # If we get an Unauthorized error from the API serer, we refresh the + # auth token and try again self._clear_session() response = self.session.get(url=url, params=params, timeout=7) @@ -186,6 +188,8 @@ def session(self) -> requests.Session: def _clear_session(self) -> None: """Clear the authentication session.""" self._session = None + self._token = None + self._token_expires_at = None def _fetch_auth_token(self) -> str: """Fetch an auth token for the Cloud API from auth0. @@ -196,8 +200,16 @@ def _fetch_auth_token(self) -> str: Returns: Auth token. """ + if ( + self._token is not None + and self._token_expires_at is not None + and datetime.now(timezone.utc) + timedelta(minutes=5) + < self._token_expires_at + ): + return self._token + # Get an auth token from auth0 - auth0_url = f"https://{self._config.auth0_domain}/oauth/token" + login_url = f"{self._config.api_url}/auth/login" headers = {"content-type": "application/x-www-form-urlencoded"} payload = { "client_id": self._config.oauth2_client_id, @@ -207,18 +219,31 @@ def _fetch_auth_token(self) -> str: } try: response = requests.post( - auth0_url, headers=headers, data=payload, timeout=7 + login_url, headers=headers, data=payload, timeout=7 ) response.raise_for_status() except Exception as e: raise RuntimeError(f"Error fetching auth token from auth0: {e}") - access_token = response.json().get("access_token", "") + json_response = response.json() + access_token = json_response.get("access_token", "") + expires_in = json_response.get("expires_in", 0) - if not access_token or not isinstance(access_token, str): + if ( + not access_token + or not isinstance(access_token, str) + or not expires_in + or not isinstance(expires_in, int) + ): raise RuntimeError("Could not fetch auth token from auth0.") - return str(access_token) + self._token = access_token + self._token_expires_at = datetime.now(timezone.utc) + timedelta( + seconds=expires_in + ) + + assert self._token is not None + return self._token def cloud_connection() -> ZenMLCloudConnection: