Commit 5b5843f (1 parent: c05041c)
Author: elliottower
Committed: Feb 1, 2023

Added support for user input to play against trained tianshou models

Showing 3 changed files with 229 additions and 12 deletions.
@@ -0,0 +1,191 @@
# Extending tianshou collector class to work with manual policy (for user input)
import time
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import gym
import numpy as np
import torch

from tianshou.data import (
    Batch,
    ReplayBuffer,
)
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy

from tianshou.data.collector import Collector


class ManualPolicyCollector(Collector):
    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        exploration_noise: bool = False,
    ) -> None:
        super(ManualPolicyCollector, self).__init__(policy=policy, env=env, exploration_noise=exploration_noise)

    # Custom function to collect the result of an inputted action
    def collect_result(
        self,
        action: Optional[int] = None,
        render: Optional[float] = None,
        no_grad: bool = True,
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Collect the results of an inputted action.

        :param int action: the action you want to collect the results for.
        :param float render: the sleep time between rendering consecutive frames.
            Defaults to None (no rendering).
        :param bool no_grad: whether to retain gradient in policy.forward(). Defaults to
            True (no gradient retaining).
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (no extra keyword arguments).
        :return: A dict including the following keys:

            * ``n/ep`` collected number of episodes.
            * ``n/st`` collected number of steps.
            * ``rews`` array of episode rewards over collected episodes.
            * ``lens`` array of episode lengths over collected episodes.
            * ``idxs`` array of episode start indices in the buffer over collected episodes.
            * ``rew`` mean of episodic rewards.
            * ``len`` mean of episodic lengths.
            * ``rew_std`` standard deviation of episodic rewards.
            * ``len_std`` standard deviation of episodic lengths.
        """
        assert not self.env.is_async, "Please use AsyncCollector if using async venv."

        ready_env_ids = np.arange(self.env_num)

        start_time = time.time()

        step_count = 0
        episode_count = 0
        episode_rews = []
        episode_lens = []
        episode_start_indices = []

        while True:
            assert len(self.data) == len(ready_env_ids)
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)

            # use hard coded action (rather than using a policy or randomly sampling)
            self.data.update(act=action)

            # hard code to only update for a single step
            n_step = 1

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
            result = self.env.step(action_remap, ready_env_ids)  # type: ignore
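            # Handle both Gym step APIs below: the newer 5-tuple
            # (obs, rew, terminated, truncated, info) and the older 4-tuple
            # (obs, rew, done, info), reconstructing terminated/truncated as needed.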
            if len(result) == 5:
                obs_next, rew, terminated, truncated, info = result
                done = np.logical_or(terminated, truncated)
            elif len(result) == 4:
                obs_next, rew, done, info = result
                if isinstance(info, dict):
                    truncated = info["TimeLimit.truncated"]
                else:
                    truncated = np.array(
                        [
                            info_item.get("TimeLimit.truncated", False)
                            for info_item in info
                        ]
                    )
                terminated = np.logical_and(done, ~truncated)
            else:
                raise ValueError(
                    f"env.step() returned {len(result)} values; expected 4 or 5."
                )

            self.data.update(
                obs_next=obs_next,
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                done=done,
                info=info
            )
            if self.preprocess_fn:
                self.data.update(
                    self.preprocess_fn(
                        obs_next=self.data.obs_next,
                        rew=self.data.rew,
                        done=self.data.done,
                        info=self.data.info,
                        policy=self.data.policy,
                        env_id=ready_env_ids,
                    )
                )

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
                self.data, buffer_ids=ready_env_ids
            )

            # collect statistics
            step_count += len(ready_env_ids)

            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.append(ep_len[env_ind_local])
                episode_rews.append(ep_rew[env_ind_local])
                episode_start_indices.append(ep_idx[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                self._reset_env_with_ids(
                    env_ind_local, env_ind_global, gym_reset_kwargs
                )
                for i in env_ind_local:
                    self._reset_state(i)

                # remove surplus env id from ready_env_ids
                # to avoid bias in selecting environments

            self.data.obs = self.data.obs_next

            if n_step and step_count >= n_step:
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += max(time.time() - start_time, 1e-9)

        if episode_count > 0:
            rews, lens, idxs = list(
                map(
                    np.concatenate,
                    [episode_rews, episode_lens, episode_start_indices]
                )
            )
            rew_mean, rew_std = rews.mean(), rews.std()
            len_mean, len_std = lens.mean(), lens.std()
        else:
            rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
            rew_mean = rew_std = len_mean = len_std = 0

        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rews": rews,
            "lens": lens,
            "idxs": idxs,
            "rew": rew_mean,
            "len": len_mean,
            "rew_std": rew_std,
            "len_std": len_std,
        }
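For context, here is a minimal usage sketch of how this collector could be driven by keyboard input if appended to the module above, one environment step per collect_result call. Everything other than ManualPolicyCollector itself is an assumption for illustration: the CartPole-v1 environment, the RandomPolicy stand-in (in practice this would be a trained tianshou policy, e.g. restored with load_state_dict), and the input loop are not part of this commit.

# Hypothetical usage sketch (not part of this commit): feed user-typed actions
# into ManualPolicyCollector.collect_result, one environment step at a time.
from tianshou.policy import RandomPolicy

if __name__ == "__main__":
    # Placeholder environment and policy; substitute the game environment and
    # the trained policy this commit is meant to play against.
    env = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
    policy = RandomPolicy()  # stand-in for a trained BasePolicy

    collector = ManualPolicyCollector(policy=policy, env=env)
    collector.reset()

    episode_over = False
    while not episode_over:
        user_action = int(input("Enter an action id: "))
        # Wrap the action in an array for the single vectorized environment.
        result = collector.collect_result(action=np.array([user_action]))
        episode_over = result["n/ep"] > 0  # an episode finished on this step
    print("episode reward:", result["rew"], "episode length:", result["len"])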