diff --git a/docs/source/models.rst b/docs/source/models.rst index 66e1517a439ed..a4c5f641c90d2 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -3996,6 +3996,40 @@ set the `env_manager` argument when calling :py:func:`mlflow.pyfunc.spark_udf`. df = spark_df.withColumn("prediction", pyfunc_udf(struct("name", "age"))) +If you want to call `:py:func:`mlflow.pyfunc.spark_udf` through Databricks connect in remote client, you need to build the model environment in Databricks runtime first. + +.. rubric:: Example + +.. code-block:: python + + from mlflow.pyfunc import build_model_env + + # Build the model env and save it as an archive file to the provided UC volume directory + # and print the saved model env archive file path (like '/Volumes/.../.../XXXXX.tar.gz') + print(build_model_env(model_uri, "/Volumes/...")) + + # print the cluster id. Databricks Connect client needs to use the cluster id. + print(spark.conf.get("spark.databricks.clusterUsageTags.clusterId")) + +Once you have pre-built the model environment, you can run `:py:func:`mlflow.pyfunc.spark_udf` with 'prebuilt_model_env' parameter through Databricks connect in remote client, + +.. rubric:: Example + +.. code-block:: python + + from databricks.connect import DatabricksSession + + spark = DatabricksSession.builder.remote( + host=os.environ["DATABRICKS_HOST"], + token=os.environ["DATABRICKS_TOKEN"], + cluster_id="", # get cluster id by spark.conf.get("spark.databricks.clusterUsageTags.clusterId") + ).getOrCreate() + + # The path generated by `build_model_env` in Databricks runtime. + model_env_uc_uri = "dbfs:/Volumes/.../.../XXXXX.tar.gz" + pyfunc_udf = mlflow.pyfunc.spark_udf( + spark, model_uri, prebuilt_env_uri=model_env_uc_uri + ) .. _deployment_plugin: diff --git a/examples/spark_udf/spark_udf_with_prebuilt_env.py b/examples/spark_udf/spark_udf_with_prebuilt_env.py index f51622b626662..b2a3e952722ec 100644 --- a/examples/spark_udf/spark_udf_with_prebuilt_env.py +++ b/examples/spark_udf/spark_udf_with_prebuilt_env.py @@ -4,7 +4,6 @@ """ import os -import tempfile from databricks.connect import DatabricksSession from databricks.sdk import WorkspaceClient @@ -33,21 +32,13 @@ # The prebuilt model environment archive file path. # To build the model environment, run the following line code in Databricks runtime: # `model_env_uc_path = mlflow.pyfunc.build_model_env(model_uri, "/Volumes/...")` -model_env_uc_path = "/Volumes/..." - -tmp_dir = tempfile.mkdtemp() -local_model_env_path = os.path.join(tmp_dir, os.path.basename(model_env_uc_path)) - -# Download model env file from UC volume. -with ws.files.download(model_env_uc_path).contents as rf, open(local_model_env_path, "wb") as wf: - while chunk := rf.read(4096): - wf.write(chunk) +model_env_uc_path = "dbfs:/Volumes/..." infer_spark_df = spark.createDataFrame(X) -# Setting 'prebuilt_env_path' parameter so that `spark_udf` can use the +# Setting 'prebuilt_env_uri' parameter so that `spark_udf` can use the # prebuilt python environment and skip rebuilding python environment. -pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_uri, prebuilt_env_path=local_model_env_path) +pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_uri, prebuilt_env_uri=model_env_uc_path) result = infer_spark_df.select(pyfunc_udf(*X.columns).alias("predictions")).toPandas() print(result) diff --git a/mlflow/environment_variables.py b/mlflow/environment_variables.py index b0d5ac1b840ed..9823084664a03 100644 --- a/mlflow/environment_variables.py +++ b/mlflow/environment_variables.py @@ -661,3 +661,9 @@ def get(self): MLFLOW_USE_DATABRICKS_SDK_MODEL_ARTIFACTS_REPO_FOR_UC = _BooleanEnvironmentVariable( "MLFLOW_USE_DATABRICKS_SDK_MODEL_ARTIFACTS_REPO_FOR_UC", False ) + +# Specifies the model environment archive file downloading path when using +# ``mlflow.pyfunc.spark_udf``. (default: ``None``) +MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR = _EnvironmentVariable( + "MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR", str, None +) diff --git a/mlflow/pyfunc/__init__.py b/mlflow/pyfunc/__init__.py index 6400b4956e042..c0940149ada73 100644 --- a/mlflow/pyfunc/__init__.py +++ b/mlflow/pyfunc/__init__.py @@ -416,10 +416,12 @@ def predict(self, context, model_input, params=None): from functools import lru_cache from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from urllib.parse import urlparse import numpy as np import pandas import yaml +from databricks.sdk import WorkspaceClient from packaging.version import Version import mlflow @@ -428,6 +430,7 @@ def predict(self, context, model_input, params=None): from mlflow.environment_variables import ( _MLFLOW_IN_CAPTURE_MODULE_PROCESS, _MLFLOW_TESTING, + MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR, MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT, ) from mlflow.exceptions import MlflowException @@ -1746,6 +1749,44 @@ def _prebuild_env_internal(local_model_path, archive_name, save_path): shutil.rmtree(env_root_dir, ignore_errors=True) +def _download_prebuilt_env_if_needed(prebuilt_env_uri): + from mlflow.utils.file_utils import get_or_create_tmp_dir + + parsed_url = urlparse(prebuilt_env_uri) + if parsed_url.scheme == "" or parsed_url.scheme == "file": + # local path + return parsed_url.path + if parsed_url.scheme == "dbfs": + tmp_dir = MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR.get() or get_or_create_tmp_dir() + model_env_uc_path = parsed_url.path + + # download file from DBFS. + local_model_env_path = os.path.join(tmp_dir, os.path.basename(model_env_uc_path)) + if os.path.exists(local_model_env_path): + # file is already downloaded. + return local_model_env_path + + try: + ws = WorkspaceClient() + # Download model env file from UC volume. + with ws.files.download(model_env_uc_path).contents as rf, open( + local_model_env_path, "wb" + ) as wf: + while chunk := rf.read(4096 * 1024): + wf.write(chunk) + return local_model_env_path + except (Exception, KeyboardInterrupt): + if os.path.exists(local_model_env_path): + # clean the partially saved file if downloading fails. + os.remove(local_model_env_path) + raise + + raise MlflowException( + f"Unsupported prebuilt env file path '{prebuilt_env_uri}', " + f"invalid scheme: '{parsed_url.scheme}'." + ) + + def build_model_env(model_uri, save_path): """ Prebuild model python environment and generate an archive file saved to provided @@ -1779,13 +1820,16 @@ def build_model_env(model_uri, save_path): from mlflow.pyfunc import build_model_env - # Create a python environment archive file at the path `prebuilt_env_path` - prebuilt_env_path = build_model_env(f"runs:/{run_id}/model", "/path/to/save_directory") + # Create a python environment archive file at the path `prebuilt_env_uri` + prebuilt_env_uri = build_model_env(f"runs:/{run_id}/model", "/path/to/save_directory") Args: model_uri: URI to the model that is used to build the python environment. save_path: The directory path that is used to save the prebuilt model environment archive file path. + The path can be either local directory path or + mounted DBFS path such as '/dbfs/...' or + mounted UC volume path such as '/Volumes/...'. Returns: Return the path of an archive file containing the python environment data. @@ -1840,7 +1884,7 @@ def spark_udf( env_manager=None, params: Optional[Dict[str, Any]] = None, extra_env: Optional[Dict[str, str]] = None, - prebuilt_env_path: Optional[str] = None, + prebuilt_env_uri: Optional[str] = None, model_config: Optional[Union[str, Path, Dict[str, Any]]] = None, ): """ @@ -1874,7 +1918,7 @@ def spark_udf( .. note:: When using Databricks Connect to connect to a remote Databricks cluster, the Databricks cluster must use runtime version >= 16, and when 'spark_udf' - param 'env_manager' is set as 'virtualenv', the 'prebuilt_env_path' param is + param 'env_manager' is set as 'virtualenv', the 'prebuilt_env_uri' param is required to be specified. .. note:: @@ -1952,7 +1996,7 @@ def spark_udf( env_manager: The environment manager to use in order to create the python environment for model inference. Note that environment is only restored in the context of the PySpark UDF; the software environment outside of the UDF is - unaffected. If `prebuilt_env_path` parameter is not set, the default value + unaffected. If `prebuilt_env_uri` parameter is not set, the default value is ``local``, and the following values are supported: - ``virtualenv``: Use virtualenv to restore the python environment that @@ -1963,19 +2007,25 @@ def spark_udf( may differ from the environment used to train the model and may lead to errors or invalid predictions. - If the `prebuilt_env_path` parameter is set, `env_manager` parameter should not + If the `prebuilt_env_uri` parameter is set, `env_manager` parameter should not be set. params: Additional parameters to pass to the model for inference. extra_env: Extra environment variables to pass to the UDF executors. - prebuilt_env_path: The path of the prebuilt env archive file created by + prebuilt_env_uri: The path of the prebuilt env archive file created by `mlflow.pyfunc.build_model_env` API. This parameter can only be used in Databricks Serverless notebook REPL, Databricks Shared cluster notebook REPL, and Databricks Connect client environment. - If this parameter is set, `env_manger` is ignored. + The path can be either local file path or DBFS path such as + 'dbfs:/Volumes/...', in this case, MLflow automatically downloads it + to local temporary directory, "MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR" + environmental variable can be set to specify the temporary directory + to use. + + If this parameter is set, `env_manger` parameter must not be set. model_config: The model configuration to set when loading the model. See 'model_config' argument in `mlflow.pyfunc.load_model` API for details. @@ -2008,10 +2058,10 @@ def spark_udf( openai_env_vars = mlflow.openai._OpenAIEnvVar.read_environ() mlflow_testing = _MLFLOW_TESTING.get_raw() - if prebuilt_env_path: + if prebuilt_env_uri: if env_manager is not None: raise MlflowException( - "If 'prebuilt_env_path' parameter is set, 'env_manager' parameter can't be set." + "If 'prebuilt_env_uri' parameter is set, 'env_manager' parameter can't be set." ) env_manager = _EnvManager.VIRTUALENV else: @@ -2027,16 +2077,16 @@ def spark_udf( is_spark_in_local_mode = spark.conf.get("spark.master").startswith("local") is_dbconnect_mode = is_databricks_connect(spark) - if prebuilt_env_path is not None and not is_dbconnect_mode: + if prebuilt_env_uri is not None and not is_dbconnect_mode: raise RuntimeError( "'prebuilt_env' parameter can only be used in Databricks Serverless " "notebook REPL, atabricks Shared cluster notebook REPL, and Databricks Connect client " "environment." ) - if prebuilt_env_path is None and is_dbconnect_mode and not is_in_databricks_runtime(): + if prebuilt_env_uri is None and is_dbconnect_mode and not is_in_databricks_runtime(): raise RuntimeError( - "'prebuilt_env_path' param is required if using Databricks Connect to connect " + "'prebuilt_env_uri' param is required if using Databricks Connect to connect " "to Databricks cluster from your own machine." ) @@ -2074,8 +2124,9 @@ def spark_udf( output_path=_create_model_downloading_tmp_dir(should_use_nfs), ) - if prebuilt_env_path: - _verify_prebuilt_env(spark, local_model_path, prebuilt_env_path) + if prebuilt_env_uri: + prebuilt_env_uri = _download_prebuilt_env_if_needed(prebuilt_env_uri) + _verify_prebuilt_env(spark, local_model_path, prebuilt_env_uri) if use_dbconnect_artifact and env_manager == _EnvManager.CONDA: raise MlflowException( "Databricks connect mode or Databricks Serverless python REPL doesn't " @@ -2109,14 +2160,14 @@ def spark_udf( "processes cannot be cleaned up if the Spark Job is canceled." ) - if prebuilt_env_path: - env_cache_key = os.path.basename(prebuilt_env_path)[:-7] + if prebuilt_env_uri: + env_cache_key = os.path.basename(prebuilt_env_uri)[:-7] elif use_dbconnect_artifact: env_cache_key = _gen_prebuilt_env_archive_name(spark, local_model_path) else: env_cache_key = None - if use_dbconnect_artifact or prebuilt_env_path is not None: + if use_dbconnect_artifact or prebuilt_env_uri is not None: prebuilt_env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, env_cache_key) pyfunc_backend_env_root_config = { "create_env_root_dir": False, @@ -2136,8 +2187,8 @@ def spark_udf( # Upload model artifacts and python environment to NFS as DBConncet artifacts. if env_manager == _EnvManager.VIRTUALENV: if not dbconnect_artifact_cache.has_cache_key(env_cache_key): - if prebuilt_env_path: - env_archive_path = prebuilt_env_path + if prebuilt_env_uri: + env_archive_path = prebuilt_env_uri else: env_archive_path = _prebuild_env_internal( local_model_path, env_cache_key, get_or_create_tmp_dir() @@ -2152,13 +2203,13 @@ def spark_udf( dbconnect_artifact_cache.add_artifact_archive(model_uri, model_archive_path) elif not should_use_spark_to_broadcast_file: - if prebuilt_env_path: + if prebuilt_env_uri: # Extract prebuilt env archive file to NFS directory. prebuilt_env_nfs_dir = os.path.join( get_or_create_nfs_tmp_dir(), "prebuilt_env", env_cache_key ) if not os.path.exists(prebuilt_env_nfs_dir): - extract_archive_to_dir(prebuilt_env_path, prebuilt_env_nfs_dir) + extract_archive_to_dir(prebuilt_env_uri, prebuilt_env_nfs_dir) else: # Prepare restored environment in driver side if possible. # Note: In databricks runtime, because databricks notebook cell output cannot capture @@ -2345,7 +2396,7 @@ def udf( # Create symlink if it does not exist if not os.path.exists(prebuilt_env_root_dir): os.symlink(env_src_dir, prebuilt_env_root_dir) - elif prebuilt_env_path is not None: + elif prebuilt_env_uri is not None: # prebuilt env is extracted to `prebuilt_env_nfs_dir` directory, # and model is downloaded to `local_model_path` which points to an NFS # directory too. diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 741cede3f525e..8473123ecda48 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -289,9 +289,11 @@ def is_databricks_connect(spark): if is_in_databricks_serverless_runtime() or is_in_databricks_shared_cluster_runtime(): return True try: - # TODO: Remove the `spark.client._builder._build` attribute access once + # TODO: Remove the `spark.client._builder` attribute usage once # Spark-connect has public attribute for this information. - return "databricks-session" in spark.client._builder.userAgent + return is_spark_connect_mode() and any( + k == "x-databricks-cluster-id" for k, v in spark.client._builder.metadata() + ) except Exception: return False @@ -355,9 +357,14 @@ def is_databricks_serverless(spark): """ from mlflow.utils.spark_utils import is_spark_connect_mode - return is_spark_connect_mode() and any( - k == "x-databricks-session-id" for k, v in spark.client.metadata() - ) + try: + # TODO: Remove the `spark.client._builder` attribute usage once + # Spark-connect has public attribute for this information. + return is_spark_connect_mode() and any( + k == "x-databricks-session-id" for k, v in spark.client._builder.metadata() + ) + except Exception: + return False def is_dbfs_fuse_available():