Commit

Merge pull request #60 from Eclectic-Sheep/fix/dreamer_v2_is_first
Fix/dreamer v2 is first
belerico authored Jul 20, 2023
2 parents 59637d2 + 539665e commit ff308c8
Showing 4 changed files with 87 additions and 62 deletions.
6 changes: 6 additions & 0 deletions sheeprl/algos/dreamer_v2/args.py
@@ -79,6 +79,12 @@ class DreamerV2Args(StandardArgs):
layer_norm: bool = Arg(
default=False, help="whether to apply nn.LayerNorm after every Linear/Conv2D/ConvTranspose2D"
)
objective_mix: float = Arg(
default=1.0,
help="the mixing coefficient for the actor objective: '0' uses the dynamics backpropagation, "
"i.e. it tries to maximize the estimated lambda values; '1' uses the standard reinforce objective, "
"i.e. log(p) * Advantage. ",
)

# Environment settings
    expl_amount: float = Arg(default=0.0, help="the exploration amount to add to the actions")
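The new objective_mix argument blends DreamerV2's two actor objectives: 0 backpropagates through the imagined dynamics, 1 uses the REINFORCE estimator. Below is a minimal, self-contained sketch of the blend applied later in dreamer_v2.py; the tensor names, shapes, and the 0.5 value are illustrative stand-ins, not part of this diff.

import torch

objective_mix = 0.5                 # hypothetical value; the new default is 1.0 (pure REINFORCE)
lam = torch.randn(10, 8, 1)         # toy lambda returns from the imagined rollout
baseline = torch.randn(10, 8, 1)    # toy target-critic values used as a baseline
log_prob = torch.randn(10, 8, 1)    # toy action log-probabilities

dynamics = lam                                    # maximize the returns by backpropagating through the model
reinforce = log_prob * (lam - baseline).detach()  # score-function estimator with a value baseline
objective = objective_mix * reinforce + (1 - objective_mix) * dynamics
actor_loss = -objective.mean()

In dreamer_v2.py the blended objective is additionally weighted by the cumulative discount and an entropy bonus before being averaged into the policy loss.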
81 changes: 40 additions & 41 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -1,4 +1,3 @@
import copy
import os
import pathlib
import time
@@ -25,13 +24,13 @@
from sheeprl.algos.dreamer_v2.agent import Player, WorldModel, build_models
from sheeprl.algos.dreamer_v2.args import DreamerV2Args
from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
from sheeprl.algos.dreamer_v2.utils import make_env, test
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, make_env, test
from sheeprl.data.buffers import EpisodeBuffer, SequentialReplayBuffer
from sheeprl.utils.callback import CheckpointCallback
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.parser import HfArgumentParser
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.utils import compute_lambda_values, polynomial_decay
from sheeprl.utils.utils import polynomial_decay

# Uncomment the following two lines if you cannot start an experiment with DMC environments
# os.environ["PYOPENGL_PLATFORM"] = ""
@@ -258,21 +257,21 @@ def train(
predicted_target_values = target_critic(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)
if args.use_continues and world_model.continue_model:
done_mask = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) * args.gamma
done_mask = torch.cat((true_done, done_mask[1:]))
continues = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
true_done = (1 - data["dones"]).reshape(1, -1, 1) * args.gamma
continues = torch.cat((true_done, continues[1:]))
else:
done_mask = torch.ones_like(predicted_rewards.detach()) * args.gamma
continues = torch.ones_like(predicted_rewards.detach()) * args.gamma

# compute the lambda_values, by passing as last values the values of the last imagined state
# the dimensions of the lambda_values tensor are
# (horizon, batch_size * sequence_length, recurrent_state_size + stochastic_size)
lambda_values = compute_lambda_values(
predicted_rewards,
predicted_target_values,
done_mask,
last_values=predicted_target_values[-1],
horizon=args.horizon + 1,
predicted_rewards[:-1],
predicted_target_values[:-1],
continues[:-1],
bootstrap=predicted_target_values[-2:-1],
horizon=args.horizon,
lmbda=args.lmbda,
)
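For reference, the compute_lambda_values added in sheeprl/algos/dreamer_v2/utils.py (shown below) implements DreamerV2's lambda-return recurrence. Writing c_t for the discounted continue flag and V for the target-critic value:

lambda_t = r_t + c_t * ((1 - lmbda) * V_{t+1} + lmbda * lambda_{t+1})

with the bootstrap argument standing in for the return beyond the last imagined step; this is why the new call drops the final time step from the rewards, values, and continues and passes the bootstrap value explicitly.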

@@ -283,56 +282,59 @@
# the imagined trajectory would have ended.
#
# Suppose the case in which the continue model is not used and gamma = .99
# done_mask.shape = (15, 2500, 1)
# done_mask = [
# continues.shape = (15, 2500, 1)
# continues = [
# [ [.99], ..., [.99] ], (2500 columns)
# ...
# ] (15 rows)
# torch.ones_like(done_mask[:1]) = [
# torch.ones_like(continues[:1]) = [
# [ [1.], ..., [1.] ]
# ] (1 row and 2500 columns), the discount of the time step 0 is 1.
# done_mask[:-2] = [
# continues[:-2] = [
# [ [.99], ..., [.99] ], (2500 columns)
# ...
# ] (13 rows)
# torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-2]), 0) = [
# torch.cat((torch.ones_like(continues[:1]), continues[:-2]), 0) = [
# [ [1.], ..., [1.] ], (2500 columns)
# [ [.99], ..., [.99] ],
# ...,
# [ [.99], ..., [.99] ],
# ] (14 rows), the total number of imagined steps is 15, but one is lost because of the values computation
# torch.cumprod(torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-2]), 0), 0) = [
# torch.cumprod(torch.cat((torch.ones_like(continues[:1]), continues[:-2]), 0), 0) = [
# [ [1.], ..., [1.] ], (2500 columns)
# [ [.99], ..., [.99] ],
# [ [.9801], ..., [.9801] ],
# ...,
# [ [.8775], ..., [.8775] ],
# ] (14 rows)
discount = torch.cumprod(torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-1]), 0), 0)
discount = torch.cumprod(torch.cat((torch.ones_like(continues[:1]), continues[:-1]), 0), 0)
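A runnable version of the discount construction described in the comment above; the gamma, horizon, and batch dimension are toy values, not part of this diff.

import torch

gamma, horizon = 0.99, 15
continues = torch.full((horizon, 1, 1), gamma)
discount = torch.cumprod(torch.cat((torch.ones_like(continues[:1]), continues[:-1]), 0), 0)
print(discount[:4, 0, 0])  # tensor([1.0000, 0.9900, 0.9801, 0.9703])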

# actor optimization step. Eq. 6 from the paper
actor_optimizer.zero_grad(set_to_none=True)
policies: Sequence[Distribution] = actor(imagined_trajectories[:-2].detach())[1]
if is_continuous:
objective = lambda_values[1:]
else:
baseline = target_critic(imagined_trajectories[:-2])
advantage = (lambda_values[1:] - baseline).detach()
objective = (
torch.stack(
[
p.log_prob(imgnd_act[1:-1].detach()).unsqueeze(-1)
for p, imgnd_act in zip(policies, torch.split(imagined_actions, actions_dim, -1))
],
-1,
).sum(-1)
* advantage
)

# Dynamics backpropagation
dynamics = lambda_values[1:]

# Reinforce
baseline = target_critic(imagined_trajectories[:-2])
advantage = (lambda_values[1:] - baseline).detach()
reinforce = (
torch.stack(
[
p.log_prob(imgnd_act[1:-1].detach()).unsqueeze(-1)
for p, imgnd_act in zip(policies, torch.split(imagined_actions, actions_dim, -1))
],
-1,
).sum(-1)
* advantage
)
objective = args.objective_mix * reinforce + (1 - args.objective_mix) * dynamics
try:
entropy = args.actor_ent_coef * torch.stack([p.entropy() for p in policies], -1).sum(-1)
except NotImplementedError:
entropy = torch.zeros_like(objective)
policy_loss = -torch.mean(discount[:-2] * (objective + entropy.unsqueeze(-1)))
policy_loss = -torch.mean(discount[:-2].detach() * (objective + entropy.unsqueeze(-1)))
fabric.backward(policy_loss)
if args.clip_gradients is not None and args.clip_gradients > 0:
actor_grads = fabric.clip_gradients(
@@ -424,9 +426,6 @@ def main():
log_dir = data[0]
os.makedirs(log_dir, exist_ok=True)

# Save args as dict automatically
args.log_dir = log_dir

env: gym.Env = make_env(
args.env_id,
args.seed + rank * args.num_envs,
@@ -592,7 +591,7 @@ def main():
step_data["dones"] = torch.zeros(args.num_envs, 1)
step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim))
step_data["rewards"] = torch.zeros(args.num_envs, 1)
step_data["is_first"] = copy.deepcopy(step_data["dones"])
step_data["is_first"] = torch.ones_like(step_data["dones"])
if buffer_type == "sequential":
rb.add(step_data[None, ...])
else:
@@ -630,7 +629,6 @@ def main():
else:
real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions])

step_data["is_first"] = copy.deepcopy(step_data["dones"])
o, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape))
dones = np.logical_or(dones, truncated)
if args.dry_run and buffer_type == "episode":
@@ -654,6 +652,7 @@
obs = next_obs

step_data["dones"] = dones
step_data["is_first"] = torch.zeros_like(step_data["dones"])
step_data["actions"] = actions
step_data["rewards"] = clip_rewards_fn(rewards)
data_to_add = step_data[None, ...]
@@ -676,7 +675,7 @@
step_data["dones"] = torch.zeros(args.num_envs, 1)
step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim))
step_data["rewards"] = torch.zeros(args.num_envs, 1)
step_data["is_first"] = copy.deepcopy(step_data["dones"])
step_data["is_first"] = torch.ones_like(step_data["dones"])
data_to_add = step_data[None, ...]
if buffer_type == "sequential":
rb.add(data_to_add)
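The is_first fix replaces copy.deepcopy(step_data["dones"]) with explicit constants, so the flag no longer mirrors the previous dones. A condensed sketch of the pattern applied above, with a toy num_envs and a bare dictionary standing in for the buffer row (not part of this diff):

import torch

num_envs = 1
step_data = {"dones": torch.zeros(num_envs, 1)}

# at the start of an episode (initial reset, or the reset stored after a done),
# the stored transition is the first of its episode
step_data["is_first"] = torch.ones_like(step_data["dones"])

# every transition stored after a regular env.step is not a first step
step_data["is_first"] = torch.zeros_like(step_data["dones"])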
20 changes: 20 additions & 0 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -221,6 +221,26 @@ def init_weights(m: nn.Module):
nn.init.constant_(m.bias.data, 0)


def compute_lambda_values(
rewards: Tensor,
values: Tensor,
continues: Tensor,
bootstrap: Optional[Tensor] = None,
horizon: int = 15,
lmbda: float = 0.95,
):
if bootstrap is None:
bootstrap = torch.zeros_like(values[-2:-1])
agg = bootstrap
next_val = torch.cat((values[1:], bootstrap), dim=0)
inputs = rewards + continues * next_val * (1 - lmbda)
lv = []
for i in reversed(range(horizon)):
agg = inputs[i] + continues[i] * lmbda * agg
lv.append(agg)
return torch.cat(list(reversed(lv)), dim=0)
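A self-contained usage sketch of the function above; the horizon, batch size, random inputs, and the choice of bootstrap are illustrative only.

import torch
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values

horizon, batch = 5, 3
rewards = torch.randn(horizon, batch, 1)
values = torch.randn(horizon, batch, 1)
continues = torch.full((horizon, batch, 1), 0.99)  # (1 - done) * gamma
bootstrap = values[-1:]                            # toy stand-in for the value beyond the horizon

lambda_values = compute_lambda_values(
    rewards, values, continues, bootstrap=bootstrap, horizon=horizon, lmbda=0.95
)
print(lambda_values.shape)  # torch.Size([5, 3, 1])

In the algorithms above, the inputs are the per-step predictions over the imagined trajectory, sliced to exclude the final step, with the bootstrap taken from the predicted target values.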


@torch.no_grad()
def test(
player: "Player", fabric: Fabric, args: DreamerV2Args, cnn_keys: List[str], mlp_keys: List[str], test_name: str = ""
42 changes: 21 additions & 21 deletions sheeprl/algos/p2e_dv2/p2e_dv2.py
@@ -26,7 +26,7 @@

from sheeprl.algos.dreamer_v2.agent import Player, WorldModel
from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
from sheeprl.algos.dreamer_v2.utils import init_weights, make_env, test
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, init_weights, make_env, test
from sheeprl.algos.p2e_dv2.agent import build_models
from sheeprl.algos.p2e_dv2.args import P2EDV2Args
from sheeprl.data.buffers import EpisodeBuffer, SequentialReplayBuffer
@@ -35,7 +35,7 @@
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.parser import HfArgumentParser
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.utils import compute_lambda_values, polynomial_decay
from sheeprl.utils.utils import polynomial_decay

# Uncomment the following two lines if you are using MineDojo on a headless machine
# os.environ["MINEDOJO_HEADLESS"] = "1"
@@ -275,26 +275,26 @@ def train(
aggregator.update("Rewards/intrinsic", intrinsic_reward.detach().cpu().mean())

if args.use_continues and world_model.continue_model:
done_mask = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
continues = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) * args.gamma
done_mask = torch.cat((true_done, done_mask[1:]))
continues = torch.cat((true_done, continues[1:]))
else:
done_mask = torch.ones_like(intrinsic_reward.detach()) * args.gamma
continues = torch.ones_like(intrinsic_reward.detach()) * args.gamma

lambda_values = compute_lambda_values(
intrinsic_reward,
predicted_target_values,
done_mask,
last_values=predicted_target_values[-1],
horizon=args.horizon + 1,
intrinsic_reward[:-1],
predicted_target_values[:-1],
continues[:-1],
bootstrap=predicted_target_values[-2:-1],
horizon=args.horizon,
lmbda=args.lmbda,
)

aggregator.update("Values_exploration/predicted_values", predicted_target_values.detach().cpu().mean())
aggregator.update("Values_exploration/lambda_values", lambda_values.detach().cpu().mean())

with torch.no_grad():
discount = torch.cumprod(torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-1]), 0), 0)
discount = torch.cumprod(torch.cat((torch.ones_like(continues[:1]), continues[:-1]), 0), 0)

actor_exploration_optimizer.zero_grad(set_to_none=True)
policies: Sequence[Distribution] = actor_exploration(imagined_trajectories[:-2].detach())[1]
@@ -379,23 +379,23 @@ def train(
predicted_target_values = target_critic_task(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)
if args.use_continues and world_model.continue_model:
done_mask = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) * args.gamma
done_mask = torch.cat((true_done, done_mask[1:]))
continues = Independent(Bernoulli(logits=world_model.continue_model(imagined_trajectories)), 1).mean
true_done = (1 - data["dones"]).reshape(1, -1, 1) * args.gamma
continues = torch.cat((true_done, continues[1:]))
else:
done_mask = torch.ones_like(predicted_rewards.detach()) * args.gamma
continues = torch.ones_like(predicted_rewards.detach()) * args.gamma

lambda_values = compute_lambda_values(
predicted_rewards,
predicted_target_values,
done_mask,
last_values=predicted_target_values[-1],
horizon=args.horizon + 1,
predicted_rewards[:-1],
predicted_target_values[:-1],
continues[:-1],
bootstrap=predicted_target_values[-2:-1],
horizon=args.horizon,
lmbda=args.lmbda,
)

with torch.no_grad():
discount = torch.cumprod(torch.cat((torch.ones_like(done_mask[:1]), done_mask[:-1]), 0), 0)
discount = torch.cumprod(torch.cat((torch.ones_like(continues[:1]), continues[:-1]), 0), 0)

actor_task_optimizer.zero_grad(set_to_none=True)
policies: Sequence[Distribution] = actor_task(imagined_trajectories[:-2].detach())[1]
