From adf7be865cd26991769dff3bba3d47a854951a6e Mon Sep 17 00:00:00 2001
From: daniel <1534513+dantp-ai@users.noreply.github.com>
Date: Wed, 17 Apr 2024 16:23:04 +0200
Subject: [PATCH] Add non-in-place version of `Batch.to_torch`

* Breaking change: the previous in-place `Batch.to_torch` is now
  `Batch.to_torch_`, following the naming convention of the other
  in-place methods in the codebase.
* Update call sites that relied on the in-place behavior
* Add tests for both `to_torch` and `to_torch_`
---
 docs/01_tutorials/03_batch.rst   |  4 ++--
 docs/02_notebooks/L1_Batch.ipynb |  2 +-
 test/base/test_batch.py          | 30 +++++++++++++++++++++++++++---
 tianshou/data/batch.py           | 22 +++++++++++++++++++++-
 tianshou/data/utils/converter.py |  2 +-
 5 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/docs/01_tutorials/03_batch.rst b/docs/01_tutorials/03_batch.rst
index 46fa86b3d..08a335948 100644
--- a/docs/01_tutorials/03_batch.rst
+++ b/docs/01_tutorials/03_batch.rst
@@ -475,12 +475,12 @@ Miscellaneous Notes
 
 .. raw:: html
 
-   <details><summary>Batch.to_torch and Batch.to_numpy</summary>
+   <details><summary>Batch.to_torch_ and Batch.to_numpy_</summary>
 
 ::
 
     >>> data = Batch(a=np.zeros((3, 4)))
-    >>> data.to_torch(dtype=torch.float32, device='cpu')
+    >>> data.to_torch_(dtype=torch.float32, device='cpu')
     >>> print(data.a)
     tensor([[0., 0., 0., 0.],
             [0., 0., 0., 0.],
diff --git a/docs/02_notebooks/L1_Batch.ipynb b/docs/02_notebooks/L1_Batch.ipynb
index 54008ee64..21e143a33 100644
--- a/docs/02_notebooks/L1_Batch.ipynb
+++ b/docs/02_notebooks/L1_Batch.ipynb
@@ -333,7 +333,7 @@
    "source": [
     "batch_cat.to_numpy_()\n",
     "print(batch_cat)\n",
-    "batch_cat.to_torch()\n",
+    "batch_cat.to_torch_()\n",
     "print(batch_cat)"
    ]
   },
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 82ff4a3fb..5e90dfb66 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -379,7 +379,7 @@ def test_batch_over_batch_to_torch() -> None:
         b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)),
     )
     batch.b.__dict__["e"] = 1  # bypass the check
-    batch.to_torch()
+    batch.to_torch_()
     assert isinstance(batch.a, torch.Tensor)
     assert isinstance(batch.b.c, torch.Tensor)
     assert isinstance(batch.b.d, torch.Tensor)
@@ -391,7 +391,7 @@ def test_batch_over_batch_to_torch() -> None:
         assert batch.b.e.dtype == torch.int32
     else:
         assert batch.b.e.dtype == torch.int64
-    batch.to_torch(dtype=torch.float32)
+    batch.to_torch_(dtype=torch.float32)
     assert batch.a.dtype == torch.float32
     assert batch.b.c.dtype == torch.float32
     assert batch.b.d.dtype == torch.float32
@@ -477,7 +477,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
     batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
     a_mem_addr_orig = batch.a.__array_interface__["data"][0]
     c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
-    batch.to_torch()
+    batch.to_torch_()
     batch.to_numpy_()
     a_mem_addr_new = batch.a.__array_interface__["data"][0]
     c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
@@ -727,6 +727,30 @@ def test_to_numpy_() -> None:
     assert isinstance(batch.c.d, np.ndarray)
 
 
+class TestToTorch:
+    """Tests for `Batch.to_torch()` and its in-place counterpart `Batch.to_torch_()`."""
+
+    @staticmethod
+    def test_to_torch() -> None:
+        batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
+        new_batch: Batch = Batch.to_torch(batch)
+        assert id(batch) != id(new_batch)
+        assert isinstance(batch.b, np.ndarray)
+        assert isinstance(batch.c.d, np.ndarray)
+
+        assert isinstance(new_batch.b, torch.Tensor)
+        assert isinstance(new_batch.c.d, torch.Tensor)
+
+    @staticmethod
+    def test_to_torch_() -> None:
+        batch = Batch(a=1, b=np.arange(5), c={"d": np.array([1, 2, 3])})
+        id_batch = id(batch)
+        batch.to_torch_()
+        assert id_batch == id(batch)
+        assert isinstance(batch.b, torch.Tensor)
+        assert isinstance(batch.c.d, torch.Tensor)
+
+
 if __name__ == "__main__":
     test_batch()
     test_batch_over_batch()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index d911788c6..e4ea4a8b1 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -281,7 +281,16 @@ def to_numpy_(self) -> None:
         """Change all torch.Tensor to numpy.ndarray in-place."""
         ...
 
+    @staticmethod
     def to_torch(
+        batch: TBatch,
+        dtype: torch.dtype | None = None,
+        device: str | int | torch.device = "cpu",
+    ) -> TBatch:
+        """Change all numpy.ndarray to torch.Tensor and return a new Batch."""
+        ...
+
+    def to_torch_(
         self,
         dtype: torch.dtype | None = None,
         device: str | int | torch.device = "cpu",
@@ -641,7 +650,18 @@ def to_numpy_(self) -> None:
         elif isinstance(obj, Batch):
             obj.to_numpy_()
 
+    @staticmethod
     def to_torch(
+        batch: TBatch,
+        dtype: torch.dtype | None = None,
+        device: str | int | torch.device = "cpu",
+    ) -> TBatch:
+        new_batch = Batch(batch, copy=True)
+        new_batch.to_torch_(dtype=dtype, device=device)
+
+        return new_batch  # type: ignore[return-value]
+
+    def to_torch_(
         self,
         dtype: torch.dtype | None = None,
         device: str | int | torch.device = "cpu",
@@ -662,7 +682,7 @@ def to_torch(
                 else:
                     self.__dict__[batch_key] = obj.to(device)
             elif isinstance(obj, Batch):
-                obj.to_torch(dtype, device)
+                obj.to_torch_(dtype, device)
             else:
                 # ndarray or scalar
                 if not isinstance(obj, np.ndarray):
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
index 7edf3ff45..8f07e0494 100644
--- a/tianshou/data/utils/converter.py
+++ b/tianshou/data/utils/converter.py
@@ -57,7 +57,7 @@ def to_torch(
         return to_torch(np.asanyarray(x), dtype, device)
     if isinstance(x, dict | Batch):
         x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
-        x.to_torch_(dtype, device)
+        x.to_torch_(dtype, device)
         return x
     if isinstance(x, list | tuple):
         return to_torch(_parse_value(x), dtype, device)
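
Usage sketch of the split API after this patch (illustrative only; the `obs`
and `act` keys below are made-up examples, not taken from the diff):

    import numpy as np
    import torch

    from tianshou.data import Batch

    batch = Batch(obs=np.zeros((3, 4)), act=np.arange(3))

    # Non in-place: `Batch.to_torch` copies, converts, and returns a new
    # Batch; the original keeps its numpy arrays.
    new_batch = Batch.to_torch(batch, dtype=torch.float32, device="cpu")
    assert isinstance(batch.obs, np.ndarray)
    assert isinstance(new_batch.obs, torch.Tensor)

    # In-place (the old `to_torch` behavior): converts `batch` itself.
    batch.to_torch_(dtype=torch.float32, device="cpu")
    assert isinstance(batch.obs, torch.Tensor)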