diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 2433f09bc6d..6ac26568c89 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -3847,6 +3847,42 @@ def outer(): foo_platform_set_bar_feature(task, 12) +class TestPipelineSemaphoreMutex(unittest.TestCase): + + def test_pipeline_with_semaphore_and_mutex(self): + from kfp import compiler + from kfp import dsl + from kfp.dsl.pipeline_config import PipelineConfig + + config = PipelineConfig() + config.set_semaphore_key('semaphore') + config.set_mutex_name('mutex') + + @dsl.pipeline(pipeline_config=config) + def my_pipeline(): + task = comp() + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_docs = list(yaml.safe_load_all(f)) + + pipeline_spec = None + for doc in pipeline_docs: + if 'platforms' in doc: + pipeline_spec = doc + break + + if pipeline_spec: + kubernetes_spec = pipeline_spec['platforms']['kubernetes'][ + 'pipelineConfig'] + assert kubernetes_spec['semaphoreKey'] == 'semaphore' + assert kubernetes_spec['mutexName'] == 'mutex' + + class ExtractInputOutputDescription(unittest.TestCase): def test_no_descriptions(self):