From 3b3eed6053087d5cf95cfcc087f3a9e1d30bda26 Mon Sep 17 00:00:00 2001 From: zazulam Date: Mon, 15 Jul 2024 13:52:29 -0400 Subject: [PATCH] address pr comments Signed-off-by: zazulam --- backend/src/v2/compiler/argocompiler/argo.go | 3 +-- .../src/v2/compiler/argocompiler/argo_test.go | 15 ++++++++++++ sdk/python/kfp/compiler/compiler_test.py | 2 +- sdk/python/kfp/dsl/component_factory.py | 7 ++++++ sdk/python/kfp/dsl/component_factory_test.py | 24 +++++++++++++++++++ 5 files changed, 48 insertions(+), 3 deletions(-) diff --git a/backend/src/v2/compiler/argocompiler/argo.go b/backend/src/v2/compiler/argocompiler/argo.go index faf5b2b69840..b6899f9b6c67 100644 --- a/backend/src/v2/compiler/argocompiler/argo.go +++ b/backend/src/v2/compiler/argocompiler/argo.go @@ -262,7 +262,7 @@ func (c *workflowCompiler) argumentsPlaceholder(componentName string) (string, e return workflowParameter(componentName), nil } -// extractBaseComponentName removes the iteration suffix that the IR compiler +// ExtractBaseComponentName removes the iteration suffix that the IR compiler // adds to the component name. func ExtractBaseComponentName(componentName string) string { baseComponentName := componentName @@ -270,7 +270,6 @@ func ExtractBaseComponentName(componentName string) string { if _, err := strconv.Atoi(componentNameArray[len(componentNameArray)-1]); err == nil { baseComponentName = strings.Join(componentNameArray[:len(componentNameArray)-1], "-") - } return baseComponentName diff --git a/backend/src/v2/compiler/argocompiler/argo_test.go b/backend/src/v2/compiler/argocompiler/argo_test.go index f3bb1fdcb1c0..3f6ccb88afe2 100644 --- a/backend/src/v2/compiler/argocompiler/argo_test.go +++ b/backend/src/v2/compiler/argocompiler/argo_test.go @@ -154,6 +154,21 @@ func Test_extractBaseComponentName(t *testing.T) { componentName: "component", expectedBaseName: "component", }, + { + name: "Last char is int", + componentName: "component-v2", + expectedBaseName: "component-v2", + }, + { + name: "Multiple dashes, ends with int", + componentName: "service-api-v2", + expectedBaseName: "service-api-v2", + }, + { + name: "Multiple dashes and ints", + componentName: "module-1-2-3", + expectedBaseName: "module-1-2", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index d417d9eec199..0546d83b252f 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -129,7 +129,7 @@ def comp(): @dsl.component -def return_1() -> int: +def return_one() -> int: return 1 diff --git a/sdk/python/kfp/dsl/component_factory.py b/sdk/python/kfp/dsl/component_factory.py index 1af26d80bfc4..e6d9656f89ad 100644 --- a/sdk/python/kfp/dsl/component_factory.py +++ b/sdk/python/kfp/dsl/component_factory.py @@ -66,6 +66,13 @@ class ComponentInfo(): def _python_function_name_to_component_name(name): name_with_spaces = re.sub(' +', ' ', name.replace('_', ' ')).strip(' ') + name_list = name_with_spaces.split(' ') + + if name_list[-1].isdigit(): + raise ValueError( + f'Invalid function name "{name}". The function name must not end in `_`.' + ) + return name_with_spaces[0].upper() + name_with_spaces[1:] diff --git a/sdk/python/kfp/dsl/component_factory_test.py b/sdk/python/kfp/dsl/component_factory_test.py index b602be241fd0..7f8770cbed5d 100644 --- a/sdk/python/kfp/dsl/component_factory_test.py +++ b/sdk/python/kfp/dsl/component_factory_test.py @@ -174,6 +174,30 @@ def comp(Output: OutputPath(str), text: str) -> str: pass +class TestPythonFunctionName(unittest.TestCase): + + def test_invalid_function_name(self): + + with self.assertRaisesRegex( + ValueError, + r'Invalid function name "comp_2". The function name must not end in `_`.' + ): + + @component + def comp_2(text: str) -> str: + pass + + def test_valid_function_name(self): + + @component + def comp_v2(text: str) -> str: + pass + + @component + def comp_(text: str) -> str: + pass + + class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase): def test_python_component_input(self):