"""Smoke test for the DCRL-ME emitter driven by MAP-Elites."""

import functools
from typing import Any, Dict, Tuple

import jax
import jax.numpy as jnp
import pytest

from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.tasks.brax_envs import (
    reset_based_scoring_function_brax_envs as scoring_function,
)
from qdax.tasks.brax_envs import (
    reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function,
)
from qdax import environments
from qdax.environments import behavior_descriptor_extractor
from qdax.core.map_elites import MAPElites
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter
from qdax.core.neuroevolution.buffers.buffer import DCRLTransition
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
from qdax.utils.metrics import default_qd_metrics


def test_dcrlme() -> None:
    """Run a few MAP-Elites iterations with the DCRL-ME emitter on ``ant_omni``.

    The test builds the environment, CVT centroids, policy and
    descriptor-conditioned actor networks, the scoring function and the
    emitter, then scans ``num_iterations`` MAP-Elites updates. It only checks
    that the resulting repertoire is not ``None``; no fitness or coverage
    value is asserted.
    """
    seed = 42

    env_name = "ant_omni"
    episode_length = 100
    min_bd = -30.0
    max_bd = 30.0

    num_iterations = 5
    batch_size = 256

    # Archive
    num_init_cvt_samples = 50000
    num_centroids = 1024
    policy_hidden_layer_sizes = (128, 128)

    # DCRL-ME
    ga_batch_size = 128
    dcg_batch_size = 64
    ai_batch_size = 64
    lengthscale = 0.1

    # GA emitter
    iso_sigma = 0.005
    line_sigma = 0.05

    # DCRL emitter
    critic_hidden_layer_size = (256, 256)
    num_critic_training_steps = 3000
    num_pg_training_steps = 150
    replay_buffer_size = 1_000_000
    discount = 0.99
    reward_scaling = 1.0
    critic_learning_rate = 3e-4
    actor_learning_rate = 3e-4
    policy_learning_rate = 5e-3
    noise_clip = 0.5
    policy_noise = 0.2
    soft_tau_update = 0.005
    policy_delay = 2

    # Init a random key
    random_key = jax.random.PRNGKey(seed)

    # Init environment
    env = environments.create(env_name, episode_length=episode_length)
    reset_fn = jax.jit(env.reset)

    # Compute the centroids
    centroids, random_key = compute_cvt_centroids(
        num_descriptors=env.behavior_descriptor_length,
        num_init_cvt_samples=num_init_cvt_samples,
        num_centroids=num_centroids,
        minval=min_bd,
        maxval=max_bd,
        random_key=random_key,
    )

    # Init policy network and descriptor-conditioned actor; both share the
    # same layer sizes, initializer and final activation.
    policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
    policy_network = MLP(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )
    actor_dc_network = MLPDC(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )

    # Init population of controllers
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, num=batch_size)
    fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))
    init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

    def make_transition(
        env_state: Any, state_desc: Any, actions: Any, next_state: Any
    ) -> DCRLTransition:
        """Package one environment step as a ``DCRLTransition``.

        ``state_desc`` must be read from ``env_state`` by the caller *before*
        stepping the environment, exactly as the step functions below do.
        ``desc`` and ``desc_prime`` are filled with NaN placeholders.
        """
        nan_desc = jnp.zeros(env.behavior_descriptor_length) * jnp.nan
        return DCRLTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            desc=nan_desc,
            desc_prime=nan_desc,
        )

    # Define the function to play a step with the policy in the environment
    def play_step_fn(
        env_state: Any, policy_params: Any, random_key: Any
    ) -> Tuple[Any, Any, Any, DCRLTransition]:
        """Apply the policy for one step and record the transition."""
        actions = policy_network.apply(policy_params, env_state.obs)
        # Keep this read ahead of env.step so it reflects the pre-step state.
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

        transition = make_transition(env_state, state_desc, actions, next_state)
        return next_state, policy_params, random_key, transition

    # Prepare the scoring function
    bd_extraction_fn = behavior_descriptor_extractor[env_name]
    scoring_fn = functools.partial(
        scoring_function,
        episode_length=episode_length,
        play_reset_fn=reset_fn,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    )

    def play_step_actor_dc_fn(
        env_state: Any, actor_dc_params: Any, desc: Any, random_key: Any
    ) -> Tuple[Any, Any, Any, Any, DCRLTransition]:
        """Apply the descriptor-conditioned actor for one step."""
        # `dcg_emitter` is assigned further down in this function; the name is
        # only resolved when this closure is actually called.
        desc_prime_normalized = dcg_emitter.emitters[0]._normalize_desc(desc)
        actions = actor_dc_network.apply(
            actor_dc_params, env_state.obs, desc_prime_normalized
        )
        # Keep this read ahead of env.step so it reflects the pre-step state.
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

        transition = make_transition(env_state, state_desc, actions, next_state)
        return next_state, actor_dc_params, desc, random_key, transition

    # Prepare the scoring function for the descriptor-conditioned actor.
    # NOTE(review): this callable is built but never invoked in this test, so
    # `play_step_actor_dc_fn` is never traced -- either exercise it or drop it.
    scoring_actor_dc_fn = jax.jit(  # noqa: F841
        functools.partial(
            scoring_actor_dc_function,
            episode_length=episode_length,
            play_reset_fn=reset_fn,
            play_step_actor_dc_fn=play_step_actor_dc_fn,
            behavior_descriptor_extractor=bd_extraction_fn,
        )
    )

    # Offset applied per step to the QD score; fixed to 0 here rather than
    # derived from the environment's minimum reward.
    reward_offset = 0

    # Define a metrics function
    metrics_function = functools.partial(
        default_qd_metrics,
        qd_offset=reward_offset * episode_length,
    )

    # Define the DCRL-ME emitter config
    dcg_emitter_config = DCRLMEConfig(
        ga_batch_size=ga_batch_size,
        dcg_batch_size=dcg_batch_size,
        ai_batch_size=ai_batch_size,
        lengthscale=lengthscale,
        critic_hidden_layer_size=critic_hidden_layer_size,
        num_critic_training_steps=num_critic_training_steps,
        num_pg_training_steps=num_pg_training_steps,
        batch_size=batch_size,
        replay_buffer_size=replay_buffer_size,
        discount=discount,
        reward_scaling=reward_scaling,
        critic_learning_rate=critic_learning_rate,
        actor_learning_rate=actor_learning_rate,
        policy_learning_rate=policy_learning_rate,
        noise_clip=noise_clip,
        policy_noise=policy_noise,
        soft_tau_update=soft_tau_update,
        policy_delay=policy_delay,
    )

    # Get the emitter
    variation_fn = functools.partial(
        isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
    )

    dcg_emitter = DCRLMEEmitter(
        config=dcg_emitter_config,
        policy_network=policy_network,
        actor_network=actor_dc_network,
        env=env,
        variation_fn=variation_fn,
    )

    # Instantiate MAP Elites
    map_elites = MAPElites(
        scoring_function=scoring_fn,
        emitter=dcg_emitter,
        metrics_function=metrics_function,
    )

    # Compute initial repertoire
    repertoire, emitter_state, random_key = map_elites.init(
        init_params, centroids, random_key
    )

    @jax.jit
    def update_scan_fn(carry: Any, unused: Any) -> Any:
        """Perform one MAP-Elites update; shaped as a ``jax.lax.scan`` body."""
        repertoire, emitter_state, metrics, random_key = map_elites.update(*carry)

        return (repertoire, emitter_state, random_key), metrics

    # Run the algorithm
    (repertoire, emitter_state, random_key), metrics = jax.lax.scan(
        update_scan_fn,
        (repertoire, emitter_state, random_key),
        (),
        length=num_iterations,
    )

    # NOTE(review): `pytest.assume` is not part of core pytest; presumably it
    # comes from the pytest-assume plugin -- confirm it is a dev dependency.
    pytest.assume(repertoire is not None)