diff --git a/src/fondant/core/exceptions.py b/src/fondant/core/exceptions.py index 8ce549ed..4143f389 100644 --- a/src/fondant/core/exceptions.py +++ b/src/fondant/core/exceptions.py @@ -21,3 +21,7 @@ class InvalidPipelineDefinition(ValidationError, FondantException): class InvalidTypeSchema(ValidationError, FondantException): """Thrown when a Type schema definition is invalid.""" + + +class UnsupportedTypeAnnotation(FondantException): + """Thrown when an unsupported type annotation is encountered during type inference.""" diff --git a/src/fondant/pipeline/argument_inference.py b/src/fondant/pipeline/argument_inference.py new file mode 100644 index 00000000..ee2c24bb --- /dev/null +++ b/src/fondant/pipeline/argument_inference.py @@ -0,0 +1,103 @@ +import inspect +import typing as t + +from fondant.component import Component +from fondant.core.component_spec import Argument +from fondant.core.exceptions import UnsupportedTypeAnnotation + +BUILTIN_TYPES = [str, int, float, bool, dict, list] + + +def annotation_to_type(annotation: t.Any) -> t.Type: + """Extract the simple built-in type from an Annotation. + + Examples: + dict[str, int] -> dict + t.Optional[str] -> str + + Args: + annotation: Annotation of an argument as returned by inspect.signature + + Raises: + UnsupportedTypeAnnotation: If the annotation is not simple or not based on a built-in type + + """ + # If no type annotation is present, default to str + if annotation == inspect.Parameter.empty: + return str + + # Unpack the annotation until we get a simple type. + # This removes complex structures such as Optional + while t.get_origin(annotation) not in [*BUILTIN_TYPES, None]: + # Filter out NoneType values (Optional[x] is represented as Union[x, NoneType] + annotation_args = [ + arg for arg in t.get_args(annotation) if arg is not type(None) + ] + + # Multiple arguments remaining (eg. Union[str, int]) + # Raise error since we cannot infer type unambiguously + if len(annotation_args) > 1: + msg = ( + f"Fondant only supports simple types for component arguments." + f"Expected one of {BUILTIN_TYPES}, received {annotation} instead." + ) + raise UnsupportedTypeAnnotation(msg) + + annotation = annotation_args[0] + + # Remove any subscription (eg. dict[str, int] -> dict) + annotation = t.get_origin(annotation) or annotation + + # Check for classes not supported as argument + if annotation not in BUILTIN_TYPES: + msg = ( + f"Fondant only supports builtin types for component arguments." + f"Expected one of {BUILTIN_TYPES}, received {annotation} instead." + ) + raise UnsupportedTypeAnnotation(msg) + + return annotation + + +def is_optional(parameter: inspect.Parameter) -> bool: + """Check if an inspect.Parameter is optional. We check this based on the presence of a + default value instead of based on the type, since this is more trustworthy. + """ + return parameter.default != inspect.Parameter.empty + + +def get_default(parameter: inspect.Parameter) -> t.Any: + """Get the default value from an inspect.Parameter.""" + if parameter.default == inspect.Parameter.empty: + return None + return parameter.default + + +def parameter_to_argument(parameter: inspect.Parameter) -> Argument: + """Translate an inspect.Parameter into a Fondant Argument.""" + return Argument( + name=parameter.name, + type=annotation_to_type(parameter.annotation), + optional=is_optional(parameter), + default=get_default(parameter), + ) + + +def infer_arguments(component: t.Type[Component]) -> t.Dict[str, Argument]: + """Infer the user arguments from a Python Component class. + Default arguments are skipped. + + Args: + component: Component class to inspect. + """ + signature = inspect.signature(component) + + arguments = {} + for name, parameter in signature.parameters.items(): + # Skip non-user arguments + if name in ["self", "consumes", "produces", "kwargs"]: + continue + + arguments[name] = parameter_to_argument(parameter) + + return arguments diff --git a/tests/pipeline/test_argument_inference.py b/tests/pipeline/test_argument_inference.py new file mode 100644 index 00000000..2f306bf8 --- /dev/null +++ b/tests/pipeline/test_argument_inference.py @@ -0,0 +1,247 @@ +import sys +import typing as t + +import pytest +from fondant.component import PandasTransformComponent +from fondant.core.component_spec import Argument +from fondant.core.exceptions import UnsupportedTypeAnnotation +from fondant.pipeline.argument_inference import infer_arguments + + +def test_no_init(): + class TestComponent(PandasTransformComponent): + pass + + assert infer_arguments(TestComponent) == {} + + +def test_no_arguments(): + class TestComponent(PandasTransformComponent): + def __init__(self, **kwargs): + pass + + assert infer_arguments(TestComponent) == {} + + +def test_missing_types(): + class TestComponent(PandasTransformComponent): + def __init__(self, *, argument, **kwargs): + pass + + assert infer_arguments(TestComponent) == { + "argument": Argument( + name="argument", + type=str, + optional=False, + ), + } + + +def test_types(): + class TestComponent(PandasTransformComponent): + def __init__( + self, + *, + str_argument: str, + int_argument: int, + float_argument: float, + bool_argument: bool, + dict_argument: dict, + list_argument: list, + **kwargs, + ): + pass + + assert infer_arguments(TestComponent) == { + "str_argument": Argument( + name="str_argument", + type=str, + optional=False, + ), + "int_argument": Argument( + name="int_argument", + type=int, + optional=False, + ), + "float_argument": Argument( + name="float_argument", + type=float, + optional=False, + ), + "bool_argument": Argument( + name="bool_argument", + type=bool, + optional=False, + ), + "dict_argument": Argument( + name="dict_argument", + type=dict, + optional=False, + ), + "list_argument": Argument( + name="list_argument", + type=list, + optional=False, + ), + } + + +def test_optional_types(): + class TestComponent(PandasTransformComponent): + def __init__( + self, + *, + str_argument: t.Optional[str] = "", + int_argument: t.Optional[int] = 1, + float_argument: t.Optional[float] = 1.0, + bool_argument: t.Optional[bool] = False, + dict_argument: t.Optional[dict] = None, + list_argument: t.Optional[list] = None, + **kwargs, + ): + pass + + assert infer_arguments(TestComponent) == { + "str_argument": Argument( + name="str_argument", + type=str, + optional=True, + default="", + ), + "int_argument": Argument( + name="int_argument", + type=int, + optional=True, + default=1, + ), + "float_argument": Argument( + name="float_argument", + type=float, + optional=True, + default=1.0, + ), + "bool_argument": Argument( + name="bool_argument", + type=bool, + optional=True, + default=False, + ), + "dict_argument": Argument( + name="dict_argument", + type=dict, + optional=True, + default=None, + ), + "list_argument": Argument( + name="list_argument", + type=list, + optional=True, + default=None, + ), + } + + +def test_parametrized_types_old(): + class TestComponent(PandasTransformComponent): + def __init__( + self, + *, + dict_argument: t.Dict[str, t.Any], + list_argument: t.Optional[t.List[int]] = None, + **kwargs, + ): + pass + + assert infer_arguments(TestComponent) == { + "dict_argument": Argument( + name="dict_argument", + type=dict, + optional=False, + default=None, + ), + "list_argument": Argument( + name="list_argument", + type=list, + optional=True, + default=None, + ), + } + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_parametrized_types_new(): + class TestComponent(PandasTransformComponent): + def __init__( + self, + *, + dict_argument: dict[str, t.Any], + list_argument: list[int] | None = None, + **kwargs, + ): + pass + + assert infer_arguments(TestComponent) == { + "dict_argument": Argument( + name="dict_argument", + type=dict, + optional=False, + default=None, + ), + "list_argument": Argument( + name="list_argument", + type=list, + optional=True, + default=None, + ), + } + + +def test_unsupported_complex_type(): + class TestComponent(PandasTransformComponent): + def __init__( + self, + *, + union_argument: t.Union[str, int], + **kwargs, + ): + pass + + with pytest.raises( + UnsupportedTypeAnnotation, + match="Fondant only supports simple types", + ): + infer_arguments(TestComponent) + + +def test_unsupported_custom_type(): + class CustomClass: + pass + + class TestComponent(PandasTransformComponent): + def __init__( + self, + *, + class_argument: CustomClass, + **kwargs, + ): + pass + + with pytest.raises( + UnsupportedTypeAnnotation, + match="Fondant only supports builtin types", + ): + infer_arguments(TestComponent) + + +def test_consumes_produces(): + class TestComponent(PandasTransformComponent): + def __init__(self, *, argument, consumes, **kwargs): + pass + + assert infer_arguments(TestComponent) == { + "argument": Argument( + name="argument", + type=str, + optional=False, + ), + }