Skip to content

Commit

Permalink
test env variable
Browse files Browse the repository at this point in the history
  • Loading branch information
safoinme committed Jul 11, 2024
1 parent 10d941c commit f69233e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
DatabricksOrchestratorSettings,
)
from zenml.integrations.databricks.orchestrators.databricks_orchestrator_entrypoint_config import (
ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID,
DatabricksEntrypointConfiguration,
)
from zenml.integrations.databricks.utils.databricks_utils import (
Expand All @@ -60,14 +61,12 @@

logger = get_logger(__name__)

ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID = (
"ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID"
)
ZENML_STEP_DEFAULT_ENTRYPOINT_COMMAND = "entrypoint.main"
DATABRICKS_WHEELS_DIRECTORY_PREFIX = "dbfs:/FileStore/zenml"
DATABRICKS_LOCAL_FILESYSTEM_PREFIX = "file:/"
DATABRICKS_CLUSTER_DEFAULT_NAME = "zenml-databricks-cluster"
DATABRICKS_SPARK_DEFAULT_VERSION = "15.3.x-scala2.12"
DATABRICKS_JOB_ID_PARAMETER_REFERENCE = "{{job.id}}"


class DatabricksOrchestrator(WheeledOrchestrator):
Expand Down Expand Up @@ -157,8 +156,17 @@ def get_orchestrator_run_id(self) -> str:
Returns:
The orchestrator run id.
Raises:
RuntimeError: If the run id cannot be read from the environment.
"""
return os.getenv(ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID) or ""
try:
return os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID]
except KeyError:
raise RuntimeError(
"Unable to read run id from environment variable "
f"{ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID}."
)

@property
def root_directory(self) -> str:
Expand Down Expand Up @@ -242,14 +250,6 @@ def prepare_or_run_pipeline(
# Get deployment id
deployment_id = deployment.id

# Set environment
os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID] = str(
deployment_id
)
environment[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID] = str(
deployment_id
)

# Create a callable for future compilation into a dsl.Pipeline.
def _construct_databricks_pipeline(
zenml_project_wheel: str, job_cluster_key: str
Expand All @@ -266,12 +266,11 @@ def _construct_databricks_pipeline(
for step_name, step in deployment.step_configurations.items():
# The arguments are passed to configure the entrypoint of the
# docker container when the step is called.
arguments = (
DatabricksEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
deployment_id=deployment_id,
wheel_package=self.package_name,
)
arguments = DatabricksEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name,
deployment_id=deployment_id,
wheel_package=self.package_name,
databricks_job_id=DATABRICKS_JOB_ID_PARAMETER_REFERENCE,
)

# Find the upstream container ops of the current step and
Expand Down Expand Up @@ -356,8 +355,6 @@ def _construct_databricks_pipeline(
for key, value in spark_env_vars.items():
env_vars[key] = value

env_vars[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID] = str(deployment_id)

fileio.rmtree(repository_temp_dir)

logger.info(
Expand Down Expand Up @@ -453,10 +450,8 @@ def get_pipeline_run_metadata(
Returns:
A dictionary of metadata.
"""
databricks_client = self._get_databricks_client()
run_url = (
f"{self.config.host}/jobs/"
f"{databricks_client.dbutils.widgets.get('job_id')}"
f"{self.config.host}/jobs/" f"{self.get_orchestrator_run_id()}"
)
return {
METADATA_ORCHESTRATOR_URL: Uri(run_url),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
)

WHEEL_PACKAGE_OPTION = "wheel_package"
DATABRICKS_JOB_ID_OPTION = "job_id"
ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID = (
"ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID"
)


class DatabricksEntrypointConfiguration(StepEntrypointConfiguration):
Expand All @@ -41,7 +45,11 @@ def get_entrypoint_options(cls) -> Set[str]:
Returns:
The superclass options as well as an option for the wheel package.
"""
return super().get_entrypoint_options() | {WHEEL_PACKAGE_OPTION}
return (
super().get_entrypoint_options()
| {WHEEL_PACKAGE_OPTION}
| {DATABRICKS_JOB_ID_OPTION}
)

@classmethod
def get_entrypoint_arguments(
Expand All @@ -65,6 +73,8 @@ def get_entrypoint_arguments(
return super().get_entrypoint_arguments(**kwargs) + [
f"--{WHEEL_PACKAGE_OPTION}",
kwargs[WHEEL_PACKAGE_OPTION],
f"--{DATABRICKS_JOB_ID_OPTION}",
kwargs[DATABRICKS_JOB_ID_OPTION],
]

def run(self) -> None:
Expand All @@ -77,5 +87,11 @@ def run(self) -> None:
sys.path.insert(0, project_root)
sys.path.insert(-1, project_root)

# Get the job id and add it to the environment
databricks_job_id = self.entrypoint_args[DATABRICKS_JOB_ID_OPTION]
os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID] = (
databricks_job_id
)

# Run the step
super().run()

0 comments on commit f69233e

Please sign in to comment.