Skip to content

Commit

Permalink
Fix bug when deep copying full config with missing parent (#3009)
Browse files Browse the repository at this point in the history
  • Loading branch information
jesszzzz authored Jan 16, 2025
1 parent ca4d25c commit 7c8fa4a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
25 changes: 16 additions & 9 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,24 @@ def _deep_copy_full_config(subconfig: Any) -> Any:
return copy.deepcopy(subconfig)

full_key = subconfig._get_full_key(None)
if full_key:
full_config_copy = copy.deepcopy(subconfig._get_root())
if OmegaConf.is_list(subconfig._get_parent()):
# OmegaConf has a bug where _get_full_key doesn't add [] if the parent
# is a list, eg. instead of foo[0], it'll return foo0
index = subconfig._key()
full_key = full_key[: -len(str(index))] + f"[{index}]"
return OmegaConf.select(full_config_copy, full_key)
else:
if not full_key:
return copy.deepcopy(subconfig)

if OmegaConf.is_list(subconfig._get_parent()):
# OmegaConf has a bug where _get_full_key doesn't add [] if the parent
# is a list, eg. instead of foo[0], it'll return foo0
index = subconfig._key()
full_key = full_key[: -len(str(index))] + f"[{index}]"
root = subconfig._get_root()
full_key = full_key.replace(root._get_full_key(None) or "", "", 1)
if OmegaConf.select(root, full_key) is not subconfig:
# The parent chain and full key are not consistent so don't
# try to copy the full config
return copy.deepcopy(subconfig)

full_config_copy = copy.deepcopy(root)
return OmegaConf.select(full_config_copy, full_key)


def instantiate(
config: Any,
Expand Down
14 changes: 14 additions & 0 deletions tests/instantiate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from omegaconf import MISSING, DictConfig, ListConfig

from hydra.types import TargetConf
from hydra.utils import instantiate
from tests.instantiate.module_shadowed_by_function import a_function

module_shadowed_by_function = a_function
Expand Down Expand Up @@ -418,6 +419,19 @@ class NestedConf:
b: Any = field(default_factory=lambda: User(name="b", age=2))


class TargetWithInstantiateInInit:
def __init__(
self, user_config: Optional[DictConfig], user: Optional[User] = None
) -> None:
if user:
self.user = user
else:
self.user = instantiate(user_config)

def __eq__(self, other: Any) -> bool:
return self.user.__eq__(other.user)


def recisinstance(got: Any, expected: Any) -> bool:
"""Compare got with expected type, recursively on dict and list."""
if not isinstance(got, type(expected)):
Expand Down
25 changes: 25 additions & 0 deletions tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
SimpleClassNonPrimitiveConf,
SimpleClassPrimitiveConf,
SimpleDataClass,
TargetWithInstantiateInInit,
Tree,
TreeConf,
UntypedPassthroughClass,
Expand Down Expand Up @@ -571,6 +572,30 @@ def test_none_cases(
OmegaConf.create({"unique_id": 5}),
id="interpolation_from_parent_with_interpolation",
),
param(
DictConfig(
{
"username": "test_user",
"node": {
"_target_": "tests.instantiate.TargetWithInstantiateInInit",
"_recursive_": False,
"user_config": {
"_target_": "tests.instantiate.User",
"name": "${foo_b.username}",
"age": 40,
},
},
"foo_b": {
"username": "${username}",
},
}
),
{},
TargetWithInstantiateInInit(
user_config=None, user=User(name="test_user", age=40)
),
id="target_with_instantiate_in_init",
),
],
)
def test_interpolation_accessing_parent(
Expand Down

0 comments on commit 7c8fa4a

Please sign in to comment.