From 6154292742f7abe57296651b2b26d4a2868934d7 Mon Sep 17 00:00:00 2001
From: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
Date: Thu, 1 Aug 2024 00:54:07 +0200
Subject: [PATCH] Improvements in batch (#1181)

This PR contains several important extensions and improvements of `Batch`.

1. A new method, `set_array_at_key`, for setting an array at a given key. This is a first step towards code that's less reliant on setting attributes directly. The method also permits setting a subsequence and filling up the remaining entries with default values. This will help ensure that all entries in a batch are of the same length, something that we should start enforcing soon.
2. A new method, `apply_values_transform`, for applying arbitrary transformations to "leaves" in a Batch. This simplifies existing code and is generally very handy for users.
3. The arbitrary transformations allowed implementing `isnull`, `hasnull` and `dropnull`, which help find errors early.
4. The arbitrary transformations now also allow extracting a `schema` from a batch. This was used in `Batch.cat_` to perform additional input validation (we now make sure that the structures are the same when concatenating batches). This input validation is a **breaking change**! Some tests that concatenated incompatible batches were removed. Eventually, we can add a `get_schema` method to the batch that will retrieve metainfo like shapes and datatypes. For now, this is delegated to the user, who can use the new `apply_values_transform`.
5. New feature: slicing of torch distributions. Previously, several batches contained instances of `Distribution` which were not properly sliced, inviting bugs and errors in user code. This is now fixed - albeit not in a pretty way (torch doesn't allow slicing these objects natively).
6. Stricter typing and some further minor extensions and simplifications.

The new code was extensively tested and documented. Short usage sketches for the new APIs follow below.
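A minimal sketch of the new assignment API, mirroring the new tests in `test_batch.py` (the concrete values are illustrative only):

```python
import numpy as np
from tianshou.data import Batch

batch = Batch(a=[4, 5, 6], b=[7, 8, 9])
# replace the whole array stored at an existing key;
# a length mismatch raises a ValueError
batch.set_array_at_key(np.array([1, 2, 3]), "a")
# set only a subsequence of an existing key
batch.set_array_at_key(np.array([10, 12]), "a", index=[0, 1])  # a == [10, 12, 3]
# a new key set at a partial index is filled up with the default value elsewhere
batch.set_array_at_key(np.array([1, 2]), "new_key", index=[0, 1], default_value=0)
assert np.array_equal(batch.new_key, np.array([1, 2, 0]))
```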
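A sketch of the leaf transformations and of the null handling built on top of them, again following the new tests:

```python
import numpy as np
from tianshou.data import Batch

batch = Batch(a=1, b=np.arange(3), c={"d": np.array([1, 2, 3])})
# the transform is applied to every leaf value, including nested ones
maxima = batch.apply_values_transform(np.max)
assert maxima.b == np.array(2)
assert maxima.c.d == np.array(3)

# null handling is implemented via the same mechanism
batch = Batch(a=[4, 5, 6], b=[7, 8, None])
assert batch.hasnull()
mask = batch.isnull()     # boolean batch with the same structure
clean = batch.dropnull()  # drops every entry in which any value is null
```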
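To illustrate the breaking change: concatenating structurally incompatible batches, which previously resulted in zero-padding of the missing keys, now fails fast. A sketch of the new behavior (the inline comment paraphrases the actual error message):

```python
import numpy as np
from tianshou.data import Batch

b1 = Batch(a=np.zeros(3), common=Batch(c=np.zeros((3, 5))))
b2 = Batch(b=np.zeros(4), common=Batch(c=np.zeros((4, 5))))
# b1 and b2 have different sets of keys, so their "schemas" differ
try:
    Batch.cat([b1, b2])
except ValueError as e:
    print(e)  # reports that the structures of the batches do not match
```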
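Finally, slicing a batch that holds a torch `Distribution` now slices the distribution itself (currently supported for `Categorical`, `Normal` and `Independent`). A sketch following the new test:

```python
import torch
from torch.distributions import Categorical
from tianshou.data import Batch

dist = Categorical(probs=torch.rand(10, 3))
batch = Batch(dist=dist)
sliced = batch[[1, 3]]  # slices the underlying probs along the batch dimension
assert sliced.dist.probs.shape == (2, 3)
# retrieving a single index also works
assert torch.allclose(batch[0].dist.probs, dist.probs[0])
```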
--- docs/spelling_wordlist.txt | 2 +- examples/inverse/irl_gail.py | 18 +- test/base/env.py | 6 +- test/base/test_batch.py | 261 ++++++++---- test/base/test_buffer.py | 533 +++++++++++++++---------- test/base/test_collector.py | 14 +- test/base/test_env_finite.py | 11 +- test/base/test_returns.py | 150 ++++--- tianshou/data/batch.py | 504 ++++++++++++++++++----- tianshou/data/buffer/base.py | 2 +- tianshou/data/buffer/prio.py | 3 + tianshou/data/types.py | 18 +- tianshou/policy/base.py | 8 +- tianshou/policy/imitation/bcq.py | 4 +- tianshou/policy/imitation/cql.py | 3 +- tianshou/policy/multiagent/mapolicy.py | 2 +- 16 files changed, 1053 insertions(+), 486 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 83de82356..c30b9f2cb 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -270,4 +270,4 @@ v_s v_s_ obs obs_next - +dtype diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 42e5bc2c9..e327fd490 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -4,7 +4,7 @@ import datetime import os import pprint -from typing import SupportsFloat +from typing import SupportsFloat, cast import d4rl import gymnasium as gym @@ -16,6 +16,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAILPolicy from tianshou.policy.base import BasePolicy @@ -185,12 +186,15 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: for i in range(dataset_size): expert_buffer.add( - Batch( - obs=dataset["observations"][i], - act=dataset["actions"][i], - rew=dataset["rewards"][i], - done=dataset["terminals"][i], - obs_next=dataset["next_observations"][i], + cast( + RolloutBatchProtocol, + Batch( + obs=dataset["observations"][i], + act=dataset["actions"][i], + rew=dataset["rewards"][i], + done=dataset["terminals"][i], + obs_next=dataset["next_observations"][i], + ), ), ) print("dataset loaded") diff --git a/test/base/env.py b/test/base/env.py index 2a7b09278..02f76ad2d 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -147,6 +147,8 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. if self.index == self.size: self.terminated = True return self._get_state(), self._get_reward(), self.terminated, False, {} + + info_dict = {"key": 1, "env": self} if action == 0: self.index = max(self.index - 1, 0) return ( @@ -154,7 +156,7 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. self._get_reward(), self.terminated, False, - {"key": 1, "env": self} if self.dict_state else {}, + info_dict, ) if action == 1: self.index += 1 @@ -164,7 +166,7 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. 
self._get_reward(), self.terminated, False, - {"key": 1, "env": self}, + info_dict, ) return None diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 0530d8232..9accff86c 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -9,8 +9,10 @@ import pytest import torch from deepdiff import DeepDiff +from torch.distributions.categorical import Categorical from tianshou.data import Batch, to_numpy, to_torch +from tianshou.data.batch import IndexType, get_sliced_dist def test_batch() -> None: @@ -122,8 +124,8 @@ def test_batch() -> None: with pytest.raises(TypeError): len(batch2[0]) assert isinstance(batch2[0].a.c, np.ndarray) - assert isinstance(batch2[0].a.b, np.float64) - assert isinstance(batch2[0].a.d.e, np.float64) + assert isinstance(batch2[0].a.b, float) + assert isinstance(batch2[0].a.d.e, float) batch2_from_list = Batch(list(batch2)) batch2_from_comp = Batch(list(batch2)) assert batch2_from_list.a.b == batch2.a.b @@ -244,13 +246,14 @@ def test_batch_cat_and_stack() -> None: assert b12_cat_in.a.d.e.ndim == 1 a = Batch(a=Batch(a=np.random.randn(3, 4))) + a_empty = Batch(a=Batch(a=Batch())) assert np.allclose( np.concatenate([a.a.a, a.a.a]), - Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a, + Batch.cat([a, a_empty, a]).a.a, ) # test cat with lens infer - a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4)) + a = Batch(a=Batch(a=np.random.randn(3, 4), t=Batch()), b=np.random.randn(3, 4)) b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4)) ans = Batch.cat([a, b, a]) assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) @@ -261,34 +264,8 @@ def test_batch_cat_and_stack() -> None: assert isinstance(b1.a.d.e, np.ndarray) assert b1.a.d.e.ndim == 2 - # test cat with incompatible keys - b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) - test = Batch.cat([b1, b2]) - ans = Batch( - a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c])), - ) - assert np.allclose(test.a, ans.a) - assert torch.allclose(test.b, ans.b) - assert np.allclose(test.common.c, ans.common.c) - - # test cat with reserved keys (values are Batch()) - b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) - test = Batch.cat([b1, b2]) - ans = Batch( - a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c])), - ) - assert np.allclose(test.a, ans.a) - assert torch.allclose(test.b, ans.b) - assert np.allclose(test.common.c, ans.common.c) - # test cat with all reserved keys (values are Batch()) - b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5))) + b1 = Batch(a=Batch(), b=torch.zeros(3, 3), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) ans = Batch( @@ -373,31 +350,6 @@ def test_batch_cat_and_stack() -> None: Batch.stack([b1, b2], axis=1) -def test_batch_over_batch_to_torch() -> None: - batch = Batch( - a=np.float64(1.0), - b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), - ) - batch.b.__dict__["e"] = 1 # bypass the check - batch.to_torch_() - assert isinstance(batch.a, torch.Tensor) - assert isinstance(batch.b.c, torch.Tensor) - assert 
isinstance(batch.b.d, torch.Tensor) - assert isinstance(batch.b.e, torch.Tensor) - assert batch.a.dtype == torch.float64 - assert batch.b.c.dtype == torch.float32 - assert batch.b.d.dtype == torch.float64 - if sys.platform in ["win32", "cygwin"]: # windows - assert batch.b.e.dtype == torch.int32 - else: - assert batch.b.e.dtype == torch.int64 - batch.to_torch_(dtype=torch.float32) - assert batch.a.dtype == torch.float32 - assert batch.b.c.dtype == torch.float32 - assert batch.b.d.dtype == torch.float32 - assert batch.b.e.dtype == torch.float32 - - def test_utils_to_torch_numpy() -> None: batch = Batch( a=np.float64(1.0), @@ -408,7 +360,7 @@ def test_utils_to_torch_numpy() -> None: a_torch_double = to_torch(batch.a, dtype=torch.float64) assert a_torch_double.dtype == torch.float64 batch_torch_float = to_torch(batch, dtype=torch.float32) - assert batch_torch_float.a.dtype == torch.float32 + assert batch_torch_float.a.dtype == torch.float64 assert batch_torch_float.b.c.dtype == torch.float32 assert batch_torch_float.b.d.dtype == torch.float32 data_list = [float("nan"), 1] @@ -473,18 +425,6 @@ def test_batch_pickle() -> None: assert np.all(batch.np == batch_pk.np) -def test_batch_from_to_numpy_without_copy() -> None: - batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) - a_mem_addr_orig = batch.a.__array_interface__["data"][0] - c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] - batch.to_torch_() - batch.to_numpy_() - a_mem_addr_new = batch.a.__array_interface__["data"][0] - c_mem_addr_new = batch.b.c.__array_interface__["data"][0] - assert a_mem_addr_new == a_mem_addr_orig - assert c_mem_addr_new == c_mem_addr_orig - - def test_batch_copy() -> None: batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6])) batch2 = Batch({"c": np.array([6, 7, 8]), "b": batch}) @@ -703,9 +643,7 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None: assert not DeepDiff(batch.to_dict(recursive=True), expected) -class TestToNumpy: - """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` .""" - +class TestBatchConversions: @staticmethod def test_to_numpy() -> None: batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) @@ -726,10 +664,6 @@ def test_to_numpy_() -> None: assert isinstance(batch.b, np.ndarray) assert isinstance(batch.c.d, np.ndarray) - -class TestToTorch: - """Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()` .""" - @staticmethod def test_to_torch() -> None: batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])}) @@ -749,3 +683,178 @@ def test_to_torch_() -> None: assert id_batch == id(batch) assert isinstance(batch.b, torch.Tensor) assert isinstance(batch.c.d, torch.Tensor) + + @staticmethod + def test_apply_array_func() -> None: + batch = Batch(a=1, b=np.arange(3), c={"d": np.array([1, 2, 3])}) + batch_with_max = batch.apply_values_transform(np.max) + assert np.array_equal(batch_with_max.a, np.array(1)) + assert np.array_equal(batch_with_max.b, np.array(2)) + assert np.array_equal(batch_with_max.c.d, np.array(3)) + + batch_array_added = batch.apply_values_transform(lambda x: x + np.array([1, 2, 3])) + assert np.array_equal(batch_array_added.a, np.array([2, 3, 4])) + assert np.array_equal(batch_array_added.c.d, np.array([2, 4, 6])) + + @staticmethod + def test_batch_to_numpy_without_copy() -> None: + batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) + a_mem_addr_orig = batch.a.__array_interface__["data"][0] + c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] + batch.to_numpy_() + 
a_mem_addr_new = batch.a.__array_interface__["data"][0] + c_mem_addr_new = batch.b.c.__array_interface__["data"][0] + assert a_mem_addr_new == a_mem_addr_orig + assert c_mem_addr_new == c_mem_addr_orig + + @staticmethod + def test_batch_from_to_numpy_without_copy() -> None: + batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) + a_mem_addr_orig = batch.a.__array_interface__["data"][0] + c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] + batch.to_torch_() + batch.to_numpy_() + a_mem_addr_new = batch.a.__array_interface__["data"][0] + c_mem_addr_new = batch.b.c.__array_interface__["data"][0] + assert a_mem_addr_new == a_mem_addr_orig + assert c_mem_addr_new == c_mem_addr_orig + + @staticmethod + def test_batch_over_batch_to_torch() -> None: + batch = Batch( + a=np.float64(1.0), + b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), + ) + batch.b.set_array_at_key(np.array([1]), "e") + batch.to_torch_() + assert isinstance(batch.a, torch.Tensor) + assert isinstance(batch.b.c, torch.Tensor) + assert isinstance(batch.b.d, torch.Tensor) + assert isinstance(batch.b.e, torch.Tensor) + assert batch.a.dtype == torch.float64 + assert batch.b.c.dtype == torch.float32 + assert batch.b.d.dtype == torch.float64 + if sys.platform in ["win32", "cygwin"]: # windows + assert batch.b.e.dtype == torch.int32 + else: + assert batch.b.e.dtype == torch.int64 + batch.to_torch_(dtype=torch.float32) + assert batch.a.dtype == torch.float32 + assert batch.b.c.dtype == torch.float32 + assert batch.b.d.dtype == torch.float32 + assert batch.b.e.dtype == torch.float32 + + +class TestAssignment: + @staticmethod + def test_assign_full_length_array() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + batch.set_array_at_key(np.array([1, 2, 3]), "a") + batch.set_array_at_key(np.array([4, 5, 6]), "new_key") + assert np.array_equal(batch.a, np.array([1, 2, 3])) + assert np.array_equal(batch.new_key, np.array([4, 5, 6])) + + # other keys are not affected + assert np.array_equal(batch.b, np.array([7, 8, 9])) + assert np.array_equal(batch.c.d, np.array([1, 2, 3])) + + with pytest.raises(ValueError): + # wrong length + batch.set_array_at_key(np.array([1, 2]), "a") + + @staticmethod + def test_assign_subarray_existing_key() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + batch.set_array_at_key(np.array([1, 2]), "a", index=[0, 1]) + assert np.array_equal(batch.a, np.array([1, 2, 6])) + batch.set_array_at_key(np.array([10, 12]), "a", index=slice(0, 2)) + assert np.array_equal(batch.a, np.array([10, 12, 6])) + batch.set_array_at_key(np.array([1, 2]), "a", index=[0, 2]) + assert np.array_equal(batch.a, np.array([1, 12, 2])) + batch.set_array_at_key(np.array([1, 2]), "a", index=[2, 0]) + assert np.array_equal(batch.a, np.array([2, 12, 1])) + batch.set_array_at_key(np.array([1, 2, 3]), "a", index=[2, 1, 0]) + assert np.array_equal(batch.a, np.array([3, 2, 1])) + + with pytest.raises(IndexError): + # Index out of bounds + batch.set_array_at_key(np.array([1, 2]), "a", index=[10, 11]) + + # other keys are not affected + assert np.array_equal(batch.b, np.array([7, 8, 9])) + assert np.array_equal(batch.c.d, np.array([1, 2, 3])) + + @staticmethod + def test_assign_subarray_new_key() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + batch.set_array_at_key(np.array([1, 2]), "new_key", index=[0, 1], default_value=0) + assert np.array_equal(batch.new_key, np.array([1, 2, 0])) + # with float, None can be 
cast to NaN + batch.set_array_at_key(np.array([1.0, 2.0]), "new_key2", index=[0, 1]) + assert np.array_equal(batch.new_key2, np.array([1.0, 2.0, np.nan]), equal_nan=True) + + @staticmethod + def test_isnull() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([1, None, 3])}) + batch_isnan = batch.isnull() + assert not batch_isnan.a.any() + assert batch_isnan.b[2] + assert not batch_isnan.b[:2].any() + assert np.array_equal(batch_isnan.c.d, np.array([False, True, False])) + + @staticmethod + def test_hasnull() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([1, 2, 3])}) + assert batch.hasnull() + batch = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + assert not batch.hasnull() + batch = Batch(a=[4, 5, 6], c={"d": np.array([1, None, 3])}) + assert batch.hasnull() + + @staticmethod + def test_dropnull() -> None: + batch = Batch(a=[4, 5, 6], b=[7, 8, None], c={"d": np.array([None, 2.1, 3.0])}) + assert batch.dropnull() == Batch( + a=[5], + b=[8], + c={"d": np.array([2.1])}, + ).apply_values_transform( + np.atleast_1d, + ) + batch2 = Batch(a=[4, 5, 6, 7], b=[7, 8, None, 10], c={"d": np.array([None, 2, 3, 4])}) + assert batch2.dropnull() == Batch(a=[5, 7], b=[8, 10], c={"d": np.array([2, 4])}) + batch_no_nan = Batch(a=[4, 5, 6], b=[7, 8, 9], c={"d": np.array([1, 2, 3])}) + assert batch_no_nan.dropnull() == batch_no_nan + + +class TestSlicing: + # TODO: parametrize with other dists + @staticmethod + def test_slice_distribution() -> None: + cat_probs = torch.randint(1, 10, (10, 3)) + dist = Categorical(probs=cat_probs) + batch = Batch(dist=dist) + selected_idx = [1, 3] + sliced_batch = batch[selected_idx] + sliced_probs = cat_probs[selected_idx] + assert (sliced_batch.dist.probs == Categorical(probs=sliced_probs).probs).all() + assert ( + Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs + ).all() + # retrieving a single index + assert torch.allclose(batch[0].dist.probs, dist.probs[0]) + + @staticmethod + def test_getitem_with_int_gives_scalars() -> None: + batch = Batch(a=[1, 2], b=Batch(c=[3, 4])) + batch_sliced = batch[0] + assert batch_sliced.a == np.array(1) + assert batch_sliced.b.c == np.array(3) + + @staticmethod + @pytest.mark.parametrize("index", ([0, 1], np.array([0, 1]), torch.tensor([0, 1]), slice(0, 2))) + def test_getitem_with_slice_gives_subslice(index: IndexType) -> None: + batch = Batch(a=[1, 2, 3], b=Batch(c=torch.tensor([4, 5, 6]))) + batch_sliced = batch[index] + assert (batch_sliced.a == batch.a[index]).all() + assert (batch_sliced.b.c == batch.b.c[index]).all() diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 9f3f40828..0996f2436 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -21,6 +21,7 @@ SegmentTree, VectorReplayBuffer, ) +from tianshou.data.types import RolloutBatchProtocol from tianshou.data.utils.converter import to_hdf5 @@ -34,14 +35,17 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) buf.add( - Batch( - obs=obs, - act=[act], - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=obs_next, - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info, + ), ), ) obs = obs_next @@ -58,33 +62,36 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert (data.terminated <= 1).all() 
assert (data.truncated >= 0).all() assert (data.truncated <= 1).all() - b = ReplayBuffer(size=10) + replay_buffer = ReplayBuffer(size=10) # neg bsz should return empty index - assert b.sample_indices(-1).tolist() == [] - ptr, ep_rew, ep_len, ep_idx = b.add( - Batch( - obs=1, - act=1, - rew=1, - terminated=1, - truncated=0, - obs_next="str", - info={"a": 3, "b": {"c": 5.0}}, + assert replay_buffer.sample_indices(-1).tolist() == [] + ptr, ep_rew, ep_len, ep_idx = replay_buffer.add( + cast( + RolloutBatchProtocol, + Batch( + obs=1, + act=1, + rew=1, + terminated=1, + truncated=0, + obs_next="str", + info={"a": 3, "b": {"c": 5.0}}, + ), ), ) - assert b.obs[0] == 1 - assert b.done[0] - assert b.terminated[0] - assert not b.truncated[0] - assert b.obs_next[0] == "str" - assert np.all(b.obs[1:] == 0) - assert np.all(b.obs_next[1:] == np.array(None)) - assert b.info.a[0] == 3 - assert b.info.a.dtype == int - assert np.all(b.info.a[1:] == 0) - assert b.info.b.c[0] == 5.0 - assert b.info.b.c.dtype == float - assert np.all(b.info.b.c[1:] == 0.0) + assert replay_buffer.obs[0] == 1 + assert replay_buffer.done[0] + assert replay_buffer.terminated[0] + assert not replay_buffer.truncated[0] + assert replay_buffer.obs_next[0] == "str" + assert np.all(replay_buffer.obs[1:] == 0) + assert np.all(replay_buffer.obs_next[1:] == np.array(None)) + assert replay_buffer.info.a[0] == 3 + assert replay_buffer.info.a.dtype == int + assert np.all(replay_buffer.info.a[1:] == 0) + assert replay_buffer.info.b.c[0] == 5.0 + assert replay_buffer.info.b.c.dtype == float + assert np.all(replay_buffer.info.b.c[1:] == 0.0) assert ptr.shape == (1,) assert ptr[0] == 0 assert ep_rew.shape == (1,) @@ -94,28 +101,32 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert ep_idx.shape == (1,) assert ep_idx[0] == 0 # test extra keys pop up, the buffer should handle it dynamically - batch = Batch( - obs=2, - act=2, - rew=2, - terminated=0, - truncated=0, - obs_next="str2", - info={"a": 4, "d": {"e": -np.inf}}, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=2, + act=2, + rew=2, + terminated=0, + truncated=0, + obs_next="str2", + info={"a": 4, "d": {"e": -np.inf}}, + ), ) - b.add(batch) + replay_buffer.add(batch) info_keys = ["a", "b", "d"] - assert set(b.info.keys()) == set(info_keys) - assert b.info.a[1] == 4 - assert b.info.b.c[1] == 0 - assert b.info.d.e[1] == -np.inf + assert set(replay_buffer.info.keys()) == set(info_keys) + assert replay_buffer.info.a[1] == 4 + assert replay_buffer.info.b.c[1] == 0 + assert replay_buffer.info.d.e[1] == -np.inf # test batch-style adding method, where len(batch) == 1 batch.done = [1] - batch.terminated = [0] - batch.truncated = [1] + batch.terminated = [0] # type: ignore[assignment] + batch.truncated = [1] # type: ignore[assignment] + assert isinstance(batch.info, Batch) batch.info.e = np.zeros([1, 4]) batch = Batch.stack([batch]) - ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) + ptr, ep_rew, ep_len, ep_idx = replay_buffer.add(batch, buffer_ids=[0]) assert ptr.shape == (1,) assert ptr[0] == 2 assert ep_rew.shape == (1,) @@ -124,17 +135,17 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert ep_len[0] == 2 assert ep_idx.shape == (1,) assert ep_idx[0] == 1 - assert set(b.info.keys()) == {*info_keys, "e"} - assert b.info.e.shape == (b.maxsize, 1, 4) + assert set(replay_buffer.info.keys()) == {*info_keys, "e"} + assert replay_buffer.info.e.shape == (replay_buffer.maxsize, 1, 4) with pytest.raises(IndexError): - b[22] + replay_buffer[22] # 
test prev / next - assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1]) - assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2]) + assert np.all(replay_buffer.prev(np.array([0, 1, 2])) == [0, 1, 1]) + assert np.all(replay_buffer.next(np.array([0, 1, 2])) == [0, 2, 2]) batch.done = [0] - b.add(batch, buffer_ids=[0]) - assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) - assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) + replay_buffer.add(batch, buffer_ids=[0]) + assert np.all(replay_buffer.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) + assert np.all(replay_buffer.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) def test_ignore_obs_next(size: int = 10) -> None: @@ -142,17 +153,20 @@ def test_ignore_obs_next(size: int = 10) -> None: buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): buf.add( - Batch( - obs={ - "mask1": np.array([i, 1, 1, 0, 0]), - "mask2": np.array([i + 4, 0, 1, 0, 0]), - "mask": i, - }, - act={"act_id": i, "position_id": i + 3}, - rew=i, - terminated=i % 3 == 0, - truncated=False, - info={"if": i}, + cast( + RolloutBatchProtocol, + Batch( + obs={ + "mask1": np.array([i, 1, 1, 0, 0]), + "mask2": np.array([i + 4, 0, 1, 0, 0]), + "mask": i, + }, + act={"act_id": i, "position_id": i + 3}, + rew=i, + terminated=i % 3 == 0, + truncated=False, + info={"if": i}, + ), ), ) indices = np.arange(len(buf)) @@ -224,34 +238,43 @@ def test_stack(size: int = 5, bufsize: int = 9, stack_num: int = 4, cached_num: obs_next, rew, terminated, truncated, info = env.step(1) done = terminated or truncated buf.add( - Batch( - obs=obs, - act=1, - rew=rew, - terminated=terminated, - truncated=truncated, - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=1, + rew=rew, + terminated=terminated, + truncated=truncated, + info=info, + ), ), ) buf2.add( - Batch( - obs=obs, - act=1, - rew=rew, - terminated=terminated, - truncated=truncated, - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=1, + rew=rew, + terminated=terminated, + truncated=truncated, + info=info, + ), ), ) buf3.add( - Batch( - obs=[obs, obs, obs], - act=1, - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=[obs, obs], - info=info, + cast( + RolloutBatchProtocol, + Batch( + obs=[obs, obs, obs], + act=1, + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=[obs, obs], + info=info, + ), ), ) obs = obs_next @@ -289,19 +312,22 @@ def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: env = MoveToRightEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) - obs, info = env.reset() + obs, _ = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) - batch = Batch( - obs=obs, - act=act, - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=obs_next, - info=info, - policy=np.random.randn() - 0.5, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=act, + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info, + policy=np.random.randn() - 0.5, + ), ) batch_stack = Batch.stack([batch, batch, batch]) buf.add(Batch.stack([batch]), buffer_ids=[0]) @@ -362,14 +388,17 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) - 
batch = Batch( - obs=obs, - act=[act], - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=obs_next, - info=info, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info, + ), ) buf.add(batch) buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2]) @@ -448,14 +477,17 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: for i in range(ep_len): act = 1 obs_next, rew, terminated, truncated, info = env.step(act) - batch = Batch( - obs=obs, - act=[act], - rew=rew, - terminated=(i == ep_len - 1), - truncated=(i == ep_len - 1), - obs_next=obs_next, - info=info, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info, + ), ) buf.add(batch) obs = obs_next @@ -476,14 +508,17 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: for i in range(ep_len): act = 1 obs_next, rew, terminated, truncated, info = env.step(act) - batch = Batch( - obs=obs, - act=[act], - rew=rew, - terminated=(i == ep_len - 1), - truncated=(i == ep_len - 1), - obs_next=obs_next, - info=info, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info, + ), ) if x == 1 and obs["observation"] < 10: obs = obs_next @@ -501,13 +536,16 @@ def test_update() -> None: buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): buf1.add( - Batch( - obs=np.array([i]), - act=float(i), - rew=i * i, - terminated=i % 2 == 0, - truncated=False, - info={"incident": "found"}, + cast( + RolloutBatchProtocol, + Batch( + obs=np.array([i]), + act=float(i), + rew=i * i, + terminated=i % 2 == 0, + truncated=False, + info={"incident": "found"}, + ), ), ) assert len(buf1) > len(buf2) @@ -610,23 +648,29 @@ def test_pickle() -> None: rew = np.array([1, 1]) for i in range(4): vbuf.add( - Batch( - obs=Batch(index=np.array([i])), - act=0, - rew=rew, - terminated=0, - truncated=0, + cast( + RolloutBatchProtocol, + Batch( + obs=Batch(index=np.array([i])), + act=0, + rew=rew, + terminated=0, + truncated=0, + ), ), ) for i in range(5): pbuf.add( - Batch( - obs=Batch(index=np.array([i])), - act=2, - rew=rew, - terminated=0, - truncated=0, - info=np.random.rand(), + cast( + RolloutBatchProtocol, + Batch( + obs=Batch(index=np.array([i])), + act=2, + rew=rew, + terminated=0, + truncated=0, + info=np.random.rand(), + ), ), ) # save & load @@ -660,8 +704,8 @@ def test_hdf5() -> None: "done": i % 3 == 2, "info": {"number": {"n": i, "t": info_t}, "extra": None}, } - buffers["array"].add(Batch(kwargs)) - buffers["prioritized"].add(Batch(kwargs)) + buffers["array"].add(cast(RolloutBatchProtocol, Batch(kwargs))) + buffers["prioritized"].add(cast(RolloutBatchProtocol, Batch(kwargs))) # save paths = {} @@ -703,12 +747,15 @@ def test_hdf5() -> None: def test_replaybuffermanager() -> None: buf = VectorReplayBuffer(20, 4) - batch = Batch( - obs=[1, 2, 3], - act=[1, 2, 3], - rew=[1, 2, 3], - terminated=[0, 0, 1], - truncated=[0, 0, 0], + batch = cast( + RolloutBatchProtocol, + Batch( + obs=[1, 2, 3], + act=[1, 2, 3], + rew=[1, 2, 3], + terminated=[0, 0, 1], + truncated=[0, 0, 0], + ), ) ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2]) assert np.all(ep_len == [0, 0, 1]) @@ -728,7 +775,10 @@ def test_replaybuffermanager() -> None: indices_next = 
buf.next(indices) assert np.allclose(indices_next, indices), indices_next assert np.allclose(buf.unfinished_index(), [0, 5]) - buf.add(Batch(obs=[4], act=[4], rew=[4], terminated=[1], truncated=[0]), buffer_ids=[3]) + buf.add( + cast(RolloutBatchProtocol, Batch(obs=[4], act=[4], rew=[4], terminated=[1], truncated=[0])), + buffer_ids=[3], + ) assert np.allclose(buf.unfinished_index(), [0, 5]) batch, indices = buf.sample(10) batch, indices = buf.sample(0) @@ -739,20 +789,32 @@ def test_replaybuffermanager() -> None: assert np.allclose(indices_next, indices), indices_next data = np.array([0, 0, 0, 0]) buf.add( - Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) buf.add( - Batch(obs=data, act=data, rew=data, terminated=1 - data, truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=1 - data, truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) assert len(buf) == 12 buf.add( - Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=data, truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) buf.add( - Batch(obs=data, act=data, rew=data, terminated=[0, 1, 0, 1], truncated=data), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=data, terminated=[0, 1, 0, 1], truncated=data), + ), buffer_ids=[0, 1, 2, 3], ) assert len(buf) == 20 @@ -839,7 +901,7 @@ def test_replaybuffermanager() -> None: ) assert np.allclose(buf.unfinished_index(), [4, 14]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], terminated=[1], truncated=[0]), + cast(RolloutBatchProtocol, Batch(obs=[1], act=[1], rew=[1], terminated=[1], truncated=[0])), buffer_ids=[2], ) assert np.all(ep_len == [3]) @@ -915,7 +977,7 @@ def test_cachedbuffer() -> None: assert buf.sample_indices(0).tolist() == [] # check the normal function/usage/storage in CachedReplayBuffer ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], terminated=[0], truncated=[0]), + cast(RolloutBatchProtocol, Batch(obs=[1], act=[1], rew=[1], terminated=[0], truncated=[0])), buffer_ids=[1], ) obs = np.zeros(buf.maxsize) @@ -930,7 +992,7 @@ def test_cachedbuffer() -> None: assert np.all(ptr == [15]) assert np.all(ep_idx == [15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[2], act=[2], rew=[2], terminated=[1], truncated=[0]), + cast(RolloutBatchProtocol, Batch(obs=[2], act=[2], rew=[2], terminated=[1], truncated=[0])), buffer_ids=[3], ) obs[[0, 25]] = 2 @@ -946,7 +1008,10 @@ def test_cachedbuffer() -> None: assert np.allclose(buf.unfinished_index(), [15]) assert np.allclose(buf.sample_indices(0), [0, 15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], terminated=[0, 1], truncated=[0, 0]), + cast( + RolloutBatchProtocol, + Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], terminated=[0, 1], truncated=[0, 0]), + ), buffer_ids=[3, 1], # TODO ) assert np.all(ep_len == [0, 2]) @@ -968,12 +1033,35 @@ def test_cachedbuffer() -> None: buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) data = np.zeros(4) rew = np.ones([4, 4]) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 1, 1], truncated=[0, 0, 0, 0])) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0])) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[1, 1, 1, 1], truncated=[0, 0, 
0, 0])) - buf.add(Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0])) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 1, 1], truncated=[0, 0, 0, 0]), + ), + ) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0]), + ), + ) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[1, 1, 1, 1], truncated=[0, 0, 0, 0]), + ), + ) + buf.add( + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 0, 0, 0], truncated=[0, 0, 0, 0]), + ), + ) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=data, act=data, rew=rew, terminated=[0, 1, 0, 1], truncated=[0, 0, 0, 0]), + cast( + RolloutBatchProtocol, + Batch(obs=data, act=data, rew=rew, terminated=[0, 1, 0, 1], truncated=[0, 0, 0, 0]), + ), ) assert np.all(ptr == [1, -1, 11, -1]) assert np.all(ep_idx == [0, -1, 10, -1]) @@ -1041,14 +1129,17 @@ def test_multibuf_stack() -> None: truncated_list = [truncated] * cached_num obs_next_list = -obs_list info_list = [info] * cached_num - batch = Batch( - obs=obs_list, - act=act_list, - rew=rew_list, - terminated=terminated_list, - truncated=truncated_list, - obs_next=obs_next_list, - info=info_list, + batch = cast( + RolloutBatchProtocol, + Batch( + obs=obs_list, + act=act_list, + rew=rew_list, + terminated=terminated_list, + truncated=truncated_list, + obs_next=obs_next_list, + info=info_list, + ), ) buf5.add(batch) buf4.add(batch) @@ -1184,13 +1275,16 @@ def test_multibuf_stack() -> None: ) obs = np.random.rand(size, 4, 84, 84) buf6.add( - Batch( - obs=[obs[2], obs[0]], - act=[1, 1], - rew=[0, 0], - terminated=[0, 1], - truncated=[0, 0], - obs_next=[obs[3], obs[1]], + cast( + RolloutBatchProtocol, + Batch( + obs=[obs[2], obs[0]], + act=[1, 1], + rew=[0, 0], + terminated=[0, 1], + truncated=[0, 0], + obs_next=[obs[3], obs[1]], + ), ), buffer_ids=[1, 2], ) @@ -1309,49 +1403,52 @@ def test_from_data() -> None: def test_custom_key() -> None: - batch = Batch( - obs_next=np.array( - [ + batch = cast( + RolloutBatchProtocol, + Batch( + obs_next=np.array( [ - 1.174, - -0.1151, - -0.609, - -0.5205, - -0.9316, - 3.236, - -2.418, - 0.386, - 0.2227, - -0.5117, - 2.293, + [ + 1.174, + -0.1151, + -0.609, + -0.5205, + -0.9316, + 3.236, + -2.418, + 0.386, + 0.2227, + -0.5117, + 2.293, + ], ], - ], - ), - rew=np.array([4.28125]), - act=np.array([[-0.3088, -0.4636, 0.4956]]), - truncated=np.array([False]), - obs=np.array( - [ + ), + rew=np.array([4.28125]), + act=np.array([[-0.3088, -0.4636, 0.4956]]), + truncated=np.array([False]), + obs=np.array( [ - 1.193, - -0.1203, - -0.6123, - -0.519, - -0.9434, - 3.32, - -2.266, - 0.9116, - 0.623, - 0.1259, - 0.363, + [ + 1.193, + -0.1203, + -0.6123, + -0.519, + -0.9434, + 3.32, + -2.266, + 0.9116, + 0.623, + 0.1259, + 0.363, + ], ], - ], + ), + terminated=np.array([False]), + done=np.array([False]), + returns=np.array([74.70343082]), + info=Batch(), + policy=Batch(), ), - terminated=np.array([False]), - done=np.array([False]), - returns=np.array([74.70343082]), - info=Batch(), - policy=Batch(), ) buffer_size = len(batch.rew) buffer = ReplayBuffer(buffer_size) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d03a54df7..95b604905 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Sequence from test.base.env import MoveToRightEnv, NXEnv -from typing import Any +from 
typing import Any, cast import gymnasium as gym import numpy as np @@ -17,7 +17,11 @@ VectorReplayBuffer, ) from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ActStateBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy, TrainingStats @@ -54,7 +58,7 @@ def forward( batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> Batch: + ) -> ActStateBatchProtocol: if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) @@ -69,9 +73,9 @@ def forward( action_shape = len(batch.obs["index"]) else: action_shape = len(batch.obs) - return Batch(act=np.ones(action_shape), state=state) + return cast(ActStateBatchProtocol, Batch(act=np.ones(action_shape), state=state)) action_shape = self.action_shape if self.action_shape else len(batch.obs) - return Batch(act=np.ones(action_shape), state=state) + return cast(ActStateBatchProtocol, Batch(act=np.ones(action_shape), state=state)) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: raise NotImplementedError diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index ce8a93640..287e79677 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -12,7 +12,12 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler from tianshou.data import Batch, Collector -from tianshou.data.types import BatchProtocol, ObsBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + BatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type from tianshou.policy import BasePolicy @@ -208,8 +213,8 @@ def forward( batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> Batch: - return Batch(act=np.stack([1] * len(batch))) + ) -> ActBatchProtocol: + return cast(ActBatchProtocol, Batch(act=np.stack([1] * len(batch)))) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: pass diff --git a/test/base/test_returns.py b/test/base/test_returns.py index ab4430b85..078893113 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -1,7 +1,10 @@ +from typing import cast + import numpy as np import torch from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import BasePolicy @@ -20,56 +23,68 @@ def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: def test_episodic_returns(size: int = 2560) -> None: fn = BasePolicy.compute_episodic_return buf = ReplayBuffer(20) - batch = Batch( - terminated=np.array([1, 0, 0, 1, 0, 0, 0, 1.0]), - truncated=np.array([0, 0, 0, 0, 0, 1, 0, 0]), - rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.0]), - info=Batch( - { - "TimeLimit.truncated": np.array( - [False, False, False, False, False, True, False, False], - ), - }, + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([1, 0, 0, 1, 0, 0, 0, 1.0]), + truncated=np.array([0, 0, 0, 0, 0, 1, 0, 0]), + rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.0]), + info=Batch( + { + "TimeLimit.truncated": np.array( + [False, False, False, False, False, True, False, False], + ), + }, + ), ), ) for b in iter(batch): - 
b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) assert np.allclose(returns, ans) buf.reset() - batch = Batch( - terminated=np.array([0, 1, 0, 1, 0, 1, 0.0]), - truncated=np.array([0, 0, 0, 0, 0, 0, 0.0]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 1, 0, 1, 0, 1, 0.0]), + truncated=np.array([0, 0, 0, 0, 0, 0, 0.0]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + ), ) for b in iter(batch): - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) assert np.allclose(returns, ans) buf.reset() - batch = Batch( - terminated=np.array([0, 1, 0, 1, 0, 0, 1.0]), - truncated=np.array([0, 0, 0, 0, 0, 0, 0]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 1, 0, 1, 0, 0, 1.0]), + truncated=np.array([0, 0, 0, 0, 0, 0, 0]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), + ), ) for b in iter(batch): - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) assert np.allclose(returns, ans) buf.reset() - batch = Batch( - terminated=np.array([0, 0, 0, 1.0, 0, 0, 0, 1, 0, 0, 0, 1]), - truncated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), - rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 0, 0, 1.0, 0, 0, 0, 1, 0, 0, 0, 1]), + truncated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), + ), ) for b in batch: - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) @@ -91,33 +106,36 @@ def test_episodic_returns(size: int = 2560) -> None: ) assert np.allclose(returns, ground_truth) buf.reset() - batch = Batch( - terminated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), - truncated=np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]), - rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), - info=Batch( - { - "TimeLimit.truncated": np.array( - [ - False, - False, - False, - True, - False, - False, - False, - True, - False, - False, - False, - False, - ], - ), - }, + batch = cast( + RolloutBatchProtocol, + Batch( + terminated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), + truncated=np.array([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]), + rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), + info=Batch( + { + "TimeLimit.truncated": np.array( + [ + False, + False, + False, + True, + False, + False, + False, + True, + False, + False, + False, + False, + ], + ), + }, + ), ), ) for b in iter(batch): - b.obs = b.act = 1 + b.obs = b.act = 1 # type: ignore[assignment] buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) @@ -180,12 +198,15 @@ def test_nstep_returns(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( - Batch( - obs=0, - act=0, 
- rew=i + 1, - terminated=i % 4 == 3, - truncated=False, + cast( + RolloutBatchProtocol, + Batch( + obs=0, + act=0, + rew=i + 1, + terminated=i % 4 == 3, + truncated=False, + ), ), ) batch, indices = buf.sample(0) @@ -258,13 +279,16 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( - Batch( - obs=0, - act=0, - rew=i + 1, - terminated=i % 4 == 3 and i != 3, - truncated=i == 3, - info={"TimeLimit.truncated": i == 3}, + cast( + RolloutBatchProtocol, + Batch( + obs=0, + act=0, + rew=i + 1, + terminated=i % 4 == 3 and i != 3, + truncated=i == 3, + info={"TimeLimit.truncated": i == 3}, + ), ), ) batch, indices = buf.sample(0) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 03b3d9849..650a5ccdf 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,11 +1,12 @@ import pprint import warnings -from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence +from collections.abc import Callable, Collection, Iterable, Iterator, KeysView, Sequence from copy import deepcopy from numbers import Number from types import EllipsisType from typing import ( Any, + Literal, Protocol, Self, TypeVar, @@ -16,13 +17,21 @@ ) import numpy as np +import pandas as pd import torch from deepdiff import DeepDiff +from torch.distributions import Categorical, Distribution, Independent, Normal + +from tianshou.utils import logging _SingleIndexType = slice | int | EllipsisType -IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] +IndexType = np.ndarray | _SingleIndexType | Sequence[_SingleIndexType] TBatch = TypeVar("TBatch", bound="BatchProtocol") -arr_type = torch.Tensor | np.ndarray +TDistribution = TypeVar("TDistribution", bound=Distribution) +T = TypeVar("T") +TArr = torch.Tensor | np.ndarray + +log = logging.getLogger(__name__) def _is_batch_set(obj: Any) -> bool: @@ -196,6 +205,36 @@ def alloc_by_keys_diff( meta[key] = create_value(batch[key], size, stack) +class ProtocolCalledException(Exception): + """The methods of a Protocol should never be called. + + Currently, no static type checker actually verifies that a class that inherits + from a Protocol does in fact provide the correct interface. Thus, it may happen + that a method of the protocol is called accidentally (this is an + implementation error). The normal error for that is a somewhat cryptic + AttributeError, wherefore we instead raise this custom exception in the + BatchProtocol. + + Finally and importantly: using this in BatchProtocol makes mypy verify the fields + in the various sub-protocols and thus renders is MUCH more useful! + """ + + +def get_sliced_dist(dist: TDistribution, index: IndexType) -> TDistribution: + """Slice a distribution object by the given index.""" + if isinstance(dist, Categorical): + return Categorical(probs=dist.probs[index]) # type: ignore[return-value] + if isinstance(dist, Normal): + return Normal(loc=dist.loc[index], scale=dist.scale[index]) # type: ignore[return-value] + if isinstance(dist, Independent): + return Independent( + get_sliced_dist(dist.base_dist, index), + dist.reinterpreted_batch_ndims, + ) # type: ignore[return-value] + else: + raise NotImplementedError(f"Unsupported distribution for slicing: {dist}") + + # Note: This is implemented as a protocol because the interface # of Batch is always extended by adding new fields. 
Having a hierarchy of # protocols building off this one allows for type safety and IDE support despite @@ -214,72 +253,75 @@ class BatchProtocol(Protocol): @property def shape(self) -> list[int]: - ... - + raise ProtocolCalledException + + # NOTE: even though setattr and getattr are defined for any object, we need + # to explicitly define them for the BatchProtocol, since otherwise mypy will + # complain about new fields being added dynamically. For example, things like + # `batch.new_field = ...` followed by using `batch.new_field` become type errors + # if getattr and setattr are missing in the BatchProtocol. + # + # For the moment, tianshou relies on this kind of dynamic-field-addition + # in many, many places. In principle, it would be better to construct new + # objects with new combinations of fields instead of mutating existing ones - the + # latter is error-prone and can't properly be expressed with types. May be in a + # future, rather different version of tianshou it would be feasible to have stricter + # typing. Then the need for Protocols would in fact disappear def __setattr__(self, key: str, value: Any) -> None: - ... + raise ProtocolCalledException def __getattr__(self, key: str) -> Any: - ... - - def __contains__(self, key: str) -> bool: - ... + raise ProtocolCalledException - def __getstate__(self) -> dict: - ... - - def __setstate__(self, state: dict) -> None: - ... + def __iter__(self) -> Iterator[Self]: + raise ProtocolCalledException @overload def __getitem__(self, index: str) -> Any: - ... + raise ProtocolCalledException @overload def __getitem__(self, index: IndexType) -> Self: - ... + raise ProtocolCalledException def __getitem__(self, index: str | IndexType) -> Any: - ... + raise ProtocolCalledException def __setitem__(self, index: str | IndexType, value: Any) -> None: - ... + raise ProtocolCalledException def __iadd__(self, other: Self | Number | np.number) -> Self: - ... + raise ProtocolCalledException def __add__(self, other: Self | Number | np.number) -> Self: - ... + raise ProtocolCalledException def __imul__(self, value: Number | np.number) -> Self: - ... + raise ProtocolCalledException def __mul__(self, value: Number | np.number) -> Self: - ... + raise ProtocolCalledException def __itruediv__(self, value: Number | np.number) -> Self: - ... + raise ProtocolCalledException def __truediv__(self, value: Number | np.number) -> Self: - ... + raise ProtocolCalledException def __repr__(self) -> str: - ... - - def __iter__(self) -> Iterator[Self]: - ... + raise ProtocolCalledException def __eq__(self, other: Any) -> bool: - ... + raise ProtocolCalledException @staticmethod def to_numpy(batch: TBatch) -> TBatch: """Change all torch.Tensor to numpy.ndarray and return a new Batch.""" - ... + raise ProtocolCalledException def to_numpy_(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" - ... + raise ProtocolCalledException @staticmethod def to_torch( @@ -288,7 +330,7 @@ def to_torch( device: str | int | torch.device = "cpu", ) -> TBatch: """Change all numpy.ndarray to torch.Tensor and return a new Batch.""" - ... + raise ProtocolCalledException def to_torch_( self, @@ -296,11 +338,11 @@ def to_torch_( device: str | int | torch.device = "cpu", ) -> None: """Change all numpy.ndarray to torch.Tensor in-place.""" - ... + raise ProtocolCalledException def cat_(self, batches: Self | Sequence[dict | Self]) -> None: """Concatenate a list of (or one) Batch objects into current batch.""" - ... 
+ raise ProtocolCalledException @staticmethod def cat(batches: Sequence[dict | TBatch]) -> TBatch: @@ -320,11 +362,11 @@ def cat(batches: Sequence[dict | TBatch]) -> TBatch: >>> c.common.c.shape (7, 5) """ - ... + raise ProtocolCalledException def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None: """Stack a list of Batch object into current batch.""" - ... + raise ProtocolCalledException @staticmethod def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch: @@ -349,7 +391,7 @@ def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch: If there are keys that are not shared across all batches, ``stack`` with ``axis != 0`` is undefined, and will cause an exception. """ - ... + raise ProtocolCalledException def empty_(self, index: slice | IndexType | None = None) -> Self: """Return an empty Batch object with 0 or None filled. @@ -376,7 +418,7 @@ def empty_(self, index: slice | IndexType | None = None) -> Self: ), ) """ - ... + raise ProtocolCalledException @staticmethod def empty(batch: TBatch, index: IndexType | None = None) -> TBatch: @@ -384,14 +426,14 @@ def empty(batch: TBatch, index: IndexType | None = None) -> TBatch: The shape is the same as the given Batch. """ - ... + raise ProtocolCalledException def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: """Update this batch from another dict/Batch.""" - ... + raise ProtocolCalledException def __len__(self) -> int: - ... + raise ProtocolCalledException def split( self, @@ -409,14 +451,109 @@ def split( :param merge_last: merge the last batch into the previous one. Default to False. """ - ... + raise ProtocolCalledException def to_dict(self, recurse: bool = True) -> dict[str, Any]: - ... + raise ProtocolCalledException def to_list_of_dicts(self) -> list[dict[str, Any]]: + raise ProtocolCalledException + + def get_keys(self) -> KeysView: + raise ProtocolCalledException + + def set_array_at_key( + self, + seq: np.ndarray, + key: str, + index: IndexType | None = None, + default_value: float | None = None, + ) -> None: + """Set a sequence of values at a given key. + + If `index` is not passed, the sequence must have the same length as the batch. + + :param seq: the array of values to set. + :param key: the key to set the sequence at. + :param index: the indices to set the sequence at. If None, the sequence must have + the same length as the batch and will be set at all indices. + :param default_value: this only applies if `index` is passed and the key does not exist yet + in the batch. In that case, entries outside the passed index will be filled + with this default value. + Note that the array at the key will be of the same dtype as the passed sequence, + so `default_value` should be such that numpy can cast it to this dtype. + """ + raise ProtocolCalledException + + def isnull(self) -> Self: + """Return a boolean mask of the same shape, indicating missing values.""" + raise ProtocolCalledException + + def hasnull(self) -> bool: + """Return whether the batch has missing values.""" + raise ProtocolCalledException + + def dropnull(self) -> Self: + """Return a batch where all items in which any value is null are dropped. + + Note that it is not the same as just dropping the entries of the sequence. + For example, with + + >>> b = Batch(a=[None, 2, 3, 4], b=[4, 5, None, 7]) + >>> b.dropnull() + + will result in + + >>> Batch(a=[2, 4], b=[5, 7]) + + This logic is applied recursively to all nested batches. 
The result is + the same as if the batch was flattened, entries were dropped, + and then the batch was reshaped back to the original nested structure. + """ + ... + + @overload + def apply_values_transform( + self, + values_transform: Callable[[np.ndarray | torch.Tensor], Any], + ) -> Self: + ... + + @overload + def apply_values_transform( + self, + values_transform: Callable, + inplace: Literal[True], + ) -> None: + ... + + @overload + def apply_values_transform( + self, + values_transform: Callable[[np.ndarray | torch.Tensor], Any], + inplace: Literal[False], + ) -> Self: ... + def apply_values_transform( + self, + values_transform: Callable[[np.ndarray | torch.Tensor], Any], + inplace: bool = False, + ) -> None | Self: + """Apply a function to all arrays in the batch, including nested ones. + + :param values_transform: the function to apply to the arrays. + :param inplace: whether to apply the function in-place. If False, a new batch is returned, + otherwise the batch is modified in-place and None is returned. + """ + raise ProtocolCalledException + + def get(self, key: str, default: Any | None = None) -> Any: + raise ProtocolCalledException + + def pop(self, key: str, default: Any | None = None) -> Any: + raise ProtocolCalledException + class Batch(BatchProtocol): """See :class:`~tianshou.data.batch.BatchProtocol`.""" @@ -459,6 +596,12 @@ def to_dict(self, recursive: bool = True) -> dict[str, Any]: def get_keys(self) -> KeysView: return self.__dict__.keys() + def get(self, key: str, default: Any | None = None) -> Any: + return self.__dict__.get(key, default) + + def pop(self, key: str, default: Any | None = None) -> Any: + return self.__dict__.pop(key, default) + def to_list_of_dicts(self) -> list[dict[str, Any]]: return [entry.to_dict() for entry in self] @@ -504,17 +647,28 @@ def __getitem__(self, index: IndexType) -> Self: ... def __getitem__(self, index: str | IndexType) -> Any: - """Return self[index].""" + """Returns either the value of a key or a sliced Batch object.""" if isinstance(index, str): return self.__dict__[index] batch_items = self.items() if len(batch_items) > 0: new_batch = Batch() + + sliced_obj: Any for batch_key, obj in batch_items: - if isinstance(obj, Batch) and len(obj.get_keys()) == 0: - new_batch.__dict__[batch_key] = Batch() + # None and empty Batches as values are added to any slice + if obj is None: + sliced_obj = None + elif isinstance(obj, Batch) and len(obj.get_keys()) == 0: + sliced_obj = Batch() + # We attempt slicing of a distribution. 
This is hacky, but presents an important special case + elif isinstance(obj, Distribution): + sliced_obj = get_sliced_dist(obj, index) + # All other objects are either array-like or Batch-like, so hopefully sliceable + # A batch should have no scalars else: - new_batch.__dict__[batch_key] = obj[index] + sliced_obj = obj[index] + new_batch.__dict__[batch_key] = sliced_obj return new_batch raise IndexError("Cannot access item from empty Batch object.") @@ -630,22 +784,17 @@ def __repr__(self) -> str: @staticmethod def to_numpy(batch: TBatch) -> TBatch: - batch_dict = deepcopy(batch) - for batch_key, obj in batch_dict.items(): - if isinstance(obj, torch.Tensor): - batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy() - elif isinstance(obj, Batch): - obj = Batch.to_numpy(obj) - batch_dict.__dict__[batch_key] = obj - - return batch_dict + result = deepcopy(batch) + result.to_numpy_() + return result def to_numpy_(self) -> None: - for batch_key, obj in self.items(): - if isinstance(obj, torch.Tensor): - self.__dict__[batch_key] = obj.detach().cpu().numpy() - elif isinstance(obj, Batch): - obj.to_numpy_() + def arr_to_numpy(arr: TArr) -> TArr: + if isinstance(arr, torch.Tensor): + return arr.detach().cpu().numpy() + return arr + + self.apply_values_transform(arr_to_numpy, inplace=True) @staticmethod def to_torch( @@ -653,10 +802,9 @@ def to_torch( dtype: torch.dtype | None = None, device: str | int | torch.device = "cpu", ) -> TBatch: - new_batch = Batch(batch, copy=True) - new_batch.to_torch_(dtype=dtype, device=device) - - return new_batch # type: ignore[return-value] + result = deepcopy(batch) + result.to_torch_(dtype=dtype, device=device) + return result def to_torch_( self, @@ -666,28 +814,23 @@ def to_torch_( if not isinstance(device, torch.device): device = torch.device(device) - for batch_key, obj in self.items(): - if isinstance(obj, torch.Tensor): - if ( - dtype is not None - and obj.dtype != dtype - or obj.device.type != device.type - or device.index != obj.device.index - ): - if dtype is not None: - self.__dict__[batch_key] = obj.type(dtype).to(device) - else: - self.__dict__[batch_key] = obj.to(device) - elif isinstance(obj, Batch): - obj.to_torch_(dtype, device) - else: - # ndarray or scalar - if not isinstance(obj, np.ndarray): - obj = np.asanyarray(obj) - obj = torch.from_numpy(obj).to(device) + def arr_to_torch(arr: TArr) -> TArr: + if isinstance(arr, np.ndarray): + return torch.from_numpy(arr).to(device) + + # TODO: simplify + if ( + dtype is not None + and arr.dtype != dtype + or arr.device.type != device.type + or device.index != arr.device.index + ): if dtype is not None: - obj = obj.type(dtype) - self.__dict__[batch_key] = obj + arr = arr.type(dtype) + return arr.to(device) + return arr + + self.apply_values_transform(arr_to_torch, inplace=True) def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: """Private method for Batch.cat_. @@ -763,18 +906,43 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: batches = [batches] # check input format batch_list = [] + + original_keys_only_batch = None + """A batch with all values removed, just keys left. Can be considered a sort of schema. + Will be either the schema of self, or of the first non-empty batch in the sequence. 
+ """ + if len(self) > 0: + original_keys_only_batch = self.apply_values_transform(lambda x: None) + original_keys_only_batch.replace_empty_batches_by_none() + for batch in batches: if isinstance(batch, dict): - if len(batch) > 0: - batch_list.append(Batch(batch)) - elif isinstance(batch, Batch): - if len(batch.get_keys()) != 0: - batch_list.append(batch) - else: + batch = Batch(batch) + if not isinstance(batch, Batch): raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_") + if len(batch.get_keys()) == 0: + continue + if original_keys_only_batch is None: + original_keys_only_batch = batch.apply_values_transform(lambda x: None) + original_keys_only_batch.replace_empty_batches_by_none() + batch_list.append(batch) + continue + + cur_keys_only_batch = batch.apply_values_transform(lambda x: None) + cur_keys_only_batch.replace_empty_batches_by_none() + if original_keys_only_batch != cur_keys_only_batch: + raise ValueError( + f"Batch.cat_ only supports concatenation of batches with the same structure but got " + f"structures: \n{original_keys_only_batch}\n and\n{cur_keys_only_batch}.", + ) + batch_list.append(batch) if len(batch_list) == 0: return + batches = batch_list + + # TODO: lot's of the remaining logic is devoted to filling up remaining keys with zeros + # this should be removed, and also the check above should be extended to nested keys try: # len(batch) here means batch is a nested empty batch # like Batch(a=Batch), and we have to treat it as length zero and @@ -788,6 +956,7 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: ) from exception if len(self.get_keys()) != 0: batches = [self, *list(batches)] + # len of zero means that that item is Batch() and should be ignored lens = [0 if len(self) == 0 else len(self), *lens] self.__cat(batches, lens) @@ -919,7 +1088,6 @@ def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: self.update(kwargs) def __len__(self) -> int: - """Return len(self).""" lens = [] for obj in self.__dict__.values(): # TODO: causes inconsistent behavior to batch with empty batches @@ -967,3 +1135,151 @@ def split( yield self[indices[idx:]] break yield self[indices[idx : idx + size]] + + @overload + def apply_values_transform( + self, + values_transform: Callable, + ) -> Self: + ... + + @overload + def apply_values_transform( + self, + values_transform: Callable, + inplace: Literal[True], + ) -> None: + ... + + @overload + def apply_values_transform( + self, + values_transform: Callable, + inplace: Literal[False], + ) -> Self: + ... + + def apply_values_transform( + self, + values_transform: Callable, + inplace: bool = False, + ) -> None | Self: + """Applies a function to all non-batch-values in the batch, including + values in nested batches. + + A batch with keys pointing to either batches or to non-batch values can + be thought of as a tree of Batch nodes. This function traverses the tree + and applies the function to all leaf nodes (i.e. values that are not + batches themselves). + + The values are usually arrays, but can also be scalar values of an + arbitrary type since retrieving a single entry from a Batch a la + `batch[0]` will return a batch with scalar values. 
+ """ + return _apply_batch_values_func_recursively(self, values_transform, inplace=inplace) + + def set_array_at_key( + self, + arr: np.ndarray, + key: str, + index: IndexType | None = None, + default_value: float | None = None, + ) -> None: + if index is not None: + if key not in self.get_keys(): + log.info( + f"Key {key} not found in batch, " + f"creating a sequence of len {len(self)} with {default_value=} for it.", + ) + try: + self[key] = np.array([default_value] * len(self), dtype=arr.dtype) + except TypeError as exception: + raise TypeError( + f"Cannot create a sequence of dtype {arr.dtype} with default value {default_value}. " + f"You can fix this either by passing an array with the correct dtype or by passing " + f"a different default value that can be cast to the array's dtype (or both).", + ) from exception + else: + existing_entry = self[key] + if isinstance(existing_entry, BatchProtocol): + raise ValueError( + f"Cannot set sequence at key {key} because it is a nested batch, " + f"can only set a subsequence of an array.", + ) + self[key][index] = arr + else: + if len(arr) != len(self): + raise ValueError( + f"Sequence length {len(arr)} does not match " + f"batch length {len(self)}. For setting a subsequence with missing " + f"entries filled up by default values, consider passing an index.", + ) + self[key] = arr + + def isnull(self) -> Self: + return self.apply_values_transform(pd.isnull, inplace=False) + + def hasnull(self) -> bool: + isnan_batch = self.isnull() + is_any_null_batch = isnan_batch.apply_values_transform(np.any, inplace=False) + + def is_any_true(boolean_batch: BatchProtocol) -> bool: + for val in boolean_batch.values(): + if isinstance(val, BatchProtocol): + if is_any_true(val): + return True + else: + assert val.size == 1, "This shouldn't have happened, it's a bug!" + # an unsized array with a boolean, e.g. np.array(False). behaves like the boolean itself + if val: + return True + return False + + return is_any_true(is_any_null_batch) + + def dropnull(self) -> Self: + # we need to use dicts since a batch retrieved for a single index has no length and cat fails + # TODO: make cat work with batches containing scalars? + sub_batches = [] + for b in self: + if b.hasnull(): + continue + # needed for cat to work + b = b.apply_values_transform(np.atleast_1d) + sub_batches.append(b) + return Batch.cat(sub_batches) + + def replace_empty_batches_by_none(self) -> None: + """Goes through the batch-tree" recursively and replaces empty batches by None. + + This is useful for extracting the structure of a batch without the actual data, + especially in combination with `apply_values_transform` with a + transform function a la `lambda x: None`. + """ + empty_batch = Batch() + for key, val in self.items(): + if isinstance(val, Batch): + if val == empty_batch: + self[key] = None + else: + val.replace_empty_batches_by_none() + + +def _apply_batch_values_func_recursively( + batch: TBatch, + values_transform: Callable, + inplace: bool = False, +) -> TBatch | None: + """Applies the desired function on all values of the batch recursively. + + See docstring of the corresponding method in the Batch class for more details. 
+ """ + result = batch if inplace else deepcopy(batch) + for key, val in batch.__dict__.items(): + if isinstance(val, BatchProtocol): + result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False) + else: + result[key] = values_transform(val) + if not inplace: + return result + return None diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 2ecddc5ce..2699d92d7 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -355,7 +355,7 @@ def get( set, return this default_value. :param stack_num: Default to self.stack_num. """ - if key not in self._meta and default_value is not None: + if key not in self._meta.get_keys() and default_value is not None: return default_value val = self._meta[key] if stack_num is None: diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index bef6a06a0..406e39afd 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -103,5 +103,8 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> PrioBatchP batch.weight = weight / np.max(weight) if self._weight_norm else weight return cast(PrioBatchProtocol, batch) + def sample(self, batch_size: int | None) -> tuple[PrioBatchProtocol, np.ndarray]: + return cast(tuple[PrioBatchProtocol, np.ndarray], super().sample(batch_size=batch_size)) + def set_beta(self, beta: float) -> None: self._beta = beta diff --git a/tianshou/data/types.py b/tianshou/data/types.py index 3572e5484..a4fd43543 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -4,7 +4,7 @@ import torch from tianshou.data import Batch -from tianshou.data.batch import BatchProtocol, arr_type +from tianshou.data.batch import BatchProtocol, TArr TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] @@ -19,24 +19,24 @@ class ObsBatchProtocol(BatchProtocol, Protocol): Typically used inside a policy's forward """ - obs: arr_type | BatchProtocol - info: arr_type + obs: TArr | BatchProtocol + info: TArr | BatchProtocol class RolloutBatchProtocol(ObsBatchProtocol, Protocol): """Typically, the outcome of sampling from a replay buffer.""" - obs_next: arr_type | BatchProtocol - act: arr_type + obs_next: TArr | BatchProtocol + act: TArr rew: np.ndarray - terminated: arr_type - truncated: arr_type + terminated: TArr + truncated: TArr class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): """With added returns, usually computed with GAE.""" - returns: arr_type + returns: TArr class PrioBatchProtocol(RolloutBatchProtocol, Protocol): @@ -55,7 +55,7 @@ class RecurrentStateBatch(BatchProtocol, Protocol): class ActBatchProtocol(BatchProtocol, Protocol): """Simplest batch, just containing the action. 
 
-    act: arr_type
+    act: TArr
 
 
 class ActStateBatchProtocol(ActBatchProtocol, Protocol):
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index c4fc9af3b..d886180a5 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -15,7 +15,7 @@
 from torch import nn
 
 from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as
-from tianshou.data.batch import Batch, BatchProtocol, arr_type
+from tianshou.data.batch import Batch, BatchProtocol, TArr
 from tianshou.data.buffer.base import TBuffer
 from tianshou.data.types import (
     ActBatchProtocol,
@@ -355,7 +355,7 @@ def forward(
         """
 
     @staticmethod
-    def _action_to_numpy(act: arr_type) -> np.ndarray:
+    def _action_to_numpy(act: TArr) -> np.ndarray:
         act = to_numpy(act)  # NOTE: to_numpy could confusingly also return a Batch
         if not isinstance(act, np.ndarray):
             raise ValueError(
@@ -365,7 +365,7 @@ def _action_to_numpy(act: arr_type) -> np.ndarray:
 
     def map_action(
         self,
-        act: arr_type,
+        act: TArr,
     ) -> np.ndarray:
         """Map raw network output to action range in gym's env.action_space.
@@ -400,7 +400,7 @@ def map_action(
 
     def map_action_inverse(
         self,
-        act: arr_type,
+        act: TArr,
     ) -> np.ndarray:
         """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.
diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py
index dee1a80a3..991c4aace 100644
--- a/tianshou/policy/imitation/bcq.py
+++ b/tianshou/policy/imitation/bcq.py
@@ -196,8 +196,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBCQT
 
         # now target_Q: (batch_size, 1)
         target_Q = (
-            batch.rew.reshape(-1, 1) + (1 - batch.done).reshape(-1, 1) * self.gamma * target_Q
+            batch.rew.reshape(-1, 1)
+            + torch.logical_not(batch.done).reshape(-1, 1) * self.gamma * target_Q
         )
+        target_Q = target_Q.float()
 
         current_Q1 = self.critic(obs, act)
         current_Q2 = self.critic2(obs, act)
diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py
index 1ce6d83d4..66438c758 100644
--- a/tianshou/policy/imitation/cql.py
+++ b/tianshou/policy/imitation/cql.py
@@ -280,7 +280,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLT
 
         target_Q = torch.min(target_Q1, target_Q2) - self.alpha * new_log_pi
 
-        target_Q = rew + self.gamma * (1 - batch.done) * target_Q.flatten()
+        target_Q = rew + torch.logical_not(batch.done) * self.gamma * target_Q.flatten()
+        target_Q = target_Q.float()
         # shape: (batch_size)
 
         # compute critic loss
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
index e88214d45..05cc8db8f 100644
--- a/tianshou/policy/multiagent/mapolicy.py
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -158,7 +158,7 @@ def process_fn(  # type: ignore
             results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice)
         if has_rew:  # restore from save_rew
             buffer._meta.rew = save_rew
-        return Batch(results)
+        return cast(MAPRolloutBatchProtocol, Batch(results))
 
 
 _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
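
For a quick impression of the new API, a minimal usage sketch (the batch contents are made up; the calls mirror the diff above):

```python
import numpy as np

from tianshou.data import Batch

b = Batch(obs=np.array([1.0, 2.0, np.nan]), info=Batch(step=np.arange(3)))

# apply_values_transform: apply a function to every leaf array, including nested ones
doubled = b.apply_values_transform(lambda arr: arr * 2)
assert float(doubled.obs[1]) == 4.0

# isnull/hasnull/dropnull: detect and remove entries containing null values
assert b.hasnull()
clean = b.dropnull()  # drops the entry containing np.nan, keeping the other two
assert not clean.hasnull() and len(clean) == 2

# set_array_at_key: set a subsequence, with missing entries filled by a default value
clean.set_array_at_key(np.array([9.0]), key="rew", index=[0], default_value=0.0)

# cat_ now validates structure: concatenating incompatible batches raises a ValueError
try:
    Batch.cat([clean, Batch(other=np.zeros(2))])
except ValueError:
    print("structures differ, as expected")
```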