-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add component argument inference #763
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import inspect | ||
import typing as t | ||
|
||
from fondant.component import Component | ||
from fondant.core.component_spec import Argument | ||
from fondant.core.exceptions import UnsupportedTypeAnnotation | ||
|
||
BUILTIN_TYPES = [str, int, float, bool, dict, list] | ||
|
||
|
||
def annotation_to_type(annotation: t.Any) -> t.Type: | ||
"""Extract the simple built-in type from an Annotation. | ||
|
||
Examples: | ||
dict[str, int] -> dict | ||
t.Optional[str] -> str | ||
|
||
Args: | ||
annotation: Annotation of an argument as returned by inspect.signature | ||
|
||
Raises: | ||
UnsupportedTypeAnnotation: If the annotation is not simple or not based on a built-in type | ||
|
||
""" | ||
# If no type annotation is present, default to str | ||
if annotation == inspect.Parameter.empty: | ||
return str | ||
|
||
# Unpack the annotation until we get a simple type. | ||
# This removes complex structures such as Optional | ||
while t.get_origin(annotation) not in [*BUILTIN_TYPES, None]: | ||
# Filter out NoneType values (Optional[x] is represented as Union[x, NoneType] | ||
annotation_args = [ | ||
arg for arg in t.get_args(annotation) if arg is not type(None) | ||
] | ||
|
||
# Multiple arguments remaining (eg. Union[str, int]) | ||
# Raise error since we cannot infer type unambiguously | ||
if len(annotation_args) > 1: | ||
msg = ( | ||
f"Fondant only supports simple types for component arguments." | ||
f"Expected one of {BUILTIN_TYPES}, received {annotation} instead." | ||
) | ||
raise UnsupportedTypeAnnotation(msg) | ||
|
||
annotation = annotation_args[0] | ||
|
||
# Remove any subscription (eg. dict[str, int] -> dict) | ||
annotation = t.get_origin(annotation) or annotation | ||
|
||
# Check for classes not supported as argument | ||
if annotation not in BUILTIN_TYPES: | ||
msg = ( | ||
f"Fondant only supports builtin types for component arguments." | ||
f"Expected one of {BUILTIN_TYPES}, received {annotation} instead." | ||
) | ||
raise UnsupportedTypeAnnotation(msg) | ||
|
||
return annotation | ||
|
||
|
||
def is_optional(parameter: inspect.Parameter) -> bool: | ||
"""Check if an inspect.Parameter is optional. We check this based on the presence of a | ||
default value instead of based on the type, since this is more trustworthy. | ||
""" | ||
return parameter.default != inspect.Parameter.empty | ||
|
||
|
||
def get_default(parameter: inspect.Parameter) -> t.Any: | ||
"""Get the default value from an inspect.Parameter.""" | ||
if parameter.default == inspect.Parameter.empty: | ||
return None | ||
return parameter.default | ||
|
||
|
||
def parameter_to_argument(parameter: inspect.Parameter) -> Argument: | ||
"""Translate an inspect.Parameter into a Fondant Argument.""" | ||
return Argument( | ||
name=parameter.name, | ||
type=annotation_to_type(parameter.annotation), | ||
optional=is_optional(parameter), | ||
default=get_default(parameter), | ||
) | ||
|
||
|
||
def infer_arguments(component: t.Type[Component]) -> t.Dict[str, Argument]: | ||
"""Infer the user arguments from a Python Component class. | ||
Default arguments are skipped. | ||
|
||
Args: | ||
component: Component class to inspect. | ||
""" | ||
signature = inspect.signature(component) | ||
|
||
arguments = {} | ||
for name, parameter in signature.parameters.items(): | ||
# Skip non-user arguments | ||
if name in ["self", "consumes", "produces", "kwargs"]: | ||
continue | ||
|
||
arguments[name] = parameter_to_argument(parameter) | ||
|
||
return arguments |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😍 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,247 @@ | ||
import sys | ||
import typing as t | ||
|
||
import pytest | ||
from fondant.component import PandasTransformComponent | ||
from fondant.core.component_spec import Argument | ||
from fondant.core.exceptions import UnsupportedTypeAnnotation | ||
from fondant.pipeline.argument_inference import infer_arguments | ||
|
||
|
||
def test_no_init(): | ||
class TestComponent(PandasTransformComponent): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == {} | ||
|
||
|
||
def test_no_arguments(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__(self, **kwargs): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == {} | ||
|
||
|
||
def test_missing_types(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__(self, *, argument, **kwargs): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"argument": Argument( | ||
name="argument", | ||
type=str, | ||
optional=False, | ||
), | ||
} | ||
|
||
|
||
def test_types(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
str_argument: str, | ||
int_argument: int, | ||
float_argument: float, | ||
bool_argument: bool, | ||
dict_argument: dict, | ||
list_argument: list, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"str_argument": Argument( | ||
name="str_argument", | ||
type=str, | ||
optional=False, | ||
), | ||
"int_argument": Argument( | ||
name="int_argument", | ||
type=int, | ||
optional=False, | ||
), | ||
"float_argument": Argument( | ||
name="float_argument", | ||
type=float, | ||
optional=False, | ||
), | ||
"bool_argument": Argument( | ||
name="bool_argument", | ||
type=bool, | ||
optional=False, | ||
), | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=False, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=False, | ||
), | ||
} | ||
|
||
|
||
def test_optional_types(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
str_argument: t.Optional[str] = "", | ||
int_argument: t.Optional[int] = 1, | ||
float_argument: t.Optional[float] = 1.0, | ||
bool_argument: t.Optional[bool] = False, | ||
dict_argument: t.Optional[dict] = None, | ||
list_argument: t.Optional[list] = None, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"str_argument": Argument( | ||
name="str_argument", | ||
type=str, | ||
optional=True, | ||
default="", | ||
), | ||
"int_argument": Argument( | ||
name="int_argument", | ||
type=int, | ||
optional=True, | ||
default=1, | ||
), | ||
"float_argument": Argument( | ||
name="float_argument", | ||
type=float, | ||
optional=True, | ||
default=1.0, | ||
), | ||
"bool_argument": Argument( | ||
name="bool_argument", | ||
type=bool, | ||
optional=True, | ||
default=False, | ||
), | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=True, | ||
default=None, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=True, | ||
default=None, | ||
), | ||
} | ||
|
||
|
||
def test_parametrized_types_old(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
dict_argument: t.Dict[str, t.Any], | ||
list_argument: t.Optional[t.List[int]] = None, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=False, | ||
default=None, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=True, | ||
default=None, | ||
), | ||
} | ||
|
||
|
||
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") | ||
def test_parametrized_types_new(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
dict_argument: dict[str, t.Any], | ||
list_argument: list[int] | None = None, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=False, | ||
default=None, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=True, | ||
default=None, | ||
), | ||
} | ||
|
||
|
||
def test_unsupported_complex_type(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
union_argument: t.Union[str, int], | ||
**kwargs, | ||
): | ||
pass | ||
|
||
with pytest.raises( | ||
UnsupportedTypeAnnotation, | ||
match="Fondant only supports simple types", | ||
): | ||
infer_arguments(TestComponent) | ||
|
||
|
||
def test_unsupported_custom_type(): | ||
class CustomClass: | ||
pass | ||
|
||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
class_argument: CustomClass, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
with pytest.raises( | ||
UnsupportedTypeAnnotation, | ||
match="Fondant only supports builtin types", | ||
): | ||
infer_arguments(TestComponent) | ||
|
||
|
||
def test_consumes_produces(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__(self, *, argument, consumes, **kwargs): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"argument": Argument( | ||
name="argument", | ||
type=str, | ||
optional=False, | ||
), | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice !