Skip to content

Commit

Permalink
Add positional args support for fdl.Config
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 549731020
  • Loading branch information
panzhufeng authored and copybara-github committed Jul 26, 2023
1 parent c882b09 commit 78b48c5
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 33 deletions.
15 changes: 11 additions & 4 deletions fiddle/_src/building.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,21 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable,

def call_buildable(
buildable: config_lib.Buildable,
arguments: Dict[str, Any],
kwargs: Dict[str, Any],
*,
current_path: daglish.Path,
) -> Any:
make_message = functools.partial(_make_message, current_path, buildable,
arguments)
"""Run the __build__ method on a Buildable given keyword arguments."""
make_message = functools.partial(
_make_message, current_path, buildable, kwargs
)
args = []
for name in buildable.__positional_arg_names__:
if name in kwargs:
args.append(kwargs.pop(name))
args.extend(kwargs.pop('__args__', []))
with reraised_exception.try_with_lazy_message(make_message):
return buildable.__build__(**arguments)
return buildable.__build__(*args, **kwargs)


# Define typing overload for `build(Partial[T])`
Expand Down
136 changes: 121 additions & 15 deletions fiddle/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import itertools
import logging
import types
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union

from fiddle._src import arg_factory
from fiddle._src import daglish
Expand Down Expand Up @@ -223,6 +223,8 @@ class Buildable(Generic[T], metaclass=abc.ABCMeta):
__arguments__: Dict[str, Any]
__argument_history__: history.History
__argument_tags__: Dict[str, Set[tag_type.TagType]]
__positional_arg_names__: List[str]
__has_var_positional__: bool
_has_var_keyword: bool

def __init__(
Expand All @@ -245,19 +247,23 @@ def __init__(
super().__setattr__('__argument_history__', arg_history)
super().__setattr__('__argument_tags__', collections.defaultdict(set))

positional_arguments = ()
arguments = signature.bind_partial(*args, **kwargs).arguments
for name in list(arguments.keys()): # Make a copy in case we mutate.
param = signature.parameters[name]
if param.kind == param.VAR_POSITIONAL:
# TODO(b/197367863): Add *args support.
err_msg = (
'Variable positional arguments (aka `*args`) not supported. '
f'Found param `{name}` in `{fn_or_cls}`.'
)
raise NotImplementedError(err_msg)
positional_arguments = arguments.pop(param.name)
elif param.kind == param.VAR_KEYWORD:
arguments.update(arguments.pop(param.name))

if positional_arguments:
self.__arguments__['__args__'] = list(positional_arguments)
self.__argument_history__.add_new_value(
'__args__', self.__arguments__['__args__']
)

for i, value in enumerate(positional_arguments):
self[i] = value
for name, value in arguments.items():
setattr(self, name, value)

Expand Down Expand Up @@ -286,10 +292,25 @@ def __init_callable__(
super().__setattr__('__arguments__', {})
signature = signatures.get_signature(fn_or_cls)
super().__setattr__('__signature__', signature)
has_var_keyword = any(
param.kind == param.VAR_KEYWORD
for param in signature.parameters.values()
)

# If *args exists, we must pass things before it in positional format. This
# list tracks those arguments.
maybe_positional_args = []

positional_only_args = []
has_var_positional, has_var_keyword = False, False
for param in signature.parameters.values():
if param.kind == param.VAR_POSITIONAL:
has_var_positional = True
positional_only_args.extend(maybe_positional_args)
elif param.kind == param.VAR_KEYWORD:
has_var_keyword = True
elif param.kind == param.POSITIONAL_ONLY:
positional_only_args.append(param.name)
elif param.kind == param.POSITIONAL_OR_KEYWORD:
maybe_positional_args.append(param.name)
super().__setattr__('__positional_arg_names__', positional_only_args)
super().__setattr__('__has_var_positional__', has_var_positional)
super().__setattr__('_has_var_keyword', has_var_keyword)
return signature

Expand Down Expand Up @@ -326,6 +347,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]:

def __getattr__(self, name: str):
"""Get parameter with given ``name``."""
if name == 'posargs':
if not self.__has_var_positional__:
raise TypeError(
"This function doesn't have variadic positional arguments (*args). "
'Please set other (including positional-only) arguments by name.'
)

name = '__args__'
value = self.__arguments__.get(name, _UNSET_SENTINEL)

if value is not _UNSET_SENTINEL:
Expand Down Expand Up @@ -387,6 +416,34 @@ def __validate_param_name__(self, name) -> None:
)
raise TypeError(err_msg)

def __setitem__(self, key: Any, value: Any):
if not isinstance(key, (int, slice)):
raise TypeError(
'Setting arguments by index is only supported for variadic '
"arguments (*args), like my_config[4] = 'foo'."
)
if not self.__has_var_positional__:
raise TypeError(
"This function doesn't have variadic positional arguments (*args). "
'Please set other (including positional-only) arguments by name.'
)

# In the future, use a specialized history-tracking list.
if '__args__' not in self.__arguments__:
self.__arguments__['__args__'] = []
self.__argument_history__.add_new_value('__args__', [])
args = self.__arguments__['__args__']
args[key] = value

def __getitem__(self, key: Any):
if not isinstance(key, slice):
raise TypeError(
'Getting arguments by index is only supported when using slice, '
'for example `v = my_config[:2]`, or using the `posargs` attr '
f'instead, like v = my_config[0]. Got {type(key)} type as key.'
)
return self.posargs[key]

def __setattr__(self, name: str, value: Any):
"""Sets parameter ``name`` to ``value``."""

Expand Down Expand Up @@ -950,13 +1007,63 @@ def update_callable(
# will result in duplicate history entries.
original_args = buildable.__arguments__
signature = signatures.get_signature(new_callable)
# Update the signature early so that we can set arguments by position
object.__setattr__(buildable, '__signature__', signature)

if any(
param.kind == param.VAR_POSITIONAL
for param in signature.parameters.values()
):
raise NotImplementedError(
'Variable positional arguments (aka `*args`) not supported.'
)
# Both callables have *args
if buildable.__has_var_positional__:
args_ptr = -1
consumed_args = 0
for idx, arg in enumerate(signature.parameters.keys()):
if arg not in original_args.keys():
args_ptr = idx
break
all_args_key = list(signature.parameters.keys())
while args_ptr < len(all_args_key):
key = all_args_key[args_ptr]
param = signature.parameters[key]
if param.kind == param.VAR_POSITIONAL:
break
else:
value = original_args['__args__'][args_ptr]
buildable.__setattr__(key, value)
args_ptr += 1
consumed_args += 1

buildable.__arguments__['__args__'] = buildable.__arguments__['__args__'][
consumed_args:
]
# Only new callable has *args
else:
object.__setattr__(buildable, '__args__', [])
buildable.__argument_history__.add_new_value('__args__', [])
else:
# If only the original config has *args
if buildable.__has_var_positional__:
args_start_at = -1
for idx, arg in enumerate(signature.parameters.keys()):
if arg not in original_args.keys():
args_start_at = idx
break

if len(signature.parameters) < args_start_at + len(
original_args['__args__']
):
if not drop_invalid_args:
raise ValueError(
'new_callable does not have enough arguments when unpack *args: '
f'{original_args["__args__"]} from the original buildable.'
)
arg_keys = list(signature.parameters.keys())[args_start_at:]
for arg, value in zip(arg_keys, original_args['__args__']):
buildable.__setattr__(arg, value)
del buildable.__args__
object.__setattr__(buildable, '__has_var_positional__', False)

has_var_keyword = any(
param.kind == param.VAR_KEYWORD for param in signature.parameters.values()
)
Expand All @@ -976,7 +1083,6 @@ def update_callable(
)

object.__setattr__(buildable, '__fn_or_cls__', new_callable)
object.__setattr__(buildable, '__signature__', signature)
object.__setattr__(buildable, '_has_var_keyword', has_var_keyword)
buildable.__argument_history__.add_new_value('__fn_or_cls__', new_callable)

Expand Down
120 changes: 109 additions & 11 deletions fiddle/_src/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,84 @@ def test_config_for_functions_with_var_args_and_kwargs(self):
'kwargs': 'kwarg_called_kwarg'
})

# "args" below refer to positional arguments, typcally `*args``
def test_args_config_access(self):
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')

with self.subTest('ordered_arguments'):
self.assertEqual(
fdl.ordered_arguments(fn_config),
{
'arg1': 'foo',
'__args__': ['bar', 'baz'],
},
)

with self.subTest('posargs_access'):
self.assertEqual(fn_config.posargs[0], 'bar')
self.assertEqual(fn_config.posargs[1], 'baz')
self.assertSequenceEqual(fn_config.posargs, ['bar', 'baz'])

with self.subTest('index_access'):
with self.assertRaisesRegex(
TypeError,
'Getting arguments by index is only supported when using slice',
):
_ = fn_config[0]

with self.subTest('slice_access'):
self.assertEmpty(fn_config[:0])
self.assertSequenceEqual(fn_config[:1], ['bar'])
self.assertSequenceEqual(fn_config[:], ['bar', 'baz'])

def test_args_config_posargs_append(self):
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
fn_config.posargs.append('foo')
self.assertSequenceEqual(fn_config.posargs, ['bar', 'baz', 'foo'])

def test_args_config_slice_mutation(self):
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
self.assertSequenceEqual(fn_config[:], ['bar', 'baz'])
fn_config[:1] = ['zero', 'one']
self.assertSequenceEqual(fn_config[:], ['zero', 'one', 'baz'])

def test_args_config_shallow_copy(self):
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
self.assertLen(fn_config[:], 2)
a_copy = fn_config[:]
a_copy.append('foo')
self.assertLen(fn_config[:], 2)
self.assertLen(a_copy, 3)

def test_index_mutation(self):
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
fn_config[0] = 'foo'
self.assertEqual(fn_config.posargs[0], 'foo')
fn_config[-1] = 'last'
self.assertLen(fn_config.posargs, 2)
self.assertEqual(fn_config.posargs[1], 'last')
self.assertEqual(fn_config.posargs[-1], 'last')

def test_index_out_of_range(self):
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
self.assertLen(fn_config[:], 2)
with self.assertRaisesRegex(
IndexError, 'list assignment index out of range'
):
fn_config[2] = 'index-2'

def test_args_config_build(self):
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
fn_args = fdl.build(fn_config)
self.assertEqual(
fn_args,
{
'arg1': 'foo',
'args': ('bar', 'baz'),
'kwarg1': None,
},
)

def test_config_for_dicts(self):
dict_config = fdl.Config(dict, a=1, b=2)
dict_config.c = 3
Expand Down Expand Up @@ -858,12 +936,6 @@ def test_nonexistent_var_args_parameter_error(self):
with self.assertRaisesRegex(TypeError, expected_msg):
fn_config.args = (1, 2, 3)

def test_unsupported_var_args_error(self):
expected_msg = (r'Variable positional arguments \(aka `\*args`\) not '
r'supported\.')
with self.assertRaisesRegex(NotImplementedError, expected_msg):
fdl.Config(fn_with_var_args, 1, 2, 3)

def test_build_inside_build(self):

def inner_build(x: int) -> str:
Expand Down Expand Up @@ -1211,11 +1283,37 @@ def test_update_callable_new_kwargs(self):
}
}, fdl.build(cfg))

def test_update_callable_varargs(self):
cfg = fdl.Config(fn_with_var_kwargs, 1, 2)
with self.assertRaisesRegex(NotImplementedError,
'Variable positional arguments'):
fdl.update_callable(cfg, fn_with_var_args_and_kwargs)
def test_update_args_to_args(self):
cfg = fdl.Config(fn_with_var_args, 1, 2, kwarg1=3)
fdl.update_callable(cfg, fn_with_var_args_and_kwargs)
self.assertEqual(
cfg.__arguments__, {'arg1': 1, '__args__': [2], 'kwarg1': 3}
)
self.assertEqual(
{'arg1': 1, 'args': (2,), 'kwarg1': 3, 'kwargs': {}}, fdl.build(cfg)
)

def test_update_args_to_no_args(self):
cfg = fdl.Config(fn_with_var_args, 1, 2, kwarg1=3)
fdl.update_callable(cfg, basic_fn)
cfg.arg2 = 22
self.assertEqual(cfg.__arguments__, {'arg1': 1, 'arg2': 22, 'kwarg1': 3})
self.assertEqual(
{'arg1': 1, 'arg2': 22, 'kwarg1': 3, 'kwarg2': None}, fdl.build(cfg)
)

def test_update_args_kwargs(self):
def my_fn(*args, **kwargs):
del args, kwargs

cfg = fdl.Config(my_fn, 1, 2, 3, kwarg1=4, kwarg2=5)
cfg.posargs[0] = 10
cfg.kwarg1 = 40
config_lib.update_callable(cfg, fn_with_var_args_and_kwargs)
self.assertEqual(
cfg.__arguments__,
{'arg1': 10, '__args__': [2, 3], 'kwarg1': 40, 'kwarg2': 5},
)

def test_get_callable(self):
cfg = fdl.Config(basic_fn)
Expand Down
13 changes: 10 additions & 3 deletions fiddle/_src/mutate_buildable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@

from fiddle._src import config

_buildable_internals_keys = ('__fn_or_cls__', '__signature__', '__arguments__',
'_has_var_keyword', '__argument_tags__',
'__argument_history__')
_buildable_internals_keys = (
'__fn_or_cls__',
'__signature__',
'__arguments__',
'_has_var_keyword',
'__argument_tags__',
'__argument_history__',
'__has_var_positional__',
'__positional_arg_names__',
)


def move_buildable_internals(*, source: config.Buildable,
Expand Down

0 comments on commit 78b48c5

Please sign in to comment.