Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add positional args support for fdl.Config. #449

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions fiddle/_src/building.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,21 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable,

def call_buildable(
buildable: config_lib.Buildable,
arguments: Dict[str, Any],
kwargs: Dict[str, Any],
*,
current_path: daglish.Path,
) -> Any:
make_message = functools.partial(_make_message, current_path, buildable,
arguments)
"""Run the __build__ method on a Buildable given keyword arguments."""
make_message = functools.partial(
_make_message, current_path, buildable, kwargs
)
args = []
for name in buildable.__signature_info__.positional_arg_names:
if name in kwargs:
args.append(kwargs.pop(name))
args.extend(kwargs.pop('__args__', []))
with reraised_exception.try_with_lazy_message(make_message):
return buildable.__build__(**arguments)
return buildable.__build__(*args, **kwargs)


# Define typing overload for `build(Partial[T])`
Expand Down
163 changes: 137 additions & 26 deletions fiddle/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import copy
import dataclasses
import functools
import inspect
import types
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union

from fiddle._src import daglish
from fiddle._src import history
Expand Down Expand Up @@ -242,10 +243,15 @@ def __init__(
arg_history.add_new_value('__fn_or_cls__', fn_or_cls)
super().__setattr__('__argument_history__', arg_history)
super().__setattr__('__argument_tags__', collections.defaultdict(set))
arguments = signatures.SignatureInfo.signature_binding(
fn_or_cls, *args, **kwargs
arguments, positional_arguments = (
signatures.SignatureInfo.signature_binding(fn_or_cls, *args, **kwargs)
)

if positional_arguments:
self.__arguments__['__args__'] = list(positional_arguments)
for i, value in enumerate(positional_arguments):
self[i] = value

for name, value in arguments.items():
setattr(self, name, value)

Expand All @@ -258,6 +264,7 @@ def __init__(
def __init_callable__(
self, fn_or_cls: Union['Buildable[T]', TypeOrCallableProducingT[T]]
) -> None:
"""Save information on `fn_or_cls` to the `Buildable`."""
if isinstance(fn_or_cls, Buildable):
raise ValueError(
'Using the Buildable constructor to convert a buildable to a new '
Expand All @@ -273,9 +280,11 @@ def __init_callable__(
super().__setattr__('__fn_or_cls__', fn_or_cls)
super().__setattr__('__arguments__', {})
signature = signatures.get_signature(fn_or_cls)
# Several attributes are computed automatically by SignatureInfo during
# `__post_init__`.
super().__setattr__(
'__signature_info__',
signatures.SignatureInfo(signature),
signatures.SignatureInfo(signature=signature),
)

def __init_subclass__(cls):
Expand Down Expand Up @@ -311,6 +320,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]:

def __getattr__(self, name: str):
"""Get parameter with given ``name``."""
if name == 'posargs':
if not self.__signature_info__.has_var_positional:
raise TypeError(
"This function doesn't have variadic positional arguments (*args). "
'Please set other (including positional-only) arguments by name.'
)

name = '__args__'
value = self.__arguments__.get(name, _UNSET_SENTINEL)

if value is not _UNSET_SENTINEL:
Expand Down Expand Up @@ -340,9 +357,39 @@ def __getattr__(self, name: str):
)
raise AttributeError(msg)

def __setitem__(self, key: Any, value: Any):
if not isinstance(key, (int, slice)):
raise TypeError(
'Setting arguments by index is only supported for variadic '
"arguments (*args), like my_config[4] = 'foo'."
)
if not self.__signature_info__.has_var_positional:
raise TypeError(
"This function doesn't have variadic positional arguments (*args). "
'Please set other (including positional-only) arguments by name.'
)

if '__args__' not in self.__arguments__:
self.__arguments__['__args__'] = []
self.__argument_history__.add_new_value('__args__', [])
self.__arguments__['__args__'][key] = value
self.__argument_history__.add_new_value(
'__args__', self.__arguments__['__args__']
)

def __getitem__(self, key: Any):
if not isinstance(key, slice):
raise TypeError(
'Getting arguments by index is only supported when using slice, '
'for example `v = my_config[:2]`, or using the `posargs` attr '
f'instead, like v = my_config[0]. Got {type(key)} type as key.'
)
return self.posargs[key]

def __setattr__(self, name: str, value: Any):
"""Sets parameter ``name`` to ``value``."""

if name == 'posargs':
name = '__args__'
self.__signature_info__.validate_param_name(name, self.__fn_or_cls__)

if isinstance(value, TaggedValueCls):
Expand All @@ -362,6 +409,8 @@ def __setattr__(self, name: str, value: Any):

def __delattr__(self, name):
"""Unsets parameter ``name``."""
if name == 'posargs':
name = '__args__'
try:
del self.__arguments__[name]
self.__argument_history__.add_deleted_value(name)
Expand Down Expand Up @@ -488,9 +537,7 @@ def __getstate__(self):
Dict of serialized state.
"""
result = dict(self.__dict__)
result['__signature_info__'] = signatures.SignatureInfo( # pytype: disable=wrong-arg-types
None, result['__signature_info__'].has_var_keyword
)
result['__signature_info__'] = signatures.SignatureInfo(None) # pytype: disable=wrong-arg-types
return result

def __setstate__(self, state) -> None:
Expand All @@ -503,8 +550,10 @@ def __setstate__(self, state) -> None:
"""
self.__dict__.update(state) # Support unpickle.
if self.__signature_info__.signature is None:
self.__signature_info__.signature = signatures.get_signature(
self.__fn_or_cls__
signature = signatures.get_signature(self.__fn_or_cls__)
super().__setattr__(
'__signature_info__',
signatures.SignatureInfo(signature=signature),
)


Expand Down Expand Up @@ -637,6 +686,51 @@ def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
return False


def _align_var_positional_args(
new_signature: inspect.Signature,
original_args: Dict[str, Any],
drop_invalid_args: bool,
) -> List[str]:
"""Returns the list of positional arguments to unpack."""
args_start_index = -1
for index, arg in enumerate(new_signature.parameters.keys()):
if arg not in original_args.keys():
args_start_index = index
break
if (args_start_index == -1 and original_args['__args__']) or (
len(new_signature.parameters)
< args_start_index + 1 + len(original_args['__args__'])
):
if not drop_invalid_args:
raise ValueError(
'new_callable does not have enough arguments when unpack'
f' *args: {original_args["__args__"]} from the original'
' buildable.'
)
arg_keys = list(new_signature.parameters.keys())[args_start_index:]
return arg_keys


def _expand_args_history(
arg_keys: List[str], buildable: Buildable
) -> List[List[history.HistoryEntry]]:
"""Returns expanded history entries for positional arguments."""
args_history = buildable.__argument_history__['__args__']
expaneded_history = []
for index in range(len(arg_keys)):
expanded_entries = []
for entry in args_history:
new_entry = copy.copy(entry)
if isinstance(new_entry.new_value, list):
if index >= len(new_entry.new_value):
new_entry.new_value = history.NOTSET
else:
new_entry.new_value = new_entry.new_value[index]
expanded_entries.append(new_entry)
expaneded_history.append(expanded_entries)
return expaneded_history


def update_callable(
buildable: Buildable,
new_callable: TypeOrCallableProducingT,
Expand Down Expand Up @@ -667,23 +761,40 @@ def update_callable(
# Note: can't call `setattr` on all the args to validate them, because that
# will result in duplicate history entries.
original_args = buildable.__arguments__
signature = signatures.get_signature(new_callable)
if any(
param.kind == param.VAR_POSITIONAL
for param in signature.parameters.values()
):
raise NotImplementedError(
'Variable positional arguments (aka `*args`) not supported.'
)
signature_info = signatures.SignatureInfo(signature)
object.__setattr__(
buildable,
'__signature_info__',
signature_info,
)
if not signature_info.has_var_keyword:
new_signature = signatures.get_signature(new_callable)
# Update the signature early so that we can set arguments by position.
# Otherwise, parameter validation logics would complain about argument
# name not exists.
object.__setattr__(buildable, '__signature__', new_signature)
new_signature_info = signatures.SignatureInfo(signature=new_signature)
original_signature_info = buildable.__signature_info__
object.__setattr__(buildable, '__signature_info__', new_signature_info)

if new_signature_info.has_var_positional:
# If only new callable has positional arguments
if not original_signature_info.has_var_positional:
buildable.__arguments__['__args__'] = []
buildable.__argument_history__.add_new_value('__args__', [])
else:
# If only the original config has *args
if original_signature_info.has_var_positional:
arg_keys = _align_var_positional_args(
new_signature, original_args, drop_invalid_args
)
expanded_history = _expand_args_history(arg_keys, buildable)

for arg, value, history_extries in zip(
arg_keys, original_args['__args__'], expanded_history
):
buildable.__setattr__(arg, value)
buildable.__argument_history__[arg] = history_extries
buildable.__delattr__('__args__')

if not new_signature_info.has_var_keyword:
invalid_args = [
arg for arg in original_args.keys() if arg not in signature.parameters
arg
for arg in original_args.keys()
if arg not in new_signature.parameters and arg != '__args__'
]
if invalid_args:
if drop_invalid_args:
Expand Down
Loading