Skip to content

Commit

Permalink
(feat) Introduce cache_key to sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Dec 13, 2024
1 parent 0eb67e1 commit 5c1ad7d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 3 deletions.
24 changes: 23 additions & 1 deletion sdk/python/kfp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def run_pipeline(
version_id: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1Run:
"""Runs a specified pipeline.
Expand All @@ -709,6 +710,8 @@ def run_pipeline(
is ``True`` for all tasks by default. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
Expand All @@ -721,6 +724,7 @@ def run_pipeline(
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

Expand Down Expand Up @@ -806,6 +810,7 @@ def create_recurring_run(
enabled: bool = True,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
) -> kfp_server_api.V2beta1RecurringRun:
"""Creates a recurring run.
Expand Down Expand Up @@ -850,6 +855,8 @@ def create_recurring_run(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account this recurring run uses.
Returns:
Expand All @@ -862,6 +869,7 @@ def create_recurring_run(
pipeline_id=pipeline_id,
version_id=version_id,
enable_caching=enable_caching,
cache_key=cache_key,
pipeline_root=pipeline_root,
)

Expand Down Expand Up @@ -908,6 +916,7 @@ def _create_job_config(
pipeline_id: Optional[str],
version_id: Optional[str],
enable_caching: Optional[bool],
cache_key: Optional[str],
pipeline_root: Optional[str],
) -> _JobConfig:
"""Creates a JobConfig with spec and resource_references.
Expand All @@ -928,6 +937,8 @@ def _create_job_config(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
pipeline_root: Root path of the pipeline outputs.
Returns:
Expand Down Expand Up @@ -956,7 +967,8 @@ def _create_job_config(
# settings.
if enable_caching is not None:
_override_caching_options(pipeline_doc.pipeline_spec,
enable_caching)
enable_caching,
cache_key)
pipeline_spec = pipeline_doc.to_dict()

pipeline_version_reference = None
Expand All @@ -983,6 +995,7 @@ def create_run_from_pipeline_func(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
Expand All @@ -1004,6 +1017,8 @@ def create_run_from_pipeline_func(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
Expand Down Expand Up @@ -1032,6 +1047,7 @@ def create_run_from_pipeline_func(
namespace=namespace,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)

Expand All @@ -1044,6 +1060,7 @@ def create_run_from_pipeline_package(
namespace: Optional[str] = None,
pipeline_root: Optional[str] = None,
enable_caching: Optional[bool] = None,
cache_key: Optional[str] = "",
service_account: Optional[str] = None,
experiment_id: Optional[str] = None,
) -> RunPipelineResult:
Expand All @@ -1065,6 +1082,8 @@ def create_run_from_pipeline_package(
different caching options for individual tasks. If set, the
setting applies to all tasks in the pipeline (overrides the
compile time settings).
cache_key (optional): Customized cache key for this task.
If set, the cache_key will be used as the key for the task's cache.
service_account: Specifies which Kubernetes service
account to use for this run.
experiment_id: ID of the experiment to add the run to. You cannot specify both experiment_id and experiment_name.
Expand Down Expand Up @@ -1105,6 +1124,7 @@ def create_run_from_pipeline_package(
params=arguments,
pipeline_root=pipeline_root,
enable_caching=enable_caching,
cache_key=cache_key,
service_account=service_account,
)
return RunPipelineResult(self, run_info)
Expand Down Expand Up @@ -1681,6 +1701,7 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc:
def _override_caching_options(
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
enable_caching: bool,
cache_key: str="",
) -> None:
"""Overrides caching options.
Expand All @@ -1690,3 +1711,4 @@ def _override_caching_options(
"""
for _, task_spec in pipeline_spec.root.dag.tasks.items():
task_spec.caching_options.enable_cache = enable_caching
task_spec.caching_options.cache_key = cache_key
2 changes: 2 additions & 0 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def build_task_spec_for_task(
utils.sanitize_component_name(task.name))
pipeline_task_spec.caching_options.enable_cache = (
task._task_spec.enable_caching)
pipeline_task_spec.caching_options.cache_key = (
task._task_spec.cache_key)

if task._task_spec.retry_policy is not None:
pipeline_task_spec.retry_policy.CopyFrom(
Expand Down
8 changes: 6 additions & 2 deletions sdk/python/kfp/dsl/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
args: Dict[str, Any],
execute_locally: bool = False,
execution_caching_default: bool = True,
execution_cache_key: str = "",
) -> None:
"""Initilizes a PipelineTask instance."""
# import within __init__ to avoid circular import
Expand Down Expand Up @@ -131,7 +132,8 @@ def __init__(
inputs=dict(args.items()),
dependent_tasks=[],
component_ref=component_spec.name,
enable_caching=execution_caching_default)
enable_caching=execution_caching_default,
cache_key=execution_cache_key)
self._run_after: List[str] = []

self.importer_spec = None
Expand Down Expand Up @@ -301,16 +303,18 @@ def _extract_container_spec_and_convert_placeholders(
return container_spec

@block_if_final()
def set_caching_options(self, enable_caching: bool) -> 'PipelineTask':
def set_caching_options(self, enable_caching: bool, cache_key: str = "") -> 'PipelineTask':
"""Sets caching options for the task.
Args:
enable_caching: Whether to enable caching.
cache_key: Customized cache key for this task.
Returns:
Self return to allow chained setting calls.
"""
self._task_spec.enable_caching = enable_caching
self._task_spec.cache_key = cache_key
return self

def _ensure_container_spec_exists(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/kfp/dsl/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ class TaskSpec:
from the [items][] collection.
enable_caching (optional): whether or not to enable caching for the task.
Default is True.
cache_key (optional): Customized cache key for this task.
Default is empty string.
display_name (optional): the display name of the task. If not specified,
the task name will be used as the display name.
"""
Expand All @@ -421,6 +423,7 @@ class TaskSpec:
iterator_items: Optional[Any] = None
iterator_item_input: Optional[str] = None
enable_caching: bool = True
cache_key: str = ""
display_name: Optional[str] = None
retry_policy: Optional[RetryPolicy] = None

Expand Down

0 comments on commit 5c1ad7d

Please sign in to comment.