diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py
index f8a1e66d..f5f18610 100644
--- a/sheeprl/algos/ppo/agent.py
+++ b/sheeprl/algos/ppo/agent.py
@@ -14,6 +14,7 @@
 
 from sheeprl.models.models import MLP, MultiEncoder, NatureCNN
 from sheeprl.utils.fabric import get_single_device_fabric
+from sheeprl.utils.utils import safeatanh, safetanh
 
 
 class CNNEncoder(nn.Module):
@@ -69,11 +70,18 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor:
 
 
 class PPOActor(nn.Module):
-    def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None:
+    def __init__(
+        self,
+        actor_backbone: torch.nn.Module,
+        actor_heads: torch.nn.ModuleList,
+        is_continuous: bool,
+        distribution: str = "auto",
+    ) -> None:
         super().__init__()
         self.actor_backbone = actor_backbone
         self.actor_heads = actor_heads
         self.is_continuous = is_continuous
+        self.distribution = distribution
 
     def forward(self, x: Tensor) -> List[Tensor]:
         x = self.actor_backbone(x)
@@ -97,6 +105,21 @@ def __init__(
         super().__init__()
         self.is_continuous = is_continuous
         self.distribution_cfg = distribution_cfg
+        self.distribution = distribution_cfg.get("type", "auto").lower()
+        if self.distribution not in ("auto", "normal", "tanh_normal", "discrete"):
+            raise ValueError(
+                "The distribution must be one of: `auto`, `discrete`, `normal` and `tanh_normal`. "
+                f"Found: {self.distribution}"
+            )
+        if self.distribution == "discrete" and is_continuous:
+            raise ValueError("You have chosen a discrete distribution but `is_continuous` is true")
+        elif self.distribution not in {"discrete", "auto"} and not is_continuous:
+            raise ValueError("You have chosen a continuous distribution but `is_continuous` is false")
+        if self.distribution == "auto":
+            if is_continuous:
+                self.distribution = "normal"
+            else:
+                self.distribution = "discrete"
         self.actions_dim = actions_dim
         in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys])
         mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys])
@@ -158,7 +181,29 @@ def __init__(
             actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
         else:
             actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim])
-        self.actor = PPOActor(actor_backbone, actor_heads, is_continuous)
+        self.actor = PPOActor(actor_backbone, actor_heads, is_continuous, self.distribution)
+
+    def _normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
+        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
+        std = log_std.exp()
+        normal = Independent(Normal(mean, std), 1)
+        actions = actions[0]
+        log_prob = normal.log_prob(actions)
+        return actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)
+
+    def _tanh_normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
+        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
+        std = log_std.exp()
+        normal = Independent(Normal(mean, std), 1)
+        tanh_actions = actions[0].float()
+        actions = safeatanh(tanh_actions, eps=torch.finfo(tanh_actions.dtype).resolution)
+        log_prob = normal.log_prob(actions)
+        log_prob -= 2.0 * (
+            torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
+            - tanh_actions
+            - torch.nn.functional.softplus(-2.0 * tanh_actions)
+        ).sum(-1, keepdim=False)
+        return tanh_actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)
 
     def forward(
         self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
@@ -167,17 +212,11 @@ def forward(
         actor_out: List[Tensor] = self.actor(feat)
         values = self.critic(feat)
         if self.is_continuous:
-            mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
-            std = log_std.exp()
-            normal = Independent(Normal(mean, std), 1)
-            if actions is None:
-                actions = normal.sample()
-            else:
-                # always composed by a tuple of one element containing all the
-                # continuous actions
-                actions = actions[0]
-            log_prob = normal.log_prob(actions)
-            return tuple([actions]), log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1), values
+            if self.distribution == "normal":
+                actions, log_prob, entropy = self._normal(actor_out[0], actions)
+            elif self.distribution == "tanh_normal":
+                actions, log_prob, entropy = self._tanh_normal(actor_out[0], actions)
+            return tuple([actions]), log_prob, entropy, values
         else:
             should_append = False
             actions_logprobs: List[Tensor] = []
@@ -207,17 +246,38 @@ def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.
         self.critic = critic
         self.actor = actor
 
+    def _normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
+        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
+        std = log_std.exp()
+        normal = Independent(Normal(mean, std), 1)
+        actions = normal.sample()
+        log_prob = normal.log_prob(actions)
+        return actions, log_prob.unsqueeze(dim=-1)
+
+    def _tanh_normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
+        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
+        std = log_std.exp()
+        normal = Independent(Normal(mean, std), 1)
+        actions = normal.sample().float()
+        tanh_actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
+        log_prob = normal.log_prob(actions)
+        log_prob -= 2.0 * (
+            torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
+            - tanh_actions
+            - torch.nn.functional.softplus(-2.0 * tanh_actions)
+        ).sum(-1, keepdim=False)
+        return tanh_actions, log_prob.unsqueeze(dim=-1)
+
     def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
         feat = self.feature_extractor(obs)
         values = self.critic(feat)
         actor_out: List[Tensor] = self.actor(feat)
         if self.actor.is_continuous:
-            mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
-            std = log_std.exp()
-            normal = Independent(Normal(mean, std), 1)
-            actions = normal.sample()
-            log_prob = normal.log_prob(actions)
-            return tuple([actions]), log_prob.unsqueeze(dim=-1), values
+            if self.actor.distribution == "normal":
+                actions, log_prob = self._normal(actor_out[0])
+            elif self.actor.distribution == "tanh_normal":
+                actions, log_prob = self._tanh_normal(actor_out[0])
+            return tuple([actions]), log_prob, values
         else:
             actions_dist: List[Distribution] = []
             actions_logprobs: List[Tensor] = []
@@ -247,6 +307,8 @@ def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[
                 std = log_std.exp()
                 normal = Independent(Normal(mean, std), 1)
                 actions = normal.sample()
+            if self.actor.distribution == "tanh_normal":
+                actions = safeatanh(actions, eps=torch.finfo(actions.dtype).resolution)
             return tuple([actions])
         else:
             actions: List[Tensor] = []
diff --git a/sheeprl/algos/ppo/loss.py b/sheeprl/algos/ppo/loss.py
index 5422da54..15209a47 100644
--- a/sheeprl/algos/ppo/loss.py
+++ b/sheeprl/algos/ppo/loss.py
@@ -52,11 +52,14 @@ def value_loss(
 ) -> Tensor:
     if not clip_vloss:
         values_pred = new_values
-        # return F.mse_loss(values_pred, returns, reduction=reduction)
+        return F.mse_loss(values_pred, returns, reduction=reduction)
     else:
-        values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
-        # return torch.max((new_values - returns) ** 2, (values_pred - returns) ** 2).mean()
-        return F.mse_loss(values_pred, returns, reduction=reduction)
+        v_loss_unclipped = (new_values - returns) ** 2
+        v_clipped = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
+        v_loss_clipped = (v_clipped - returns) ** 2
+        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
+        v_loss = 0.5 * v_loss_max.mean()
+        return v_loss
 
 
 def entropy_loss(entropy: Tensor, reduction: str = "mean") -> Tensor:
diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 95057f2d..205489d9 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -169,6 +169,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         if is_continuous
         else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
     )
+    clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r
     # Create the actor and critic models
     agent, player = build_agent(
         fabric,
@@ -304,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                     vals = player.get_values(real_next_obs).cpu().numpy()
                     rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape)
                 dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
-                rewards = rewards.reshape(cfg.env.num_envs, -1)
+                rewards = clip_rewards_fn(rewards).reshape(cfg.env.num_envs, -1).astype(np.float32)
 
                 # Update the step data
                 step_data["dones"] = dones[np.newaxis]
@@ -347,7 +348,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
         next_values = player.get_values(torch_obs)
         returns, advantages = gae(
-            local_data["rewards"].to(torch.float64),
+            local_data["rewards"],
             local_data["values"],
             local_data["dones"],
             next_values,
diff --git a/sheeprl/configs/exp/ppo.yaml b/sheeprl/configs/exp/ppo.yaml
index c5c05719..f149611d 100644
--- a/sheeprl/configs/exp/ppo.yaml
+++ b/sheeprl/configs/exp/ppo.yaml
@@ -13,6 +13,10 @@ algo:
   mlp_keys:
     encoder: [state]
 
+# Distribution
+distribution:
+  type: "auto"
+
 # Buffer
 buffer:
   share_data: False
diff --git a/sheeprl/utils/distribution.py b/sheeprl/utils/distribution.py
index 31765bb6..842a745d 100644
--- a/sheeprl/utils/distribution.py
+++ b/sheeprl/utils/distribution.py
@@ -307,6 +307,7 @@ class OneHotCategoricalValidateArgs(Distribution):
         probs (Tensor): event probabilities
         logits (Tensor): event log probabilities (unnormalized)
     """
+    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
     support = constraints.one_hot
     has_enumerate_support = True
 
@@ -391,6 +392,7 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs
 
     [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
     (Bengio et al, 2013)
     """
+    has_rsample = True
 
     def rsample(self, sample_shape=torch.Size()):
diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py
index 74bf8a35..971c4192 100644
--- a/sheeprl/utils/utils.py
+++ b/sheeprl/utils/utils.py
@@ -298,3 +298,16 @@ def load_state_dict(self, state_dict: Mapping[str, Any]):
         self._prev = state_dict["_prev"]
         self._pretrain_steps = state_dict["_pretrain_steps"]
         return self
+
+
+# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L156
+def safetanh(x, eps):
+    lim = 1.0 - eps
+    y = x.tanh()
+    return y.clamp(-lim, lim)
+
+
+# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L161
+def safeatanh(y, eps):
+    lim = 1.0 - eps
+    return y.clamp(-lim, lim).atanh()
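
Note on the `_tanh_normal` helpers added above: the term subtracted from `log_prob` is the tanh change-of-variables correction, written in the numerically stable form `log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))` instead of the naive `log(1 - tanh(u)^2)`. The snippet below is only a sanity-check sketch, not part of the patch; it assumes nothing beyond PyTorch, re-declares `safetanh`/`safeatanh` as added in `sheeprl/utils/utils.py`, and verifies both the identity and the tanh/atanh round trip.

import torch
import torch.nn.functional as F


def safetanh(x, eps):
    # Squash to (-1, 1), keeping a margin of `eps` from the boundaries.
    lim = 1.0 - eps
    return x.tanh().clamp(-lim, lim)


def safeatanh(y, eps):
    # Inverse squash; the clamp keeps atanh finite for inputs at or beyond +/-1.
    lim = 1.0 - eps
    return y.clamp(-lim, lim).atanh()


u = torch.linspace(-4.0, 4.0, steps=1000, dtype=torch.float64)
eps = torch.finfo(u.dtype).resolution

# log|d tanh(u)/du| = log(1 - tanh(u)^2), rewritten to avoid cancellation near saturation.
stable = 2.0 * (torch.log(torch.tensor(2.0, dtype=u.dtype)) - u - F.softplus(-2.0 * u))
naive = torch.log1p(-torch.tanh(u) ** 2)
assert torch.allclose(stable, naive)

# safeatanh(safetanh(u)) recovers u up to the clamping margin near the boundaries of (-1, 1).
assert torch.allclose(safeatanh(safetanh(u, eps), eps), u, atol=1e-6)

Clamping to `1 - eps` before `atanh` is what keeps the inverse finite when `tanh` saturates in low precision, which is why both helpers are called with the dtype resolution as `eps` throughout the patch.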