
Commit

Merge branch 'main' into vmap_dropout
BY571 committed Jan 8, 2024
2 parents 7cdc1e6 + 975a205 commit f737172
Showing 5 changed files with 388 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -572,6 +572,7 @@ to be able to create this other composition:
TransformedEnv
ActionMask
BinarizeReward
BurnInTransform
CatFrames
CatTensors
CenterCrop
230 changes: 229 additions & 1 deletion test/test_transforms.py
@@ -60,6 +60,7 @@
from torchrl.envs import (
ActionMask,
BinarizeReward,
BurnInTransform,
CatFrames,
CatTensors,
CenterCrop,
@@ -114,7 +115,7 @@
from torchrl.envs.transforms.vc1 import _has_vc
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp
-from torchrl.modules import LSTMModule, MLP, ProbabilisticActor, TanhNormal
+from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal

TIMEOUT = 100.0

@@ -9366,6 +9367,233 @@ def test_transform_inverse(self):
pass


class TestBurnInTransform(TransformBase):
def _make_gru_module(self, input_size=4, hidden_size=4, device="cpu"):
return GRUModule(
input_size=input_size,
hidden_size=hidden_size,
batch_first=True,
in_keys=["observation", "rhs", "is_init"],
out_keys=["output", ("next", "rhs")],
device=device,
).set_recurrent_mode(True)

def _make_lstm_module(self, input_size=4, hidden_size=4, device="cpu"):
return LSTMModule(
input_size=input_size,
hidden_size=hidden_size,
batch_first=True,
in_keys=["observation", "rhs_h", "rhs_c", "is_init"],
out_keys=["output", ("next", "rhs_h"), ("next", "rhs_c")],
device=device,
).set_recurrent_mode(True)

def _make_batch(self, batch_size: int = 2, sequence_length: int = 5):
observation = torch.randn(batch_size, sequence_length + 1, 4)
is_init = torch.zeros(batch_size, sequence_length, 1, dtype=torch.bool)
batch = TensorDict(
{
"observation": observation[:, :-1],
"is_init": is_init,
"next": TensorDict(
{
"observation": observation[:, 1:],
},
batch_size=[batch_size, sequence_length],
),
},
batch_size=[batch_size, sequence_length],
)
return batch

def test_single_trans_env_check(self):
module = self._make_gru_module()
burn_in_transform = BurnInTransform(module, burn_in=2)
with pytest.raises(
RuntimeError,
match="BurnInTransform can only be appended to a ReplayBuffer.",
):
env = TransformedEnv(ContinuousActionVecMockEnv(), burn_in_transform)
check_env_specs(env)
env.close()

def test_serial_trans_env_check(self):
raise pytest.skip(
"BurnInTransform can only be appended to a ReplayBuffer, not to a TransformedEnv."
)

def test_parallel_trans_env_check(self):
raise pytest.skip(
"BurnInTransform can only be appended to a ReplayBuffer, not to a TransformedEnv."
)

def test_trans_serial_env_check(self):
raise pytest.skip(
"BurnInTransform can only be appended to a ReplayBuffer, not to a TransformedEnv."
)

def test_trans_parallel_env_check(self):
raise pytest.skip(
"BurnInTransform can only be appended to a ReplayBuffer, not to a TransformedEnv."
)

@pytest.mark.parametrize("module", ["gru", "lstm"])
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("sequence_length", [4, 8])
@pytest.mark.parametrize("burn_in", [2])
def test_transform_no_env(self, module, batch_size, sequence_length, burn_in):
"""tests the transform on dummy data, without an env."""
torch.manual_seed(0)
data = self._make_batch(batch_size, sequence_length)

if module == "gru":
module = self._make_gru_module()
hidden = torch.zeros(
data.batch_size + (module.gru.num_layers, module.gru.hidden_size)
)
data.set("rhs", hidden)
else:
module = self._make_lstm_module()
hidden_h = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
hidden_c = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
data.set("rhs_h", hidden_h)
data.set("rhs_c", hidden_c)

burn_in_transform = BurnInTransform(module, burn_in=burn_in)
data = burn_in_transform(data)
assert data.shape[-1] == sequence_length - burn_in

for key in data.keys():
if key.startswith("rhs"):
assert data[:, 0].get(key).abs().sum() > 0.0
assert data[:, 1:].get(key).sum() == 0.0

@pytest.mark.parametrize("module", ["gru", "lstm"])
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("sequence_length", [4, 8])
@pytest.mark.parametrize("burn_in", [2])
def test_transform_compose(self, module, batch_size, sequence_length, burn_in):
"""tests the transform on dummy data, without an env but inside a Compose."""
torch.manual_seed(0)
data = self._make_batch(batch_size, sequence_length)

if module == "gru":
module = self._make_gru_module()
hidden = torch.zeros(
data.batch_size + (module.gru.num_layers, module.gru.hidden_size)
)
data.set("rhs", hidden)
else:
module = self._make_lstm_module()
hidden_h = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
hidden_c = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
data.set("rhs_h", hidden_h)
data.set("rhs_c", hidden_c)

burn_in_compose = Compose(BurnInTransform(module, burn_in=burn_in))
data = burn_in_compose(data)
assert data.shape[-1] == sequence_length - burn_in

for key in data.keys():
if key.startswith("rhs"):
assert data[:, 0].get(key).abs().sum() > 0.0
assert data[:, 1:].get(key).sum() == 0.0

def test_transform_env(self):
module = self._make_gru_module()
burn_in_transform = BurnInTransform(module, burn_in=2)
env = TransformedEnv(ContinuousActionVecMockEnv(), burn_in_transform)
with pytest.raises(
RuntimeError,
match="BurnInTransform can only be appended to a ReplayBuffer.",
):
rollout = env.rollout(3)

@pytest.mark.parametrize("module", ["gru", "lstm"])
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("sequence_length", [4, 8])
@pytest.mark.parametrize("burn_in", [2])
def test_transform_model(self, module, batch_size, sequence_length, burn_in):
torch.manual_seed(0)
data = self._make_batch(batch_size, sequence_length)

if module == "gru":
module = self._make_gru_module()
hidden = torch.zeros(
data.batch_size + (module.gru.num_layers, module.gru.hidden_size)
)
data.set("rhs", hidden)
else:
module = self._make_lstm_module()
hidden_h = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
hidden_c = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
data.set("rhs_h", hidden_h)
data.set("rhs_c", hidden_c)

burn_in_transform = BurnInTransform(module, burn_in=burn_in)
module = nn.Sequential(burn_in_transform, nn.Identity())
data = module(data)
assert data.shape[-1] == sequence_length - burn_in

for key in data.keys():
if key.startswith("rhs"):
assert data[:, 0].get(key).abs().sum() > 0.0
assert data[:, 1:].get(key).sum() == 0.0

@pytest.mark.parametrize("module", ["gru", "lstm"])
@pytest.mark.parametrize("batch_size", [2, 4])
@pytest.mark.parametrize("sequence_length", [4, 8])
@pytest.mark.parametrize("burn_in", [2])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, module, batch_size, sequence_length, burn_in, rbclass):
torch.manual_seed(0)
data = self._make_batch(batch_size, sequence_length)

if module == "gru":
module = self._make_gru_module()
hidden = torch.zeros(
data.batch_size + (module.gru.num_layers, module.gru.hidden_size)
)
data.set("rhs", hidden)
else:
module = self._make_lstm_module()
hidden_h = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
hidden_c = torch.zeros(
data.batch_size + (module.lstm.num_layers, module.lstm.hidden_size)
)
data.set("rhs_h", hidden_h)
data.set("rhs_c", hidden_c)

burn_in_transform = BurnInTransform(module, burn_in=burn_in)
rb = rbclass(storage=LazyTensorStorage(20))
rb.append_transform(burn_in_transform)
rb.extend(data)
batch = rb.sample(2)
assert batch.shape[-1] == sequence_length - burn_in

for key in batch.keys():
if key.startswith("rhs"):
assert batch[:, 0].get(key).abs().sum() > 0.0
assert batch[:, 1:].get(key).sum() == 0.0

def test_transform_inverse(self):
raise pytest.skip("No inverse for BurnInTransform")


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -38,6 +38,7 @@
from .transforms import (
ActionMask,
BinarizeReward,
BurnInTransform,
CatFrames,
CatTensors,
CenterCrop,
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
@@ -9,6 +9,7 @@
from .transforms import (
ActionMask,
BinarizeReward,
BurnInTransform,
CatFrames,
CatTensors,
CenterCrop,
(The diff for the remaining changed file did not load and is not shown.)

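For orientation, here is a minimal usage sketch distilled from the new TestBurnInTransform tests above. The GRUModule configuration, the key names ("observation", "rhs", "is_init"), the storage size, and the burn-in length simply mirror the illustrative values used in the tests; treat this as a sketch of the intended pattern, not as canonical API documentation.

import torch
from tensordict import TensorDict

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.envs import BurnInTransform
from torchrl.modules import GRUModule

# Recurrent module run in recurrent mode, mirroring _make_gru_module above.
gru = GRUModule(
    input_size=4,
    hidden_size=4,
    batch_first=True,
    in_keys=["observation", "rhs", "is_init"],
    out_keys=["output", ("next", "rhs")],
).set_recurrent_mode(True)

# A [batch, time] batch of stored trajectories with a zero recurrent state,
# built the same way as _make_batch in the tests.
batch_size, sequence_length, burn_in = 2, 8, 2
obs = torch.randn(batch_size, sequence_length + 1, 4)
data = TensorDict(
    {
        "observation": obs[:, :-1],
        "is_init": torch.zeros(batch_size, sequence_length, 1, dtype=torch.bool),
        "rhs": torch.zeros(
            batch_size, sequence_length, gru.gru.num_layers, gru.gru.hidden_size
        ),
        "next": TensorDict(
            {"observation": obs[:, 1:]}, batch_size=[batch_size, sequence_length]
        ),
    },
    batch_size=[batch_size, sequence_length],
)

# The transform is appended to a replay buffer, not to an environment.
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(20))
rb.append_transform(BurnInTransform(gru, burn_in=burn_in))
rb.extend(data)

# Sampled sequences come back shortened by `burn_in` steps; those steps are
# consumed to warm up the recurrent state written at the first remaining step.
sample = rb.sample(2)
assert sample.shape[-1] == sequence_length - burn_in

Appending the transform to the buffer means the burn-in forward passes run at sampling time on stored trajectories, which is why the tests above assert that attaching it to a TransformedEnv raises a RuntimeError.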