Skip to content

Commit

Permalink
Implement non-inplace to_numpy for Batch
Browse files Browse the repository at this point in the history
  * Breaking change: Previous in-place `Batch.to_numpy` is now `Batch.to_numpy_` (following naming convention of other in-place methods).
  * Update places where in-place was expected
  * Add tests for both to_numpy/to_numpy_
  • Loading branch information
dantp-ai committed Apr 16, 2024
1 parent 98d611c commit 164cf84
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/01_tutorials/03_batch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,8 @@ Miscellaneous Notes
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> # data.to_numpy is also available
>>> data.to_numpy()
>>> # data.to_numpy_ is also available
>>> data.to_numpy_()

.. raw:: html

Expand Down
2 changes: 1 addition & 1 deletion docs/02_notebooks/L1_Batch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@
},
"outputs": [],
"source": [
"batch_cat.to_numpy()\n",
"batch_cat.to_numpy_()\n",
"print(batch_cat)\n",
"batch_cat.to_torch()\n",
"print(batch_cat)"
Expand Down
26 changes: 25 additions & 1 deletion test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None:
a_mem_addr_orig = batch.a.__array_interface__["data"][0]
c_mem_addr_orig = batch.b.c.__array_interface__["data"][0]
batch.to_torch()
batch.to_numpy()
batch.to_numpy_()
a_mem_addr_new = batch.a.__array_interface__["data"][0]
c_mem_addr_new = batch.b.c.__array_interface__["data"][0]
assert a_mem_addr_new == a_mem_addr_orig
Expand Down Expand Up @@ -703,6 +703,30 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None:
assert not DeepDiff(batch.to_dict(recurse=True), expected)


class TestToNumpy:
    """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()`."""

    @staticmethod
    def test_to_numpy() -> None:
        """`Batch.to_numpy` returns a converted copy and leaves its input untouched."""
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        new_batch: Batch = Batch.to_numpy(batch)

        # A distinct object must come back, also for nested batches.
        assert new_batch is not batch
        assert new_batch.c is not batch.c

        # The input still holds torch tensors.
        assert isinstance(batch.b, torch.Tensor)
        assert isinstance(batch.c.d, torch.Tensor)

        # The result holds numpy arrays carrying the same values.
        assert isinstance(new_batch.b, np.ndarray)
        assert isinstance(new_batch.c.d, np.ndarray)
        assert np.array_equal(new_batch.b, np.arange(5))
        assert np.array_equal(new_batch.c.d, [1, 2, 3])

    @staticmethod
    def test_to_numpy_() -> None:
        """`Batch.to_numpy_` converts the tensors of the batch it is called on."""
        batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])})
        # Keep an alias to the nested batch: comparing `id(batch)` before and
        # after the call is always true and would not prove in-place behaviour.
        nested = batch.c
        batch.to_numpy_()

        # The nested batch object is preserved and was itself converted.
        assert batch.c is nested
        assert isinstance(batch.b, np.ndarray)
        assert isinstance(nested.d, np.ndarray)
        assert np.array_equal(batch.b, np.arange(5))
        assert np.array_equal(nested.d, [1, 2, 3])


if __name__ == "__main__":
test_batch()
test_batch_over_batch()
Expand Down
31 changes: 24 additions & 7 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,12 @@ def __iter__(self) -> Iterator[Self]:
def __eq__(self, other: Any) -> bool:
...

def to_numpy(self) -> None:
@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
"""Return a copy of ``batch`` in which every torch.Tensor is a numpy.ndarray.

Nested batches are converted as well. The batch passed in is not modified;
use ``to_numpy_`` for the in-place variant.

:param batch: the batch to convert.
:return: a new batch holding numpy arrays in place of torch tensors.
"""
...

def to_numpy_(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place.

Nested batches are converted in place as well. Nothing is returned; use the
static ``to_numpy`` to obtain a converted copy instead.
"""
...

Expand Down Expand Up @@ -508,10 +513,10 @@ def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False

self.to_numpy()
other.to_numpy()
this_dict = self.to_dict(recurse=True)
other_dict = other.to_dict(recurse=True)
this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
this_dict = this_batch_no_torch_tensor.to_dict(recurse=True)
other_dict = other_batch_no_torch_tensor.to_dict(recurse=True)

return not DeepDiff(this_dict, other_dict)

Expand Down Expand Up @@ -614,12 +619,24 @@ def __repr__(self) -> str:
self_str = self.__class__.__name__ + "()"
return self_str

def to_numpy(self) -> None:
@staticmethod
def to_numpy(batch: TBatch) -> TBatch:
    """Return a copy of ``batch`` with every torch.Tensor turned into a numpy.ndarray.

    The input batch is left untouched. Nested batches are converted as well.

    :param batch: the batch to convert.
    :return: a deep copy of ``batch`` holding numpy arrays instead of tensors.
    """
    # Copy once, then convert the copy in place. Recursing through
    # `Batch.to_numpy` would deep-copy every nested batch a second time
    # (and again per nesting level), although it already belongs to the copy.
    result = deepcopy(batch)
    result.to_numpy_()
    return result

def to_numpy_(self) -> None:
for batch_key, obj in self.items():
if isinstance(obj, torch.Tensor):
self.__dict__[batch_key] = obj.detach().cpu().numpy()
elif isinstance(obj, Batch):
obj.to_numpy()
obj.to_numpy_()

def to_torch(
self,
Expand Down
2 changes: 1 addition & 1 deletion tianshou/data/utils/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray:
return np.array(None, dtype=object)
if isinstance(x, dict | Batch):
x = Batch(x) if isinstance(x, dict) else deepcopy(x)
x.to_numpy()
x.to_numpy_()
return x
if isinstance(x, list | tuple):
return to_numpy(_parse_value(x))
Expand Down

0 comments on commit 164cf84

Please sign in to comment.