diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index e935ed89..03b002b2 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -16,6 +16,7 @@ from typing_extensions import Protocol from simple_parsing.utils import ( + Dataclass, DataclassT, all_subclasses, get_args, @@ -173,6 +174,7 @@ class SerializableMixin: >>> assert config == config_ """ + # TODO: Get rid of this, use __subclasses__ method of the type instead. subclasses: ClassVar[list[type[D]]] = [] decode_into_subclasses: ClassVar[bool] = False @@ -684,6 +686,10 @@ def dumps_yaml(dc, dump_fn: DumpsFn | None = None, **kwargs) -> str: DC_TYPE_KEY = "_type_" +def get_dc_type_path(dc: type[Dataclass]) -> str: + return dc.__module__ + "." + dc.__qualname__ + + def to_dict( dc: DataclassT, dict_factory: type[dict] = dict, @@ -706,9 +712,8 @@ def to_dict( d: dict[str, Any] = dict_factory() if save_dc_types: - class_name = dc.__class__.__qualname__ - module = type(dc).__module__ - if "<locals>" in class_name: + class_path = get_dc_type_path(type(dc)) + if "<locals>" in class_path: # Don't save the type of function-scoped dataclasses. warnings.warn( RuntimeWarning( @@ -718,7 +723,7 @@ def to_dict( ) ) else: - d[DC_TYPE_KEY] = module + "." 
+ class_name + d[DC_TYPE_KEY] = class_path for f in fields(dc): name = f.name diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index ec81bdd1..5fac766a 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -289,8 +289,9 @@ def parse_known_args( # default Namespace built from parser defaults if namespace is None: namespace = Namespace() + if self.config_path: - if isinstance(self.config_path, Path): + if isinstance(self.config_path, (str, Path)): config_paths = [self.config_path] else: config_paths = self.config_path @@ -387,33 +388,47 @@ def set_defaults(self, config_path: str | Path | None = None, **kwargs: Any) -> # The kwargs that are set in the dataclasses, rather than on the namespace. kwarg_defaults_set_in_dataclasses = {} for wrapper in self._wrappers: - if wrapper.dest in kwargs: - default_for_dataclass = kwargs[wrapper.dest] - - if isinstance(default_for_dataclass, (str, Path)): - default_for_dataclass = read_file(path=default_for_dataclass) - elif not isinstance(default_for_dataclass, dict) and not dataclasses.is_dataclass( - default_for_dataclass - ): - raise ValueError( - f"Got a default for field {wrapper.dest} that isn't a dataclass, dict or " - f"path: {default_for_dataclass}" - ) + if wrapper.dest not in kwargs: + # The dataclass doesn't have any values in the loaded config dict. + continue + + default_in_config_for_dc = kwargs[wrapper.dest] + # TODO: This interacts weirdly with subgroups! + if wrapper.was_subgroup: + # FIXME: Debugging + assert False, (wrapper, default_in_config_for_dc) + + if isinstance(default_in_config_for_dc, dict): + # wrapper.set_default(config_default_for_dc) + pass + elif isinstance(default_in_config_for_dc, (str, Path)): + # BUG: Need to check if the field is a subgroup first! 
+ default_in_config_for_dc = read_file(path=default_in_config_for_dc) + # wrapper.set_default(config_default_for_dc) + elif is_dataclass_instance(default_in_config_for_dc): + # wrapper.set_default(config_default_for_dc) + pass + else: + raise ValueError( + f"Got a default for field {wrapper.dest} that isn't a dataclass, dict or " + f"Path: {default_in_config_for_dc}" + ) - # Set the .default attribute on the DataclassWrapper (which also updates the - # defaults of the fields and any nested dataclass fields). - wrapper.set_default(default_for_dataclass) + # Set the .default attribute on the DataclassWrapper (which also updates the + # defaults of the fields and any nested dataclass fields). + wrapper.set_default(default_in_config_for_dc) + + # It's impossible for multiple wrappers in kwargs to have the same destination. + assert wrapper.dest not in kwarg_defaults_set_in_dataclasses + value_for_constructor_arguments = ( + default_in_config_for_dc + if isinstance(default_in_config_for_dc, dict) + else dataclasses.asdict(default_in_config_for_dc) + ) + kwarg_defaults_set_in_dataclasses[wrapper.dest] = value_for_constructor_arguments + # Remove this from the **kwargs, so they don't get set on the namespace. + kwargs.pop(wrapper.dest) - # It's impossible for multiple wrappers in kwargs to have the same destination. - assert wrapper.dest not in kwarg_defaults_set_in_dataclasses - value_for_constructor_arguments = ( - default_for_dataclass - if isinstance(default_for_dataclass, dict) - else dataclasses.asdict(default_for_dataclass) - ) - kwarg_defaults_set_in_dataclasses[wrapper.dest] = value_for_constructor_arguments - # Remove this from the **kwargs, so they don't get set on the namespace. - kwargs.pop(wrapper.dest) # TODO: Stop using a defaultdict for the very important `self.constructor_arguments`! 
self.constructor_arguments = dict_union( self.constructor_arguments, @@ -643,16 +658,25 @@ def _resolve_subgroups( flags = subgroup_field.option_strings argument_options = subgroup_field.arg_options + argparse_default = argument_options.get("default") if subgroup_field.subgroup_default is dataclasses.MISSING: assert argument_options["required"] + assert argparse_default is None + elif isinstance(argparse_default, dict): + # BUG: We have a default value that is a dictionary. This is probably because + # .set_defaults was called before we could resolve the subgroups, and so the + # subgroup_field has this weird default value of a dictionary. + # TODO: Figure out how to fix this! + argument_options["default"] = subgroup_field.subgroup_default else: - assert argument_options["default"] is subgroup_field.subgroup_default + assert argparse_default is subgroup_field.subgroup_default, ( + argparse_default, + subgroup_field.subgroup_default, + ) assert not is_dataclass_instance(argument_options["default"]) - # TODO: Do we really need to care about this "SUPPRESS" stuff here? if argparse.SUPPRESS in subgroup_field.parent.defaults: assert argument_options["default"] is argparse.SUPPRESS - argument_options["default"] = argparse.SUPPRESS logger.debug( f"Adding subgroup argument: add_argument(*{flags} **{str(argument_options)})" diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index 2c00c98d..bea9651c 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -9,10 +9,13 @@ from dataclasses import MISSING from logging import getLogger from typing import Any, Callable, Generic, TypeVar, cast +import warnings import docstring_parser as dp from typing_extensions import Literal +from simple_parsing.helpers.serialization.serializable import DC_TYPE_KEY + from .. 
import docstring, utils from ..utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type from .field_wrapper import FieldWrapper @@ -274,6 +277,10 @@ def defaults(self) -> list[DataclassT | dict[str, Any] | None | Literal[argparse self._defaults = [default_field_value] return self._defaults + @property + def was_subgroup(self) -> bool: + return self._field is not None and "subgroups" in self._field.metadata + @defaults.setter def defaults(self, value: list[DataclassT]): self._defaults = value @@ -288,24 +295,55 @@ def default(self) -> DataclassT | None: def set_default(self, value: DataclassT | dict | None): """Sets the default values for the arguments of the fields of this dataclass.""" - if value is not None and not isinstance(value, dict): - field_default_values = dataclasses.asdict(value) - else: - field_default_values = value - self._default = value - if field_default_values is None: + if value is None: + self._default = None return + + field_default_values = ( + value if isinstance(value, dict) else dataclasses.asdict(value) + ).copy() + if DC_TYPE_KEY in field_default_values: + from simple_parsing.helpers.serialization.serializable import _locate + + dc_type_qualpath = field_default_values.pop(DC_TYPE_KEY) + assert isinstance(dc_type_qualpath, str) + dataclass_type = _locate(dc_type_qualpath) + logger.debug( + f"The default value dictionary has the {DC_TYPE_KEY=} of {dc_type_qualpath}" + ) + if self.dataclass_fn is not dataclass_type: + logger.debug( + f"Overwriting the dataclass_fn from {self.dataclass_fn} to {dataclass_type}" + ) + self.dataclass_fn = dataclass_type + + self._default = field_default_values + for field_wrapper in self.fields: if field_wrapper.name not in field_default_values: continue # Manually set the default value for this argument. 
- field_default_value = field_default_values[field_wrapper.name] + field_default_value = field_default_values.pop(field_wrapper.name) + if field_wrapper.is_subgroup and isinstance(field_default_value, dict): + # TODO: FIX THIS: Perhaps we need to store this "default" for later, so that we can + # apply them to the DataclassWrapper that will be created for this field once + # subgroups are resolved. + assert False, field_default_value + field_wrapper.set_default(field_default_value) for nested_dataclass_wrapper in self._children: if nested_dataclass_wrapper.name not in field_default_values: continue - field_default_value = field_default_values[nested_dataclass_wrapper.name] + field_default_value = field_default_values.pop(nested_dataclass_wrapper.name) nested_dataclass_wrapper.set_default(field_default_value) + if field_default_values: + warnings.warn( + RuntimeWarning( + f"Got some unexpected leftover default values for dataclass at path " + f"{self.dest}: " + f"{field_default_values}" + ) + ) @property def title(self) -> str: diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 0becac2b..108d4bf6 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -797,6 +797,17 @@ def _get_value(dataclass_default: utils.Dataclass | dict, name: str) -> Any: def set_default(self, value: Any): logger.debug(f"The field {self.name} has its default manually set to a value of {value}.") + # TODO: Should we do something different if we're a subgroup field and this default value + # is a dict (which happens when using a config path)? This field is intended to just parse + # the subgroup choice. 
+ if self.is_subgroup and isinstance(value, dict): + raise RuntimeError( + f"Something weird happened: We don't expect this subgroup field at {self.dest!r} " + f"to get a dict default value (it's just used to parse the choice of subgroup " + f"(a string)) " + f"Received default value of {value}" + ) + # logger.debug("Setting the default value for a subgroup field to a dict!") self._default = value @property diff --git a/test/subgroups/test_issue_276.py b/test/subgroups/test_issue_276.py index 52a6f627..21f254ab 100644 --- a/test/subgroups/test_issue_276.py +++ b/test/subgroups/test_issue_276.py @@ -46,7 +46,7 @@ class C: "", C(c=B2(a=0, b=1)), True, - marks=pytest.mark.xfail(strict=True, raises=AssertionError), + # marks=pytest.mark.xfail(strict=True, raises=AssertionError), ), pytest.param( "--c.a=1", diff --git a/test/test_examples.py b/test/test_examples.py index d20967c4..9a79b14a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -38,6 +38,7 @@ def set_prog(prog_name: str, args: str): @pytest.fixture def assert_equals_stdout(capsys): + # TODO: Replace all this with a FileRegression. def strip(string): return "".join(string.split())