Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 19, 2024
1 parent c18f7a4 commit b102010
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
18 changes: 13 additions & 5 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6544,12 +6544,16 @@ def update_(

named = True

def inplace_update(name, dest, source):
def inplace_update(name, source, dest):
if source is None:
return None
name = _unravel_key_to_tuple(name)
for key in keys_to_update:
if key == name[: len(key)]:
if dest is None:
raise KeyError(
f"The key {name} was not found in the dest tensordict."
)
dest.copy_(source, non_blocking=non_blocking)

else:
Expand All @@ -6564,16 +6568,20 @@ def inplace_update(name, dest, source):
vals = [vals[k] for k in new_keys]
_foreach_copy_(vals, other_val, non_blocking=non_blocking)
return self
named = False
named = True

def inplace_update(dest, source):
def inplace_update(name, source, dest):
if source is None:
return None
if dest is None:
raise KeyError(
f"The key {name} was not found in the dest tensordict."
)
dest.copy_(source, non_blocking=non_blocking)

self._apply_nest(
input_dict_or_td._apply_nest(
inplace_update,
input_dict_or_td,
self,
nested_keys=True,
default=None,
filter_empty=True,
Expand Down
22 changes: 22 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3033,6 +3033,28 @@ def make(val, todict=False, stack=False):
assert (td1.select(("a", "b")) == 2).all()
assert (td1.exclude(("a", "b")) == 1).all()

# Any extra key in dest will raise an exception
with pytest.raises(KeyError):
td_dest = TensorDict(a=0)
td_source = TensorDict(b=1)
td_dest.update_(td_source)
with pytest.raises(KeyError):
td_dest = TensorDict(a=0)
td_source = {"b": torch.ones(())}
td_dest.update_(td_source)
with pytest.raises(KeyError):
td_dest = TensorDict(a=0)
td_source = TensorDict(b=1)
td_dest.update_(td_source, keys_to_update="b")
with pytest.raises(KeyError):
td_dest = TensorDict(a=0)
td_source = {"b": torch.ones(())}
td_dest.update_(td_source, keys_to_update="b")

td_dest = TensorDict(a=0, b=1)
td_source = TensorDict(a=0)
td_dest.update_(td_source)

def test_update_nested_dict(self):
t = TensorDict({"a": {"d": [[[0]] * 3] * 2}}, [2, 3])
assert ("a", "d") in t.keys(include_nested=True)
Expand Down

0 comments on commit b102010

Please sign in to comment.