Commit
Added support for user input to play against trained tianshou models
elliottower committed Feb 1, 2023
1 parent c05041c commit 5b5843f
Showing 3 changed files with 229 additions and 12 deletions.
50 changes: 38 additions & 12 deletions gobblet/examples/example_DQN_tianshou.py
@@ -30,6 +30,7 @@
from torch.utils.tensorboard import SummaryWriter

from gobblet import gobblet_v1
from gobblet.game.collector_manual_policy import ManualPolicyCollector


def get_parser() -> argparse.ArgumentParser:
@@ -155,8 +156,8 @@ def get_agents(
return policy, optim, env.agents


def get_env(render_mode=None, debug=False):
return PettingZooEnv(gobblet_v1.env(render_mode=render_mode, debug=debug))
def get_env(render_mode=None, args=None):
return PettingZooEnv(gobblet_v1.env(render_mode=render_mode, args=args))


def train_agent(
@@ -256,7 +257,7 @@ def watch(
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, debug=args.debug)])
env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)])
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
)
@@ -270,7 +271,10 @@ def watch(
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")

# TODO: Look more into Tianshou and see if self-play is possible
# For watching, the same policy could likely be used for both agents; training with self-play would require a different setup
def watch_selfplay(args, agent):
raise NotImplementedError()
env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)])
agent.set_eps(args.eps_test)
policy = MultiAgentPolicyManager([agent, deepcopy(agent)], env) # fixed here
@@ -280,26 +284,48 @@ def watch_selfplay(args, agent):
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")

if __name__ == "__main__":
# train the agent and watch its performance in a match!
args = get_args()
result, agent = train_agent(args)
watch(args, agent)
# play(args, agent)

# Allows the user to input moves and play vs the learned agent
def play(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode)])
env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)])
# env = get_env(render_mode=args.render_mode, args=args)  # Fails because the collector expects a vector env with a length; ManualPolicyCollector could override that check instead
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
)
policy.eval()
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)

collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions

pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env
manual_policy = gobblet_v1.ManualPolicy(pettingzoo_env) # Gobblet keyboard input requires access to raw_env (uses functions from board)

# Get the first move from the CPU player
result = collector.collect(n_step=1, render=args.render)

while not (collector.data.terminated or collector.data.truncated):
agent_id = collector.data.obs.agent_id
if agent_id == pettingzoo_env.agents[1]:
# action_mask = collector.data.obs.mask[0]
# action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask))
observation = {"observation": collector.data.obs.obs,
"action_mask": collector.data.obs.mask} # PettingZoo expects a dict with this format
action = manual_policy(observation, agent_id)

result = collector.collect_result(action=action.reshape(1), render=args.render)
else:
result = collector.collect(n_step=1, render=args.render)

rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")

if __name__ == "__main__":
# train the agent and watch its performance in a match!
args = get_args()
result, agent = train_agent(args)
# watch(args, agent)
play(args, agent)
191 changes: 191 additions & 0 deletions gobblet/game/collector_manual_policy.py
@@ -0,0 +1,191 @@
# Extends the Tianshou Collector class to work with a manual policy (for user input)
import time
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import gym
import numpy as np
import torch

from tianshou.data import (
Batch,
ReplayBuffer,
)
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy

from tianshou.data.collector import Collector


class ManualPolicyCollector(Collector):
def __init__(
self,
policy: BasePolicy,
env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[ReplayBuffer] = None,
preprocess_fn: Optional[Callable[..., Batch]] = None,
exploration_noise: bool = False,
) -> None:
super().__init__(policy=policy, env=env, buffer=buffer, preprocess_fn=preprocess_fn, exploration_noise=exploration_noise)

# Custom function to collect the result of an inputted action
def collect_result(
self,
action: int = None,
render: Optional[float] = None,
no_grad: bool = True,
gym_reset_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Collect the results of an inputted action..
:param int action: the action you want to collect the results for.
:param float render: the sleep time between rendering consecutive frames.
Default to None (no rendering).
:param bool no_grad: whether to retain gradient in policy.forward(). Default to
True (no gradient retaining).
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (no extra keyword arguments).
:return: A dict including the following keys
* ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps.
* ``rews`` array of episode reward over collected episodes.
* ``lens`` array of episode length over collected episodes.
* ``idxs`` array of episode start index in buffer over collected episodes.
* ``rew`` mean of episodic rewards.
* ``len`` mean of episodic lengths.
* ``rew_std`` standard deviation of episodic rewards.
* ``len_std`` standard deviation of episodic lengths.
"""
assert not self.env.is_async, "Please use AsyncCollector if using async venv."

ready_env_ids = np.arange(self.env_num)

start_time = time.time()

step_count = 0
episode_count = 0
episode_rews = []
episode_lens = []
episode_start_indices = []

while True:
assert len(self.data) == len(ready_env_ids)
# restore the state: if the last state is None, it won't store
last_state = self.data.policy.pop("hidden_state", None)

# use the provided action (rather than querying a policy or sampling randomly)
self.data.update(act=action)

# hard-coded to collect only a single step
n_step = 1

# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
result = self.env.step(action_remap, ready_env_ids) # type: ignore
if len(result) == 5:
obs_next, rew, terminated, truncated, info = result
done = np.logical_or(terminated, truncated)
elif len(result) == 4:
obs_next, rew, done, info = result
if isinstance(info, dict):
truncated = info["TimeLimit.truncated"]
else:
truncated = np.array(
[
info_item.get("TimeLimit.truncated", False)
for info_item in info
]
)
terminated = np.logical_and(done, ~truncated)
else:
raise ValueError()

self.data.update(
obs_next=obs_next,
rew=rew,
terminated=terminated,
truncated=truncated,
done=done,
info=info
)
if self.preprocess_fn:
self.data.update(
self.preprocess_fn(
obs_next=self.data.obs_next,
rew=self.data.rew,
done=self.data.done,
info=self.data.info,
policy=self.data.policy,
env_id=ready_env_ids,
)
)

if render:
self.env.render()
if render > 0 and not np.isclose(render, 0):
time.sleep(render)

# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
self.data, buffer_ids=ready_env_ids
)

# collect statistics
step_count += len(ready_env_ids)

if np.any(done):
env_ind_local = np.where(done)[0]
env_ind_global = ready_env_ids[env_ind_local]
episode_count += len(env_ind_local)
episode_lens.append(ep_len[env_ind_local])
episode_rews.append(ep_rew[env_ind_local])
episode_start_indices.append(ep_idx[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.
self._reset_env_with_ids(
env_ind_local, env_ind_global, gym_reset_kwargs
)
for i in env_ind_local:
self._reset_state(i)

# NOTE: the original Collector removes surplus env ids from ready_env_ids here
# to avoid bias when collecting n_episode; not needed when collecting a single step

self.data.obs = self.data.obs_next

if n_step and step_count >= n_step:
break

# generate statistics
self.collect_step += step_count
self.collect_episode += episode_count
self.collect_time += max(time.time() - start_time, 1e-9)

if episode_count > 0:
rews, lens, idxs = list(
map(
np.concatenate,
[episode_rews, episode_lens, episode_start_indices]
)
)
rew_mean, rew_std = rews.mean(), rews.std()
len_mean, len_std = lens.mean(), lens.std()
else:
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
rew_mean = rew_std = len_mean = len_std = 0

return {
"n/ep": episode_count,
"n/st": step_count,
"rews": rews,
"lens": lens,
"idxs": idxs,
"rew": rew_mean,
"len": len_mean,
"rew_std": rew_std,
"len_std": len_std,
}
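
A brief usage sketch of the statistics dict returned by collect_result, assuming collector, action, and args are set up as in the play() function above (variable names here are illustrative, not part of the commit):

# Hypothetical sketch; `collector`, `action`, and `args` are assumed to exist as in play() above.
result = collector.collect_result(action=action.reshape(1), render=args.render)
if result["n/ep"] > 0:
    # at least one episode ended on this step, so the aggregate stats are populated
    print(f"episodes: {result['n/ep']}, mean reward: {result['rew']:.2f}, mean length: {result['len']:.1f}")
else:
    # no episode finished yet; only the step count is meaningful
    print(f"steps collected: {result['n/st']}")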
Binary file removed gobblet/log/gobblet/dqn/policy.pth
