Skip to content

Commit

Permalink
feat(sdk): Allow disabling default caching via a CLI flag and env var (
Browse files Browse the repository at this point in the history
…#11222)

* feat(sdk): Allow setting a default of execution caching disabled via a compiler CLI flag and env var

Co-authored-by: Greg Sheremeta <[email protected]>
Signed-off-by: ddalvi <[email protected]>

* Add tests for disabling default caching var and flag

Signed-off-by: ddalvi <[email protected]>

---------

Signed-off-by: ddalvi <[email protected]>
Co-authored-by: Greg Sheremeta <[email protected]>
  • Loading branch information
DharmitD and gregsheremeta authored Sep 25, 2024
1 parent 880e46d commit 3f49522
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 1 deletion.
84 changes: 84 additions & 0 deletions sdk/python/kfp/cli/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from click import testing
from kfp.cli import cli
from kfp.cli import compile_
import yaml


class TestCliNounAliases(unittest.TestCase):
Expand Down Expand Up @@ -196,5 +197,88 @@ def test(self, noun: str, verb: str):
self.assertEqual(result.exit_code, 0)


class TestKfpDslCompile(unittest.TestCase):
    """CLI-level tests for the default execution-caching switch.

    Each test compiles a one-task pipeline through ``kfp dsl compile`` and
    inspects the ``cachingOptions`` written for every task in the root DAG.
    All files are created in a per-test scratch directory, and the process
    state touched by the command (env var, ``Pipeline`` class default) is
    restored afterwards so tests cannot influence each other.
    """

    _CACHING_ENV_VAR = 'KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'

    def setUp(self):
        # Scratch directory for the generated pipeline source and compiled
        # YAML; removed on cleanup so nothing is left in the CWD or /tmp.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self._scratch_dir = scratch.name
        self._output_file = os.path.join(self._scratch_dir, 'test_output.yaml')

        # Start every test with the env var unset, and put back whatever the
        # surrounding environment had once the test is done (even on failure).
        previous_env = os.environ.pop(self._CACHING_ENV_VAR, None)
        self.addCleanup(self._restore_env_var, previous_env)

        # The compile command assigns Pipeline._execution_caching_default;
        # restore the pre-test value so later tests see the original default.
        # Imported locally to keep this class self-contained.
        from kfp.dsl.pipeline_context import Pipeline
        self.addCleanup(setattr, Pipeline, '_execution_caching_default',
                        Pipeline._execution_caching_default)

    def _restore_env_var(self, previous):
        """Resets the caching env var to ``previous`` (``None`` means unset)."""
        if previous is None:
            os.environ.pop(self._CACHING_ENV_VAR, None)
        else:
            os.environ[self._CACHING_ENV_VAR] = previous

    def invoke(self, args):
        """Runs ``kfp dsl compile`` with ``args`` and returns the click result."""
        starting_args = ['dsl', 'compile']
        args = starting_args + args
        runner = testing.CliRunner()
        return runner.invoke(
            cli=cli.cli, args=args, catch_exceptions=False, obj={})

    def create_pipeline_file(self):
        """Writes a minimal pipeline module into the scratch directory.

        Returns:
            The (already closed) temporary file object; use ``.name`` for the
            path to pass to ``--py``.
        """
        pipeline_code = b"""
from kfp import dsl
@dsl.component
def my_component():
    pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
    my_component_task = my_component()
"""
        # Closing the handle before the CLI reads the file guarantees the
        # content is flushed and avoids leaking an open file descriptor.
        with tempfile.NamedTemporaryFile(
                suffix='.py', dir=self._scratch_dir,
                delete=False) as temp_pipeline:
            temp_pipeline.write(pipeline_code)
        return temp_pipeline

    def load_output_yaml(self, output_file):
        """Parses the compiled pipeline YAML at ``output_file``."""
        with open(output_file, 'r') as f:
            return yaml.safe_load(f)

    def _compile_and_get_tasks(self, extra_args=()):
        """Compiles the tiny pipeline and returns the root DAG's task specs.

        Args:
            extra_args: Additional CLI arguments appended after ``--py`` and
                ``--output``.

        Returns:
            A list of the task dictionaries under ``root.dag.tasks``.
        """
        temp_pipeline = self.create_pipeline_file()
        result = self.invoke([
            '--py', temp_pipeline.name, '--output', self._output_file,
            *extra_args
        ])
        self.assertEqual(result.exit_code, 0)

        output_data = self.load_output_yaml(self._output_file)
        self.assertIn('root', output_data)
        self.assertIn('tasks', output_data['root']['dag'])
        tasks = list(output_data['root']['dag']['tasks'].values())
        # Guard against the per-task assertions passing on an empty DAG.
        self.assertTrue(tasks)
        return tasks

    def test_compile_with_caching_flag_enabled(self):
        for task in self._compile_and_get_tasks():
            self.assertIn('cachingOptions', task)
            caching_options = task['cachingOptions']
            self.assertEqual(caching_options.get('enableCache'), True)

    def test_compile_with_caching_flag_disabled(self):
        tasks = self._compile_and_get_tasks(
            ['--disable-execution-caching-by-default'])
        for task in tasks:
            self.assertIn('cachingOptions', task)
            caching_options = task['cachingOptions']
            self.assertEqual(caching_options, {})

    def test_compile_with_caching_disabled_env_var(self):
        # Restored by the cleanup registered in setUp, even if compilation
        # raises.
        os.environ[self._CACHING_ENV_VAR] = 'true'
        for task in self._compile_and_get_tasks():
            self.assertIn('cachingOptions', task)
            caching_options = task['cachingOptions']
            self.assertEqual(caching_options, {})


# Run the test suite when this module is executed directly.
if __name__ == '__main__':
    unittest.main()
10 changes: 10 additions & 0 deletions sdk/python/kfp/cli/compile_.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from kfp import compiler
from kfp.dsl import base_component
from kfp.dsl import graph_component
from kfp.dsl.pipeline_context import Pipeline


def is_pipeline_func(func: Callable) -> bool:
Expand Down Expand Up @@ -133,14 +134,23 @@ def parse_parameters(parameters: Optional[str]) -> Dict:
is_flag=True,
default=False,
help='Whether to disable type checking.')
@click.option(
'--disable-execution-caching-by-default',
is_flag=True,
default=False,
envvar='KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT',
help='Whether to disable execution caching by default.')
def compile_(
py: str,
output: str,
function_name: Optional[str] = None,
pipeline_parameters: Optional[str] = None,
disable_type_check: bool = False,
disable_execution_caching_by_default: bool = False,
) -> None:
"""Compiles a pipeline or component written in a .py file."""

Pipeline._execution_caching_default = not disable_execution_caching_by_default
pipeline_func = collect_pipeline_or_component_func(
python_file=py, function_name=function_name)
parsed_parameters = parse_parameters(parameters=pipeline_parameters)
Expand Down
53 changes: 53 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,59 @@ def my_pipeline() -> NamedTuple('Outputs', [
task = print_and_return(text='Hello')


class TestCompilePipelineCaching(unittest.TestCase):
    """Checks how per-task caching options show up in the compiled spec."""

    def _compile_to_spec(self, pipeline_func):
        """Compiles ``pipeline_func`` to YAML in a scratch dir and parses it.

        Returns:
            The compiled pipeline spec loaded as a dictionary.
        """
        with tempfile.TemporaryDirectory() as workdir:
            package_path = os.path.join(workdir, 'pipeline.yaml')
            compiler.Compiler().compile(
                pipeline_func=pipeline_func, package_path=package_path)

            with open(package_path, 'r') as spec_file:
                return yaml.safe_load(spec_file)

    def test_compile_pipeline_with_caching_enabled(self):
        """A task with caching switched on compiles with enableCache set."""

        @dsl.component
        def my_component():
            pass

        @dsl.pipeline(name='tiny-pipeline')
        def my_pipeline():
            task = my_component()
            task.set_caching_options(True)

        spec = self._compile_to_spec(my_pipeline)
        component_task = spec['root']['dag']['tasks']['my-component']

        self.assertTrue(component_task['cachingOptions']['enableCache'])

    def test_compile_pipeline_with_caching_disabled(self):
        """A task with caching switched off compiles with empty options."""

        @dsl.component
        def my_component():
            pass

        @dsl.pipeline(name='tiny-pipeline')
        def my_pipeline():
            task = my_component()
            task.set_caching_options(False)

        spec = self._compile_to_spec(my_pipeline)
        component_task = spec['root']['dag']['tasks']['my-component']

        # A missing 'cachingOptions' key is treated the same as an empty one.
        self.assertEqual(component_task.get('cachingOptions', {}), {})


class V2NamespaceAliasTest(unittest.TestCase):
"""Test that imports of both modules and objects are aliased (e.g. all
import path variants work)."""
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/kfp/dsl/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask:
args=task_inputs,
execute_locally=pipeline_context.Pipeline.get_default_pipeline() is
None,
execution_caching_default=pipeline_context.Pipeline
.get_execution_caching_default(),
)

@property
Expand Down
14 changes: 14 additions & 0 deletions sdk/python/kfp/dsl/pipeline_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Definition for Pipeline."""

import functools
import os
from typing import Callable, Optional

from kfp.dsl import component_factory
Expand Down Expand Up @@ -101,6 +102,19 @@ def get_default_pipeline():
"""Gets the default pipeline."""
return Pipeline._default_pipeline

# Class-level default handed to new tasks as their execution-caching setting.
# It can be turned off with the click option
# --disable-execution-caching-by-default or with the env var
# KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT. The accepted spellings mirror
# click's treatment of env vars for boolean flags: per the click docs,
# "1", "true", "t", "yes", "y", and "on" are all converted to True.
# An unset variable reads as '' and therefore leaves caching enabled.
_execution_caching_default = (
    os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT', '').strip().lower()
    not in {'1', 'true', 't', 'yes', 'y', 'on'})

@staticmethod
def get_execution_caching_default() -> bool:
    """Returns the current default for execution caching.

    Returns:
        The value stored in ``Pipeline._execution_caching_default``.
    """
    return Pipeline._execution_caching_default

def __init__(self, name: str):
"""Creates a new instance of Pipeline.
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/kfp/dsl/pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
component_spec: structures.ComponentSpec,
args: Dict[str, Any],
execute_locally: bool = False,
execution_caching_default: bool = True,
) -> None:
"""Initilizes a PipelineTask instance."""
# import within __init__ to avoid circular import
Expand Down Expand Up @@ -130,7 +131,7 @@ def __init__(
inputs=dict(args.items()),
dependent_tasks=[],
component_ref=component_spec.name,
enable_caching=True)
enable_caching=execution_caching_default)
self._run_after: List[str] = []

self.importer_spec = None
Expand Down

0 comments on commit 3f49522

Please sign in to comment.