Skip to content

Commit

Permalink
[Feature] RB MultiStep transform (pytorch#2008)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored and SandishKumarHN committed Mar 18, 2024
1 parent 8b5e0ff commit a87094d
Show file tree
Hide file tree
Showing 10 changed files with 504 additions and 120 deletions.
8 changes: 8 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -823,3 +823,11 @@ Utils
consolidate_spec
check_no_exclusive_keys
contains_lazy_spec

.. currentmodule:: torchrl.envs.transforms.rb_transforms

.. autosummary::
:toctree: generated/
:template: rl_template.rst

MultiStepTransform
44 changes: 22 additions & 22 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def test_dqn_state_dict(self, delay_value, device, action_spec_type):
loss_fn2 = DQNLoss(actor, loss_function="l2", delay_value=delay_value)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss")]
Expand Down Expand Up @@ -1126,7 +1126,7 @@ def test_qmixer_state_dict(self, delay_value, device, action_spec_type):
loss_fn2 = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -1159,7 +1159,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss")]
Expand Down Expand Up @@ -1803,7 +1803,7 @@ def test_ddpg_separate_losses(
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):
Expand Down Expand Up @@ -1834,7 +1834,7 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -2436,7 +2436,7 @@ def test_td3_separate_losses(
loss_fn.zero_grad()

@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)])
@pytest.mark.parametrize("policy_noise", [0.1, 1.0])
Expand Down Expand Up @@ -2482,7 +2482,7 @@ def test_td3_batcher(
np.random.seed(0)
loss = loss_fn(td)

if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -3231,7 +3231,7 @@ def test_sac_separate_losses(
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
Expand Down Expand Up @@ -3295,7 +3295,7 @@ def test_sac_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -3930,7 +3930,7 @@ def test_discrete_sac_state_dict(
)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [2])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -3986,7 +3986,7 @@ def test_discrete_sac_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -4875,7 +4875,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est):
# TODO: find a way to compare the losses: problem is that we sample actions either sequentially or in batch,
# so setting seed has little impact

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -4918,7 +4918,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -5487,7 +5487,7 @@ def test_cql_state_dict(
)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("max_q_backup", [True, False])
Expand Down Expand Up @@ -5542,7 +5542,7 @@ def test_cql_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -5848,7 +5848,7 @@ def test_dcql_state_dict(self, delay_value, device, action_spec_type):
loss_fn2 = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -5879,7 +5879,7 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum([item for key, item in loss.items() if key.startswith("loss_")])
_loss_ms = sum(
Expand Down Expand Up @@ -9364,7 +9364,7 @@ def test_iql_separate_losses(self, separate_losses):
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@pytest.mark.parametrize("expectile", [0.1, 0.5, 1.0])
Expand Down Expand Up @@ -9415,7 +9415,7 @@ def test_iql_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -10176,7 +10176,7 @@ def test_discrete_iql_separate_losses(self, separate_losses):
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@pytest.mark.parametrize("expectile", [0.1, 0.5])
Expand Down Expand Up @@ -10227,7 +10227,7 @@ def test_discrete_iql_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down
47 changes: 13 additions & 34 deletions test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchrl.data.postprocs.postprocs import MultiStep


@pytest.mark.parametrize("n", range(13))
@pytest.mark.parametrize("n", range(1, 14))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("key", ["observation", "pixels", "observation_whatever"])
def test_multistep(n, key, device, T=11):
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_multistep(n, key, device, T=11):

assert ms_tensordict.get("done").max() == 1

if n == 0:
if n == 1:
assert_allclose_td(
tensordict, ms_tensordict.select(*list(tensordict.keys(True, True)))
)
Expand All @@ -76,20 +76,18 @@ def test_multistep(n, key, device, T=11):
)

# check that next obs is properly replaced, or that it is terminated
next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps) :]
true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps)]
next_obs = ms_tensordict.get(key)[:, (ms.n_steps) :]
true_next_obs = ms_tensordict.get(("next", key))[:, : -(ms.n_steps)]
terminated = ~ms_tensordict.get("nonterminal")
assert (
(next_obs == true_next_obs).all(-1) | terminated[:, (1 + ms.n_steps) :]
).all()
assert ((next_obs == true_next_obs).all(-1) | terminated[:, (ms.n_steps) :]).all()

# test gamma computation
torch.testing.assert_close(
ms_tensordict.get("gamma"), ms.gamma ** ms_tensordict.get("steps_to_next_obs")
)

# test reward
if n > 0:
if n > 1:
assert (
ms_tensordict.get(("next", "reward"))
!= ms_tensordict.get(("next", "original_reward"))
Expand All @@ -105,36 +103,17 @@ def test_multistep(n, key, device, T=11):
@pytest.mark.parametrize(
"batch_size",
[
[
4,
],
[4],
[],
[
1,
],
[1],
[2, 3],
],
)
@pytest.mark.parametrize(
"T",
[
10,
1,
2,
],
)
@pytest.mark.parametrize(
"obs_dim",
[
[
1,
],
[],
],
)
@pytest.mark.parametrize("T", [10, 1, 2])
@pytest.mark.parametrize("obs_dim", [[1], []])
@pytest.mark.parametrize("unsq_reward", [True, False])
@pytest.mark.parametrize("last_done", [True, False])
@pytest.mark.parametrize("n_steps", [3, 1, 0])
@pytest.mark.parametrize("n_steps", [4, 2, 1])
def test_mutistep_cattrajs(
batch_size, T, obs_dim, unsq_reward, last_done, device, n_steps
):
Expand Down Expand Up @@ -166,7 +145,7 @@ def test_mutistep_cattrajs(
)
ms = MultiStep(0.98, n_steps)
tdm = ms(td)
if n_steps == 0:
if n_steps == 1:
# n_steps = 1 has no effect (single-step transition, identical to the input)
for k in td["next"].keys():
assert (tdm["next", k] == td["next", k]).all()
Expand All @@ -179,7 +158,7 @@ def test_mutistep_cattrajs(
if unsq_reward:
done = done.squeeze(-1)
for t in range(T):
idx = t + n_steps
idx = t + n_steps - 1
while (done[..., t:idx].any() and idx > t) or idx > done.shape[-1] - 1:
idx = idx - 1
next_obs.append(obs[..., idx])
Expand Down
Loading

0 comments on commit a87094d

Please sign in to comment.