Skip to content

Commit

Permalink
Debugging Issue #276
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jul 26, 2023
1 parent fc0c977 commit 403b81f
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 42 deletions.
13 changes: 9 additions & 4 deletions simple_parsing/helpers/serialization/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing_extensions import Protocol

from simple_parsing.utils import (
Dataclass,
DataclassT,
all_subclasses,
get_args,
Expand Down Expand Up @@ -173,6 +174,7 @@ class SerializableMixin:
>>> assert config == config_
"""

# TODO: Get rid of this, use __subclasses__ method of the type instead.
subclasses: ClassVar[list[type[D]]] = []
decode_into_subclasses: ClassVar[bool] = False

Expand Down Expand Up @@ -684,6 +686,10 @@ def dumps_yaml(dc, dump_fn: DumpsFn | None = None, **kwargs) -> str:
DC_TYPE_KEY = "_type_"


def get_dc_type_path(dc: type[Dataclass]) -> str:
return dc.__module__ + "." + dc.__qualname__


def to_dict(
dc: DataclassT,
dict_factory: type[dict] = dict,
Expand All @@ -706,9 +712,8 @@ def to_dict(
d: dict[str, Any] = dict_factory()

if save_dc_types:
class_name = dc.__class__.__qualname__
module = type(dc).__module__
if "<locals>" in class_name:
class_path = get_dc_type_path(type(dc))
if "<locals>" in class_path:
# Don't save the type of function-scoped dataclasses.
warnings.warn(
RuntimeWarning(
Expand All @@ -718,7 +723,7 @@ def to_dict(
)
)
else:
d[DC_TYPE_KEY] = module + "." + class_name
d[DC_TYPE_KEY] = class_path

for f in fields(dc):
name = f.name
Expand Down
82 changes: 53 additions & 29 deletions simple_parsing/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,9 @@ def parse_known_args(
# default Namespace built from parser defaults
if namespace is None:
namespace = Namespace()

if self.config_path:
if isinstance(self.config_path, Path):
if isinstance(self.config_path, (str, Path)):
config_paths = [self.config_path]
else:
config_paths = self.config_path
Expand Down Expand Up @@ -387,33 +388,47 @@ def set_defaults(self, config_path: str | Path | None = None, **kwargs: Any) ->
# The kwargs that are set in the dataclasses, rather than on the namespace.
kwarg_defaults_set_in_dataclasses = {}
for wrapper in self._wrappers:
if wrapper.dest in kwargs:
default_for_dataclass = kwargs[wrapper.dest]

if isinstance(default_for_dataclass, (str, Path)):
default_for_dataclass = read_file(path=default_for_dataclass)
elif not isinstance(default_for_dataclass, dict) and not dataclasses.is_dataclass(
default_for_dataclass
):
raise ValueError(
f"Got a default for field {wrapper.dest} that isn't a dataclass, dict or "
f"path: {default_for_dataclass}"
)
if wrapper.dest not in kwargs:
# The dataclass doesn't have any values in the loaded config dict.
continue

default_in_config_for_dc = kwargs[wrapper.dest]
# TODO: This interacts weirdly with subgroups!
if wrapper.was_subgroup:
# FIXME: Debugging
assert False, (wrapper, default_in_config_for_dc)

if isinstance(default_in_config_for_dc, dict):
# wrapper.set_default(config_default_for_dc)
pass
elif isinstance(default_in_config_for_dc, (str, Path)):
# BUG: Need to check if the field is a subgroup first!
default_in_config_for_dc = read_file(path=default_in_config_for_dc)
# wrapper.set_default(config_default_for_dc)
elif is_dataclass_instance(default_in_config_for_dc):
# wrapper.set_default(config_default_for_dc)
pass
else:
raise ValueError(
f"Got a default for field {wrapper.dest} that isn't a dataclass, dict or "
f"Path: {default_in_config_for_dc}"
)

# Set the .default attribute on the DataclassWrapper (which also updates the
# defaults of the fields and any nested dataclass fields).
wrapper.set_default(default_for_dataclass)
# Set the .default attribute on the DataclassWrapper (which also updates the
# defaults of the fields and any nested dataclass fields).
wrapper.set_default(default_in_config_for_dc)

# It's impossible for multiple wrappers in kwargs to have the same destination.
assert wrapper.dest not in kwarg_defaults_set_in_dataclasses
value_for_constructor_arguments = (
default_in_config_for_dc
if isinstance(default_in_config_for_dc, dict)
else dataclasses.asdict(default_in_config_for_dc)
)
kwarg_defaults_set_in_dataclasses[wrapper.dest] = value_for_constructor_arguments
# Remove this from the **kwargs, so they don't get set on the namespace.
kwargs.pop(wrapper.dest)

# It's impossible for multiple wrappers in kwargs to have the same destination.
assert wrapper.dest not in kwarg_defaults_set_in_dataclasses
value_for_constructor_arguments = (
default_for_dataclass
if isinstance(default_for_dataclass, dict)
else dataclasses.asdict(default_for_dataclass)
)
kwarg_defaults_set_in_dataclasses[wrapper.dest] = value_for_constructor_arguments
# Remove this from the **kwargs, so they don't get set on the namespace.
kwargs.pop(wrapper.dest)
# TODO: Stop using a defaultdict for the very important `self.constructor_arguments`!
self.constructor_arguments = dict_union(
self.constructor_arguments,
Expand Down Expand Up @@ -643,16 +658,25 @@ def _resolve_subgroups(
flags = subgroup_field.option_strings
argument_options = subgroup_field.arg_options

argparse_default = argument_options.get("default")
if subgroup_field.subgroup_default is dataclasses.MISSING:
assert argument_options["required"]
assert argparse_default is None
elif isinstance(argparse_default, dict):
# BUG: We have a default value that is a dictionary. This is probably because
# .set_defaults was called before we could resolve the subgroups, and so the
# subgroup_field has this weird default value of a dictionary.
# TODO: Figure out how to fix this!
argument_options["default"] = subgroup_field.subgroup_default
else:
assert argument_options["default"] is subgroup_field.subgroup_default
assert argparse_default is subgroup_field.subgroup_default, (
argparse_default,
subgroup_field.subgroup_default,
)
assert not is_dataclass_instance(argument_options["default"])

# TODO: Do we really need to care about this "SUPPRESS" stuff here?
if argparse.SUPPRESS in subgroup_field.parent.defaults:
assert argument_options["default"] is argparse.SUPPRESS
argument_options["default"] = argparse.SUPPRESS

logger.debug(
f"Adding subgroup argument: add_argument(*{flags} **{str(argument_options)})"
Expand Down
54 changes: 46 additions & 8 deletions simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
from dataclasses import MISSING
from logging import getLogger
from typing import Any, Callable, Generic, TypeVar, cast
import warnings

import docstring_parser as dp
from typing_extensions import Literal

from simple_parsing.helpers.serialization.serializable import DC_TYPE_KEY

from .. import docstring, utils
from ..utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type
from .field_wrapper import FieldWrapper
Expand Down Expand Up @@ -274,6 +277,10 @@ def defaults(self) -> list[DataclassT | dict[str, Any] | None | Literal[argparse
self._defaults = [default_field_value]
return self._defaults

@property
def was_subgroup(self) -> bool:
return self._field is not None and "subgroups" in self._field.metadata

@defaults.setter
def defaults(self, value: list[DataclassT]):
self._defaults = value
Expand All @@ -288,24 +295,55 @@ def default(self) -> DataclassT | None:

def set_default(self, value: DataclassT | dict | None):
"""Sets the default values for the arguments of the fields of this dataclass."""
if value is not None and not isinstance(value, dict):
field_default_values = dataclasses.asdict(value)
else:
field_default_values = value
self._default = value
if field_default_values is None:
if value is None:
self._default = None
return

field_default_values = (
value if isinstance(value, dict) else dataclasses.asdict(value)
).copy()
if DC_TYPE_KEY in field_default_values:
from simple_parsing.helpers.serialization.serializable import _locate

dc_type_qualpath = field_default_values.pop(DC_TYPE_KEY)
assert isinstance(dc_type_qualpath, str)
dataclass_type = _locate(dc_type_qualpath)
logger.debug(
f"The default value dictionary has the {DC_TYPE_KEY=} of {dc_type_qualpath}"
)
if self.dataclass_fn is not dataclass_type:
logger.debug(
f"Overwriting the dataclass_fn from {self.dataclass_fn} to {dataclass_type}"
)
self.dataclass_fn = dataclass_type

self._default = field_default_values

for field_wrapper in self.fields:
if field_wrapper.name not in field_default_values:
continue
# Manually set the default value for this argument.
field_default_value = field_default_values[field_wrapper.name]
field_default_value = field_default_values.pop(field_wrapper.name)
if field_wrapper.is_subgroup and isinstance(field_default_value, dict):
# TODO: FIX THIS: Perhaps we need to store this "default" for later, so that we can
# apply them to the DataclassWrapper that will be created for this field once
# subgroups are resolved.
assert False, field_default_value

field_wrapper.set_default(field_default_value)
for nested_dataclass_wrapper in self._children:
if nested_dataclass_wrapper.name not in field_default_values:
continue
field_default_value = field_default_values[nested_dataclass_wrapper.name]
field_default_value = field_default_values.pop(nested_dataclass_wrapper.name)
nested_dataclass_wrapper.set_default(field_default_value)
if field_default_values:
warnings.warn(
RuntimeWarning(
f"Got some unexpected leftover default values for dataclass at path "
f"{self.dest}: "
f"{field_default_values}"
)
)

@property
def title(self) -> str:
Expand Down
11 changes: 11 additions & 0 deletions simple_parsing/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,17 @@ def _get_value(dataclass_default: utils.Dataclass | dict, name: str) -> Any:

def set_default(self, value: Any):
logger.debug(f"The field {self.name} has its default manually set to a value of {value}.")
# TODO: Should we do something different if we're a subgroup field and this default value
# is a dict (which happens when using a config path)? This field is intended to just parse
# the subgroup choice.
if self.is_subgroup and isinstance(value, dict):
raise RuntimeError(
f"Something weird happened: We don't expect this subgroup field at {self.dest!r} "
f"to get a dict default value (it's just used to parse the choice of subgroup "
f"(a string)) "
f"Received default value of {value}"
)
# logger.debug("Setting the default value for a subgroup field to a dict!")
self._default = value

@property
Expand Down
2 changes: 1 addition & 1 deletion test/subgroups/test_issue_276.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class C:
"",
C(c=B2(a=0, b=1)),
True,
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
# marks=pytest.mark.xfail(strict=True, raises=AssertionError),
),
pytest.param(
"--c.a=1",
Expand Down
1 change: 1 addition & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def set_prog(prog_name: str, args: str):

@pytest.fixture
def assert_equals_stdout(capsys):
# TODO: Replace all this with a FileRegression.
def strip(string):
return "".join(string.split())

Expand Down

0 comments on commit 403b81f

Please sign in to comment.