Add tanh_normal dist to PPO (#312)
* Add tanh_normal dist to PPO

* Removed unneeded check

* Fix argument to safeatanh

* Fix distribution check
belerico authored Jul 12, 2024
1 parent b9e57ed commit 33b6366
Showing 6 changed files with 110 additions and 25 deletions.
100 changes: 81 additions & 19 deletions sheeprl/algos/ppo/agent.py
@@ -14,6 +14,7 @@

from sheeprl.models.models import MLP, MultiEncoder, NatureCNN
from sheeprl.utils.fabric import get_single_device_fabric
from sheeprl.utils.utils import safeatanh, safetanh


class CNNEncoder(nn.Module):
@@ -69,11 +70,18 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor:


class PPOActor(nn.Module):
def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None:
def __init__(
self,
actor_backbone: torch.nn.Module,
actor_heads: torch.nn.ModuleList,
is_continuous: bool,
distribution: str = "auto",
) -> None:
super().__init__()
self.actor_backbone = actor_backbone
self.actor_heads = actor_heads
self.is_continuous = is_continuous
self.distribution = distribution

def forward(self, x: Tensor) -> List[Tensor]:
x = self.actor_backbone(x)
@@ -97,6 +105,21 @@ def __init__(
super().__init__()
self.is_continuous = is_continuous
self.distribution_cfg = distribution_cfg
self.distribution = distribution_cfg.get("type", "auto").lower()
if self.distribution not in ("auto", "normal", "tanh_normal", "discrete"):
    raise ValueError(
        "The distribution must be one of: `auto`, `discrete`, `normal` or `tanh_normal`. "
        f"Found: {self.distribution}"
    )
if self.distribution == "discrete" and is_continuous:
    raise ValueError("You have chosen a discrete distribution but `is_continuous` is true")
elif self.distribution not in ("discrete", "auto") and not is_continuous:
    raise ValueError("You have chosen a continuous distribution but `is_continuous` is false")
if self.distribution == "auto":
    self.distribution = "normal" if is_continuous else "discrete"
self.actions_dim = actions_dim
in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys])
mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys])
@@ -158,7 +181,29 @@ def __init__(
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
else:
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim])
self.actor = PPOActor(actor_backbone, actor_heads, is_continuous)
self.actor = PPOActor(actor_backbone, actor_heads, is_continuous, self.distribution)

def _normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
    mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
    std = log_std.exp()
    normal = Independent(Normal(mean, std), 1)
    if actions is None:
        actions = normal.sample()
    else:
        # always composed by a tuple of one element containing all the continuous actions
        actions = actions[0]
    log_prob = normal.log_prob(actions)
    return actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)

def _tanh_normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
    mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
    std = log_std.exp()
    normal = Independent(Normal(mean, std), 1)
    if actions is None:
        actions = normal.sample().float()
        tanh_actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
    else:
        # the buffer stores the squashed actions, so invert tanh to recover the pre-tanh sample
        tanh_actions = actions[0].float()
        actions = safeatanh(tanh_actions, eps=torch.finfo(tanh_actions.dtype).resolution)
    log_prob = normal.log_prob(actions)
    # change-of-variables correction, log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x)),
    # evaluated at the pre-tanh actions
    log_prob -= 2.0 * (
        torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
        - actions
        - torch.nn.functional.softplus(-2.0 * actions)
    ).sum(-1, keepdim=False)
    return tanh_actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)

def forward(
self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
@@ -167,17 +212,11 @@ def forward(
actor_out: List[Tensor] = self.actor(feat)
values = self.critic(feat)
if self.is_continuous:
mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
if actions is None:
actions = normal.sample()
else:
# always composed by a tuple of one element containing all the
# continuous actions
actions = actions[0]
log_prob = normal.log_prob(actions)
return tuple([actions]), log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1), values
if self.distribution == "normal":
    actions, log_prob, entropy = self._normal(actor_out[0], actions)
elif self.distribution == "tanh_normal":
    actions, log_prob, entropy = self._tanh_normal(actor_out[0], actions)
return tuple([actions]), log_prob, entropy, values
else:
should_append = False
actions_logprobs: List[Tensor] = []
@@ -207,17 +246,38 @@ def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.Module) -> None:
self.critic = critic
self.actor = actor

def _normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
    mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
    std = log_std.exp()
    normal = Independent(Normal(mean, std), 1)
    actions = normal.sample()
    log_prob = normal.log_prob(actions)
    return actions, log_prob.unsqueeze(dim=-1)

def _tanh_normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
    mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
    std = log_std.exp()
    normal = Independent(Normal(mean, std), 1)
    actions = normal.sample().float()
    tanh_actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
    log_prob = normal.log_prob(actions)
    # change-of-variables correction evaluated at the pre-tanh sample `actions`
    log_prob -= 2.0 * (
        torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
        - actions
        - torch.nn.functional.softplus(-2.0 * actions)
    ).sum(-1, keepdim=False)
    return tanh_actions, log_prob.unsqueeze(dim=-1)

def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
feat = self.feature_extractor(obs)
values = self.critic(feat)
actor_out: List[Tensor] = self.actor(feat)
if self.actor.is_continuous:
mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample()
log_prob = normal.log_prob(actions)
return tuple([actions]), log_prob.unsqueeze(dim=-1), values
if self.actor.distribution == "normal":
    actions, log_prob = self._normal(actor_out[0])
elif self.actor.distribution == "tanh_normal":
    actions, log_prob = self._tanh_normal(actor_out[0])
return tuple([actions]), log_prob, values
else:
actions_dist: List[Distribution] = []
actions_logprobs: List[Tensor] = []
@@ -247,6 +307,8 @@ def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[Tensor]:
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample()
if self.actor.distribution == "tanh_normal":
    # squash the Gaussian sample into (-1, 1) for the environment
    actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
return tuple([actions])
else:
actions: List[Tensor] = []
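The change-of-variables term in `_tanh_normal` can be sanity-checked against PyTorch's built-in `TanhTransform`, which implements the same correction. A minimal standalone sketch (not part of the commit; assumes only `torch`):

import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

# log p(tanh(x)) = log N(x | mean, std) - sum_i log(1 - tanh(x_i)^2),
# with log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x))
mean, std = torch.zeros(4), torch.ones(4)
normal = Independent(Normal(mean, std), 1)
x = normal.sample()  # pre-tanh sample
a = torch.tanh(x)    # squashed action in (-1, 1)
log_prob = normal.log_prob(x) - 2.0 * (
    torch.log(torch.tensor(2.0)) - x - torch.nn.functional.softplus(-2.0 * x)
).sum(-1)

# reference value from PyTorch's own transform machinery
ref = TransformedDistribution(Normal(mean, std), [TanhTransform()])
assert torch.allclose(log_prob, ref.log_prob(a).sum(-1), atol=1e-5)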
11 changes: 7 additions & 4 deletions sheeprl/algos/ppo/loss.py
@@ -52,11 +52,14 @@ def value_loss(
) -> Tensor:
if not clip_vloss:
values_pred = new_values
# return F.mse_loss(values_pred, returns, reduction=reduction)
return F.mse_loss(values_pred, returns, reduction=reduction)
else:
values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
# return torch.max((new_values - returns) ** 2, (values_pred - returns) ** 2).mean()
return F.mse_loss(values_pred, returns, reduction=reduction)
v_loss_unclipped = (new_values - returns) ** 2
v_clipped = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
v_loss_clipped = (v_clipped - returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max.mean()
return v_loss


def entropy_loss(entropy: Tensor, reduction: str = "mean") -> Tensor:
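The rewritten branch is the standard PPO clipped value objective, 0.5 * max((V_new - R)^2, (V_clip - R)^2). A tiny numeric check with made-up values:

import torch

old_values, new_values = torch.tensor([1.0]), torch.tensor([3.0])
returns, clip_coef = torch.tensor([1.5]), 0.2

v_clipped = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)  # 1.2
loss = 0.5 * torch.max((new_values - returns) ** 2, (v_clipped - returns) ** 2).mean()
print(loss)  # tensor(1.1250) = 0.5 * max(2.25, 0.09)

Note that the clipped branch now always reduces with `.mean()` and scales by 0.5, regardless of the `reduction` argument.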
5 changes: 3 additions & 2 deletions sheeprl/algos/ppo/ppo.py
@@ -169,6 +169,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if is_continuous
else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
)
clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r
# Create the actor and critic models
agent, player = build_agent(
fabric,
@@ -304,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
vals = player.get_values(real_next_obs).cpu().numpy()
rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape)
dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
rewards = rewards.reshape(cfg.env.num_envs, -1)
rewards = clip_rewards_fn(rewards).reshape(cfg.env.num_envs, -1).astype(np.float32)

# Update the step data
step_data["dones"] = dones[np.newaxis]
@@ -347,7 +348,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
local_data["rewards"],
local_data["values"],
local_data["dones"],
next_values,
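Note that `clip_rewards_fn` squashes rewards smoothly with `np.tanh` rather than hard-clipping them, preserving their ordering while bounding their magnitude; a quick illustration with made-up rewards:

import numpy as np

r = np.array([-5.0, -0.5, 0.0, 0.5, 5.0])
print(np.tanh(r))  # [-0.9999, -0.4621, 0., 0.4621, 0.9999] (approx)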
4 changes: 4 additions & 0 deletions sheeprl/configs/exp/ppo.yaml
@@ -13,6 +13,10 @@ algo:
mlp_keys:
encoder: [state]

# Distribution
distribution:
  type: "auto"

# Buffer
buffer:
share_data: False
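The new key defaults to `auto`, which the agent resolves to `normal` for continuous action spaces and `discrete` otherwise; set it to `tanh_normal` to enable the squashed policy. A minimal sketch of the resolution logic (hypothetical plain-dict stand-in for the Hydra config; mirrors `distribution_cfg.get("type", "auto")` in `agent.py`):

# hypothetical plain-dict stand-in for the resolved Hydra config
distribution_cfg = {"type": "auto"}
is_continuous = True

dist = distribution_cfg.get("type", "auto").lower()
if dist == "auto":
    dist = "normal" if is_continuous else "discrete"
print(dist)  # normal; use {"type": "tanh_normal"} for the squashed policy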
2 changes: 2 additions & 0 deletions sheeprl/utils/distribution.py
@@ -307,6 +307,7 @@ class OneHotCategoricalValidateArgs(Distribution):
probs (Tensor): event probabilities
logits (Tensor): event log probabilities (unnormalized)
"""

arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True
@@ -391,6 +392,7 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs):
[1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
(Bengio et al, 2013)
"""

has_rsample = True

def rsample(self, sample_shape=torch.Size()):
13 changes: 13 additions & 0 deletions sheeprl/utils/utils.py
@@ -298,3 +298,16 @@ def load_state_dict(self, state_dict: Mapping[str, Any]):
self._prev = state_dict["_prev"]
self._pretrain_steps = state_dict["_pretrain_steps"]
return self


# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L156
def safetanh(x, eps):
    # tanh with outputs clamped to [-(1 - eps), 1 - eps], so a later atanh stays finite
    lim = 1.0 - eps
    y = x.tanh()
    return y.clamp(-lim, lim)


# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L161
def safeatanh(y, eps):
    # atanh with inputs clamped away from +/-1 to avoid infinite outputs
    lim = 1.0 - eps
    return y.clamp(-lim, lim).atanh()

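A quick illustration of why the clamping in `safetanh`/`safeatanh` matters (made-up inputs; assumes the two functions above are in scope): in float32, `tanh(20.0)` rounds to exactly 1.0, so a plain `atanh` round-trip overflows to infinity, while the safe variants stay finite:

import torch

x = torch.tensor([0.0, 5.0, 20.0])
eps = torch.finfo(x.dtype).resolution  # 1e-6 for float32
print(torch.atanh(torch.tanh(x)))        # tensor([0., 5., inf])
print(safeatanh(safetanh(x, eps), eps))  # finite everywhere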