PPO #272

Merged — 86 commits merged into dev from ppo_spaces on Mar 28, 2023

Changes shown below are from 39 of the 86 commits.

Commits:
dae80b6
Adding initial PPO Code
kshitijkg May 3, 2022
f817ce9
Added buffer sampling and solved some bugs
kshitijkg May 3, 2022
4de8417
ppo agent: device and type errors fixed
sriyash421 May 3, 2022
bc5817e
ppo updater fixed
sriyash421 May 4, 2022
f422b36
ppo config updated
May 4, 2022
d1b7ee2
Updated Hive to use gym spaces instead of raw tuples to represent act…
dapatil211 Apr 20, 2022
81ac508
Updated tests to affect api change of 1106ec2
dapatil211 Apr 20, 2022
e2c1133
Adding initial PPO Code
kshitijkg May 3, 2022
6eb1e4e
Added buffer sampling and solved some bugs
kshitijkg May 3, 2022
2277abd
ppo agent: device and type errors fixed
sriyash421 May 3, 2022
d780683
ppo updater fixed
sriyash421 May 4, 2022
e0cc263
ppo config updated
May 4, 2022
ab5ef03
ppo replay added
sriyash421 May 6, 2022
301e77c
ppo replay conflict
sriyash421 May 6, 2022
6459606
ppo replay fixed
sriyash421 May 6, 2022
babfcce
ppo agent updated
sriyash421 May 6, 2022
c0f9039
ppo agent and config updated
sriyash421 May 6, 2022
a973cb3
ppo code running but buggy
May 9, 2022
0c78e2d
cartpole working
May 12, 2022
f5271bf
ppo configs
May 18, 2022
ad3ed9b
ppo net fixed
May 24, 2022
0f63802
merge dev
May 24, 2022
f598433
atari configs added
May 24, 2022
a3c3c1c
ppo_nets done
May 24, 2022
47e106f
ppo_replay done
May 24, 2022
12a4b73
ppo env wrappers added
May 24, 2022
277b9bc
ppo agent done
May 24, 2022
9417e05
configs done
May 24, 2022
d3616a7
stack size > 1 handled temporarily
May 27, 2022
9de6e3a
linting fixed
sriyash421 May 27, 2022
c151b7b
Merge branch 'dev' into ppo_spaces
sriyash421 Jun 27, 2022
7201e97
last batch drop fix
sriyash421 Jun 27, 2022
45a9da4
config changes
sriyash421 Jun 29, 2022
4ea7527
Merge branch 'ppo_spaces' of github.com:chandar-lab/RLHive into ppo_s…
sriyash421 Jun 29, 2022
a9848e1
shared network added
sriyash421 Jul 7, 2022
3adcb73
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 7, 2022
54996f3
reward wrapper added
sriyash421 Jul 13, 2022
fa9297b
linting fixed
sriyash421 Jul 13, 2022
2ac73ba
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 13, 2022
3d11136
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 29, 2022
4f5b4b8
docs fixed
sriyash421 Aug 28, 2022
cc4aab8
replay changed
sriyash421 Aug 28, 2022
2d55afa
update loop
sriyash421 Aug 28, 2022
dcf2aa7
type specification
sriyash421 Aug 28, 2022
936c3b1
env wrappers registered
sriyash421 Aug 28, 2022
04d8692
linting fixed
sriyash421 Aug 29, 2022
360dc00
Merge branch 'dev' into ppo_spaces
kshitijkg Sep 25, 2022
b1613cd
Removed one off transition, cleaned up replay buffer
kshitijkg Sep 25, 2022
bdcd11e
Fixed linter issues
kshitijkg Sep 25, 2022
5a8e2da
wrapper error fixed
sriyash421 Sep 29, 2022
a54377a
added vars to dict; fixed long lines and var names; moved wrapper reg…
sriyash421 Oct 11, 2022
9680185
config fixed
sriyash421 Oct 13, 2022
2c9295f
addded normalisation and fixed log
sriyash421 Oct 13, 2022
767f96c
norm filed added
sriyash421 Oct 14, 2022
b4f2ea1
norm bug fixed
sriyash421 Nov 3, 2022
58f5ec2
rew norm updated
sriyash421 Nov 11, 2022
306faea
fixes
sriyash421 Nov 11, 2022
35d6aeb
fixing norm bug; config
sriyash421 Nov 23, 2022
7d31faf
config fixes
sriyash421 Nov 23, 2022
b84722e
obs norm
sriyash421 Nov 24, 2022
a4c1692
hardcoded wrappers added
sriyash421 Nov 24, 2022
11ccb21
normaliser shape fixed
sriyash421 Dec 6, 2022
0991e84
rew shape fixed; norm structure updated
sriyash421 Dec 6, 2022
c7f42a1
rew norm
sriyash421 Dec 6, 2022
84d933e
configs and wrapper fixed
sriyash421 Dec 7, 2022
3f01532
merge dev
sriyash421 Dec 19, 2022
54799c2
Merge branch 'dev' into ppo_spaces
sriyash421 Dec 19, 2022
8fb9902
Fixed formatting and naming
kshitijkg Jan 30, 2023
bd5c587
Added env wrapper logic
kshitijkg Jan 30, 2023
697a78c
Merging dev
kshitijkg Jan 30, 2023
a1e77fa
Renamed PPO Replay Buffer to On Policy Replay buffer
kshitijkg Jan 30, 2023
031f462
Made PPO Stateless Agent
kshitijkg Jan 30, 2023
28733ec
Fixed linting issues
kshitijkg Jan 30, 2023
8885a89
Minor modifications
kshitijkg Feb 7, 2023
0e42146
Fixed changed
kshitijkg Feb 8, 2023
d785c85
Formatting and minor changes
kshitijkg Mar 2, 2023
4946874
Merge branch 'dev' into ppo_spaces
dapatil211 Mar 20, 2023
308f111
Refactored Advatange Computation
kshitijkg Mar 21, 2023
543fc74
Reformating with black
kshitijkg Mar 21, 2023
43c3fb1
Renaming
kshitijkg Mar 21, 2023
4d82f99
Refactored Normalization code
kshitijkg Mar 21, 2023
e7d08d5
Added saving and loading of state dict for normalizers
kshitijkg Mar 21, 2023
aba7c49
Fixed multiplayer replay buffer for PPO
kshitijkg Mar 21, 2023
000c4e4
Fixed minor bug
kshitijkg Mar 22, 2023
3d6d076
Renamed file
kshitijkg Mar 22, 2023
aabeed0
Added lr annealing
dapatil211 Mar 28, 2023
2 changes: 2 additions & 0 deletions hive/agents/__init__.py
@@ -3,6 +3,7 @@
from hive.agents.ddpg import DDPG
from hive.agents.dqn import DQNAgent
from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent
from hive.agents.ppo import PPOAgent
from hive.agents.rainbow import RainbowDQNAgent
from hive.agents.random import RandomAgent
from hive.agents.td3 import TD3
@@ -13,6 +14,7 @@
{
"DDPG": DDPG,
"DQNAgent": DQNAgent,
"PPOAgent": PPOAgent,
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
384 changes: 384 additions & 0 deletions hive/agents/ppo.py

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions hive/agents/qnets/ppo_nets.py
@@ -0,0 +1,113 @@
import gym
import numpy as np
import torch

from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.utils import calculate_output_dim


# TODO: add Multi-Discrete
class CategoricalHead(torch.nn.Module):
"""A module that implements a discrete actor head. It uses the ouput from the
:obj:`actor_net`, and adds creates a :py:class:`~torch.distributions.categorical.Categorical`
kshitijkg marked this conversation as resolved.
Show resolved Hide resolved
object to compute the action distribution."""

def __init__(self, feature_dim, action_space: gym.spaces.Discrete) -> None:
"""
Args:
feature_dim: Expected output shape of the actor network.
action_space: The discrete action space of the environment.
"""
super().__init__()
self.network = torch.nn.Linear(feature_dim, action_space.n)
self.distribution = torch.distributions.categorical.Categorical

def forward(self, x):
logits = self.network(x)
return self.distribution(logits=logits)


class GaussianPolicyHead(torch.nn.Module):
"""A module that implements a continuous actor head. It uses the output from the
:obj:`actor_net` and state independent learnable parameter :obj:`policy_logstd` to
create a :py:class:`~torch.distributions.normal.Normal` object to compute
the action distribution."""

def __init__(self, feature_dim, action_space: gym.spaces.Box) -> None:
"""
Args:
feature_dim: Expected output shape of the actor network.
action_space: The continuous action space of the environment.
"""
super().__init__()
self._action_shape = action_space.shape
self.policy_mean = torch.nn.Sequential(
torch.nn.Linear(feature_dim, np.prod(self._action_shape))
)
self.policy_logstd = torch.nn.Parameter(
torch.zeros(1, np.prod(action_space.shape))
)
self.distribution = torch.distributions.normal.Normal

def forward(self, x):
_mean = self.policy_mean(x)
_std = self.policy_logstd.repeat(x.shape[0], 1).exp()
distribution = self.distribution(
torch.reshape(_mean, (x.size(0), *self._action_shape)),
torch.reshape(_std, (x.size(0), *self._action_shape)),
)
return distribution


class PPOActorCriticNetwork(torch.nn.Module):
"""A module that implements the PPO actor-critic computation. It puts together
the :obj:`representation_network`, :obj:`actor_net`, and :obj:`critic_net`, adds
an actor head on top of :obj:`actor_net` to compute the action distribution, and
adds a final :py:class:`~torch.nn.Linear` layer on top of :obj:`critic_net` to
compute the value."""

def __init__(
self,
representation_network,
actor_net,
critic_net,
network_output_dim,
action_space,
continuous_action,
) -> None:
super().__init__()
self._network = representation_network
self._continuous_action = continuous_action
if actor_net is None:
actor_network = torch.nn.Identity()
else:
actor_network = actor_net(network_output_dim)
feature_dim = np.prod(calculate_output_dim(actor_network, network_output_dim))
actor_head = GaussianPolicyHead if self._continuous_action else CategoricalHead

self.actor = torch.nn.Sequential(
actor_network,
torch.nn.Flatten(),
actor_head(feature_dim, action_space),
)

if critic_net is None:
critic_network = torch.nn.Identity()
else:
critic_network = critic_net(network_output_dim)
feature_dim = np.prod(calculate_output_dim(critic_network, network_output_dim))
self.critic = torch.nn.Sequential(
critic_network,
torch.nn.Flatten(),
torch.nn.Linear(feature_dim, 1),
)

def forward(self, x, action=None):
hidden_state = self._network(x)
distribution = self.actor(hidden_state)
value = self.critic(hidden_state)
if action is None:
action = distribution.sample()

logprob, entropy = distribution.log_prob(action), distribution.entropy()
if self._continuous_action:
logprob, entropy = logprob.sum(dim=-1), entropy.sum(dim=-1)
return action, logprob, entropy, value
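
For orientation, a minimal usage sketch of the two policy heads defined above, assuming this branch is installed so that hive.agents.qnets.ppo_nets is importable (batch size and feature dimension are arbitrary):

import gym
import torch

from hive.agents.qnets.ppo_nets import CategoricalHead, GaussianPolicyHead

features = torch.randn(8, 64)  # batch of 8 feature vectors from the actor network

# Discrete case: logits -> Categorical distribution over 4 actions.
discrete_head = CategoricalHead(feature_dim=64, action_space=gym.spaces.Discrete(4))
dist = discrete_head(features)   # torch.distributions.Categorical
actions = dist.sample()          # shape: (8,)

# Continuous case: state-independent log-std, Normal distribution per action dimension.
box = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))
continuous_head = GaussianPolicyHead(feature_dim=64, action_space=box)
dist = continuous_head(features)              # torch.distributions.Normal, batch shape (8, 3)
actions = dist.sample()                       # shape: (8, 3)
logprob = dist.log_prob(actions).sum(dim=-1)  # summed over action dims, as in the forward above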
64 changes: 64 additions & 0 deletions hive/configs/atari/ppo.yml
@@ -0,0 +1,64 @@
run_name: &run_name 'atari-ppo'
train_steps: 10000000
test_frequency: 250000
test_episodes: 10
max_steps_per_episode: 27000
stack_size: &stack_size 4
save_dir: 'experiment'
saving_schedule:
name: 'PeriodicSchedule'
kwargs:
off_value: False
on_value: True
period: 1000000
environment:
name: 'AtariEnv'
kwargs:
env_name: 'Breakout'

agent:
name: 'PPOAgent'
kwargs:
representation_net:
name: 'ConvNetwork'
kwargs:
channels: [32, 64, 64]
kernel_sizes: [8, 4, 3]
strides: [4, 2, 1]
paddings: [2, 2, 1]
mlp_layers: [512]
optimizer_fn:
name: 'Adam'
kwargs:
lr: .00025
init_fn:
name: 'orthogonal'
replay_buffer:
name: 'PPOReplayBuffer'
kwargs:
transitions_per_update: 4096
stack_size: *stack_size
use_gae: True
gae_lambda: .95
discount_rate: .99
grad_clip: .5
clip_coef: .1
ent_coef: .0
clip_vloss: True
vf_coef: .5
num_epochs_per_update: 4
normalize_advantages: True
batch_size: 256
device: 'cuda'
id: 'agent'
# List of logger configs used.
loggers:
-
name: ChompLogger
-
name: WandbLogger
kwargs:
project: Hive
name: *run_name
resume: "allow"
start_method: "fork"
64 changes: 64 additions & 0 deletions hive/configs/gym/ppo.yml
@@ -0,0 +1,64 @@
run_name: &run_name 'ppo'
train_steps: 500000
test_frequency: 1000
test_episodes: 10
max_steps_per_episode: 500
stack_size: &stack_size 1
save_dir: 'experiment'
saving_schedule:
name: 'PeriodicSchedule'
kwargs:
off_value: False
on_value: True
period: 10000

environment:
name: 'GymEnv'
kwargs:
env_name: 'CartPole-v1'

agent:
name: 'PPOAgent'
kwargs:
actor_net:
name: 'MLPNetwork'
kwargs:
hidden_units: [256, 256]
critic_net:
name: 'MLPNetwork'
kwargs:
hidden_units: [256, 256]
replay_buffer:
name: 'PPOReplayBuffer'
kwargs:
transitions_per_update: 2048
use_gae: True
gae_lambda: .95
optimizer_fn:
name: 'Adam'
kwargs:
lr: .00025
discount_rate: .99
grad_clip: .5
clip_coef: .2
ent_coef: .01
clip_vloss: True
vf_coef: .5
num_epochs_per_update: 4
normalize_advantages: True
batch_size: 128
device: 'cuda'
id: 'agent'
init_fn:
name: 'orthogonal'
# List of logger configs used.
loggers:
-
name: ChompLogger
-
name: WandbLogger
kwargs:
project: Hive
name: *run_name
resume: "allow"
start_method: "fork"
65 changes: 65 additions & 0 deletions hive/configs/mujoco/ppo.yml
@@ -0,0 +1,65 @@
run_name: &run_name 'ppo'
train_steps: 2000000
test_frequency: 20000
test_episodes: 5
max_steps_per_episode: 100000
stack_size: &stack_size 1
save_dir: 'experiment'
saving_schedule:
name: 'PeriodicSchedule'
kwargs:
off_value: False
on_value: True
period: 500000
environment:
name: 'GymEnv'
kwargs:
env_name: 'Hopper-v3'
mujoco_wrapper: True

agent:
name: 'PPOAgent'
kwargs:
actor_net:
name: 'MLPNetwork'
kwargs:
hidden_units: [64, 64]
critic_net:
name: 'MLPNetwork'
kwargs:
hidden_units: [64, 64]
replay_buffer:
name: 'PPOReplayBuffer'
kwargs:
transitions_per_update: 2048
use_gae: True
gae_lambda: .95
optimizer_fn:
name: 'Adam'
kwargs:
lr: .0003
discount_rate: .99
grad_clip: .5
clip_coef: .2
ent_coef: .0
clip_vloss: True
vf_coef: .5
num_epochs_per_update: 10
normalize_advantages: True
batch_size: 64
device: 'cuda'
id: 'agent'
init_fn:
name: 'orthogonal'

# List of logger configs used.
loggers:
-
name: ChompLogger
-
name: WandbLogger
kwargs:
project: Hive
name: *run_name
resume: "allow"
start_method: "fork"
22 changes: 20 additions & 2 deletions hive/envs/gym_env.py
@@ -1,5 +1,5 @@
import gym

import numpy as np
from hive.envs.base import BaseEnv
from hive.envs.env_spec import EnvSpec

@@ -29,7 +29,25 @@ def create_env(self, env_name, **kwargs):
Args:
env_name (str): Name of the environment
"""
self._env = gym.make(env_name)
env = gym.make(env_name)
if kwargs.get("mujoco_wrapper", False):
env = gym.wrappers.RecordEpisodeStatistics(env)
env = gym.wrappers.ClipAction(env)
env = gym.wrappers.NormalizeObservation(env)
env = gym.wrappers.TransformObservation(
env, lambda obs: np.clip(obs, -10, 10)
)
env = gym.wrappers.NormalizeReward(env)
env = gym.wrappers.TransformReward(
env, lambda reward: np.clip(reward, -10, 10)
)

elif kwargs.get("atari_wrapper", False):
env = gym.wrappers.NoopResetEnv(env, noop_max=30)
env = gym.wrappers.EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
env = gym.wrappers.FireResetEnv(env)
self._env = env

def create_env_spec(self, env_name, **kwargs):
"""Function used to create the specification. Subclasses can override this method
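
When mujoco_wrapper is set, the wrapper chain above normalizes observations and rewards with running statistics and clips them to [-10, 10]. A simplified, self-contained sketch of that observation normalization (not gym's exact implementation):

import numpy as np

class RunningMeanStd:
    # Tracks a running mean and variance with batched Welford-style updates.
    def __init__(self, shape, epsilon=1e-4):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x):
        batch_mean, batch_var, batch_count = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count
        m2 += delta ** 2 * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total

def normalize_obs(obs, rms, epsilon=1e-8):
    # obs: (batch, *obs_shape). Update stats, normalize, then clip as the wrappers do.
    rms.update(obs)
    return np.clip((obs - rms.mean) / np.sqrt(rms.var + epsilon), -10, 10)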
2 changes: 2 additions & 0 deletions hive/replays/__init__.py
@@ -1,5 +1,6 @@
from hive.replays.circular_replay import CircularReplayBuffer, SimpleReplayBuffer
from hive.replays.legal_moves_replay import LegalMovesBuffer
from hive.replays.ppo_replay import PPOReplayBuffer
from hive.replays.prioritized_replay import PrioritizedReplayBuffer
from hive.replays.replay_buffer import BaseReplayBuffer
from hive.utils.registry import registry
@@ -11,6 +12,7 @@
"SimpleReplayBuffer": SimpleReplayBuffer,
"PrioritizedReplayBuffer": PrioritizedReplayBuffer,
"LegalMovesBuffer": LegalMovesBuffer,
"PPOReplayBuffer": PPOReplayBuffer,
},
)

2 changes: 1 addition & 1 deletion hive/replays/circular_replay.py
@@ -242,7 +242,7 @@ def sample(self, batch_size):

if self._n_step == 1:
is_terminal = terminals
trajectory_lengths = np.ones(batch_size)
trajectory_lengths = np.ones(terminals.shape[0])
else:
is_terminal = terminals.any(axis=1).astype(int)
trajectory_lengths = (