Add objective mix for the actor loss
belerico committed Jul 19, 2023
1 parent 1de03ee commit 539665e
Showing 2 changed files with 26 additions and 16 deletions.
6 changes: 6 additions & 0 deletions sheeprl/algos/dreamer_v2/args.py
@@ -79,6 +79,12 @@ class DreamerV2Args(StandardArgs):
     layer_norm: bool = Arg(
         default=False, help="whether to apply nn.LayerNorm after every Linear/Conv2D/ConvTranspose2D"
     )
+    objective_mix: float = Arg(
+        default=1.0,
+        help="the mixing coefficient for the actor objective: '0' uses the dynamics backpropagation, "
+        "i.e. it tries to maximize the estimated lambda values; '1' uses the standard reinforce objective, "
+        "i.e. log(p) * Advantage.",
+    )

     # Environment settings
     expl_amount: float = Arg(default=0.0, help="the exploration amount to add to the actions")
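For readers skimming the diff: objective_mix linearly blends the two actor objectives described in the help text above. Below is a minimal, standalone sketch of that blend on toy tensors; the shapes and variable names are illustrative only and are not part of this commit.

import torch

# Toy stand-ins for quantities produced during imagination (shapes chosen for illustration).
lambda_values = torch.randn(15, 1024, 1)   # estimated lambda returns
log_probs = torch.randn(15, 1024, 1)       # log-probabilities of the imagined actions
baseline = torch.randn(15, 1024, 1)        # target critic value estimates

dynamics = lambda_values                                      # objective_mix = 0: maximize the lambda returns directly
reinforce = log_probs * (lambda_values - baseline).detach()   # objective_mix = 1: reinforce with a baseline

objective_mix = 0.5
objective = objective_mix * reinforce + (1 - objective_mix) * dynamics

Note that the default of 1.0 keeps the previous behaviour for discrete action spaces (pure reinforce), while continuous-control runs, which previously always used dynamics backpropagation, now also follow objective_mix.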
36 changes: 20 additions & 16 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -312,26 +312,29 @@ def train(
     # actor optimization step. Eq. 6 from the paper
     actor_optimizer.zero_grad(set_to_none=True)
     policies: Sequence[Distribution] = actor(imagined_trajectories[:-2].detach())[1]
-    if is_continuous:
-        objective = lambda_values[1:]
-    else:
-        baseline = target_critic(imagined_trajectories[:-2])
-        advantage = (lambda_values[1:] - baseline).detach()
-        objective = (
-            torch.stack(
-                [
-                    p.log_prob(imgnd_act[1:-1].detach()).unsqueeze(-1)
-                    for p, imgnd_act in zip(policies, torch.split(imagined_actions, actions_dim, -1))
-                ],
-                -1,
-            ).sum(-1)
-            * advantage
-        )
+
+    # Dynamics backpropagation
+    dynamics = lambda_values[1:]
+
+    # Reinforce
+    baseline = target_critic(imagined_trajectories[:-2])
+    advantage = (lambda_values[1:] - baseline).detach()
+    reinforce = (
+        torch.stack(
+            [
+                p.log_prob(imgnd_act[1:-1].detach()).unsqueeze(-1)
+                for p, imgnd_act in zip(policies, torch.split(imagined_actions, actions_dim, -1))
+            ],
+            -1,
+        ).sum(-1)
+        * advantage
+    )
+    objective = args.objective_mix * reinforce + (1 - args.objective_mix) * dynamics
     try:
         entropy = args.actor_ent_coef * torch.stack([p.entropy() for p in policies], -1).sum(-1)
     except NotImplementedError:
         entropy = torch.zeros_like(objective)
-    policy_loss = -torch.mean(discount[:-2] * (objective + entropy.unsqueeze(-1)))
+    policy_loss = -torch.mean(discount[:-2].detach() * (objective + entropy.unsqueeze(-1)))
     fabric.backward(policy_loss)
     if args.clip_gradients is not None and args.clip_gradients > 0:
         actor_grads = fabric.clip_gradients(
@@ -422,6 +425,7 @@ def main():
         world_collective.broadcast_object_list(data, src=0)
         log_dir = data[0]
         os.makedirs(log_dir, exist_ok=True)
+
     env: gym.Env = make_env(
         args.env_id,
         args.seed + rank * args.num_envs,
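For reference, the mixed objective assembled in the actor hunk above corresponds to the DreamerV2 actor loss mentioned by the "Eq. 6 from the paper" comment. As a sketch, up to notation, with \rho playing the role of objective_mix, \eta the role of actor_ent_coef, and sg denoting stop-gradient:

\mathcal{L}(\psi) = \mathbb{E}\left[ \sum_t -\rho\, \ln p_\psi(\hat{a}_t \mid \hat{z}_t)\, \mathrm{sg}\!\left(V_t^\lambda - v_\xi(\hat{z}_t)\right) \;-\; (1-\rho)\, V_t^\lambda \;-\; \eta\, \mathrm{H}\!\left[a_t \mid \hat{z}_t\right] \right]

The code minimizes the negative of the discounted sum of objective + entropy, so \rho = 1 recovers the pure reinforce term and \rho = 0 recovers pure dynamics backpropagation through the lambda returns.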
