Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] DQN compatibility with compile #2571

Open
wants to merge 25 commits into
base: gh/vmoens/41/base
Choose a base branch
from
5 changes: 5 additions & 0 deletions sota-implementations/dqn/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ loss:
gamma: 0.99
hard_update_freq: 10_000
num_updates: 1

compile:
compile: False
compile_mode:
cudagraphs: False
5 changes: 5 additions & 0 deletions sota-implementations/dqn/config_cartpole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@ loss:
gamma: 0.99
hard_update_freq: 50
num_updates: 1

compile:
compile: False
compile_mode:
cudagraphs: False
113 changes: 65 additions & 48 deletions sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
Deep Q-Learning Algorithm on Atari Environments.
"""
import tempfile
import time
import warnings

import hydra
import torch.nn
import torch.optim
import tqdm
from tensordict.nn import TensorDictSequential
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule, TensorDictSequential
from torchrl._utils import timeit

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
Expand Down Expand Up @@ -58,18 +58,6 @@ def main(cfg: "DictConfig"): # noqa: F821
greedy_module,
).to(device)

# Create the collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
policy=model_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
init_random_frames=init_random_frames,
)

# Create the replay buffer
if cfg.buffer.scratch_dir is None:
tempdir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -127,25 +115,68 @@ def main(cfg: "DictConfig"): # noqa: F821
)
test_env.eval()

def update(sampled_tensordict):
loss_td = loss_module(sampled_tensordict)
q_loss = loss_td["loss"]
optimizer.zero_grad()
q_loss.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=max_grad
)
optimizer.step()
target_net_updater.step()
return q_loss.detach()

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Create the collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
policy=model_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
init_random_frames=init_random_frames,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Main loop
collected_frames = 0
start_time = time.time()
sampling_start = time.time()
num_updates = cfg.loss.num_updates
max_grad = cfg.optim.max_grad_norm
num_test_episodes = cfg.logger.num_test_episodes
q_losses = torch.zeros(num_updates, device=device)
pbar = tqdm.tqdm(total=total_frames)
for i, data in enumerate(collector):

c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
data = next(c_iter)
log_info = {}
sampling_time = time.time() - sampling_start
pbar.update(data.numel())
data = data.reshape(-1)
current_frames = data.numel() * frame_skip
collected_frames += current_frames
greedy_module.step(current_frames)
replay_buffer.extend(data)
with timeit("rb - extend"):
replay_buffer.extend(data)

# Get and log training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
Expand All @@ -167,24 +198,13 @@ def main(cfg: "DictConfig"): # noqa: F821
continue

# optimization steps
training_start = time.time()
for j in range(num_updates):

sampled_tensordict = replay_buffer.sample()
sampled_tensordict = sampled_tensordict.to(device)

loss_td = loss_module(sampled_tensordict)
q_loss = loss_td["loss"]
optimizer.zero_grad()
q_loss.backward()
torch.nn.utils.clip_grad_norm_(
list(loss_module.parameters()), max_norm=max_grad
)
optimizer.step()
target_net_updater.step()
q_losses[j].copy_(q_loss.detach())

training_time = time.time() - training_start
with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample()
sampled_tensordict = sampled_tensordict.to(device)
with timeit("update"):
q_loss = update(sampled_tensordict)
q_losses[j].copy_(q_loss)

# Get and log q-values, loss, epsilon, sampling time and training time
log_info.update(
Expand All @@ -193,48 +213,45 @@ def main(cfg: "DictConfig"): # noqa: F821
/ frames_per_batch,
"train/q_loss": q_losses.mean().item(),
"train/epsilon": greedy_module.eps,
"train/sampling_time": sampling_time,
"train/training_time": training_time,
}
)

# Get and log evaluation rewards and eval time
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
with torch.no_grad(), set_exploration_type(
ExplorationType.DETERMINISTIC
), timeit("eval"):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
model.eval()
eval_start = time.time()
test_rewards = eval_model(
model, test_env, num_episodes=num_test_episodes
)
eval_time = time.time() - eval_start
log_info.update(
{
"eval/reward": test_rewards,
"eval/eval_time": eval_time,
}
)
model.train()

if i % 200 == 0:
timeit.print()
log_info.update(timeit.todict(prefix="time"))
timeit.erase()

# Log all the information
if logger:
for key, value in log_info.items():
logger.log_scalar(key, value, step=collected_frames)

# update weights of the inference policy
collector.update_policy_weights_()
sampling_start = time.time()

collector.shutdown()
if not test_env.is_closed:
test_env.close()

end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
main()
Loading
Loading