Skip to content

Commit

Permalink
Allow two (same/different) Batch objs to be tested for equality
Browse files Browse the repository at this point in the history
  • Loading branch information
dantp-ai committed Apr 4, 2024
1 parent 8a0629d commit a4206ee
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
42 changes: 42 additions & 0 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,47 @@ def test_batch_standard_compatibility() -> None:
Batch()[0]


def test_batch_eq() -> None:
# Different keys
batch1 = Batch(a=[1, 2], b=[100, 50])
batch2 = Batch(b=[1, 2], c=[100, 50])
assert batch1 != batch2, "Keys are not matching."

# Missing keys
batch1 = Batch(a=[1, 2], b=[2, 3, 4])
batch2 = Batch(a=[1, 2], b=[2, 3, 4])
batch2.pop("b")
assert batch1 != batch2, "Keys are not matching."

# Different types for the same key
batch1 = Batch(a=[1, 2, 3], b=[4, 5])
batch2 = Batch(a=[1, 2, 3], b=Batch(a=[4, 5]))
assert batch1 != batch2, "Objects have different types"

# Different array types for the same key
batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5]))
batch2 = Batch(a=[1, 2, 3], b=torch.Tensor([4, 5]))
assert batch1 != batch2, "Objects have different types"

# Nested Batch objects with values
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
batch2 = Batch(a=Batch(a=[1, 2, 4]), b=[4, 5])
assert batch1 != batch2, "Nested batches have different values."

# Arrays with different shapes or values
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5])
batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
assert batch1 != batch2, "Nested objects have different lengths"

# Same slice from the same batch
batch1 = Batch(a=[1, 2, 3])
assert batch1[:2] == batch1[:2], "Batch slice should be the same"

# Same slice from the same batch with ellipsis and slice
batch1 = Batch(a=Batch(a=[1, 2, 3]), b=[4, 5], c=[100, 1001, 2000])
assert batch1[..., 1:] == batch1[..., 1:], "Batch slice should be the same"


if __name__ == "__main__":
test_batch()
test_batch_over_batch()
Expand All @@ -576,3 +617,4 @@ def test_batch_standard_compatibility() -> None:
test_batch_cat_and_stack()
test_batch_copy()
test_batch_empty()
test_batch_eq()
32 changes: 32 additions & 0 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ def __repr__(self) -> str:
def __iter__(self) -> Iterator[Self]:
...

def __eq__(self, other: Self) -> bool: # type: ignore
...

def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
...
Expand Down Expand Up @@ -500,6 +503,35 @@ def __getitem__(self, index: str | IndexType) -> Any:
return new_batch
raise IndexError("Cannot access item from empty Batch object.")

def __eq__(self, other: Self) -> bool: # type: ignore
this_dict = self.__dict__
other_dict = other.__dict__

if len(this_dict) != len(other_dict):
return False
for batch_key, obs in this_dict.items():
if batch_key not in other_dict:
return False

other_val = other.__dict__[batch_key]

if batch_key in other_dict:
if isinstance(obs, Batch) and isinstance(other_val, Batch):
if not obs == other_val:
return False
elif isinstance(obs, np.ndarray) and isinstance(other_val, np.ndarray):
if not np.all(np.equal(obs.shape, other_val.shape)):
return False
if not np.all(np.equal(obs, other_val)):
return False
elif isinstance(obs, torch.Tensor) and isinstance(other_val, torch.Tensor):
if not torch.equal(obs, other_val):
return False
else:
return False

return True

def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
Expand Down

0 comments on commit a4206ee

Please sign in to comment.