
[FOLLOW-UP] Simplify Spark udf serverless API and add doc section (mlflow#13496)

Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 authored Oct 24, 2024
1 parent 2ec80d3 commit 892a1a5
Showing 5 changed files with 129 additions and 40 deletions.
34 changes: 34 additions & 0 deletions docs/source/models.rst
@@ -3996,6 +3996,40 @@ set the `env_manager` argument when calling :py:func:`mlflow.pyfunc.spark_udf`.
df = spark_df.withColumn("prediction", pyfunc_udf(struct("name", "age")))
If you want to call :py:func:`mlflow.pyfunc.spark_udf` through Databricks Connect from a remote client, you need to build the model environment in the Databricks runtime first.

.. rubric:: Example

.. code-block:: python

    from mlflow.pyfunc import build_model_env

    # Build the model env, save it as an archive file to the provided UC volume directory,
    # and print the saved model env archive file path (like '/Volumes/.../.../XXXXX.tar.gz')
    print(build_model_env(model_uri, "/Volumes/..."))

    # Print the cluster id. The Databricks Connect client needs the cluster id to connect.
    print(spark.conf.get("spark.databricks.clusterUsageTags.clusterId"))
Once you have pre-built the model environment, you can run :py:func:`mlflow.pyfunc.spark_udf` with the 'prebuilt_env_uri' parameter through Databricks Connect from a remote client:

.. rubric:: Example

.. code-block:: python

    from databricks.connect import DatabricksSession

    spark = DatabricksSession.builder.remote(
        host=os.environ["DATABRICKS_HOST"],
        token=os.environ["DATABRICKS_TOKEN"],
        # Get the cluster id via spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
        cluster_id="<cluster id>",
    ).getOrCreate()

    # The path generated by `build_model_env` in Databricks runtime.
    model_env_uc_uri = "dbfs:/Volumes/.../.../XXXXX.tar.gz"

    pyfunc_udf = mlflow.pyfunc.spark_udf(
        spark, model_uri, prebuilt_env_uri=model_env_uc_uri
    )
.. _deployment_plugin:

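For reference, here is a minimal end-to-end sketch of the workflow documented above. It assumes access to a Databricks workspace with a UC volume and a cluster on runtime >= 16; the model URI, volume path, and connection values are placeholders.

# --- Step 1: run inside a Databricks runtime notebook ---
import mlflow
from mlflow.pyfunc import build_model_env

model_uri = "models:/my_model/1"  # placeholder model URI
# Prebuild the model environment and store the archive on a UC volume (placeholder path).
env_archive_path = build_model_env(model_uri, "/Volumes/main/default/envs")
print(env_archive_path)  # e.g. '/Volumes/main/default/envs/<archive>.tar.gz'
print(spark.conf.get("spark.databricks.clusterUsageTags.clusterId"))  # note this cluster id

# --- Step 2: run from a remote client through Databricks Connect ---
import os
from databricks.connect import DatabricksSession

spark = DatabricksSession.builder.remote(
    host=os.environ["DATABRICKS_HOST"],
    token=os.environ["DATABRICKS_TOKEN"],
    cluster_id="<cluster id printed in step 1>",
).getOrCreate()

# Prefix the archive path printed in step 1 with 'dbfs:' to form the DBFS URI.
model_env_uc_uri = "dbfs:/Volumes/main/default/envs/<archive>.tar.gz"
pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_uri, prebuilt_env_uri=model_env_uc_uri)
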
15 changes: 3 additions & 12 deletions examples/spark_udf/spark_udf_with_prebuilt_env.py
@@ -4,7 +4,6 @@
"""

import os
import tempfile

from databricks.connect import DatabricksSession
from databricks.sdk import WorkspaceClient
@@ -33,21 +32,13 @@
# The prebuilt model environment archive file path.
# To build the model environment, run the following line of code in Databricks runtime:
# `model_env_uc_path = mlflow.pyfunc.build_model_env(model_uri, "/Volumes/...")`
model_env_uc_path = "/Volumes/..."

tmp_dir = tempfile.mkdtemp()
local_model_env_path = os.path.join(tmp_dir, os.path.basename(model_env_uc_path))

# Download model env file from UC volume.
with ws.files.download(model_env_uc_path).contents as rf, open(local_model_env_path, "wb") as wf:
while chunk := rf.read(4096):
wf.write(chunk)
model_env_uc_path = "dbfs:/Volumes/..."

infer_spark_df = spark.createDataFrame(X)

# Setting 'prebuilt_env_path' parameter so that `spark_udf` can use the
# Setting 'prebuilt_env_uri' parameter so that `spark_udf` can use the
# prebuilt python environment and skip rebuilding python environment.
pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_uri, prebuilt_env_path=local_model_env_path)
pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_uri, prebuilt_env_uri=model_env_uc_path)
result = infer_spark_df.select(pyfunc_udf(*X.columns).alias("predictions")).toPandas()

print(result)
6 changes: 6 additions & 0 deletions mlflow/environment_variables.py
@@ -661,3 +661,9 @@ def get(self):
MLFLOW_USE_DATABRICKS_SDK_MODEL_ARTIFACTS_REPO_FOR_UC = _BooleanEnvironmentVariable(
"MLFLOW_USE_DATABRICKS_SDK_MODEL_ARTIFACTS_REPO_FOR_UC", False
)

# Specifies the directory used to download the model environment archive file when using
# ``mlflow.pyfunc.spark_udf``. (default: ``None``)
MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR = _EnvironmentVariable(
"MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR", str, None
)
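
As a small usage sketch of the new variable (the model URI and paths are placeholders, and 'spark' is assumed to be an existing Databricks Connect session), it can be set before calling spark_udf so that a 'dbfs:' prebuilt env archive is downloaded into a chosen directory instead of the default temporary directory:

import os
import mlflow

# Placeholder scratch directory for the downloaded prebuilt env archive.
os.environ["MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR"] = "/tmp/mlflow_env_downloads"

pyfunc_udf = mlflow.pyfunc.spark_udf(
    spark,  # existing Databricks Connect session
    "models:/my_model/1",  # placeholder model URI
    prebuilt_env_uri="dbfs:/Volumes/main/default/envs/my_env.tar.gz",  # placeholder archive path
)
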
97 changes: 74 additions & 23 deletions mlflow/pyfunc/__init__.py
@@ -416,10 +416,12 @@ def predict(self, context, model_input, params=None):
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from urllib.parse import urlparse

import numpy as np
import pandas
import yaml
from databricks.sdk import WorkspaceClient
from packaging.version import Version

import mlflow
@@ -428,6 +430,7 @@ def predict(self, context, model_input, params=None):
from mlflow.environment_variables import (
_MLFLOW_IN_CAPTURE_MODULE_PROCESS,
_MLFLOW_TESTING,
MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR,
MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT,
)
from mlflow.exceptions import MlflowException
@@ -1746,6 +1749,44 @@ def _prebuild_env_internal(local_model_path, archive_name, save_path):
shutil.rmtree(env_root_dir, ignore_errors=True)


def _download_prebuilt_env_if_needed(prebuilt_env_uri):
from mlflow.utils.file_utils import get_or_create_tmp_dir

parsed_url = urlparse(prebuilt_env_uri)
if parsed_url.scheme == "" or parsed_url.scheme == "file":
# local path
return parsed_url.path
if parsed_url.scheme == "dbfs":
tmp_dir = MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR.get() or get_or_create_tmp_dir()
model_env_uc_path = parsed_url.path

# download file from DBFS.
local_model_env_path = os.path.join(tmp_dir, os.path.basename(model_env_uc_path))
if os.path.exists(local_model_env_path):
# file is already downloaded.
return local_model_env_path

try:
ws = WorkspaceClient()
# Download model env file from UC volume.
with ws.files.download(model_env_uc_path).contents as rf, open(
local_model_env_path, "wb"
) as wf:
while chunk := rf.read(4096 * 1024):
wf.write(chunk)
return local_model_env_path
except (Exception, KeyboardInterrupt):
if os.path.exists(local_model_env_path):
# clean the partially saved file if downloading fails.
os.remove(local_model_env_path)
raise

raise MlflowException(
f"Unsupported prebuilt env file path '{prebuilt_env_uri}', "
f"invalid scheme: '{parsed_url.scheme}'."
)


def build_model_env(model_uri, save_path):
"""
Prebuild model python environment and generate an archive file saved to provided
@@ -1779,13 +1820,16 @@ def build_model_env(model_uri, save_path):
from mlflow.pyfunc import build_model_env
# Create a python environment archive file at the path `prebuilt_env_path`
prebuilt_env_path = build_model_env(f"runs:/{run_id}/model", "/path/to/save_directory")
# Create a python environment archive file at the path `prebuilt_env_uri`
prebuilt_env_uri = build_model_env(f"runs:/{run_id}/model", "/path/to/save_directory")
Args:
model_uri: URI to the model that is used to build the python environment.
save_path: The directory path that is used to save the prebuilt model environment
archive file.
The path can be either a local directory path,
a mounted DBFS path such as '/dbfs/...', or
a mounted UC volume path such as '/Volumes/...'.
Returns:
Return the path of an archive file containing the python environment data.
@@ -1840,7 +1884,7 @@ def spark_udf(
env_manager=None,
params: Optional[Dict[str, Any]] = None,
extra_env: Optional[Dict[str, str]] = None,
prebuilt_env_path: Optional[str] = None,
prebuilt_env_uri: Optional[str] = None,
model_config: Optional[Union[str, Path, Dict[str, Any]]] = None,
):
"""
@@ -1874,7 +1918,7 @@
.. note::
When using Databricks Connect to connect to a remote Databricks cluster,
the Databricks cluster must use runtime version >= 16, and when 'spark_udf'
param 'env_manager' is set as 'virtualenv', the 'prebuilt_env_path' param is
param 'env_manager' is set as 'virtualenv', the 'prebuilt_env_uri' param is
required to be specified.
.. note::
@@ -1952,7 +1996,7 @@
env_manager: The environment manager to use in order to create the python environment
for model inference. Note that environment is only restored in the context
of the PySpark UDF; the software environment outside of the UDF is
unaffected. If `prebuilt_env_path` parameter is not set, the default value
unaffected. If `prebuilt_env_uri` parameter is not set, the default value
is ``local``, and the following values are supported:
- ``virtualenv``: Use virtualenv to restore the python environment that
@@ -1963,19 +2007,25 @@
may differ from the environment used to train the model and may lead to
errors or invalid predictions.
If the `prebuilt_env_path` parameter is set, `env_manager` parameter should not
If the `prebuilt_env_uri` parameter is set, `env_manager` parameter should not
be set.
params: Additional parameters to pass to the model for inference.
extra_env: Extra environment variables to pass to the UDF executors.
prebuilt_env_path: The path of the prebuilt env archive file created by
prebuilt_env_uri: The path of the prebuilt env archive file created by
`mlflow.pyfunc.build_model_env` API.
This parameter can only be used in Databricks Serverless notebook REPL,
Databricks Shared cluster notebook REPL, and Databricks Connect client
environment.
If this parameter is set, `env_manger` is ignored.
The path can be either a local file path or a DBFS path such as
'dbfs:/Volumes/...'. In the latter case, MLflow automatically downloads it
to a local temporary directory; the "MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR"
environment variable can be set to specify the temporary directory
to use.
If this parameter is set, the `env_manager` parameter must not be set.
model_config: The model configuration to set when loading the model.
See 'model_config' argument in `mlflow.pyfunc.load_model` API for details.
@@ -2008,10 +2058,10 @@
openai_env_vars = mlflow.openai._OpenAIEnvVar.read_environ()
mlflow_testing = _MLFLOW_TESTING.get_raw()

if prebuilt_env_path:
if prebuilt_env_uri:
if env_manager is not None:
raise MlflowException(
"If 'prebuilt_env_path' parameter is set, 'env_manager' parameter can't be set."
"If 'prebuilt_env_uri' parameter is set, 'env_manager' parameter can't be set."
)
env_manager = _EnvManager.VIRTUALENV
else:
@@ -2027,16 +2077,16 @@
is_spark_in_local_mode = spark.conf.get("spark.master").startswith("local")

is_dbconnect_mode = is_databricks_connect(spark)
if prebuilt_env_path is not None and not is_dbconnect_mode:
if prebuilt_env_uri is not None and not is_dbconnect_mode:
raise RuntimeError(
"'prebuilt_env' parameter can only be used in Databricks Serverless "
"notebook REPL, atabricks Shared cluster notebook REPL, and Databricks Connect client "
"environment."
)

if prebuilt_env_path is None and is_dbconnect_mode and not is_in_databricks_runtime():
if prebuilt_env_uri is None and is_dbconnect_mode and not is_in_databricks_runtime():
raise RuntimeError(
"'prebuilt_env_path' param is required if using Databricks Connect to connect "
"'prebuilt_env_uri' param is required if using Databricks Connect to connect "
"to Databricks cluster from your own machine."
)

@@ -2074,8 +2124,9 @@
output_path=_create_model_downloading_tmp_dir(should_use_nfs),
)

if prebuilt_env_path:
_verify_prebuilt_env(spark, local_model_path, prebuilt_env_path)
if prebuilt_env_uri:
prebuilt_env_uri = _download_prebuilt_env_if_needed(prebuilt_env_uri)
_verify_prebuilt_env(spark, local_model_path, prebuilt_env_uri)
if use_dbconnect_artifact and env_manager == _EnvManager.CONDA:
raise MlflowException(
"Databricks connect mode or Databricks Serverless python REPL doesn't "
@@ -2109,14 +2160,14 @@
"processes cannot be cleaned up if the Spark Job is canceled."
)

if prebuilt_env_path:
env_cache_key = os.path.basename(prebuilt_env_path)[:-7]
if prebuilt_env_uri:
env_cache_key = os.path.basename(prebuilt_env_uri)[:-7]
elif use_dbconnect_artifact:
env_cache_key = _gen_prebuilt_env_archive_name(spark, local_model_path)
else:
env_cache_key = None

if use_dbconnect_artifact or prebuilt_env_path is not None:
if use_dbconnect_artifact or prebuilt_env_uri is not None:
prebuilt_env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, env_cache_key)
pyfunc_backend_env_root_config = {
"create_env_root_dir": False,
@@ -2136,8 +2187,8 @@
# Upload model artifacts and python environment to NFS as DBConnect artifacts.
if env_manager == _EnvManager.VIRTUALENV:
if not dbconnect_artifact_cache.has_cache_key(env_cache_key):
if prebuilt_env_path:
env_archive_path = prebuilt_env_path
if prebuilt_env_uri:
env_archive_path = prebuilt_env_uri
else:
env_archive_path = _prebuild_env_internal(
local_model_path, env_cache_key, get_or_create_tmp_dir()
@@ -2152,13 +2203,13 @@
dbconnect_artifact_cache.add_artifact_archive(model_uri, model_archive_path)

elif not should_use_spark_to_broadcast_file:
if prebuilt_env_path:
if prebuilt_env_uri:
# Extract prebuilt env archive file to NFS directory.
prebuilt_env_nfs_dir = os.path.join(
get_or_create_nfs_tmp_dir(), "prebuilt_env", env_cache_key
)
if not os.path.exists(prebuilt_env_nfs_dir):
extract_archive_to_dir(prebuilt_env_path, prebuilt_env_nfs_dir)
extract_archive_to_dir(prebuilt_env_uri, prebuilt_env_nfs_dir)
else:
# Prepare restored environment in driver side if possible.
# Note: In databricks runtime, because databricks notebook cell output cannot capture
@@ -2345,7 +2396,7 @@ def udf(
# Create symlink if it does not exist
if not os.path.exists(prebuilt_env_root_dir):
os.symlink(env_src_dir, prebuilt_env_root_dir)
elif prebuilt_env_path is not None:
elif prebuilt_env_uri is not None:
# prebuilt env is extracted to `prebuilt_env_nfs_dir` directory,
# and model is downloaded to `local_model_path` which points to an NFS
# directory too.
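A short sketch of the validation behavior added to spark_udf above: 'env_manager' and 'prebuilt_env_uri' are mutually exclusive, so passing both raises an MlflowException. The model URI and archive path are placeholders, and 'spark' is assumed to be an existing Databricks Connect session.

import mlflow
from mlflow.exceptions import MlflowException

try:
    mlflow.pyfunc.spark_udf(
        spark,  # existing Databricks Connect session
        "models:/my_model/1",  # placeholder model URI
        env_manager="virtualenv",
        prebuilt_env_uri="dbfs:/Volumes/main/default/envs/my_env.tar.gz",  # placeholder
    )
except MlflowException as e:
    # Expected: "If 'prebuilt_env_uri' parameter is set, 'env_manager' parameter can't be set."
    print(e)
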
17 changes: 12 additions & 5 deletions mlflow/utils/databricks_utils.py
@@ -289,9 +289,11 @@ def is_databricks_connect(spark):
if is_in_databricks_serverless_runtime() or is_in_databricks_shared_cluster_runtime():
return True
try:
# TODO: Remove the `spark.client._builder._build` attribute access once
# TODO: Remove the `spark.client._builder` attribute usage once
# Spark-connect has public attribute for this information.
return "databricks-session" in spark.client._builder.userAgent
return is_spark_connect_mode() and any(
k == "x-databricks-cluster-id" for k, v in spark.client._builder.metadata()
)
except Exception:
return False

@@ -355,9 +357,14 @@ def is_databricks_serverless(spark):
"""
from mlflow.utils.spark_utils import is_spark_connect_mode

return is_spark_connect_mode() and any(
k == "x-databricks-session-id" for k, v in spark.client.metadata()
)
try:
# TODO: Remove the `spark.client._builder` attribute usage once
# Spark-connect has public attribute for this information.
return is_spark_connect_mode() and any(
k == "x-databricks-session-id" for k, v in spark.client._builder.metadata()
)
except Exception:
return False


def is_dbfs_fuse_available():
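As an illustrative sketch only (these helpers are internal to mlflow.utils.databricks_utils, so calling them outside MLflow is an assumption), the two checks updated above can be probed against an existing session:

from mlflow.utils.databricks_utils import is_databricks_connect, is_databricks_serverless

# 'spark' is assumed to be an existing SparkSession / DatabricksSession.
print(is_databricks_connect(spark))     # True for Databricks Connect sessions (including serverless/shared runtimes)
print(is_databricks_serverless(spark))  # True when the Spark Connect session targets Databricks serverless compute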
