diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 8d76b31d88a..1abf951c44b 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -106,6 +106,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \ record_video=True \ record_frames=4 \ buffer_size=120 +python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + optim.batch_size=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.device=cuda:0 \ + optim.optim_steps_per_batch=1 \ + replay_buffer.size=120 \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \ num_workers=4 \ collector.total_frames=48 \ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 26979e2ae96..29bfa7d466e 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -136,6 +136,7 @@ CQL :template: rl_template_noinherit.rst CQLLoss + DiscreteCQLLoss DT ---- diff --git a/examples/cql/discrete_cql_config.yaml b/examples/cql/discrete_cql_config.yaml new file mode 100644 index 00000000000..1bfbb6916e9 --- /dev/null +++ b/examples/cql/discrete_cql_config.yaml @@ -0,0 +1,57 @@ +# Task and env +env: + name: CartPole-v1 + task: "" + library: gym + exp_name: cql_cartpole_gym + n_samples_stats: 1000 + max_episode_steps: 200 + seed: 0 + +# Collector +collector: + frames_per_batch: 200 + total_frames: 20000 + multi_step: 0 + init_random_frames: 1000 + env_per_collector: 1 + device: cpu + max_frames_per_traj: 200 + annealing_frames: 10000 + eps_start: 1.0 + eps_end: 0.01 +# logger +logger: + backend: wandb + log_interval: 5000 # record interval in frames + eval_steps: 200 + mode: online + eval_iter: 1000 + +# Buffer +replay_buffer: + prb: 0 + buffer_prefetch: 64 + size: 1_000_000 + scratch_dir: ${env.exp_name}_${env.seed} + +# Optimization +optim: + utd_ratio: 1 + device: cuda:0 + lr: 1e-3 + weight_decay: 0.0 + batch_size: 256 + lr_scheduler: "" + optim_steps_per_batch: 200 + +# Policy and model +model: + hidden_sizes: [256, 256] + activation: relu + +# loss +loss: + loss_function: l2 + gamma: 0.99 + tau: 0.005 diff --git a/examples/cql/discrete_cql_online.py b/examples/cql/discrete_cql_online.py new file mode 100644 index 00000000000..5dfde6a082d --- /dev/null +++ b/examples/cql/discrete_cql_online.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Discrete (DQN) CQL Example. + +This is a simple self-contained example of a discrete CQL training script. + +It supports state environments like gym and gymnasium. + +The helper functions are coded in the utils.py associated with this script. 
+""" + +import time + +import hydra +import numpy as np +import torch +import torch.cuda +import tqdm + +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from torchrl.record.loggers import generate_exp_name, get_logger +from utils import ( + log_metrics, + make_collector, + make_cql_optimizer, + make_discretecql_model, + make_discreteloss, + make_environment, + make_replay_buffer, +) + + +@hydra.main(version_base="1.1", config_path=".", config_name="discrete_cql_config") +def main(cfg: "DictConfig"): # noqa: F821 + device = torch.device(cfg.optim.device) + + # Create logger + exp_name = generate_exp_name("DiscreteCQL", cfg.env.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="discretecql_logging", + experiment_name=exp_name, + wandb_kwargs={"mode": cfg.logger.mode, "config": cfg}, + ) + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create environments + train_env, eval_env = make_environment(cfg) + + # Create agent + model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device) + + # Create loss + loss_module, target_net_updater = make_discreteloss(cfg.loss, model) + + # Create off-policy collector + collector = make_collector(cfg, train_env, explore_policy) + + # Create replay buffer + replay_buffer = make_replay_buffer( + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + buffer_scratch_dir=cfg.replay_buffer.scratch_dir, + device="cpu", + ) + + # Create optimizers + optimizer = make_cql_optimizer(cfg, loss_module) + + # Main loop + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_rollout_steps = cfg.env.max_episode_steps + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch + + start_time = sampling_start = time.time() + for tensordict in collector: + sampling_time = time.time() - sampling_start + + # Update exploration policy + explore_policy[1].step(tensordict.numel()) + + # Update weights of the inference policy + collector.update_policy_weights_() + + pbar.update(tensordict.numel()) + + tensordict = tensordict.reshape(-1) + current_frames = tensordict.numel() + # Add to replay buffer + replay_buffer.extend(tensordict.cpu()) + collected_frames += current_frames + + # Optimization steps + training_start = time.time() + if collected_frames >= init_random_frames: + ( + q_losses, + cql_losses, + ) = ([], []) + for _ in range(num_updates): + + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to( + device, non_blocking=True + ) + else: + sampled_tensordict = sampled_tensordict.clone() + + # Compute loss + loss_dict = loss_module(sampled_tensordict) + + q_loss = loss_dict["loss_qvalue"] + cql_loss = loss_dict["loss_cql"] + loss = q_loss + cql_loss + + # Update model + optimizer.zero_grad() + loss.backward() + optimizer.step() + q_losses.append(q_loss.item()) + cql_losses.append(cql_loss.item()) + + # Update target params + target_net_updater.step() + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + episode_end = ( + tensordict["next", 
"done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + metrics_to_log["train/epsilon"] = explore_policy[1].eps + + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses) + metrics_to_log["train/cql_loss"] = np.mean(cql_losses) + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model, + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/cql/utils.py b/examples/cql/utils.py index ac62eea28bc..c64e9d62db7 100644 --- a/examples/cql/utils.py +++ b/examples/cql/utils.py @@ -1,10 +1,11 @@ import torch.nn import torch.optim -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from tensordict.nn.distributions import NormalParamExtractor from torchrl.collectors import SyncDataCollector from torchrl.data import ( + CompositeSpec, LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, @@ -18,17 +19,23 @@ DoubleToFloat, EnvCreator, ParallelEnv, - RewardScaling, + RewardSum, TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator -from torchrl.objectives import CQLLoss, SoftUpdate +from torchrl.modules import ( + EGreedyModule, + MLP, + ProbabilisticActor, + QValueActor, + TanhNormal, + ValueOperator, +) +from torchrl.objectives import CQLLoss, DiscreteCQLLoss, SoftUpdate from torchrl.trainers.helpers.models import ACTIVATIONS - # ==================================================================== # Environment utils # ----------------- @@ -55,8 +62,9 @@ def apply_env_transforms(env, reward_scaling=1.0): transformed_env = TransformedEnv( env, Compose( - RewardScaling(loc=0.0, scale=reward_scaling), + # RewardScaling(loc=0.0, scale=reward_scaling), DoubleToFloat(), + RewardSum(), ), ) return transformed_env @@ -208,6 +216,43 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): return model +def make_discretecql_model(cfg, train_env, eval_env, device="cpu"): + model_cfg = cfg.model + + action_spec = train_env.action_spec + + actor_net_kwargs = { + "num_cells": model_cfg.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": ACTIVATIONS[model_cfg.activation], + } + actor_net = 
MLP(**actor_net_kwargs) + qvalue_module = QValueActor( + module=actor_net, + spec=CompositeSpec(action=action_spec), + in_keys=["observation"], + ) + qvalue_module = qvalue_module.to(device) + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = eval_env.reset() + td = td.to(device) + qvalue_module(td) + + del td + greedy_module = EGreedyModule( + annealing_num_steps=cfg.collector.annealing_frames, + eps_init=cfg.collector.eps_start, + eps_end=cfg.collector.eps_end, + spec=action_spec, + ) + model_explore = TensorDictSequential( + qvalue_module, + greedy_module, + ).to(device) + return qvalue_module, model_explore + + def make_cql_modules_state(model_cfg, proof_environment): action_spec = proof_environment.action_spec @@ -258,10 +303,29 @@ def make_loss(loss_cfg, model): return loss_module, target_net_updater -def make_cql_optimizer(optim_cfg, loss_module): +def make_cql_optimizer(cfg, loss_module): optim = torch.optim.Adam( loss_module.parameters(), - lr=optim_cfg.lr, - weight_decay=optim_cfg.weight_decay, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, ) return optim + + +def make_discreteloss(loss_cfg, model): + loss_module = DiscreteCQLLoss( + model, + loss_function=loss_cfg.loss_function, + delay_value=True, + gamma=loss_cfg.gamma, + ) + loss_module.make_value_estimator(gamma=loss_cfg.gamma) + target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) + + return loss_module, target_net_updater + + +def log_metrics(logger, metrics, step): + if logger is not None: + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) diff --git a/test/test_cost.py b/test/test_cost.py index c74bd0e3ca0..eddf1dfc3bf 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -99,6 +99,7 @@ ClipPPOLoss, CQLLoss, DDPGLoss, + DiscreteCQLLoss, DiscreteSACLoss, DistributionalDQNLoss, DQNLoss, @@ -5164,6 +5165,367 @@ def test_cql_batcher( ) +class TestDiscreteCQL(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + action_spec_type, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + is_nn_module=False, + action_value_key=None, + ): + # Actor + if action_spec_type == "one_hot": + action_spec = OneHotDiscreteTensorSpec(action_dim) + elif action_spec_type == "categorical": + action_spec = DiscreteTensorSpec(action_dim) + else: + raise ValueError(f"Wrong action spec type: {action_spec_type}") + + module = nn.Linear(obs_dim, action_dim) + if is_nn_module: + return module.to(device) + actor = QValueActor( + spec=CompositeSpec( + { + "action": action_spec, + "action_value" + if action_value_key is None + else action_value_key: None, + "chosen_action_value": None, + }, + shape=[], + ), + action_space=action_spec_type, + module=module, + action_value_key=action_value_key, + ).to(device) + return actor + + def _create_mock_data_dcql( + self, + action_spec_type, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + action_key="action", + action_value_key="action_value", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim) + next_obs = torch.randn(batch, obs_dim) + + action_value = torch.randn(batch, action_dim) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + + if action_spec_type == "categorical": + action_value = torch.max(action_value, -1, keepdim=True)[0] + action = torch.argmax(action, -1, keepdim=False) + reward = torch.randn(batch, 1) + done = torch.zeros(batch, 1, dtype=torch.bool) + terminated = torch.zeros(batch, 1, dtype=torch.bool) + td = 
TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "next": { + "observation": next_obs, + "done": done, + "terminated": terminated, + "reward": reward, + }, + action_key: action, + action_value_key: action_value, + }, + device=device, + ) + return td + + def _create_seq_mock_data_dcql( + self, + action_spec_type, + batch=2, + T=4, + obs_dim=3, + action_dim=4, + device="cpu", + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + + action_value = torch.randn(batch, T, action_dim, device=device) + action = (action_value == action_value.max(-1, True)[0]).to(torch.long) + + # action_value = action_value.unsqueeze(-1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + if action_spec_type == "categorical": + action_value = torch.max(action_value, -1, keepdim=True)[0] + action = torch.argmax(action, -1, keepdim=False) + action = action.masked_fill_(~mask, 0.0) + else: + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "done": done, + "terminated": terminated, + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + "collector": {"mask": mask}, + "action": action, + "action_value": action_value.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + names=[None, "time"], + ) + return td + + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_dcql(self, delay_value, device, action_spec_type, td_est): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + td = self._create_mock_data_dcql( + action_spec_type=action_spec_type, device=device + ) + loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) + if td_est is ValueEstimators.GAE: + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if delay_value + else contextlib.nullcontext() + ), _check_td_steady(td): + loss = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys(True) + + sum([item for key, item in loss.items() if key.startswith("loss")]).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = loss_fn.target_value_network_params.clone() + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_value2 = loss_fn.target_value_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in 
zip(parameters, actor.parameters())) + + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + def test_dcql_state_dict(self, delay_value, device, action_spec_type): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) + sd = loss_fn.state_dict() + loss_fn2 = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("n", range(4)) + @pytest.mark.parametrize("delay_value", (False, True)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical")) + def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): + torch.manual_seed(self.seed) + actor = self._create_mock_actor( + action_spec_type=action_spec_type, device=device + ) + + td = self._create_seq_mock_data_dcql( + action_spec_type=action_spec_type, device=device + ) + loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + ms_td = ms(td.clone()) + + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if delay_value + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*td.keys(True, True))) + _loss = sum([item for key, item in loss.items() if key.startswith("loss_")]) + _loss_ms = sum( + [item for key, item in loss_ms.items() if key.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum( + [item for key, item in loss_ms.items() if key.startswith("loss_")] + ).backward() + assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + + # Check param update effect on targets + target_value = loss_fn.target_value_network_params.clone() + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_value2 = loss_fn.target_value_network_params.clone() + if loss_fn.delay_value: + assert_allclose_td(target_value, target_value2) + else: + assert not (target_value == target_value2).any() + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_dcql_tensordict_keys(self, td_est): + torch.manual_seed(self.seed) + action_spec_type = "one_hot" + actor = self._create_mock_actor(action_spec_type=action_spec_type) + loss_fn = DQNLoss(actor) + + default_keys = { + "value_target": "value_target", + "value": "chosen_action_value", + "priority": "td_error", + "action_value": "action_value", + "action": "action", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test(loss_fn, default_keys=default_keys) + + loss_fn = 
DiscreteCQLLoss(actor) + key_mapping = { + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + actor = self._create_mock_actor( + action_spec_type=action_spec_type, action_value_key="chosen_action_value_2" + ) + loss_fn = DiscreteCQLLoss(actor) + key_mapping = { + "value": ("value", "chosen_action_value_2"), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("action_spec_type", ("categorical", "one_hot")) + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_dcql_tensordict_run(self, action_spec_type, td_est): + torch.manual_seed(self.seed) + tensor_keys = { + "action_value": "action_value_test", + "action": "action_test", + "priority": "priority_test", + } + actor = self._create_mock_actor( + action_spec_type=action_spec_type, + action_value_key=tensor_keys["action_value"], + ) + td = self._create_mock_data_dcql( + action_spec_type=action_spec_type, + action_key=tensor_keys["action"], + action_value_key=tensor_keys["action_value"], + ) + + loss_fn = DiscreteCQLLoss(actor, loss_function="l2") + loss_fn.set_keys(**tensor_keys) + + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with _check_td_steady(td): + _ = loss_fn(td) + assert loss_fn.tensor_keys.priority in td.keys() + + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_dcql_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): + n_obs = 3 + n_action = 4 + action_spec = OneHotDiscreteTensorSpec(n_action) + module = nn.Linear(n_obs, n_action) # a simple value model + actor = QValueActor( + spec=action_spec, + action_space="one_hot", + module=module, + in_keys=[observation_key], + ) + loss = DiscreteCQLLoss(actor) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) + # define data + observation = torch.randn(n_obs) + next_observation = torch.randn(n_obs) + action = action_spec.rand() + next_reward = torch.randn(1) + next_done = torch.zeros(1, dtype=torch.bool) + next_terminated = torch.zeros(1, dtype=torch.bool) + kwargs = { + observation_key: observation, + f"next_{observation_key}": next_observation, + f"next_{reward_key}": next_reward, + f"next_{done_key}": next_done, + f"next_{terminated_key}": next_terminated, + "action": action, + } + td = TensorDict(kwargs, []).unflatten_keys("_") + loss_val = loss(**kwargs) + + loss_val_td = loss(td) + + torch.testing.assert_close(loss_val_td.get(loss.out_keys[0]), loss_val[0]) + torch.testing.assert_close(loss_val_td.get(loss.out_keys[1]), loss_val[1]) + + class TestPPO(LossModuleTestBase): seed = 0 diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index d2e8ed8e3a1..46f71e2b3d6 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -93,6 +93,10 @@ def __init__( action_key: Optional[NestedKey] = "action", action_mask_key: Optional[NestedKey] = None, ): + if not isinstance(eps_init, float): + warnings.warn("eps_init should be a float.") + if eps_end > eps_init: + raise RuntimeError("eps 
should decrease over time or be constant") self.action_key = action_key self.action_mask_key = action_mask_key in_keys = [self.action_key] @@ -105,8 +109,6 @@ def __init__( self.register_buffer("eps_init", torch.tensor([eps_init])) self.register_buffer("eps_end", torch.tensor([eps_end])) - if self.eps_end > self.eps_init: - raise RuntimeError("eps should decrease over time or be constant") self.annealing_num_steps = annealing_num_steps self.register_buffer("eps", torch.tensor([eps_init])) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 023b22ba3c4..4840d12b2d4 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -5,7 +5,7 @@ from .a2c import A2CLoss from .common import LossModule -from .cql import CQLLoss +from .cql import CQLLoss, DiscreteCQLLoss from .ddpg import DDPGLoss from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 249166a6bd2..9055e5464c6 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -6,19 +6,22 @@ import warnings from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch +import torch.nn as nn from tensordict.nn import dispatch, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase -from tensordict.utils import NestedKey +from tensordict.utils import NestedKey, unravel_key from torch import Tensor from torchrl.data import CompositeSpec +from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import ProbabilisticActor +from torchrl.modules import ProbabilisticActor, QValueActor +from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, @@ -27,6 +30,7 @@ distance_loss, ValueEstimators, ) + from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator try: @@ -43,7 +47,7 @@ class CQLLoss(LossModule): - """TorchRL implementation of the CQL loss. + """TorchRL implementation of the continuous CQL loss. Presented in "Conservative Q-Learning for Offline Reinforcement Learning" https://arxiv.org/abs/2006.04779 @@ -793,3 +797,366 @@ def _alpha(self): with torch.no_grad(): alpha = self.log_alpha.exp() return alpha + + +class DiscreteCQLLoss(LossModule): + """TorchRL implementation of the discrete CQL loss. + + This class implements the discrete conservative Q-learning (CQL) loss function, as presented in the paper + "Conservative Q-Learning for Offline Reinforcement Learning" (https://arxiv.org/abs/2006.04779). + + Args: + value_network (Union[QValueActor, nn.Module]): The Q-value network used to estimate state-action values. + Keyword Args: + loss_function (Optional[str]): The distance function used to calculate the distance between the predicted + Q-values and the target Q-values. Defaults to ``l2``. + delay_value (bool): Whether to separate the target Q value + networks from the Q value networks used for data collection. + Default is ``True``. + gamma (float, optional): Discount factor. Default is ``None``. + action_space: The action space of the environment. If None, it is inferred from the value network. + Defaults to None. 
+ + + Examples: + >>> from torchrl.modules import MLP + >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> n_obs, n_act = 4, 3 + >>> value_net = MLP(in_features=n_obs, out_features=n_act) + >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec) + >>> loss = DiscreteCQLLoss(actor, action_space=spec) + >>> batch = [10,] + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": spec.rand(batch), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1) + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + loss: Tensor(shape=torch.Size([]), device=cuda:0, dtype=torch.float32, is_shared=True), + loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"]``, + and a single loss value is returned. + + Examples: + >>> from torchrl.objectives import DiscreteCQLLoss + >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torch import nn + >>> import torch + >>> n_obs = 3 + >>> n_action = 4 + >>> action_spec = OneHotDiscreteTensorSpec(n_action) + >>> value_network = nn.Linear(n_obs, n_action) # a simple value model + >>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec) + >>> # define data + >>> observation = torch.randn(n_obs) + >>> next_observation = torch.randn(n_obs) + >>> action = action_spec.rand() + >>> next_reward = torch.randn(1) + >>> next_done = torch.zeros(1, dtype=torch.bool) + >>> next_terminated = torch.zeros(1, dtype=torch.bool) + >>> loss_val = dcql_loss( + ... observation=observation, + ... next_observation=next_observation, + ... next_reward=next_reward, + ... next_done=next_done, + ... next_terminated=next_terminated, + ... action=action) + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + value_target (NestedKey): The input tensordict key where the target state value is expected. + Will be used for the underlying value estimator Defaults to ``"value_target"``. + value (NestedKey): The input tensordict key where the chosen action value is expected. + Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``. + action_value (NestedKey): The input tensordict key where the action value is expected. + Defaults to ``"action_value"``. + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. 
+ Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. + pred_val (NestedKey): The key where the predicted value will be written + in the input tensordict. This value is subsequently used by cql_loss. + Defaults to ``"pred_val"``. + + """ + + value_target: NestedKey = "value_target" + value: NestedKey = "chosen_action_value" + action_value: NestedKey = "action_value" + action: NestedKey = "action" + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + pred_val: NestedKey = "pred_val" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + out_keys = [ + "loss_qvalue", + "loss_cql", + ] + + def __init__( + self, + value_network: Union[QValueActor, nn.Module], + *, + loss_function: Optional[str] = "l2", + delay_value: bool = True, + gamma: float = None, + action_space=None, + ) -> None: + super().__init__() + self._in_keys = None + self.delay_value = delay_value + value_network = ensure_tensordict_compatible( + module=value_network, + wrapper_type=QValueActor, + action_space=action_space, + ) + + self.convert_to_functional( + value_network, + "value_network", + create_target_params=self.delay_value, + ) + + self.value_network_in_keys = value_network.in_keys + + self.loss_function = loss_function + if action_space is None: + # infer from value net + try: + action_space = value_network.spec + except AttributeError: + # let's try with action_space then + try: + action_space = value_network.action_space + except AttributeError: + raise ValueError(self.ACTION_SPEC_ERROR) + if action_space is None: + warnings.warn( + "action_space was not specified. DiscreteCQLLoss will default to 'one-hot'. " + "This behaviour will be deprecated soon and a space will have to be passed. " + "Check the DiscreteCQLLoss documentation to see how to pass the action space." 
+            )
+            action_space = "one-hot"
+        self.action_space = _find_action_space(action_space)
+
+        if gamma is not None:
+            warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning)
+            self.gamma = gamma
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        if self._value_estimator is not None:
+            self._value_estimator.set_keys(
+                value_target=self.tensor_keys.value_target,
+                value=self._tensor_keys.value,
+                reward=self._tensor_keys.reward,
+                done=self._tensor_keys.done,
+                terminated=self._tensor_keys.terminated,
+            )
+        self._set_in_keys()
+
+    def _set_in_keys(self):
+        in_keys = {
+            self.tensor_keys.action,
+            unravel_key(("next", self.tensor_keys.reward)),
+            unravel_key(("next", self.tensor_keys.done)),
+            unravel_key(("next", self.tensor_keys.terminated)),
+            *self.value_network.in_keys,
+            *[unravel_key(("next", key)) for key in self.value_network.in_keys],
+        }
+        self._in_keys = sorted(in_keys, key=str)
+
+    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+        if value_type is None:
+            value_type = self.default_value_estimator
+        self.value_type = value_type
+
+        # we will take care of computing the next value inside this module
+        value_net = self.value_network
+
+        hp = dict(default_value_kwargs(value_type))
+        hp.update(hyperparams)
+        if value_type is ValueEstimators.TD1:
+            self._value_estimator = TD1Estimator(
+                **hp,
+                value_network=value_net,
+            )
+        elif value_type is ValueEstimators.TD0:
+            self._value_estimator = TD0Estimator(
+                **hp,
+                value_network=value_net,
+            )
+        elif value_type is ValueEstimators.GAE:
+            raise NotImplementedError(
+                f"Value type {value_type} is not implemented for loss {type(self)}."
+            )
+        elif value_type is ValueEstimators.TDLambda:
+            self._value_estimator = TDLambdaEstimator(
+                **hp,
+                value_network=value_net,
+            )
+        else:
+            raise NotImplementedError(f"Unknown value type {value_type}")
+
+        tensor_keys = {
+            "value_target": "value_target",
+            "value": self.tensor_keys.value,
+            "reward": self.tensor_keys.reward,
+            "done": self.tensor_keys.done,
+            "terminated": self.tensor_keys.terminated,
+        }
+        self._value_estimator.set_keys(**tensor_keys)
+
+    @property
+    def in_keys(self):
+        if self._in_keys is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @dispatch
+    def value_loss(
+        self,
+        tensordict: TensorDictBase,
+    ) -> Tuple[torch.Tensor, dict]:
+        td_copy = tensordict.clone(False)
+        self.value_network(
+            td_copy,
+            params=self.value_network_params,
+        )
+
+        action = tensordict.get(self.tensor_keys.action)
+        pred_val = td_copy.get(self.tensor_keys.action_value)
+
+        if self.action_space == "categorical":
+            if action.shape != pred_val.shape:
+                # unsqueeze the action if it lacks a trailing singleton dim
+                action = action.unsqueeze(-1)
+            pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1)
+        else:
+            action = action.to(torch.float)
+            pred_val_index = (pred_val * action).sum(-1)
+
+        # calculate target value
+        with torch.no_grad():
+            target_value = self.value_estimator.value_estimate(
+                td_copy,
+                target_params=self._cached_detached_target_value_params,
+            ).squeeze(-1)
+
+        with torch.no_grad():
+            td_error = (pred_val_index - target_value).pow(2)
+            td_error = td_error.unsqueeze(-1)
+            if tensordict.device is not None:
+                td_error = td_error.to(tensordict.device)
+
+        tensordict.set(
+            self.tensor_keys.priority,
+            td_error,
+            inplace=True,
+        )
+        tensordict.set(
+            self.tensor_keys.pred_val,
+            pred_val,
+            inplace=True,
+        )
+        loss = (
+            0.5 * distance_loss(pred_val_index, target_value, self.loss_function).mean()
+        )
+
+        metadata = {
+            "td_error": td_error.mean(0).detach(),
+            "pred_value": pred_val.mean().detach(),
+            "target_value": target_value.mean().detach(),
+        }
+
+        return loss, metadata
+
+    @dispatch
+    def forward(self, tensordict: TensorDictBase) -> TensorDict:
+        """Computes the (DQN) CQL loss given a tensordict sampled from the replay buffer.
+
+        This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
+        a priority to items in the tensordict.
+
+        Args:
+            tensordict (TensorDictBase): a tensordict with keys ["action"] and the in_keys of
+                the value network (observations, "done", "terminated", "reward" in a "next" tensordict).
+
+        Returns:
+            a tensordict containing the CQL loss components.
+
+        """
+        loss_qval, metadata = self.value_loss(tensordict)
+        loss_cql, _ = self.cql_loss(tensordict)
+        source = {
+            "loss_qvalue": loss_qval,
+            "loss_cql": loss_cql,
+        }
+        source.update(metadata)
+        td_out = TensorDict(
+            source=source,
+            batch_size=[],
+        )
+
+        return td_out
+
+    @property
+    @_cache_values
+    def _cached_detached_target_value_params(self):
+        return self.target_value_network_params.detach()
+
+    def cql_loss(self, tensordict):
+        qvalues = tensordict.get(self.tensor_keys.pred_val, default=None)
+        if qvalues is None:
+            raise KeyError(
+                f"Couldn't find the predicted qvalue with key {self.tensor_keys.pred_val} in the input tensordict. "
+                "This could be caused by calling the cql_loss method before value_loss."
+            )
+
+        current_action = tensordict.get(self.tensor_keys.action)
+
+        logsumexp = torch.logsumexp(qvalues, dim=-1, keepdim=True)
+        if self.action_space == "categorical":
+            if current_action.shape != qvalues.shape:
+                # unsqueeze the action if it lacks a trailing singleton dim
+                current_action = current_action.unsqueeze(-1)
+            q_a = qvalues.gather(-1, current_action)
+        else:
+            q_a = (qvalues * current_action).sum(dim=-1, keepdim=True)
+
+        return (logsumexp - q_a).mean(), {}
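Below is a minimal usage sketch of the pieces this patch introduces: a QValueActor wrapped by DiscreteCQLLoss with a SoftUpdate target updater, wired together the way make_discreteloss() and the discrete_cql_online.py training loop do. The network sizes, batch size, and optimizer/update hyperparameters are illustrative assumptions, not values taken from the configs above.

# Minimal sketch (assumed toy sizes and hyperparameters) of one discrete CQL update step.
import torch
from tensordict import TensorDict

from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import MLP, QValueActor
from torchrl.objectives import DiscreteCQLLoss, SoftUpdate

n_obs, n_act, batch = 4, 3, 8  # assumed toy dimensions
action_spec = OneHotDiscreteTensorSpec(n_act)
value_net = MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32])
actor = QValueActor(
    module=value_net,
    spec=action_spec,
    action_space="one_hot",
    in_keys=["observation"],
)

loss_module = DiscreteCQLLoss(actor, loss_function="l2", delay_value=True)
loss_module.make_value_estimator(gamma=0.99)
target_updater = SoftUpdate(loss_module, tau=0.005)
optimizer = torch.optim.Adam(loss_module.parameters(), lr=1e-3)

# A fake replay-buffer sample with the keys DiscreteCQLLoss expects.
data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": action_spec.rand((batch,)),
        ("next", "observation"): torch.randn(batch, n_obs),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)

loss_vals = loss_module(data)
# As in the online example, the TD term and the conservative term are simply summed.
loss = loss_vals["loss_qvalue"] + loss_vals["loss_cql"]
optimizer.zero_grad()
loss.backward()
optimizer.step()
target_updater.step()  # soft-update the target Q-network parameters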