Improvements in batch (thu-ml#1181)
This PR contains several important extensions and improvements to `Batch`.

1. A method for setting a new sequence. This is a first step towards code
that is less reliant on setting attributes directly. The method also
permits setting a subsequence and filling the remaining positions with
default values, which will help ensure that all entries in a batch have the
same length, something we should start enforcing soon (sketched below).
2. A new method for applying arbitrary transformations to the "leaves" of a
Batch. This simplifies existing code and is generally very handy for users
(sketched below).
3. The arbitrary transformations allowed implementing `isnull`,
`hasnull` and `dropnull`, which help find errors early (sketched below).
4. The arbitrary transformations now also allow extracting a `schema`
from a batch. This was used in `Batch.cat_` to perform additional input
validation (we now make sure that the structures are the same when
concatenating batches). This input validation is a **breaking change**!
Some tests that concatenated incompatible batches were removed.
Eventually, we can add a `get_schema` method to the batch that retrieves
metainfo like shapes and datatypes. For now, this is delegated to the user,
who can use the new `apply_values_transform` (sketched below).
5. New feature: slicing of torch distributions. Previously, several
batches contained instances of `Distribution` that were not properly
sliced, inviting bugs and errors in user code. This is now fixed, albeit
not in a pretty way (torch doesn't allow slicing these objects natively);
see the sketch below.
6. Stricter typing and some further minor extensions and simplifications.

The new code was extensively tested and documented.
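
A minimal sketch of point 1, assuming the method is called `set_array_at_key` and takes `index` and `default_value` keywords; the actual name and signature are not given in this description:

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(obs=np.zeros((5, 2)), act=np.zeros(5))

# Replace the whole "act" sequence without touching the attribute directly.
batch.set_array_at_key(np.ones(5), key="act")

# Set only a subsequence under a new key; the remaining positions are
# filled with the default value, so all entries keep the same length.
batch.set_array_at_key(np.ones(2), key="mask", index=[0, 1], default_value=0.0)
```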
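
For point 2, a leaf transformation could look like this (a sketch; it assumes `apply_values_transform` takes a callable and returns a new `Batch` rather than modifying in place):

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(obs=Batch(pos=np.zeros(3), vel=np.ones(3)), rew=np.arange(3))

# The callable is applied to every array "leaf" of the nested structure,
# preserving the nesting itself.
scaled = batch.apply_values_transform(lambda arr: arr * 2.0)
```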
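
The null-handling helpers from point 3 in action (a sketch; whether `dropnull` removes offending rows or keys is an assumption):

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(a=np.array([1.0, np.nan, 3.0]), b=np.array([4.0, 5.0, 6.0]))

mask = batch.isnull()     # element-wise null mask over all leaves
assert batch.hasnull()    # True, since "a" contains a NaN
clean = batch.dropnull()  # the batch with null entries removed
```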
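
Until a dedicated `get_schema` method exists, shapes and dtypes can be extracted with `apply_values_transform`, roughly as follows (a sketch; wrapping the results in arrays is only there to keep the leaves `Batch`-compatible):

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(obs=np.zeros((4, 84, 84)), act=np.zeros(4, dtype=np.int64))

# Replace every leaf by its shape (resp. dtype) to obtain a schema-like view.
shapes = batch.apply_values_transform(lambda arr: np.asarray(arr.shape))
dtypes = batch.apply_values_transform(lambda arr: np.asarray(str(arr.dtype)))
```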
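
Point 5 means that code like the following now slices the contained distribution along with the arrays (a sketch, assuming a `Batch` can hold a `Distribution` instance directly):

```python
import torch
from torch.distributions import Normal

from tianshou.data import Batch

batch = Batch(
    act=torch.zeros(4, 2),
    dist=Normal(loc=torch.zeros(4, 2), scale=torch.ones(4, 2)),
)

# Slicing the batch also slices the distribution's parameters:
# first_two.dist is a Normal over the first two rows only.
first_two = batch[:2]
```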
MischaPanch authored Jul 31, 2024
1 parent 33e30b2 commit 6154292
Showing 16 changed files with 1,053 additions and 486 deletions.
2 changes: 1 addition & 1 deletion docs/spelling_wordlist.txt
@@ -270,4 +270,4 @@ v_s
 v_s_
 obs
 obs_next
-
+dtype
18 changes: 11 additions & 7 deletions examples/inverse/irl_gail.py
@@ -4,7 +4,7 @@
 import datetime
 import os
 import pprint
-from typing import SupportsFloat
+from typing import SupportsFloat, cast
 
 import d4rl
 import gymnasium as gym
@@ -16,6 +16,7 @@
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.data.types import RolloutBatchProtocol
 from tianshou.env import SubprocVectorEnv, VectorEnvNormObs
 from tianshou.policy import GAILPolicy
 from tianshou.policy.base import BasePolicy
@@ -185,12 +186,15 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
 
     for i in range(dataset_size):
         expert_buffer.add(
-            Batch(
-                obs=dataset["observations"][i],
-                act=dataset["actions"][i],
-                rew=dataset["rewards"][i],
-                done=dataset["terminals"][i],
-                obs_next=dataset["next_observations"][i],
+            cast(
+                RolloutBatchProtocol,
+                Batch(
+                    obs=dataset["observations"][i],
+                    act=dataset["actions"][i],
+                    rew=dataset["rewards"][i],
+                    done=dataset["terminals"][i],
+                    obs_next=dataset["next_observations"][i],
+                ),
             ),
         )
     print("dataset loaded")
6 changes: 4 additions & 2 deletions test/base/env.py
@@ -147,14 +147,16 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf.
         if self.index == self.size:
             self.terminated = True
             return self._get_state(), self._get_reward(), self.terminated, False, {}
+
+        info_dict = {"key": 1, "env": self}
         if action == 0:
             self.index = max(self.index - 1, 0)
             return (
                 self._get_state(),
                 self._get_reward(),
                 self.terminated,
                 False,
-                {"key": 1, "env": self} if self.dict_state else {},
+                info_dict,
             )
         if action == 1:
             self.index += 1
@@ -164,7 +166,7 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf.
                 self._get_reward(),
                 self.terminated,
                 False,
-                {"key": 1, "env": self},
+                info_dict,
             )
         return None

