From f1eeaf8f5204de08295b4263fb9d871691845473 Mon Sep 17 00:00:00 2001 From: Andrew Liounis Date: Mon, 17 Apr 2023 08:36:21 -0400 Subject: [PATCH] [#241] fixes the issue where defaults from the config path were not being passed when using subparsers (field_wrapper.py). Additionally fixes an issue where subdataclasses were requiring every value to be defaulted in the config path, instead of falling back to the default in the dataclass definition if it wasn't (dataclass_wrapper.py) --- simple_parsing/wrappers/dataclass_wrapper.py | 11 +++++++++-- simple_parsing/wrappers/field_wrapper.py | 7 ++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index 2c00c98d..fdcbc574 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -264,8 +264,15 @@ def defaults(self) -> list[DataclassT | dict[str, Any] | None | Literal[argparse self._defaults = [] for default in self.parent.defaults: if default not in (None, argparse.SUPPRESS): - default = getattr(default, self.name) - self._defaults.append(default) + # we need to check here if the default has been provided. + # If not we'll use the default_value option function + if hasattr(default, self.name): + default = getattr(default, self.name) + else: + default = utils.default_value(self._field) + if default is MISSING: + continue + self._defaults.append(default) else: default_field_value = utils.default_value(self._field) if default_field_value is MISSING: diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 0becac2b..362e6f65 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -1030,7 +1030,12 @@ def add_subparsers(self, parser: ArgumentParser): # Just for typing correctness, as we didn't explicitly change # the return type of subparsers.add_parser method.) subparser = cast("ArgumentParser", subparser) - subparser.add_arguments(dataclass_type, dest=self.dest) + # we need to propagate the defaults down to the sub dataclass if they've been set. + # there may need to be some error handling here in case the use has specified the wrong values for the default. + if isinstance(self.default, dict) and self.default.get(subcommand, None) is not None: + subparser.add_arguments(dataclass_type, dest=self.dest, default=dataclass_type(**self.default[subcommand])) + else: + subparser.add_arguments(dataclass_type, dest=self.dest) def equivalent_argparse_code(self): arg_options = self.arg_options.copy()