From 3f495229f26ef08360048d050dfe014ca4b57b4f Mon Sep 17 00:00:00 2001 From: Dharmit Dalvi Date: Thu, 26 Sep 2024 03:58:35 +0530 Subject: [PATCH] feat(sdk): Allow disabling default caching via a CLI flag and env var (#11222) * feat(sdk): Allow setting a default of execution caching disabled via a compiler CLI flag and env var Co-authored-by: Greg Sheremeta Signed-off-by: ddalvi * Add tests for disabling default caching var and flag Signed-off-by: ddalvi --------- Signed-off-by: ddalvi Co-authored-by: Greg Sheremeta --- sdk/python/kfp/cli/cli_test.py | 84 ++++++++++++++++++++++++ sdk/python/kfp/cli/compile_.py | 10 +++ sdk/python/kfp/compiler/compiler_test.py | 53 +++++++++++++++ sdk/python/kfp/dsl/base_component.py | 2 + sdk/python/kfp/dsl/pipeline_context.py | 14 ++++ sdk/python/kfp/dsl/pipeline_task.py | 3 +- 6 files changed, 165 insertions(+), 1 deletion(-) diff --git a/sdk/python/kfp/cli/cli_test.py b/sdk/python/kfp/cli/cli_test.py index 361db73a14e..d1af4095fde 100644 --- a/sdk/python/kfp/cli/cli_test.py +++ b/sdk/python/kfp/cli/cli_test.py @@ -27,6 +27,7 @@ from click import testing from kfp.cli import cli from kfp.cli import compile_ +import yaml class TestCliNounAliases(unittest.TestCase): @@ -196,5 +197,88 @@ def test(self, noun: str, verb: str): self.assertEqual(result.exit_code, 0) +class TestKfpDslCompile(unittest.TestCase): + + def invoke(self, args): + starting_args = ['dsl', 'compile'] + args = starting_args + args + runner = testing.CliRunner() + return runner.invoke( + cli=cli.cli, args=args, catch_exceptions=False, obj={}) + + def create_pipeline_file(self): + pipeline_code = b""" +from kfp import dsl + +@dsl.component +def my_component(): + pass + +@dsl.pipeline(name="tiny-pipeline") +def my_pipeline(): + my_component_task = my_component() +""" + temp_pipeline = tempfile.NamedTemporaryFile(suffix='.py', delete=False) + temp_pipeline.write(pipeline_code) + temp_pipeline.flush() + return temp_pipeline + + def load_output_yaml(self, output_file): + with open(output_file, 'r') as f: + return yaml.safe_load(f) + + def test_compile_with_caching_flag_enabled(self): + temp_pipeline = self.create_pipeline_file() + output_file = 'test_output.yaml' + + result = self.invoke( + ['--py', temp_pipeline.name, '--output', output_file]) + self.assertEqual(result.exit_code, 0) + + output_data = self.load_output_yaml(output_file) + self.assertIn('root', output_data) + self.assertIn('tasks', output_data['root']['dag']) + for task in output_data['root']['dag']['tasks'].values(): + self.assertIn('cachingOptions', task) + caching_options = task['cachingOptions'] + self.assertEqual(caching_options.get('enableCache'), True) + + def test_compile_with_caching_flag_disabled(self): + temp_pipeline = self.create_pipeline_file() + output_file = 'test_output.yaml' + + result = self.invoke([ + '--py', temp_pipeline.name, '--output', output_file, + '--disable-execution-caching-by-default' + ]) + self.assertEqual(result.exit_code, 0) + + output_data = self.load_output_yaml(output_file) + self.assertIn('root', output_data) + self.assertIn('tasks', output_data['root']['dag']) + for task in output_data['root']['dag']['tasks'].values(): + self.assertIn('cachingOptions', task) + caching_options = task['cachingOptions'] + self.assertEqual(caching_options, {}) + + def test_compile_with_caching_disabled_env_var(self): + temp_pipeline = self.create_pipeline_file() + output_file = 'test_output.yaml' + + os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true' + result = self.invoke( + ['--py', temp_pipeline.name, '--output', output_file]) + self.assertEqual(result.exit_code, 0) + del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] + + output_data = self.load_output_yaml(output_file) + self.assertIn('root', output_data) + self.assertIn('tasks', output_data['root']['dag']) + for task in output_data['root']['dag']['tasks'].values(): + self.assertIn('cachingOptions', task) + caching_options = task['cachingOptions'] + self.assertEqual(caching_options, {}) + + if __name__ == '__main__': unittest.main() diff --git a/sdk/python/kfp/cli/compile_.py b/sdk/python/kfp/cli/compile_.py index 2bd3bab18c2..e1fc28a8328 100644 --- a/sdk/python/kfp/cli/compile_.py +++ b/sdk/python/kfp/cli/compile_.py @@ -24,6 +24,7 @@ from kfp import compiler from kfp.dsl import base_component from kfp.dsl import graph_component +from kfp.dsl.pipeline_context import Pipeline def is_pipeline_func(func: Callable) -> bool: @@ -133,14 +134,23 @@ def parse_parameters(parameters: Optional[str]) -> Dict: is_flag=True, default=False, help='Whether to disable type checking.') +@click.option( + '--disable-execution-caching-by-default', + is_flag=True, + default=False, + envvar='KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT', + help='Whether to disable execution caching by default.') def compile_( py: str, output: str, function_name: Optional[str] = None, pipeline_parameters: Optional[str] = None, disable_type_check: bool = False, + disable_execution_caching_by_default: bool = False, ) -> None: """Compiles a pipeline or component written in a .py file.""" + + Pipeline._execution_caching_default = not disable_execution_caching_by_default pipeline_func = collect_pipeline_or_component_func( python_file=py, function_name=function_name) parsed_parameters = parse_parameters(parameters=pipeline_parameters) diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 598983af778..2433f09bc6d 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -910,6 +910,59 @@ def my_pipeline() -> NamedTuple('Outputs', [ task = print_and_return(text='Hello') +class TestCompilePipelineCaching(unittest.TestCase): + + def test_compile_pipeline_with_caching_enabled(self): + """Test pipeline compilation with caching enabled.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name='tiny-pipeline') + def my_pipeline(): + my_task = my_component() + my_task.set_caching_options(True) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec['cachingOptions'] + + self.assertTrue(caching_options['enableCache']) + + def test_compile_pipeline_with_caching_disabled(self): + """Test pipeline compilation with caching disabled.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name='tiny-pipeline') + def my_pipeline(): + my_task = my_component() + my_task.set_caching_options(False) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec.get('cachingOptions', {}) + + self.assertEqual(caching_options, {}) + + class V2NamespaceAliasTest(unittest.TestCase): """Test that imports of both modules and objects are aliased (e.g. all import path variants work).""" diff --git a/sdk/python/kfp/dsl/base_component.py b/sdk/python/kfp/dsl/base_component.py index 86e291018ad..2682321417d 100644 --- a/sdk/python/kfp/dsl/base_component.py +++ b/sdk/python/kfp/dsl/base_component.py @@ -103,6 +103,8 @@ def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask: args=task_inputs, execute_locally=pipeline_context.Pipeline.get_default_pipeline() is None, + execution_caching_default=pipeline_context.Pipeline + .get_execution_caching_default(), ) @property diff --git a/sdk/python/kfp/dsl/pipeline_context.py b/sdk/python/kfp/dsl/pipeline_context.py index 4881bc5680c..4d0bbbaa840 100644 --- a/sdk/python/kfp/dsl/pipeline_context.py +++ b/sdk/python/kfp/dsl/pipeline_context.py @@ -14,6 +14,7 @@ """Definition for Pipeline.""" import functools +import os from typing import Callable, Optional from kfp.dsl import component_factory @@ -101,6 +102,19 @@ def get_default_pipeline(): """Gets the default pipeline.""" return Pipeline._default_pipeline + # _execution_caching_default can be disabled via the click option --disable-execution-caching-by-default + # or the env var KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT. + # align with click's treatment of env vars for boolean flags. + # per click doc, "1", "true", "t", "yes", "y", and "on" are all converted to True + _execution_caching_default = not str( + os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower( + ) in {'1', 'true', 't', 'yes', 'y', 'on'} + + @staticmethod + def get_execution_caching_default(): + """Gets the default execution caching.""" + return Pipeline._execution_caching_default + def __init__(self, name: str): """Creates a new instance of Pipeline. diff --git a/sdk/python/kfp/dsl/pipeline_task.py b/sdk/python/kfp/dsl/pipeline_task.py index 773fb1e0676..822f5520788 100644 --- a/sdk/python/kfp/dsl/pipeline_task.py +++ b/sdk/python/kfp/dsl/pipeline_task.py @@ -98,6 +98,7 @@ def __init__( component_spec: structures.ComponentSpec, args: Dict[str, Any], execute_locally: bool = False, + execution_caching_default: bool = True, ) -> None: """Initilizes a PipelineTask instance.""" # import within __init__ to avoid circular import @@ -130,7 +131,7 @@ def __init__( inputs=dict(args.items()), dependent_tasks=[], component_ref=component_spec.name, - enable_caching=True) + enable_caching=execution_caching_default) self._run_after: List[str] = [] self.importer_spec = None