From b4125c3159f4bddb9ff9683c09e88cd923fd83c9 Mon Sep 17 00:00:00 2001
From: Maxence Faldor
Date: Tue, 9 Jan 2024 10:03:18 +0000
Subject: [PATCH 01/20] feat(algo): add DCG-MAP-Elites (#167)

* a new method for the MAP-Elites repertoire that enables sampling individuals
  together with their corresponding descriptors.
* a new extra_info output for Emitter.emit methods, similar to the extra_scores
  of the scoring function, that enables passing information from the emit step
  to the state_update (necessary for DCG-MAP-Elites).
* a new DCGTransition that adds desc and desc_prime to the QDTransition.
* a descriptor-conditioned TD3 loss, descriptor-conditioned scoring functions,
  and a descriptor-conditioned MLP.
* two new reward wrappers to clip and offset the reward (necessary for
  DCG-MAP-Elites).
---
 examples/distributed_mapelites.ipynb | 2 +-
 examples/me_sac_pbt.ipynb | 2 +-
 examples/me_td3_pbt.ipynb | 2 +-
 examples/mome.ipynb | 4 +-
 examples/nsga2_spea2.ipynb | 6 +-
 qdax/baselines/genetic_algorithm.py | 21 +-
 qdax/baselines/nsga2.py | 15 +-
 qdax/baselines/spea2.py | 15 +-
 qdax/core/aurora.py | 24 +-
 qdax/core/containers/mapelites_repertoire.py | 32 +
 qdax/core/distributed_map_elites.py | 21 +-
 qdax/core/emitters/cma_emitter.py | 14 +-
 qdax/core/emitters/cma_mega_emitter.py | 16 +-
 qdax/core/emitters/cma_pool_emitter.py | 25 +-
 qdax/core/emitters/cma_rnd_emitter.py | 10 +-
 qdax/core/emitters/dcg_me_emitter.py | 88 ++
 qdax/core/emitters/dpg_emitter.py | 25 +-
 qdax/core/emitters/emitter.py | 12 +-
 qdax/core/emitters/mees_emitter.py | 26 +-
 qdax/core/emitters/multi_emitter.py | 29 +-
 qdax/core/emitters/omg_mega_emitter.py | 28 +-
 qdax/core/emitters/pbt_me_emitter.py | 20 +-
 qdax/core/emitters/qdcg_emitter.py | 763 ++++++++++++++++++
 qdax/core/emitters/qpg_emitter.py | 27 +-
 qdax/core/emitters/standard_emitters.py | 6 +-
 qdax/core/map_elites.py | 26 +-
 qdax/core/mels.py | 17 +-
 qdax/core/mome.py | 17 +-
 qdax/core/neuroevolution/buffers/buffer.py | 159 +++-
 qdax/core/neuroevolution/losses/td3_loss.py | 110 ++-
 qdax/core/neuroevolution/mdp_utils.py | 56 +-
 qdax/core/neuroevolution/networks/networks.py | 105 ++-
 qdax/environments/base_wrappers.py | 11 +-
 qdax/environments/wrappers.py | 79 +-
 qdax/tasks/brax_envs.py | 154 +++-
 requirements.txt | 5 +-
 tests/baselines_test/ga_test.py | 6 +-
 tests/baselines_test/me_pbt_sac_test.py | 4 +-
 tests/baselines_test/me_pbt_td3_test.py | 4 +-
 tests/core_test/mome_test.py | 4 +-
 40 files changed, 1798 insertions(+), 192 deletions(-)
 create mode 100644 qdax/core/emitters/dcg_me_emitter.py
 create mode 100644 qdax/core/emitters/qdcg_emitter.py

diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb
index 434725a3..18d4f0f3 100644
--- a/examples/distributed_mapelites.ipynb
+++ b/examples/distributed_mapelites.ipynb
@@ -348,7 +348,7 @@
     "repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n",
     "    centroids=centroids,\n",
     "    devices=devices,\n",
-    ")(init_genotypes=init_variables, random_key=random_key)"
+    ")(genotypes=init_variables, random_key=random_key)"
    ]
   },
   {
diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb
index 6d4dfdfe..6b4ae0b5 100644
--- a/examples/me_sac_pbt.ipynb
+++ b/examples/me_sac_pbt.ipynb
@@ -311,7 +311,7 @@
     "# initialize map-elites\n",
     "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n",
     "    devices=devices, centroids=centroids\n",
-    ")(init_genotypes=training_states, random_key=keys)"
+    ")(genotypes=training_states, random_key=keys)"
    ]
   },
   {
diff --git 
a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 1cf17c5e..ca127e72 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -314,7 +314,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(init_genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, random_key=keys)" ] }, { diff --git a/examples/mome.ipynb b/examples/mome.ipynb index a4ca36a6..05387158 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -212,7 +212,7 @@ "# initial population\n", "random_key = jax.random.PRNGKey(42)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -303,7 +303,7 @@ "outputs": [], "source": [ "repertoire, emitter_state, random_key = mome.init(\n", - " init_genotypes,\n", + " genotypes,\n", " centroids,\n", " pareto_front_max_length,\n", " random_key\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 51c5f5bd..e10c0d91 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -189,7 +189,7 @@ "# Initial population\n", "random_key = jax.random.PRNGKey(0)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -238,7 +238,7 @@ "\n", "# init nsga2\n", "repertoire, emitter_state, random_key = nsga2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " random_key\n", ")" @@ -303,7 +303,7 @@ "\n", "# init spea2\n", "repertoire, emitter_state, random_key = spea2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " num_neighbours,\n", " random_key\n", diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index 0714fb6c..a01a13b1 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -39,12 +39,12 @@ def __init__( @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]: """Initialize a GARepertoire with an initial population of genotypes. 
Args: - init_genotypes: the initial population of genotypes + genotypes: the initial population of genotypes population_size: the maximal size of the repertoire random_key: a random key to handle stochastic operations @@ -54,26 +54,21 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = GARepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=None, extra_scores=extra_scores, @@ -108,7 +103,7 @@ def update( """ # generate offsprings - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) @@ -127,7 +122,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=None, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index 663d6f0e..a889eadc 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -28,31 +28,36 @@ class NSGA2(GeneticAlgorithm): @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]: # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = NSGA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index 72ec2791..c52063b6 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -40,7 +40,7 @@ class SPEA2(GeneticAlgorithm): ) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, population_size: int, num_neighbours: int, random_key: RNGKey, @@ -48,12 +48,12 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = SPEA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, num_neighbours=num_neighbours, @@ -61,14 +61,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + 
repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index fed716e3..a0968ccc 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -118,7 +118,7 @@ def container_size_control( def init( self, - init_genotypes: Genotype, + genotypes: Genotype, aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, @@ -128,7 +128,7 @@ def init( genotypes. Also performs the first training of the AURORA encoder. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) aurora_extra_info: information to perform AURORA encodings, such as the encoder parameters @@ -141,7 +141,7 @@ def init( the emitter, and the updated information to perform AURORA encodings """ fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, + genotypes, random_key, ) @@ -150,7 +150,7 @@ def init( descriptors = self._encoder_fn(observations, aurora_extra_info) repertoire = UnstructuredRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, observations=observations, @@ -160,13 +160,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, - genotypes=init_genotypes, + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -208,9 +204,10 @@ def update( a new key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, @@ -232,10 +229,11 @@ def update( # update emitter state after scoring is made emitter_state = self._emitter.state_update( emitter_state=emitter_state, + repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index aed74c78..b1145c34 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -228,6 +228,38 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey return samples, random_key + @partial(jax.jit, static_argnames=("num_samples",)) + def sample_with_descs( + self, + random_key: RNGKey, + num_samples: int, + ) -> Tuple[Genotype, Descriptor, RNGKey]: + """Sample elements in the repertoire. 
+ + Args: + random_key: a jax PRNG random key + num_samples: the number of elements to be sampled + + Returns: + samples: a batch of genotypes sampled in the repertoire + random_key: an updated jax PRNG random key + """ + + repertoire_empty = self.fitnesses == -jnp.inf + p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) + + random_key, subkey = jax.random.split(random_key) + samples = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.genotypes, + ) + descs = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.descriptors, + ) + + return samples, descs, random_key + @jax.jit def add( self, diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index c8a1ea44..7b5609f2 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -17,7 +17,7 @@ class DistributedMAPElites(MAPElites): @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -30,7 +30,7 @@ def init( devices. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. @@ -41,7 +41,7 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # gather across all devices @@ -51,7 +51,7 @@ def init( gathered_descriptors, ) = jax.tree_util.tree_map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), - (init_genotypes, fitnesses, descriptors), + (genotypes, fitnesses, descriptors), ) # init the repertoire @@ -64,14 +64,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -108,7 +113,7 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) # scores the offsprings @@ -138,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index f9d58caa..66e5677a 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -99,14 +99,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> 
Tuple[CMAEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -135,7 +141,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -154,7 +160,7 @@ def emit( cmaes_state=emitter_state.cmaes_state, random_key=random_key ) - return offsprings, random_key + return offsprings, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index f63654fd..1fd0e1e6 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -100,14 +100,20 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAMEGAState, RNGKey]: """ Initializes the CMA-MEGA emitter. Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -117,7 +123,7 @@ def init( # define init theta as 0 theta = jax.tree_util.tree_map( lambda x: jnp.zeros_like(x[:1, ...]), - init_genotypes, + genotypes, ) # score it @@ -147,7 +153,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAMEGAState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -181,7 +187,7 @@ def emit( # Compute new candidates new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad) - return new_thetas, random_key + return new_thetas, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index d5424a01..24556f8b 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -49,14 +49,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAPoolEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. 
Returns: @@ -67,7 +73,14 @@ def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: random_key = carry - emitter_state, random_key = self._emitter.init(init_genotypes, random_key) + emitter_state, random_key = self._emitter.init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) return random_key, emitter_state # init all the emitter states @@ -91,7 +104,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAPoolEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. @@ -111,11 +124,11 @@ def emit( ) # use it to emit offsprings - offsprings, random_key = self._emitter.emit( + offsprings, extra_info, random_key = self._emitter.emit( repertoire, used_emitter_state, random_key ) - return offsprings, random_key + return offsprings, extra_info, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index 4afb2f5d..e05cc453 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -35,14 +35,20 @@ class CMARndEmitterState(CMAEmitterState): class CMARndEmitter(CMAEmitter): @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMARndEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py new file mode 100644 index 00000000..94e0bb9d --- /dev/null +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Callable, Tuple + +import flax.linen as nn + +from qdax.core.emitters.multi_emitter import MultiEmitter +from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Params, RNGKey + + +@dataclass +class DCGMEConfig: + """Configuration for DCGME Algorithm""" + + ga_batch_size: int = 128 + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + # PG emitter + critic_hidden_layer_size: Tuple[int, ...] 
= (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class DCGMEEmitter(MultiEmitter): + def __init__( + self, + config: DCGMEConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]], + ) -> None: + self._config = config + self._env = env + self._variation_fn = variation_fn + + qdcg_config = QualityDCGConfig( + qpg_batch_size=config.qpg_batch_size, + ai_batch_size=config.ai_batch_size, + lengthscale=config.lengthscale, + critic_hidden_layer_size=config.critic_hidden_layer_size, + num_critic_training_steps=config.num_critic_training_steps, + num_pg_training_steps=config.num_pg_training_steps, + batch_size=config.batch_size, + replay_buffer_size=config.replay_buffer_size, + discount=config.discount, + reward_scaling=config.reward_scaling, + critic_learning_rate=config.critic_learning_rate, + actor_learning_rate=config.actor_learning_rate, + policy_learning_rate=config.policy_learning_rate, + noise_clip=config.noise_clip, + policy_noise=config.policy_noise, + soft_tau_update=config.soft_tau_update, + policy_delay=config.policy_delay, + ) + + # define the quality emitter + q_emitter = QualityDCGEmitter( + config=qdcg_config, + policy_network=policy_network, + actor_network=actor_network, + env=env, + ) + + # define the GA emitter + ga_emitter = MixingEmitter( + mutation_fn=lambda x, r: (x, r), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=config.ga_batch_size, + ) + + super().__init__(emitters=(q_emitter, ga_emitter)) diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 8b858db4..2c55cbd2 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -77,12 +77,18 @@ def __init__( self._score_novelty = score_novelty def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[DiversityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. 
Returns: @@ -90,7 +96,14 @@ def init( """ # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init(init_genotypes, random_key) + diversity_emitter_state, random_key = super().init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) # store elements in a dictionary attributes_dict = vars(diversity_emitter_state) @@ -102,6 +115,12 @@ def init( max_size=self._config.archive_max_size, ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + archive = archive.insert(transitions.state_desc) + # init emitter state emitter_state = DiversityPGEmitterState( # retrieve all attributes from the QualityPGEmitterState diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index d32ed981..056798ba 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -30,14 +30,20 @@ class EmitterState(PyTreeNode): class Emitter(ABC): def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """Initialises the state of the emitter. Some emitters do not need a state, in which case, the value None can be outputted. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: @@ -51,7 +57,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Function used to emit a population of offspring by any possible mean. New population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire. diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index b5bb1ada..0a03a6ba 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -236,26 +236,32 @@ def batch_size(self) -> int: static_argnames=("self",), ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[MEESEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: The initial state of the MEESEmitter, a new random key. 
""" # Initialisation requires one initial genotype - if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1: - init_genotypes = jax.tree_util.tree_map( + if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1: + genotypes = jax.tree_util.tree_map( lambda x: x[0], - init_genotypes, + genotypes, ) # Initialise optimizer - initial_optimizer_state = self._optimizer.init(init_genotypes) + initial_optimizer_state = self._optimizer.init(genotypes) # Create empty Novelty archive if self._config.use_explore: @@ -270,7 +276,7 @@ def init( # Create empty updated genotypes and fitness last_updated_genotypes = jax.tree_util.tree_map( lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]), - init_genotypes, + genotypes, ) last_updated_fitnesses = -jnp.inf * jnp.ones( shape=self._config.last_updated_size @@ -280,7 +286,7 @@ def init( MEESEmitterState( initial_optimizer_state=initial_optimizer_state, optimizer_state=initial_optimizer_state, - offspring=init_genotypes, + offspring=genotypes, generation_count=0, novelty_archive=novelty_archive, last_updated_genotypes=last_updated_genotypes, @@ -300,7 +306,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: MEESEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Return the offspring generated through gradient update. Params: @@ -313,7 +319,7 @@ def emit( a new jax PRNG key """ - return emitter_state.offspring, random_key + return emitter_state.offspring, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 2da46639..b3ad23c6 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -56,13 +56,19 @@ def get_indexes_separation_batches( return tuple(indexes_separation_batches) def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """ Initialize the state of the emitter. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: @@ -76,7 +82,14 @@ def init( # init all emitter states - gather them emitter_states = [] for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init(init_genotypes, subkey_emitter) + emitter_state, _ = emitter.init( + subkey_emitter, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) emitter_states.append(emitter_state) return MultiEmitterState(tuple(emitter_states)), random_key @@ -87,7 +100,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[MultiEmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Emit new population. Use all the sub emitters to emit subpopulation and gather them. 
@@ -108,21 +121,25 @@ def emit( # emit from all emitters and gather offsprings all_offsprings = [] + all_extra_info: ExtraScores = {} for emitter, sub_emitter_state, subkey_emitter in zip( self.emitters, emitter_state.emitter_states, subkeys, ): - genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) + genotype, extra_info, _ = emitter.emit( + repertoire, sub_emitter_state, subkey_emitter + ) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) + all_extra_info = {**all_extra_info, **extra_info} # concatenate offsprings together offsprings = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=0), *all_offsprings ) - return offsprings, random_key + return offsprings, all_extra_info, random_key @partial(jax.jit, static_argnames=("self",)) def state_update( diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 7336750d..54766152 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -84,20 +84,26 @@ def __init__( self._num_descriptors = num_descriptors def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[OMGMEGAEmitterState, RNGKey]: """Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: The initial emitter state. """ # retrieve one genotype from the population - first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes) # add a dimension of size num descriptors + 1 gradient_genotype = jax.tree_util.tree_map( @@ -112,6 +118,18 @@ def init( genotype=gradient_genotype, centroids=self._centroids ) + # get gradients out of the extra scores + assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key" + gradients = extra_scores["gradients"] + + # update the gradients repertoire + gradients_repertoire = gradients_repertoire.add( + gradients, + descriptors, + fitnesses, + extra_scores, + ) + return ( OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire), random_key, @@ -126,7 +144,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: OMGMEGAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates. 
@@ -190,7 +208,7 @@ def emit( lambda x, y: x + y, genotypes, update_grad ) - return new_genotypes, random_key + return new_genotypes, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 3fdb4418..a2266bfa 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -91,12 +91,18 @@ def __init__( ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[PBTEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: @@ -145,13 +151,13 @@ def init( # Create emitter state # keep only pg population size training states if more are provided - init_genotypes = jax.tree_util.tree_map( - lambda x: x[: self._config.pg_population_size_per_device], init_genotypes + genotypes = jax.tree_util.tree_map( + lambda x: x[: self._config.pg_population_size_per_device], genotypes ) emitter_state = PBTEmitterState( replay_buffers=replay_buffers, env_states=env_states, - training_states=init_genotypes, + training_states=genotypes, random_key=subkey2, ) @@ -166,7 +172,7 @@ def emit( repertoire: Repertoire, emitter_state: PBTEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a single PGA-ME iteration: train critics and greedy policy, make mutations (evo and pg), score solution, fill replay buffer and insert back in the MAP-Elites grid. @@ -199,7 +205,7 @@ def emit( else: genotypes = x_mutation_pg - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py new file mode 100644 index 00000000..0d560cbb --- /dev/null +++ b/qdax/core/emitters/qdcg_emitter.py @@ -0,0 +1,763 @@ +"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm +in JAX for Brax environments. +""" + +from dataclasses import dataclass +from functools import partial +from typing import Any, Tuple + +import flax.linen as nn +import jax +import optax +from jax import numpy as jnp + +from qdax.core.containers.repertoire import Repertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer +from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn +from qdax.core.neuroevolution.networks.networks import QModuleDC +from qdax.environments.base_wrappers import QDEnv +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey + + +@dataclass +class QualityDCGConfig: + """Configuration for QualityDCG Emitter""" + + qpg_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + critic_hidden_layer_size: Tuple[int, ...] 
= (256, 256)
+    num_critic_training_steps: int = 3000
+    num_pg_training_steps: int = 150
+    batch_size: int = 100
+    replay_buffer_size: int = 1_000_000
+    discount: float = 0.99
+    reward_scaling: float = 1.0
+    critic_learning_rate: float = 3e-4
+    actor_learning_rate: float = 3e-4
+    policy_learning_rate: float = 1e-3
+    noise_clip: float = 0.5
+    policy_noise: float = 0.2
+    soft_tau_update: float = 0.005
+    policy_delay: int = 2
+
+
+class QualityDCGEmitterState(EmitterState):
+    """Contains training state for the learner."""
+
+    critic_params: Params
+    critic_opt_state: optax.OptState
+    actor_params: Params
+    actor_opt_state: optax.OptState
+    target_critic_params: Params
+    target_actor_params: Params
+    replay_buffer: ReplayBuffer
+    random_key: RNGKey
+    steps: jnp.ndarray
+
+
+class QualityDCGEmitter(Emitter):
+    """
+    A descriptor-conditioned policy gradient emitter used to implement the
+    DCG-MAP-Elites (DCG-ME) algorithm.
+    """
+
+    def __init__(
+        self,
+        config: QualityDCGConfig,
+        policy_network: nn.Module,
+        actor_network: nn.Module,
+        env: QDEnv,
+    ) -> None:
+        self._config = config
+        self._env = env
+        self._policy_network = policy_network
+        self._actor_network = actor_network
+
+        # Init Critics
+        critic_network = QModuleDC(
+            n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size
+        )
+        self._critic_network = critic_network
+
+        # Set up the losses and optimizers - return the opt states
+        (
+            self._policy_loss_fn,
+            self._actor_loss_fn,
+            self._critic_loss_fn,
+        ) = make_td3_loss_dc_fn(
+            policy_fn=policy_network.apply,
+            actor_fn=actor_network.apply,
+            critic_fn=critic_network.apply,
+            reward_scaling=self._config.reward_scaling,
+            discount=self._config.discount,
+            noise_clip=self._config.noise_clip,
+            policy_noise=self._config.policy_noise,
+        )
+
+        # Init optimizers
+        self._actor_optimizer = optax.adam(
+            learning_rate=self._config.actor_learning_rate
+        )
+        self._critic_optimizer = optax.adam(
+            learning_rate=self._config.critic_learning_rate
+        )
+        self._policies_optimizer = optax.adam(
+            learning_rate=self._config.policy_learning_rate
+        )
+
+    @property
+    def batch_size(self) -> int:
+        """
+        Returns:
+            the batch size emitted by the emitter.
+        """
+        return self._config.qpg_batch_size + self._config.ai_batch_size
+
+    @property
+    def use_all_data(self) -> bool:
+        """Whether to use all data or not when used along with other emitters.
+
+        QualityDCGEmitter uses the transitions from the genotypes that were generated
+        by other emitters.
+        """
+        return True
+
+    def init(
+        self,
+        random_key: RNGKey,
+        repertoire: Repertoire,
+        genotypes: Genotype,
+        fitnesses: Fitness,
+        descriptors: Descriptor,
+        extra_scores: ExtraScores,
+    ) -> Tuple[QualityDCGEmitterState, RNGKey]:
+        """Initializes the emitter state.
+
+        Args:
+            genotypes: The initial population.
+            random_key: A random key.
+
+        Returns:
+            The initial state of the QualityDCGEmitter, a new random key. 
+ """ + + observation_size = jax.tree_util.tree_leaves(genotypes)[1].shape[1] + descriptor_size = self._env.behavior_descriptor_length + action_size = self._env.action_size + + # Initialise critic, greedy actor and population + random_key, subkey = jax.random.split(random_key) + fake_obs = jnp.zeros(shape=(observation_size,)) + fake_desc = jnp.zeros(shape=(descriptor_size,)) + fake_action = jnp.zeros(shape=(action_size,)) + + critic_params = self._critic_network.init( + subkey, obs=fake_obs, actions=fake_action, desc=fake_desc + ) + target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + + random_key, subkey = jax.random.split(random_key) + actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc) + target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) + + # Prepare init optimizer states + critic_opt_state = self._critic_optimizer.init(critic_params) + actor_opt_state = self._actor_optimizer.init(actor_params) + + # Initialize replay buffer + dummy_transition = DCGTransition.init_dummy( + observation_dim=self._env.observation_size, + action_dim=action_size, + descriptor_dim=descriptor_size, + ) + + replay_buffer = ReplayBuffer.init( + buffer_size=self._config.replay_buffer_size, transition=dummy_transition + ) + + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_normalized + ) + replay_buffer = replay_buffer.insert(transitions) + + # Initial training state + random_key, subkey = jax.random.split(random_key) + emitter_state = QualityDCGEmitterState( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + replay_buffer=replay_buffer, + random_key=subkey, + steps=jnp.array(0), + ) + + return emitter_state, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: + """Compute the similarity between two batches of descriptors. + Args: + descs_1: batch of descriptors. + descs_2: batch of descriptors. + Returns: + batch of similarity measures. + """ + return jnp.exp( + -jnp.linalg.norm(descs_1 - descs_2, axis=-1) / self._config.lengthscale + ) + + @partial(jax.jit, static_argnames=("self",)) + def _normalize_desc(self, desc: Descriptor) -> Descriptor: + return ( + 2 + * (desc - self._env.behavior_descriptor_limits[0]) + / ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) + - 1 + ) + + @partial(jax.jit, static_argnames=("self",)) + def _unnormalize_desc(self, desc_normalized: Descriptor) -> Descriptor: + return 0.5 * ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) * desc_normalized + 0.5 * ( + self._env.behavior_descriptor_limits[1] + + self._env.behavior_descriptor_limits[0] + ) + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_kernel_bias_with_desc( + self, actor_dc_params: Params, desc: Descriptor + ) -> Tuple[Params, Params]: + """ + Compute the equivalent bias of the first layer of the actor network + given a descriptor. 
+ """ + # Extract kernel and bias of the first layer + kernel = actor_dc_params["params"]["Dense_0"]["kernel"] + bias = actor_dc_params["params"]["Dense_0"]["bias"] + + # Compute the equivalent bias + equivalent_kernel = kernel[: -desc.shape[0], :] + equivalent_bias = bias + jnp.dot(desc, kernel[-desc.shape[0] :]) + + return equivalent_kernel, equivalent_bias + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_params_with_desc( + self, actor_dc_params: Params, desc: Descriptor + ) -> Params: + desc_normalized = self._normalize_desc(desc) + ( + equivalent_kernel, + equivalent_bias, + ) = self._compute_equivalent_kernel_bias_with_desc( + actor_dc_params, desc_normalized + ) + actor_dc_params["params"]["Dense_0"]["kernel"] = equivalent_kernel + actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias + return actor_dc_params + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit( + self, + repertoire: Repertoire, + emitter_state: QualityDCGEmitterState, + random_key: RNGKey, + ) -> Tuple[Genotype, ExtraScores, RNGKey]: + """Do a step of PG emission. + + Args: + repertoire: the current repertoire of genotypes + emitter_state: the state of the emitter used + random_key: a random key + + Returns: + A batch of offspring, the new emitter state and a new key. + """ + # PG emitter + parents_pg, descs_pg, random_key = repertoire.sample_with_descs( + random_key, self._config.qpg_batch_size + ) + genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) + + # Actor injection emitter + _, descs_ai, random_key = repertoire.sample_with_descs( + random_key, self._config.ai_batch_size + ) + descs_ai = descs_ai.reshape( + descs_ai.shape[0], self._env.behavior_descriptor_length + ) + genotypes_ai = self.emit_ai(emitter_state, descs_ai) + + # Concatenate PG and AI genotypes + genotypes = jax.tree_util.tree_map( + lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai + ) + + return ( + genotypes, + {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, + random_key, + ) + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit_pg( + self, + emitter_state: QualityDCGEmitterState, + parents: Genotype, + descs: Descriptor, + ) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + mutation_fn = partial( + self._mutation_function_pg, + emitter_state=emitter_state, + ) + offsprings = jax.vmap(mutation_fn)(parents, descs) + + return offsprings + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit_ai( + self, emitter_state: QualityDCGEmitterState, descs: Descriptor + ) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + offsprings = jax.vmap( + self._compute_equivalent_params_with_desc, in_axes=(None, 0) + )(emitter_state.actor_params, descs) + + return offsprings + + @partial(jax.jit, static_argnames=("self",)) + def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: + """Emit the greedy actor. 
+ + Simply needs to be retrieved from the emitter state. + + Args: + emitter_state: the current emitter state, it stores the + greedy actor. + + Returns: + The parameters of the actor. + """ + return emitter_state.actor_params + + @partial( + jax.jit, + static_argnames=("self",), + ) + def state_update( + self, + emitter_state: QualityDCGEmitterState, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> QualityDCGEmitterState: + """This function gives an opportunity to update the emitter state + after the genotypes have been scored. + + Here it is used to fill the Replay Buffer with the transitions + from the scoring of the genotypes, and then the training of the + critic/actor happens. Hence the params of critic/actor are updated, + as well as their optimizer states. + + Args: + emitter_state: current emitter state. + repertoire: the current genotypes repertoire + genotypes: unused here - but compulsory in the signature. + fitnesses: unused here - but compulsory in the signature. + descriptors: unused here - but compulsory in the signature. + extra_scores: extra information coming from the scoring function, + this contains the transitions added to the replay buffer. + + Returns: + New emitter state where the replay buffer has been filled with + the new experienced transitions. + """ + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc_prime = jnp.concatenate( + [ + extra_scores["desc_prime"], + descriptors[self._config.qpg_batch_size + self._config.ai_batch_size :], + ], + axis=0, + ) + desc_prime = jnp.repeat(desc_prime[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + + desc_prime_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc_prime) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_prime_normalized + ) + + # Add transitions to replay buffer + replay_buffer = emitter_state.replay_buffer.insert(transitions) + emitter_state = emitter_state.replace(replay_buffer=replay_buffer) + + # sample transitions from the replay buffer + random_key, subkey = jax.random.split(emitter_state.random_key) + transitions, random_key = replay_buffer.sample( + subkey, self._config.num_critic_training_steps * self._config.batch_size + ) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, + ( + self._config.num_critic_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) + transitions = transitions.replace( + rewards=self._similarity(transitions.desc, transitions.desc_prime) + * transitions.rewards + ) + emitter_state = emitter_state.replace(random_key=random_key) + + def scan_train_critics( + carry: QualityDCGEmitterState, + transitions: DCGTransition, + ) -> Tuple[QualityDCGEmitterState, Any]: + emitter_state = carry + new_emitter_state = self._train_critics(emitter_state, transitions) + return new_emitter_state, () + + # Train critics and greedy actor + emitter_state, _ = jax.lax.scan( + scan_train_critics, + emitter_state, + transitions, + length=self._config.num_critic_training_steps, + ) + + return emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _train_critics( + self, emitter_state: QualityDCGEmitterState, transitions: DCGTransition + ) -> 
QualityDCGEmitterState: + """Apply one gradient step to critics and to the greedy actor + (contained in carry in training_state), then soft update target critics + and target actor. + + Those updates are very similar to those made in TD3. + + Args: + emitter_state: actual emitter state + + Returns: + New emitter state where the critic and the greedy actor have been + updated. Optimizer states have also been updated in the process. + """ + # Update Critic + ( + critic_opt_state, + critic_params, + target_critic_params, + random_key, + ) = self._update_critic( + critic_params=emitter_state.critic_params, + target_critic_params=emitter_state.target_critic_params, + target_actor_params=emitter_state.target_actor_params, + critic_opt_state=emitter_state.critic_opt_state, + transitions=transitions, + random_key=emitter_state.random_key, + ) + + # Update greedy actor + (actor_opt_state, actor_params, target_actor_params,) = jax.lax.cond( + emitter_state.steps % self._config.policy_delay == 0, + lambda x: self._update_actor(*x), + lambda _: ( + emitter_state.actor_opt_state, + emitter_state.actor_params, + emitter_state.target_actor_params, + ), + operand=( + emitter_state.actor_params, + emitter_state.actor_opt_state, + emitter_state.target_actor_params, + emitter_state.critic_params, + transitions, + ), + ) + + # Create new training state + new_emitter_state = emitter_state.replace( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + random_key=random_key, + steps=emitter_state.steps + 1, + ) + + return new_emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _update_critic( + self, + critic_params: Params, + target_critic_params: Params, + target_actor_params: Params, + critic_opt_state: Params, + transitions: DCGTransition, + random_key: RNGKey, + ) -> Tuple[Params, Params, Params, RNGKey]: + + # compute loss and gradients + random_key, subkey = jax.random.split(random_key) + critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( + critic_params, + target_actor_params, + target_critic_params, + transitions, + subkey, + ) + critic_updates, critic_opt_state = self._critic_optimizer.update( + critic_gradient, critic_opt_state + ) + + # update critic + critic_params = optax.apply_updates(critic_params, critic_updates) + + # Soft update of target critic network + target_critic_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_critic_params, + critic_params, + ) + + return critic_opt_state, critic_params, target_critic_params, random_key + + @partial(jax.jit, static_argnames=("self",)) + def _update_actor( + self, + actor_params: Params, + actor_opt_state: optax.OptState, + target_actor_params: Params, + critic_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params, Params]: + + # Update greedy actor + policy_loss, policy_gradient = jax.value_and_grad(self._actor_loss_fn)( + actor_params, + critic_params, + transitions, + ) + ( + policy_updates, + actor_opt_state, + ) = self._actor_optimizer.update(policy_gradient, actor_opt_state) + actor_params = optax.apply_updates(actor_params, policy_updates) + + # Soft update of target greedy actor + target_actor_params = jax.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + 
target_actor_params, + actor_params, + ) + + return ( + actor_opt_state, + actor_params, + target_actor_params, + ) + + @partial( + jax.jit, + static_argnames=("self",), + ) + def _mutation_function_pg( + self, + policy_params: Genotype, + descs: Descriptor, + emitter_state: QualityDCGEmitterState, + ) -> Genotype: + """Apply pg mutation to a policy via multiple steps of gradient descent. + First, update the rewards to be diversity rewards, then apply the gradient + steps. + + Args: + policy_params: a policy, supposed to be a differentiable neural + network. + emitter_state: the current state of the emitter, containing among others, + the replay buffer, the critic. + + Returns: + The updated params of the neural network. + """ + # Get transitions + transitions, random_key = emitter_state.replay_buffer.sample( + emitter_state.random_key, + sample_size=self._config.num_pg_training_steps * self._config.batch_size, + ) + descs_prime = jnp.tile( + descs, (self._config.num_pg_training_steps * self._config.batch_size, 1) + ) + descs_prime_normalized = jax.vmap(self._normalize_desc)(descs_prime) + transitions = transitions.replace( + rewards=self._similarity(transitions.desc, descs_prime_normalized) + * transitions.rewards, + desc_prime=descs_prime_normalized, + ) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, + ( + self._config.num_pg_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) + + # Replace random_key + emitter_state = emitter_state.replace(random_key=random_key) + + # Define new policy optimizer state + policy_opt_state = self._policies_optimizer.init(policy_params) + + def scan_train_policy( + carry: Tuple[QualityDCGEmitterState, Genotype, optax.OptState], + transitions: DCGTransition, + ) -> Tuple[Tuple[QualityDCGEmitterState, Genotype, optax.OptState], Any]: + emitter_state, policy_params, policy_opt_state = carry + ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ) = self._train_policy( + emitter_state, + policy_params, + policy_opt_state, + transitions, + ) + return ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ), () + + (emitter_state, policy_params, policy_opt_state,), _ = jax.lax.scan( + scan_train_policy, + (emitter_state, policy_params, policy_opt_state), + transitions, + length=self._config.num_pg_training_steps, + ) + + return policy_params + + @partial(jax.jit, static_argnames=("self",)) + def _train_policy( + self, + emitter_state: QualityDCGEmitterState, + policy_params: Params, + policy_opt_state: optax.OptState, + transitions: DCGTransition, + ) -> Tuple[QualityDCGEmitterState, Params, optax.OptState]: + """Apply one gradient step to a policy (called policy_params). + + Args: + emitter_state: current state of the emitter. + policy_params: parameters corresponding to the weights and bias of + the neural network that defines the policy. + + Returns: + The new emitter state and new params of the NN. 
+ """ + # update policy + policy_opt_state, policy_params = self._update_policy( + critic_params=emitter_state.critic_params, + policy_opt_state=policy_opt_state, + policy_params=policy_params, + transitions=transitions, + ) + + return emitter_state, policy_params, policy_opt_state + + @partial(jax.jit, static_argnames=("self",)) + def _update_policy( + self, + critic_params: Params, + policy_opt_state: optax.OptState, + policy_params: Params, + transitions: DCGTransition, + ) -> Tuple[optax.OptState, Params]: + + # compute loss + _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_params, + critic_params, + transitions, + ) + # Compute gradient and update policies + ( + policy_updates, + policy_opt_state, + ) = self._policies_optimizer.update(policy_gradient, policy_opt_state) + policy_params = optax.apply_updates(policy_params, policy_updates) + + return policy_opt_state, policy_params diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c07e3b18..4a173b51 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -119,12 +119,18 @@ def use_all_data(self) -> bool: return True def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[QualityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: @@ -144,8 +150,8 @@ def init( ) target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) - actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) - target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) + target_actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) # Prepare init optimizer states critic_optimizer_state = self._critic_optimizer.init(critic_params) @@ -162,6 +168,13 @@ def init( buffer_size=self._config.replay_buffer_size, transition=dummy_transition ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + # add transitions in the replay buffer + replay_buffer = replay_buffer.insert(transitions) + # Initial training state random_key, subkey = jax.random.split(random_key) emitter_state = QualityPGEmitterState( @@ -171,9 +184,9 @@ def init( actor_opt_state=actor_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, + replay_buffer=replay_buffer, random_key=subkey, steps=jnp.array(0), - replay_buffer=replay_buffer, ) return emitter_state, random_key @@ -187,7 +200,7 @@ def emit( repertoire: Repertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a step of PG emission. 
Args: @@ -223,7 +236,7 @@ def emit( offspring_actor, ) - return genotypes, random_key + return genotypes, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 8b877792..860962d4 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -6,7 +6,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Genotype, RNGKey +from qdax.types import ExtraScores, Genotype, RNGKey class MixingEmitter(Emitter): @@ -31,7 +31,7 @@ def emit( repertoire: Repertoire, emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled in the repertoire, @@ -75,7 +75,7 @@ def emit( x_mutation, ) - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index c71b0013..8b649d0c 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -23,7 +23,7 @@ class MAPElites: """Core elements of the MAP-Elites algorithm. Note: Although very similar to the GeneticAlgorithm, we decided to keep the - MAPElites class independent of the GeneticAlgorithm class at the moment to keep + MAPElites class independant of the GeneticAlgorithm class at the moment to keep elements explicit. Args: @@ -52,7 +52,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -62,9 +62,9 @@ def init( such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tessellation centroids of shape (batch_size, num_descriptors) + centroids: tesselation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. 
Returns: @@ -73,12 +73,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MapElitesRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -87,14 +87,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -129,9 +124,10 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, random_key @@ -147,7 +143,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 6c06b785..6dc8f551 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -55,7 +55,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: @@ -64,7 +64,7 @@ def init( be computed with any method such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. @@ -75,12 +75,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MELSRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -89,14 +89,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/mome.py b/qdax/core/mome.py index 2a004f59..db450b9a 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -23,7 +23,7 @@ class MOME(MAPElites): @partial(jax.jit, static_argnames=("self", "pareto_front_max_length")) def init( self, - init_genotypes: jnp.ndarray, + genotypes: jnp.ndarray, centroids: Centroid, pareto_front_max_length: int, random_key: RNGKey, @@ -33,7 +33,7 @@ def init( CVT or Euclidean mapping. 
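Taken together, these changes give MAPElites, MELS and MOME the same initialization flow: the initial population is scored first, the emitter is initialized directly from that scored population, and each later `update` merges the `extra_info` returned by `emit` into the `extra_scores` passed to `state_update`. The sketch below traces that data flow with plain Python stand-ins; none of these functions are real QDax objects, they only mimic the signatures used above.

```python
import jax
import jax.numpy as jnp

def scoring_fn(genotypes, key):
    # stand-in scoring function: fitnesses and descriptors from the genotypes,
    # plus an extra_scores dict (QDax puts the rollout transitions in there)
    fitnesses = -jnp.sum(genotypes**2, axis=-1)
    descriptors = genotypes[:, :2]
    return fitnesses, descriptors, {"transitions": None}, key

def emitter_init(key, repertoire, genotypes, fitnesses, descriptors, extra_scores):
    return {"steps": 0}, key  # stand-in emitter state

def emitter_emit(repertoire, emitter_state, key):
    key, subkey = jax.random.split(key)
    offspring = jax.random.normal(subkey, (4, 8))
    return offspring, {"parents": None}, key  # extra_info travels with the offspring

def emitter_state_update(emitter_state, repertoire, genotypes,
                         fitnesses, descriptors, extra_scores):
    return {"steps": emitter_state["steps"] + 1}

key = jax.random.PRNGKey(0)
genotypes = jax.random.normal(key, (4, 8))
fitnesses, descriptors, extra_scores, key = scoring_fn(genotypes, key)
repertoire = {"genotypes": genotypes}  # stand-in repertoire

# init: the emitter now sees the scored initial population directly,
# so no separate state_update call is needed after init
emitter_state, key = emitter_init(
    key, repertoire, genotypes, fitnesses, descriptors, extra_scores
)

# update: the extra_info returned by emit is merged into extra_scores
offspring, extra_info, key = emitter_emit(repertoire, emitter_state, key)
fitnesses, descriptors, extra_scores, key = scoring_fn(offspring, key)
emitter_state = emitter_state_update(
    emitter_state, repertoire, offspring, fitnesses, descriptors,
    {**extra_scores, **extra_info},
)
```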
Args: - init_genotypes: genotypes of the initial population. + genotypes: genotypes of the initial population. centroids: centroids of the repertoire. pareto_front_max_length: maximum size of the pareto front. This is necessary to respect jax.jit fixed shape size constraint. @@ -45,12 +45,12 @@ def init( # first score fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MOMERepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -60,14 +60,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 42ed7552..d25c8c6c 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,15 @@ import jax import jax.numpy as jnp -from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor +from qdax.types import ( + Action, + Descriptor, + Done, + Observation, + Reward, + RNGKey, + StateDescriptor, +) class Transition(flax.struct.PyTreeNode): @@ -262,6 +270,155 @@ def init_dummy( # type: ignore return dummy_transition +class DCGTransition(QDTransition): + """Stores data corresponding to a transition collected by a QD algorithm.""" + + desc: Descriptor + desc_prime: Descriptor + + @property + def descriptor_dim(self) -> int: + """ + Returns: + the dimension of the descriptors. + """ + return self.state_desc.shape[-1] # type: ignore + + @property + def flatten_dim(self) -> int: + """ + Returns: + the dimension of the transition once flattened. + """ + flatten_dim = ( + 2 * self.observation_dim + + self.action_dim + + 3 + + 2 * self.state_descriptor_dim + + 2 * self.descriptor_dim + ) + return flatten_dim + + def flatten(self) -> jnp.ndarray: + """ + Returns: + a jnp.ndarray that corresponds to the flattened transition. + """ + flatten_transition = jnp.concatenate( + [ + self.obs, + self.next_obs, + jnp.expand_dims(self.rewards, axis=-1), + jnp.expand_dims(self.dones, axis=-1), + jnp.expand_dims(self.truncations, axis=-1), + self.actions, + self.state_desc, + self.next_state_desc, + self.desc, + self.desc_prime, + ], + axis=-1, + ) + return flatten_transition + + @classmethod + def from_flatten( + cls, + flattened_transition: jnp.ndarray, + transition: QDTransition, + ) -> QDTransition: + """ + Creates a transition from a flattened transition in a jnp.ndarray. 
+ Args: + flattened_transition: flattened transition in a jnp.ndarray of shape + (batch_size, flatten_dim) + transition: a transition object (might be a dummy one) to + get the dimensions right + Returns: + a Transition object + """ + obs_dim = transition.observation_dim + action_dim = transition.action_dim + state_desc_dim = transition.state_descriptor_dim + desc_dim = transition.descriptor_dim + + obs = flattened_transition[:, :obs_dim] + next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)] + rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)]) + dones = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)] + ) + truncations = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)] + ) + actions = flattened_transition[ + :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim) + ] + state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim) : ( + 2 * obs_dim + 3 + action_dim + state_desc_dim + ), + ] + next_state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + ), + ] + desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim + ), + ] + desc_prime = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + 2 * desc_dim + ), + ] + return cls( + obs=obs, + next_obs=next_obs, + rewards=rewards, + dones=dones, + truncations=truncations, + actions=actions, + state_desc=state_desc, + next_state_desc=next_state_desc, + desc=desc, + desc_prime=desc_prime, + ) + + @classmethod + def init_dummy( # type: ignore + cls, observation_dim: int, action_dim: int, descriptor_dim: int + ) -> QDTransition: + """ + Initialize a dummy transition that then can be passed to constructors to get + all shapes right. + Args: + observation_dim: observation dimension + action_dim: action dimension + Returns: + a dummy transition + """ + dummy_transition = DCGTransition( + obs=jnp.zeros(shape=(1, observation_dim)), + next_obs=jnp.zeros(shape=(1, observation_dim)), + rewards=jnp.zeros(shape=(1,)), + dones=jnp.zeros(shape=(1,)), + truncations=jnp.zeros(shape=(1,)), + actions=jnp.zeros(shape=(1, action_dim)), + state_desc=jnp.zeros(shape=(1, descriptor_dim)), + next_state_desc=jnp.zeros(shape=(1, descriptor_dim)), + desc=jnp.zeros(shape=(1, descriptor_dim)), + desc_prime=jnp.zeros(shape=(1, descriptor_dim)), + ) + return dummy_transition + + class ReplayBuffer(flax.struct.PyTreeNode): """ A replay buffer where transitions are flattened before being stored. 
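The `DCGTransition` added above extends `QDTransition` with `desc` and `desc_prime`, and its `flatten`/`from_flatten` pair is what lets the replay buffer store descriptor-conditioned transitions as flat vectors. A quick round-trip check, assuming a QDax install that includes this patch, might look as follows (the dimensions are arbitrary; with `init_dummy`, the state descriptors share the same dimension as the descriptors).

```python
from qdax.core.neuroevolution.buffers.buffer import DCGTransition

obs_dim, action_dim, desc_dim = 6, 2, 3

# dummy transition used to fix all shapes
dummy = DCGTransition.init_dummy(
    observation_dim=obs_dim, action_dim=action_dim, descriptor_dim=desc_dim
)

flat = dummy.flatten()
# 2*obs + action + (reward, done, truncation) + 2*state_desc + 2*desc
assert flat.shape[-1] == dummy.flatten_dim
assert dummy.flatten_dim == 2 * obs_dim + action_dim + 3 + 2 * desc_dim + 2 * desc_dim

recovered = DCGTransition.from_flatten(flat, dummy)
assert recovered.desc.shape == (1, desc_dim)
assert recovered.desc_prime.shape == (1, desc_dim)
```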
diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 7f34a036..e12797b9 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.types import Action, Descriptor, Observation, Params, RNGKey def make_td3_loss_fn( @@ -94,6 +94,110 @@ def _critic_loss_fn( return _policy_loss_fn, _critic_loss_fn +def make_td3_loss_dc_fn( + policy_fn: Callable[[Params, Observation], jnp.ndarray], + actor_fn: Callable[[Params, Observation, Descriptor], jnp.ndarray], + critic_fn: Callable[[Params, Observation, Action, Descriptor], jnp.ndarray], + reward_scaling: float, + discount: float, + noise_clip: float, + policy_noise: float, +) -> Tuple[ + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray], +]: + """Creates the loss functions for TD3. + Args: + policy_fn: forward pass through the neural network defining the policy. + actor_fn: forward pass through the neural network defining the + descriptor-conditioned policy. + critic_fn: forward pass through the neural network defining the + descriptor-conditioned critic. + reward_scaling: value to multiply the reward given by the environment. + discount: discount factor. + noise_clip: value that clips the noise to avoid extreme values. + policy_noise: noise applied to smooth the bootstrapping. + Returns: + Return the loss functions used to train the policy and the critic in TD3. + """ + + @jax.jit + def _policy_loss_fn( + policy_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Policy loss function for TD3 agent""" + action = policy_fn(policy_params, transitions.obs) + q_value = critic_fn( + critic_params, transitions.obs, action, transitions.desc_prime + ) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _actor_loss_fn( + actor_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Descriptor-conditioned policy loss function for TD3 agent""" + action = actor_fn(actor_params, transitions.obs, transitions.desc_prime) + q_value = critic_fn( + critic_params, transitions.obs, action, transitions.desc_prime + ) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _critic_loss_fn( + critic_params: Params, + target_actor_params: Params, + target_critic_params: Params, + transitions: Transition, + random_key: RNGKey, + ) -> jnp.ndarray: + """Descriptor-conditioned critic loss function for TD3 agent""" + noise = ( + jax.random.normal(random_key, shape=transitions.actions.shape) + * policy_noise + ).clip(-noise_clip, noise_clip) + + next_action = ( + actor_fn(target_actor_params, transitions.next_obs, transitions.desc_prime) + + noise + ).clip(-1.0, 1.0) + next_q = critic_fn( + target_critic_params, + transitions.next_obs, + next_action, + transitions.desc_prime, + ) + next_v = jnp.min(next_q, axis=-1) + target_q = jax.lax.stop_gradient( + transitions.rewards * reward_scaling + + (1.0 - transitions.dones) * discount * next_v + ) + q_old_action = critic_fn( + critic_params, transitions.obs, transitions.actions, transitions.desc_prime + 
) + q_error = q_old_action - jnp.expand_dims(target_q, -1) + + # Better bootstrapping for truncated episodes. + q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1) + + # compute the loss + q_losses = jnp.mean(jnp.square(q_error), axis=-2) + q_loss = jnp.sum(q_losses, axis=-1) + + return q_loss + + return _policy_loss_fn, _actor_loss_fn, _critic_loss_fn + + def td3_policy_loss_fn( policy_params: Params, critic_params: Params, @@ -115,9 +219,7 @@ def td3_policy_loss_fn( """ action = policy_fn(policy_params, transitions.obs) - q_value = critic_fn( - critic_params, obs=transitions.obs, actions=action # type: ignore - ) + q_value = critic_fn(critic_params, transitions.obs, action) # type: ignore q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) return policy_loss diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 984d1aeb..3b077069 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Genotype, Params, RNGKey +from qdax.types import Descriptor, Genotype, Params, RNGKey class TrainingState(PyTreeNode): @@ -67,6 +67,60 @@ def _scan_play_step_fn( return state, transitions +@partial(jax.jit, static_argnames=("play_step_actor_dc_fn", "episode_length")) +def generate_unroll_actor_dc( + init_state: EnvState, + actor_dc_params: Params, + desc: Descriptor, + random_key: RNGKey, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[ + EnvState, + Descriptor, + Params, + RNGKey, + Transition, + ], + ], +) -> Tuple[EnvState, Transition]: + """Generates an episode according to the agent's policy and descriptor, + returns the final state of the episode and the transitions of the episode. + + Args: + init_state: first state of the rollout. + policy_dc_params: descriptor-conditioned policy params. + desc: descriptor the policy attempts to achieve. + random_key: random key for stochasiticity handling. + episode_length: length of the rollout. + play_step_fn: function describing how a step need to be taken. + + Returns: + A new state, the experienced transition. + """ + + def _scan_play_step_fn( + carry: Tuple[EnvState, Params, Descriptor, RNGKey], unused_arg: Any + ) -> Tuple[Tuple[EnvState, Params, Descriptor, RNGKey], Transition]: + ( + env_state, + actor_dc_params, + desc, + random_key, + transitions, + ) = play_step_actor_dc_fn(*carry) + return (env_state, actor_dc_params, desc, random_key), transitions + + (state, _, _, _), transitions = jax.lax.scan( + _scan_play_step_fn, + (init_state, actor_dc_params, desc, random_key), + (), + length=episode_length, + ) + return state, transitions + + @jax.jit def get_first_episode(transition: Transition) -> Transition: """Extracts the first episode from a batch of transitions, returns the batch of diff --git a/qdax/core/neuroevolution/networks/networks.py b/qdax/core/neuroevolution/networks/networks.py index b2b176ef..365c8d56 100644 --- a/qdax/core/neuroevolution/networks/networks.py +++ b/qdax/core/neuroevolution/networks/networks.py @@ -5,31 +5,51 @@ import flax.linen as nn import jax import jax.numpy as jnp -from brax.training import networks -class QModule(nn.Module): - """Q Module.""" +class MLP(nn.Module): + """MLP module.""" - hidden_layer_sizes: Tuple[int, ...] - n_critics: int = 2 + layer_sizes: Tuple[int, ...] 
+ activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() + final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None + bias: bool = True + kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: - hidden = jnp.concatenate([obs, actions], axis=-1) - res = [] - for _ in range(self.n_critics): - q = networks.MLP( - layer_sizes=self.hidden_layer_sizes + (1,), - activation=nn.relu, - kernel_init=jax.nn.initializers.lecun_uniform(), - )(hidden) - res.append(q) - return jnp.concatenate(res, axis=-1) + def __call__(self, obs: jnp.ndarray) -> jnp.ndarray: + hidden = obs + for i, hidden_size in enumerate(self.layer_sizes): + if i != len(self.layer_sizes) - 1: + hidden = nn.Dense( + hidden_size, + kernel_init=self.kernel_init, + use_bias=self.bias, + )(hidden) + hidden = self.activation(hidden) # type: ignore -class MLP(nn.Module): - """MLP module.""" + else: + if self.kernel_init_final is not None: + kernel_init = self.kernel_init_final + else: + kernel_init = self.kernel_init + + hidden = nn.Dense( + hidden_size, + kernel_init=kernel_init, + use_bias=self.bias, + )(hidden) + + if self.final_activation is not None: + hidden = self.final_activation(hidden) + + return hidden + + +class MLPDC(nn.Module): + """Descriptor-conditioned MLP module.""" layer_sizes: Tuple[int, ...] activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @@ -39,15 +59,13 @@ class MLP(nn.Module): kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, data: jnp.ndarray) -> jnp.ndarray: - hidden = data + def __call__(self, obs: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, desc], axis=-1) for i, hidden_size in enumerate(self.layer_sizes): if i != len(self.layer_sizes) - 1: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", with this version of flax, changing the name - # changes the initialization kernel_init=self.kernel_init, use_bias=self.bias, )(hidden) @@ -61,7 +79,6 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", kernel_init=kernel_init, use_bias=self.bias, )(hidden) @@ -70,3 +87,45 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = self.final_activation(hidden) return hidden + + +class QModule(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] + n_critics: int = 2 + + @nn.compact + def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden) + res.append(q) + return jnp.concatenate(res, axis=-1) + + +class QModuleDC(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] 
+ n_critics: int = 2 + + @nn.compact + def __call__( + self, obs: jnp.ndarray, actions: jnp.ndarray, desc: jnp.ndarray + ) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLPDC( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden, desc) + res.append(q) + return jnp.concatenate(res, axis=-1) diff --git a/qdax/environments/base_wrappers.py b/qdax/environments/base_wrappers.py index 6f317e7f..3f709fa7 100644 --- a/qdax/environments/base_wrappers.py +++ b/qdax/environments/base_wrappers.py @@ -1,6 +1,7 @@ from abc import abstractmethod -from typing import Any, List, Tuple +from typing import Any, Tuple +import jax from brax.v1 import jumpy as jp from brax.v1.envs import Env, State @@ -22,7 +23,7 @@ def state_descriptor_name(self) -> str: @property @abstractmethod - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -32,7 +33,7 @@ def behavior_descriptor_length(self) -> int: @property @abstractmethod - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -71,7 +72,7 @@ def state_descriptor_name(self) -> str: return self.env.state_descriptor_name @property - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.state_descriptor_limits @property @@ -79,7 +80,7 @@ def behavior_descriptor_length(self) -> int: return self.env.behavior_descriptor_length @property - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.behavior_descriptor_limits @property diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index 720f662a..cf0c3336 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -3,7 +3,7 @@ import flax.struct import jax from brax.v1 import jumpy as jp -from brax.v1.envs import State, Wrapper +from brax.v1.envs import Env, State, Wrapper class CompletedEvalMetrics(flax.struct.PyTreeNode): @@ -69,3 +69,80 @@ def step(self, state: State, action: jp.ndarray) -> State: ) nstate.info[self.STATE_INFO_KEY] = eval_metrics return nstate + + +class ClipRewardWrapper(Wrapper): + """Wraps gym environments to clip the reward to be greater than 0. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__( + self, env: Env, clip_min: float = None, clip_max: float = None + ) -> None: + super().__init__(env) + self._clip_min = clip_min + self._clip_max = clip_max + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + +class AffineRewardWrapper(Wrapper): + """Wraps gym environments to clip the reward. 
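To see how the descriptor-conditioned modules and the descriptor-conditioned critic loss above fit together, the sketch below (assuming a QDax install with this patch) initializes an `MLPDC` actor and a `QModuleDC` critic, runs a forward pass, and forms the clipped double-Q target used by `_critic_loss_fn`. Sizes and hyperparameters are arbitrary placeholders.

```python
import jax
import jax.numpy as jnp
from qdax.core.neuroevolution.networks.networks import MLPDC, QModuleDC

obs_dim, action_dim, desc_dim, batch = 8, 2, 2, 16
key = jax.random.PRNGKey(0)
obs = jnp.zeros((batch, obs_dim))
desc = jnp.zeros((batch, desc_dim))
actions = jnp.zeros((batch, action_dim))
rewards = jnp.ones((batch,))
dones = jnp.zeros((batch,))

# descriptor-conditioned actor: outputs an action in [-1, 1]
actor = MLPDC(layer_sizes=(64, 64, action_dim), final_activation=jnp.tanh)
actor_params = actor.init(key, obs, desc)

# descriptor-conditioned double critic: one Q-value per critic head
critic = QModuleDC(hidden_layer_sizes=(64, 64), n_critics=2)
critic_params = critic.init(key, obs, actions, desc)

next_action = actor.apply(actor_params, obs, desc)
next_q = critic.apply(critic_params, obs, next_action, desc)  # shape (batch, 2)
next_v = jnp.min(next_q, axis=-1)                             # clipped double Q

discount, reward_scaling = 0.99, 1.0
target_q = jax.lax.stop_gradient(
    rewards * reward_scaling + (1.0 - dones) * discount * next_v
)
print(target_q.shape)  # (batch,)
```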
+ + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__( + self, env: Env, clip_min: float = None, clip_max: float = None + ) -> None: + super().__init__(env) + self._clip_min = clip_min + self._clip_max = clip_max + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + +class OffsetRewardWrapper(Wrapper): + """Wraps gym environments to offset the reward to be greater than 0. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__(self, env: Env, offset: float = 0.0) -> None: + super().__init__(env) + self._offset = offset + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace(reward=state.reward + self._offset) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace(reward=state.reward + self._offset) diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 931ee9d3..1a588e52 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -10,7 +10,7 @@ import qdax.environments from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.core.neuroevolution.mdp_utils import generate_unroll +from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc from qdax.core.neuroevolution.networks.networks import MLP from qdax.types import ( Descriptor, @@ -160,6 +160,81 @@ def scoring_function_brax_envs( ) +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + init_states: EnvState, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel in + deterministic or pseudo-deterministic environments. + + This rollout is only deterministic when all the init states are the same. + If the init states are fixed but different, as a policy is not necessarily + evaluated with the same environment everytime, this won't be determinist. + When the init states are different, this is not purely stochastic. + + Args: + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. 
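The reward wrappers above exist to keep per-step rewards, and therefore fitnesses, non-negative, which matters for DCG-MAP-Elites because the policy-gradient mutation earlier in this patch rescales rewards by a descriptor-similarity term. A hedged usage sketch, assuming the usual QDax environment factory and that this patch is installed (the environment name and episode length are arbitrary):

```python
import qdax.environments as environments
from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper

env_name = "walker2d_uni"
base_env = environments.create(env_name, episode_length=100)

# shift every step reward by the per-environment offset so rewards stay positive
offset_env = OffsetRewardWrapper(base_env, offset=environments.reward_offset[env_name])

# or clip rewards from below at zero
clipped_env = ClipRewardWrapper(base_env, clip_min=0.0)
```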
+ + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from evaluation + random_key: The updated random key. + """ + + # Perform rollouts with each policy + random_key, subkey = jax.random.split(random_key) + unroll_fn = partial( + generate_unroll_actor_dc, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + random_key=subkey, + ) + + _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs) + + # create a mask to extract data properly + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + + # Scores - add offset to ensure positive fitness (through positive rewards) + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors = behavior_descriptor_extractor(data, mask) + + return ( + fitnesses, + descriptors, + { + "transitions": data, + }, + random_key, + ) + + @partial( jax.jit, static_argnames=( @@ -225,6 +300,83 @@ def reset_based_scoring_function_brax_envs( return fitnesses, descriptors, extra_scores, random_key +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_reset_fn", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def reset_based_scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + episode_length: int, + play_reset_fn: Callable[[RNGKey], EnvState], + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel. + The play_reset_fn function allows for a more general scoring_function that can be + called with different batch-size and not only with a batch-size of the same + dimension as init_states. + + To define purely stochastic environments, using the reset function from the + environment, use "play_reset_fn = env.reset". + + To define purely deterministic environments, as in "scoring_function", generate + a single init_state using "init_state = env.reset(random_key)", then use + "play_reset_fn = lambda random_key: init_state". + + Args: + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_reset_fn: The function to reset the environment + and obtain initial states. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from the evaluation + random_key: The updated random key. 
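Both scoring functions above rely on the same masking trick to ignore transitions collected after an episode has ended: the cumulative sum of `dones` is clipped to 1 and shifted right by one step, so the terminal transition itself still counts towards the fitness. A toy example of that mask, using only `jax.numpy`:

```python
import jax.numpy as jnp

# dones for 2 rollouts of length 5; rollout 0 ends at t=2, rollout 1 never ends
dones = jnp.array([[0., 0., 1., 0., 0.],
                   [0., 0., 0., 0., 0.]])

is_done = jnp.clip(jnp.cumsum(dones, axis=1), 0, 1)
mask = jnp.roll(is_done, 1, axis=1)
mask = mask.at[:, 0].set(0)
print(mask)
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 0.]]

# fitness sums rewards where mask == 0, i.e. up to and including the terminal step
rewards = jnp.ones_like(dones)
fitness = jnp.sum(rewards * (1.0 - mask), axis=1)
print(fitness)  # [3. 5.]
```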
+ """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split( + subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0] + ) + reset_fn = jax.vmap(play_reset_fn) + init_states = reset_fn(keys) + + ( + fitnesses, + descriptors, + extra_scores, + random_key, + ) = scoring_actor_dc_function_brax_envs( + actors_dc_params=actors_dc_params, + descs=descs, + random_key=random_key, + init_states=init_states, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + behavior_descriptor_extractor=behavior_descriptor_extractor, + ) + + return fitnesses, descriptors, extra_scores, random_key + + def create_brax_scoring_fn( env: brax.envs.Env, policy_network: nn.Module, diff --git a/requirements.txt b/requirements.txt index 978a1c87..50d7899c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + absl-py==1.0.0 brax==0.9.2 chex==0.1.83 @@ -5,8 +7,7 @@ dm-haiku==0.0.10 flax==0.7.4 gym==0.26.2 ipython -jax==0.4.16 -jaxlib==0.4.16 +jax[cuda12_pip] jumanji==0.3.1 jupyter numpy==1.24.1 diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 4e11370b..5f9ec5f7 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -73,7 +73,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, genotype_dim), minval=minval, @@ -111,11 +111,11 @@ def scoring_fn( if isinstance(algo_instance, SPEA2): repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, population_size, num_neighbours, random_key + genotypes, population_size, num_neighbours, random_key ) else: repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, population_size, random_key + genotypes, population_size, random_key ) # Run the algorithm diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 079fde45..98a5b960 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -126,7 +126,7 @@ def scoring_function(genotypes, random_key): # type: ignore lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -178,7 +178,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 5c6fbb0a..fc2e89b0 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -124,7 +124,7 @@ def scoring_function(genotypes, random_key): # type: ignore lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return 
population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -176,7 +176,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index c70683ef..103f9489 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -81,7 +81,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, num_variables), minval=minval, @@ -127,7 +127,7 @@ def scoring_fn( ) repertoire, emitter_state, random_key = mome.init( - init_genotypes, centroids, pareto_front_max_length, random_key + genotypes, centroids, pareto_front_max_length, random_key ) # Run the algorithm From 30ea2087742c6fcafbf12c53d1b8abd018bf0799 Mon Sep 17 00:00:00 2001 From: Hannah Janmohamed <49594227+hannah-jan@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:49:36 +0000 Subject: [PATCH 02/20] fix: Fix pareto dominance definition (#174) Fixes the definition of Pareto dominance in pareto_front.py to account for solutions which have the same fitness values along one axis. --- qdax/utils/pareto_front.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 692a2bde..f9bd77ae 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -24,7 +24,10 @@ def compute_pareto_dominance( Return booleans when the vector is dominated by the batch. 
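The fix described in this commit replaces the strict test (every objective strictly better) with the usual weak-dominance test: a point is dominated if some other point is at least as good on every objective and strictly better on at least one. A small check of the two definitions on a case with a tied objective, using hypothetical points:

```python
import jax.numpy as jnp

criteria_point = jnp.array([1.0, 1.0])
batch = jnp.array([[1.0, 2.0],   # ties on the first objective, better on the second
                   [0.5, 0.5]])

diff = batch - criteria_point

# previous definition: requires strict improvement everywhere, so it misses the tie
old_dominated = jnp.any(jnp.all(diff > 0, axis=-1))

# fixed definition: strictly better somewhere and at least as good everywhere
new_dominated = jnp.any(
    jnp.logical_and(jnp.any(diff > 0, axis=-1), jnp.all(diff >= 0, axis=-1))
)

print(old_dominated, new_dominated)  # False True
```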
""" diff = jnp.subtract(batch_of_criteria, criteria_point) - return jnp.any(jnp.all(diff > 0, axis=-1)) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray: @@ -67,7 +70,10 @@ def compute_masked_pareto_dominance( diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( neutral_values, diff ) - return jnp.any(jnp.all(diff > 0, axis=-1)) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) def compute_masked_pareto_front( From fe5c66a65df1b34c94f2fc8564b96aa81131ff8e Mon Sep 17 00:00:00 2001 From: Felix Chalumeau Date: Mon, 10 Jun 2024 00:16:38 +0200 Subject: [PATCH 03/20] docs(contribution): clarify contribution process (#171) From 552df36f8ca984282f6049058f868dbcc4653780 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 12 Jun 2024 10:57:21 +0100 Subject: [PATCH 04/20] update version number --- qdax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qdax/__init__.py b/qdax/__init__.py index 493f7415..6a9beea8 100644 --- a/qdax/__init__.py +++ b/qdax/__init__.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.4.0" From 039e79b525067a2d5087cbb2c5fa55776c4652a9 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Wed, 12 Jun 2024 10:58:49 +0100 Subject: [PATCH 05/20] update citation --- README.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 551680eb..dc767f37 100644 --- a/README.md +++ b/README.md @@ -166,13 +166,14 @@ Issues and contributions are welcome. Please refer to the [contribution guide](h ## Citing QDax If you use QDax in your research and want to cite it in your work, please use: ``` -@misc{chalumeau2023qdax, - title={QDax: A Library for Quality-Diversity and Population-based Algorithms with Hardware Acceleration}, - author={Felix Chalumeau and Bryan Lim and Raphael Boige and Maxime Allard and Luca Grillotti and Manon Flageat and Valentin Macé and Arthur Flajolet and Thomas Pierrot and Antoine Cully}, - year={2023}, - eprint={2308.03665}, - archivePrefix={arXiv}, - primaryClass={cs.AI} +@article{chalumeau2024qdax, + title={Qdax: A library for quality-diversity and population-based algorithms with hardware acceleration}, + author={Chalumeau, Felix and Lim, Bryan and Boige, Raphael and Allard, Maxime and Grillotti, Luca and Flageat, Manon and Mac{\'e}, Valentin and Richard, Guillaume and Flajolet, Arthur and Pierrot, Thomas and others}, + journal={Journal of Machine Learning Research}, + volume={25}, + number={108}, + pages={1--16}, + year={2024} } ``` From f2c6311401031439da7634b9b83f729498da04f6 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Fri, 30 Aug 2024 17:32:32 +0100 Subject: [PATCH 06/20] feat: Upgrade Library Versions and Python Version (#187) - moving to Python 3.10 - Upgrade all library versions in requirements.txt and setup.py - Remove DM-Haiku cause it is now deprecated. - all networks now are based on a single MLP class - jax.tree_map has been replaced with jax.tree_util.tree_map everywhere to avoid Deprecation Warnings. 
- types -> custom_types - add extra_require for jax[cuda12] - fix all notebooks and update typing extensions (for running notebooks) - fix dependabot security issues - added instructions for pip install qdax[cuda12] in README - fix observation space in jumanji test script --- .pre-commit-config.yaml | 16 +- .readthedocs.yaml | 2 +- README.md | 6 + dev.Dockerfile | 6 +- docs/installation.md | 2 +- environment.yaml | 4 +- examples/aurora.ipynb | 7 +- examples/cmaes.ipynb | 36 +-- examples/cmame.ipynb | 8 +- examples/cmamega.ipynb | 2 +- examples/dads.ipynb | 8 +- examples/diayn.ipynb | 8 +- examples/distributed_mapelites.ipynb | 16 +- examples/jumanji_snake.ipynb | 74 ++--- examples/me_sac_pbt.ipynb | 11 +- examples/me_td3_pbt.ipynb | 15 +- examples/mome.ipynb | 48 ++-- examples/pgame.ipynb | 2 +- examples/sac_pbt.ipynb | 71 +++-- examples/scripts/me_example.py | 7 +- examples/smerl.ipynb | 6 - examples/td3_pbt.ipynb | 47 ++- qdax/baselines/dads.py | 17 +- qdax/baselines/dads_smerl.py | 2 +- qdax/baselines/diayn.py | 2 +- qdax/baselines/diayn_smerl.py | 2 +- qdax/baselines/genetic_algorithm.py | 3 +- qdax/baselines/nsga2.py | 2 +- qdax/baselines/pbt.py | 2 +- qdax/baselines/sac.py | 17 +- qdax/baselines/sac_pbt.py | 2 +- qdax/baselines/spea2.py | 2 +- qdax/baselines/td3.py | 12 +- qdax/baselines/td3_pbt.py | 7 +- qdax/core/aurora.py | 4 +- qdax/core/cmaes.py | 3 +- qdax/core/containers/archive.py | 10 +- qdax/core/containers/ga_repertoire.py | 2 +- qdax/core/containers/mapelites_repertoire.py | 13 +- qdax/core/containers/mels_repertoire.py | 11 +- qdax/core/containers/mome_repertoire.py | 2 +- qdax/core/containers/nsga2_repertoire.py | 8 +- qdax/core/containers/repertoire.py | 7 +- qdax/core/containers/spea2_repertoire.py | 2 +- .../containers/uniform_replacement_archive.py | 4 +- .../containers/unstructured_repertoire.py | 23 +- qdax/core/distributed_map_elites.py | 18 +- qdax/core/emitters/cma_emitter.py | 9 +- qdax/core/emitters/cma_improvement_emitter.py | 6 +- qdax/core/emitters/cma_mega_emitter.py | 10 +- qdax/core/emitters/cma_opt_emitter.py | 2 +- qdax/core/emitters/cma_pool_emitter.py | 2 +- qdax/core/emitters/cma_rnd_emitter.py | 4 +- qdax/core/emitters/dcg_me_emitter.py | 2 +- qdax/core/emitters/dpg_emitter.py | 22 +- qdax/core/emitters/emitter.py | 2 +- qdax/core/emitters/mees_emitter.py | 25 +- qdax/core/emitters/multi_emitter.py | 2 +- qdax/core/emitters/mutation_operators.py | 2 +- qdax/core/emitters/omg_mega_emitter.py | 9 +- qdax/core/emitters/pbt_me_emitter.py | 2 +- qdax/core/emitters/pbt_variation_operators.py | 7 +- qdax/core/emitters/pga_me_emitter.py | 2 +- qdax/core/emitters/qdcg_emitter.py | 18 +- qdax/core/emitters/qdpg_emitter.py | 3 +- qdax/core/emitters/qpg_emitter.py | 18 +- qdax/core/emitters/standard_emitters.py | 2 +- qdax/core/map_elites.py | 10 +- qdax/core/mels.py | 3 +- qdax/core/mome.py | 2 +- qdax/core/neuroevolution/buffers/buffer.py | 2 +- .../buffers/trajectory_buffer.py | 2 +- qdax/core/neuroevolution/losses/dads_loss.py | 9 +- qdax/core/neuroevolution/losses/diayn_loss.py | 2 +- qdax/core/neuroevolution/losses/sac_loss.py | 2 +- qdax/core/neuroevolution/losses/td3_loss.py | 2 +- qdax/core/neuroevolution/mdp_utils.py | 4 +- .../neuroevolution/networks/dads_networks.py | 271 +++++++----------- .../neuroevolution/networks/diayn_networks.py | 117 ++++---- .../neuroevolution/networks/sac_networks.py | 89 +++--- .../networks/seq2seq_networks.py | 1 - .../neuroevolution/normalization_utils.py | 3 +- qdax/core/neuroevolution/sac_td3_utils.py | 6 +- 
qdax/{types.py => custom_types.py} | 0 qdax/environments/bd_extractors.py | 2 +- qdax/environments/exploration_wrappers.py | 8 +- qdax/environments/locomotion_wrappers.py | 4 +- qdax/environments/pointmaze.py | 20 +- qdax/environments/wrappers.py | 12 +- qdax/tasks/arm.py | 2 +- qdax/tasks/brax_envs.py | 3 +- qdax/tasks/hypervolume_functions.py | 2 +- qdax/tasks/jumanji_envs.py | 5 +- qdax/tasks/qd_suite/archimedean_spiral.py | 2 +- qdax/tasks/qd_suite/deceptive_evolvability.py | 2 +- qdax/tasks/qd_suite/qd_suite_task.py | 2 +- qdax/tasks/qd_suite/ssf.py | 2 +- qdax/tasks/standard_functions.py | 2 +- qdax/utils/metrics.py | 2 +- qdax/utils/pareto_front.py | 2 +- qdax/utils/plotting.py | 4 +- qdax/utils/sampling.py | 3 +- qdax/utils/train_seq2seq.py | 4 +- requirements.txt | 27 +- setup.py | 30 +- tests/baselines_test/cmame_test.py | 12 +- tests/baselines_test/cmamega_test.py | 8 +- tests/baselines_test/dads_smerl_test.py | 1 + tests/baselines_test/dads_test.py | 1 + tests/baselines_test/ga_test.py | 12 +- tests/baselines_test/me_pbt_sac_test.py | 8 +- tests/baselines_test/me_pbt_td3_test.py | 8 +- tests/baselines_test/mees_test.py | 8 +- tests/baselines_test/omgmega_test.py | 8 +- tests/baselines_test/pbt_sac_test.py | 2 +- tests/baselines_test/pbt_td3_test.py | 2 +- tests/baselines_test/pgame_test.py | 8 +- tests/baselines_test/qdpg_test.py | 8 +- tests/baselines_test/sac_test.py | 2 +- tests/core_test/aurora_test.py | 2 +- .../mapelites_repertoire_test.py | 2 +- .../containers_test/mels_repertoire_test.py | 2 +- .../emitters_test/multi_emitter_test.py | 6 +- tests/core_test/map_elites_test.py | 8 +- tests/core_test/mels_test.py | 8 +- tests/core_test/mome_test.py | 12 +- .../buffers_test/buffer_test.py | 12 +- .../buffers_test/trajectory_buffer_test.py | 4 +- tests/default_tasks_test/arm_test.py | 6 +- tests/default_tasks_test/brax_task_test.py | 6 +- .../hypervolume_functions_test.py | 6 +- tests/default_tasks_test/jumanji_envs_test.py | 17 +- tests/default_tasks_test/qd_suite_test.py | 6 +- .../standard_functions_test.py | 6 +- tests/environments_test/pointmaze_test.py | 2 +- tests/utils_test/sampling_test.py | 2 +- tool.Dockerfile | 2 +- 137 files changed, 857 insertions(+), 733 deletions(-) rename qdax/{types.py => custom_types.py} (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9329a64..1414d749 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,17 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/ambv/black - rev: 22.3.0 + rev: 24.8.0 hooks: - id: black - language_version: python3.9 - args: ["--target-version", "py39"] + language_version: python3.10 + args: ["--target-version", "py310"] - repo: https://github.com/PyCQA/flake8 - rev: 3.8.4 + rev: 7.1.1 hooks: - id: flake8 args: ['--max-line-length=88', '--extend-ignore=E203'] @@ -21,12 +21,12 @@ repos: - flake8-comprehensions - flake8-bugbear - repo: https://github.com/kynan/nbstripout - rev: 0.3.9 + rev: 0.7.1 hooks: - id: nbstripout args: ["examples/"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.6.0 hooks: - id: debug-statements - id: requirements-txt-fixer @@ -42,6 +42,6 @@ repos: - id: trailing-whitespace # This hook trims trailing whitespace - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.942 + rev: v1.11.2 hooks: - id: mypy diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d9f0965b..e22967ae 100644 --- 
a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.9" + python: "3.10" apt_packages: - swig diff --git a/README.md b/README.md index 551680eb..dab09614 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,12 @@ QDax is available on PyPI and can be installed with: ```bash pip install qdax ``` + +To install QDax with CUDA 12 support, use: +```bash +pip install qdax[cuda12] +``` + Alternatively, the latest commit of QDax can be installed directly from source with: ```bash pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main diff --git a/dev.Dockerfile b/dev.Dockerfile index 458599db..305e29e0 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -16,7 +16,7 @@ RUN micromamba create -y --file /tmp/environment.yaml \ FROM python as test-image -ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy310/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH COPY --from=conda /opt/conda/envs/. /opt/conda/envs/ @@ -26,7 +26,7 @@ RUN pip install -r requirements-dev.txt FROM nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 as cuda-image -ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy310/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH @@ -70,7 +70,7 @@ RUN apt-get update && \ libosmesa6-dev \ patchelf \ python3-opengl \ - python3-dev=3.9* \ + python3-dev=3.10* \ python3-pip \ screen \ sudo \ diff --git a/docs/installation.md b/docs/installation.md index 90c62659..585af828 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -86,7 +86,7 @@ git clone git@github.com:adaptive-intelligent-robotics/QDax.git 2. Activate the environment and manually install the package qdax ```zsh - conda activate qdaxpy39 + conda activate qdaxpy310 pip install -e . 
``` diff --git a/environment.yaml b/environment.yaml index 0ddf80d5..d93726af 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,9 +1,9 @@ -name: qdaxpy39 +name: qdaxpy310 channels: - defaults - conda-forge dependencies: -- python=3.9 +- python=3.10 - pip>=20.3.3 - conda>=4.9.2 - pip: diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index e4b86238..645ed911 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -512,11 +512,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -530,7 +527,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index c8e2a9fe..d7b30b1d 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "5c4ab97a", + "id": "0", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmaes.ipynb)" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "222bbe00", + "id": "1", "metadata": {}, "source": [ "# Optimizing with CMA-ES in Jax\n", @@ -26,7 +26,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d731f067", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "markdown", - "id": "7b6e910b", + "id": "3", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -80,7 +80,7 @@ { "cell_type": "code", "execution_count": null, - "id": "404fb0dc", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "ccc7cbeb", + "id": "5", "metadata": { "pycharm": { "name": "#%% md\n" @@ -111,7 +111,7 @@ { "cell_type": "code", "execution_count": null, - "id": "436dccbb", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -133,7 +133,7 @@ }, { "cell_type": "markdown", - "id": "62bdd2a4", + "id": "7", "metadata": { "pycharm": { "name": "#%% md\n" @@ -146,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4cf03f55", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -167,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "f1f69f50", + "id": "9", "metadata": { "pycharm": { "name": "#%% md\n" @@ -180,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1a95b74d", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -194,7 +194,7 @@ }, { "cell_type": "markdown", - "id": "ac2d5c0d", + "id": "11", "metadata": { "pycharm": { "name": "#%% md\n" @@ -207,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "363198ca", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -245,7 +245,7 @@ }, { "cell_type": "markdown", - "id": "0e5820b8", + "id": "13", "metadata": {}, "source": [ "## Check final fitnesses and distribution mean" @@ -254,7 +254,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e4a2c7b", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -272,7 +272,7 @@ }, { "cell_type": "markdown", - "id": "f3bd2b0f", + "id": "15", "metadata": { "pycharm": { "name": "#%% md\n" @@ -285,7 +285,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ad85551c", + "id": "16", "metadata": { "pycharm": { "name": 
"#%%\n" @@ -333,7 +333,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 3c355eea..ff5fa5c2 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -141,9 +141,9 @@ "def clip(x: jnp.ndarray):\n", " in_bound = (x <= maxval) * (x >= minval)\n", " return jnp.where(\n", - " condition=in_bound,\n", - " x=x,\n", - " y=(maxval / x)\n", + " in_bound,\n", + " x,\n", + " (maxval / x)\n", " )\n", "\n", "def _behavior_descriptor_1(x: jnp.ndarray):\n", @@ -387,7 +387,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index e5749993..739ac3d5 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -315,7 +315,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/dads.ipynb b/examples/dads.ipynb index b3cc43b5..47abd1ec 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -554,7 +548,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 0562e7c2..8e085fce 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -544,7 +538,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 574c56a2..ea2f9b9b 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -40,7 +40,12 @@ "from IPython.display import clear_output\n", "import functools\n", "\n", - "from tqdm import tqdm\n", + "try:\n", + " from tqdm import tqdm\n", + "except:\n", + " !pip install tqdm | tail -n 1\n", + " from tqdm import tqdm\n", + "\n", "import time\n", "\n", "import jax\n", @@ -128,8 +133,7 @@ "outputs": [], "source": [ "# Get devices (change gpu by tpu if needed)\n", - "# devices = jax.devices('gpu')\n", - "devices = jax.devices('tpu')\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f'Detected the following {num_devices} device(s): {devices}')" ] @@ -351,7 +355,7 @@ "random_key = jnp.stack(random_key)\n", "\n", "# add a dimension for devices\n", - "init_variables = jax.tree_map(\n", + "init_variables = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (num_devices, batch_size_per_device,) + x.shape[1:]),\n", " 
init_variables\n", ")\n", @@ -397,7 +401,7 @@ " repertoire, emitter_state, random_key, metrics = update_fn(repertoire, emitter_state, random_key)\n", "\n", " # get metrics\n", - " metrics = jax.tree_map(lambda x: x[0], metrics)\n", + " metrics = jax.tree_util.tree_map(lambda x: x[0], metrics)\n", " timelapse = time.time() - start_time\n", "\n", " # log metrics\n", @@ -454,7 +458,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index a6a140fd..bfba1e5a 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "233e0f03", + "id": "0", "metadata": {}, "source": [ "# Training a population on Jumanji-Snake with QDax\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47b46c2f", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "03c2f1f7", + "id": "2", "metadata": {}, "source": [ "## Define hyperparameters" @@ -87,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52dd1e3b", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -97,7 +97,7 @@ "population_size = 100\n", "batch_size = population_size\n", "\n", - "num_iterations = 5000\n", + "num_iterations = 1000\n", "\n", "iso_sigma = 0.005\n", "line_sigma = 0.05" @@ -105,7 +105,7 @@ }, { "cell_type": "markdown", - "id": "8b8c890a", + "id": "4", "metadata": {}, "source": [ "## Instantiate the snake environment" @@ -114,7 +114,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a842cccc", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "markdown", - "id": "776862f1", + "id": "6", "metadata": {}, "source": [ "## Define the type of policy that will be used to solve the problem" @@ -141,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a1ce7d0", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -161,7 +161,7 @@ }, { "cell_type": "markdown", - "id": "49586b07", + "id": "8", "metadata": {}, "source": [ "## Utils to interact with the environment\n", @@ -172,12 +172,12 @@ { "cell_type": "code", "execution_count": null, - "id": "d1ff7827", + "id": "9", "metadata": {}, "outputs": [], "source": [ "def observation_processing(observation):\n", - " network_input = jnp.ravel(observation)\n", + " network_input = jnp.concatenate([jnp.ravel(observation.grid), jnp.array([observation.step_count]), observation.action_mask.ravel()])\n", " return network_input\n", "\n", "\n", @@ -207,7 +207,7 @@ " obs=timestep.observation,\n", " next_obs=next_timestep.observation,\n", " rewards=next_timestep.reward,\n", - " dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)),\n", + " dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)),\n", " actions=action,\n", " truncations=jnp.array(0),\n", " state_desc=state_desc,\n", @@ -219,7 +219,7 @@ }, { "cell_type": "markdown", - "id": "0078bc01", + "id": "10", "metadata": {}, "source": [ "## Init a population of policies\n", @@ -230,7 +230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6cbd2065", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -240,7 +240,7 @@ "\n", "# compute observation size from observation spec\n", "obs_spec = env.observation_spec()\n", - "observation_size = np.prod(np.array(obs_spec.grid.shape + obs_spec.step_count.shape 
+ obs_spec.action_mask.shape))\n", + "observation_size = int(np.prod(obs_spec.grid.shape) + np.prod(obs_spec.step_count.shape) + np.prod(obs_spec.action_mask.shape))\n", "\n", "fake_batch = jnp.zeros(shape=(batch_size, observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", @@ -255,7 +255,7 @@ }, { "cell_type": "markdown", - "id": "fe6bf07f", + "id": "12", "metadata": {}, "source": [ "## Define a method to extract behavior descriptor when relevant" @@ -264,7 +264,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a264b672", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +311,7 @@ }, { "cell_type": "markdown", - "id": "1cdc5f87", + "id": "14", "metadata": {}, "source": [ "## Define the scoring function" @@ -320,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b77d826", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "6555491a", + "id": "16", "metadata": {}, "source": [ "## Define the emitter used" @@ -342,7 +342,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30061ff4", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "da7e9b74", + "id": "18", "metadata": {}, "source": [ "## Define the algorithm used and apply the initial step\n", @@ -371,7 +371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7b5c2d6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +415,7 @@ }, { "cell_type": "markdown", - "id": "9b1bfee5", + "id": "20", "metadata": {}, "source": [ "## Run the optimization loop" @@ -424,7 +424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1af3a35", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "114ea4a8", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "92a35bf0", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79ada2d5", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -472,7 +472,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe5da301", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ }, { "cell_type": "markdown", - "id": "93d8154e", + "id": "26", "metadata": {}, "source": [ "## Play snake with the best policy\n", @@ -500,7 +500,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3ff882f4", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -511,7 +511,7 @@ { "cell_type": "code", "execution_count": null, - "id": "762c167e", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -524,7 +524,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07523e33", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -537,7 +537,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c75ce088", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -550,7 +550,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50ef95f6", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -563,7 +563,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40a03409", + "id": "32", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index b2de823e..fc6fbe8b 100644 --- a/examples/me_sac_pbt.ipynb +++ 
b/examples/me_sac_pbt.ipynb @@ -38,12 +38,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -80,7 +74,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -261,7 +256,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 238f703c..f28a1db4 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -39,12 +39,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -82,7 +76,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -264,7 +259,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { @@ -443,7 +438,7 @@ "num_cols = 5\n", "\n", "fig, axes = plt.subplots(\n", - " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30)\n", + " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30), squeeze=False,\n", ")\n", "for i, repertoire in enumerate(repertoires):\n", "\n", @@ -492,7 +487,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/mome.ipynb b/examples/mome.ipynb index bf0a5225..7e28b608 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "59f748d3", + "id": "0", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb)" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "a5e13ff6", + "id": "1", "metadata": {}, "source": [ "# Optimizing multiple objectives with MOME in Jax\n", @@ -28,7 +28,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af063418", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ }, { "cell_type": "markdown", - "id": "22495c16", + "id": "3", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -95,7 +95,7 @@ { "cell_type": "code", "execution_count": null, 
- "id": "b96b5d07", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,7 @@ }, { "cell_type": "markdown", - "id": "c2850d54", + "id": "5", "metadata": {}, "source": [ "## Define the scoring function: rastrigin multi-objective\n", @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b5effe11", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "231d273d", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "29250e72", + "id": "8", "metadata": {}, "source": [ "## Define the metrics function that will be used" @@ -187,7 +187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab5d6334", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "a4828ca8", + "id": "10", "metadata": {}, "source": [ "## Define the initial population and the emitter" @@ -211,7 +211,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebf3bd27", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +248,7 @@ }, { "cell_type": "markdown", - "id": "c904664b", + "id": "12", "metadata": {}, "source": [ "## Compute the centroids" @@ -257,7 +257,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76547c4c", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -273,7 +273,7 @@ }, { "cell_type": "markdown", - "id": "15936d15", + "id": "14", "metadata": {}, "source": [ "## Define a MOME instance" @@ -282,7 +282,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07a0d1d9", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -295,7 +295,7 @@ }, { "cell_type": "markdown", - "id": "f7ec5a77", + "id": "16", "metadata": {}, "source": [ "## Init the algorithm" @@ -304,7 +304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c05cbf1e", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +318,7 @@ }, { "cell_type": "markdown", - "id": "6de4cedf", + "id": "18", "metadata": {}, "source": [ "## Run MOME iterations" @@ -327,7 +327,7 @@ { "cell_type": "code", "execution_count": null, - "id": "96ea04e6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "3ff9ca98", + "id": "20", "metadata": {}, "source": [ "## Plot the results" @@ -353,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6766dc4f", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -363,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28ab56c9", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ab36cb7", + "id": "23", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 9b638b2d..03fd9c00 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -106,7 +106,7 @@ "#@markdown ---\n", "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", "episode_length = 250 #@param {type:\"integer\"}\n", - "num_iterations = 4000 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", "seed = 42 #@param {type:\"integer\"}\n", "policy_hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", "iso_sigma = 0.005 #@param {type:\"number\"}\n", diff --git a/examples/sac_pbt.ipynb 
b/examples/sac_pbt.ipynb index 7762083f..a484b035 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1606cdf6", + "id": "0", "metadata": { "jupyter": { "outputs_hidden": false @@ -44,12 +44,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -67,7 +61,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df61dcc5", + "id": "1", "metadata": { "jupyter": { "outputs_hidden": false @@ -84,7 +78,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7d0b0ef", + "id": "2", "metadata": { "jupyter": { "outputs_hidden": false @@ -95,7 +89,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -103,7 +98,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f342948", + "id": "3", "metadata": { "jupyter": { "outputs_hidden": false @@ -123,8 +118,8 @@ "buffer_size = 100000\n", "\n", "# PBT Config\n", - "num_best_to_replace_from = 20\n", - "num_worse_to_replace = 40\n", + "num_best_to_replace_from = 1\n", + "num_worse_to_replace = 1\n", "\n", "# SAC config\n", "batch_size = 256\n", @@ -144,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "090f8d4d", + "id": "4", "metadata": { "jupyter": { "outputs_hidden": false @@ -175,7 +170,7 @@ { "cell_type": "code", "execution_count": null, - "id": "efac713a", + "id": "5", "metadata": { "jupyter": { "outputs_hidden": false @@ -193,7 +188,7 @@ " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_map(\n", + " lambda tree: jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(\n", " x,\n", " (\n", @@ -214,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bccfc6d", + "id": "6", "metadata": { "jupyter": { "outputs_hidden": false @@ -237,7 +232,7 @@ { "cell_type": "code", "execution_count": null, - "id": "708eea0a", + "id": "7", "metadata": { "jupyter": { "outputs_hidden": false @@ -266,7 +261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e6e2bec", + "id": "8", "metadata": { "jupyter": { "outputs_hidden": false @@ -293,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f09fe1e", + "id": "9", "metadata": { "jupyter": { "outputs_hidden": false @@ -311,7 +306,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66ba826a", + "id": "10", "metadata": { "jupyter": { "outputs_hidden": false @@ -336,7 +331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a49af55e", + "id": "11", "metadata": { "jupyter": { "outputs_hidden": false @@ -362,7 +357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b137c8e5", + "id": "12", "metadata": { "jupyter": { "outputs_hidden": false @@ -384,7 +379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1dbbb855", + "id": "13", "metadata": { "jupyter": { "outputs_hidden": false @@ -397,8 +392,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = 
jax.tree_map(\n", + " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (population_size,) + x.shape[2:]), tree\n", " )\n", " return tree" @@ -407,7 +402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a052ba2e", + "id": "14", "metadata": { "jupyter": { "outputs_hidden": false @@ -447,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f7354f2", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" @@ -461,7 +456,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd1c27e8-1fe7-464d-8af3-72fa8d61852d", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n" @@ -471,13 +466,13 @@ "source": [ "training_states = unshard_fn(training_states)\n", "best_idx = jnp.argmax(population_returns)\n", - "best_training_state = jax.tree_map(lambda x: x[best_idx], training_states)" + "best_training_state = jax.tree_util.tree_map(lambda x: x[best_idx], training_states)" ] }, { "cell_type": "code", "execution_count": null, - "id": "60e8ee82-27cf-4fa6-b189-66e6e10e2177", + "id": "17", "metadata": { "pycharm": { "name": "#%%\n" @@ -491,7 +486,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84bd809c-0127-4241-9556-9e81e550bbd2", + "id": "18", "metadata": { "pycharm": { "name": "#%%\n" @@ -509,7 +504,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2954be53-ffa5-42cf-8696-7f40f139edaf", + "id": "19", "metadata": { "pycharm": { "name": "#%%\n" @@ -523,7 +518,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14026ff2-e7f2-46eb-91d5-6e7394136e96", + "id": "20", "metadata": { "pycharm": { "name": "#%%\n" @@ -537,7 +532,7 @@ "rng = jax.random.PRNGKey(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", - "training_state, env_state = jax.tree_map(\n", + "training_state, env_state = jax.tree_util.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)\n", ")\n", "\n", @@ -552,7 +547,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5701084-e876-43f8-8de0-4216361ef5b4", + "id": "21", "metadata": { "pycharm": { "name": "#%%\n" @@ -561,7 +556,7 @@ "outputs": [], "source": [ "rollout = [\n", - " jax.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", + " jax.tree_util.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", " for env_state in rollout\n", "]" ] @@ -569,7 +564,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85bb7556-37bb-4a20-88b3-28b298c8b0a9", + "id": "22", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index 699c6aba..433bc1d2 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -79,7 +79,12 @@ def run_me() -> None: # Run MAP-Elites loop for _ in range(num_iterations): - (repertoire, emitter_state, metrics, random_key,) = map_elites.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = map_elites.update( repertoire, emitter_state, random_key, diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index fe655fe2..d50f448f 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps 
git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index ec98b9da..484f6d12 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf95707f", + "id": "0", "metadata": { "pycharm": { "name": "#%%\n" @@ -41,12 +41,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -62,7 +56,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15a43429", + "id": "1", "metadata": { "pycharm": { "name": "#%%\n" @@ -76,7 +70,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32d15301", + "id": "2", "metadata": { "pycharm": { "name": "#%%\n" @@ -84,7 +78,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"gpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -92,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7520b673", + "id": "3", "metadata": { "pycharm": { "name": "#%%\n" @@ -129,7 +124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3718a4c", + "id": "4", "metadata": { "pycharm": { "name": "#%%\n" @@ -157,7 +152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5485a16c", + "id": "5", "metadata": { "pycharm": { "name": "#%%\n" @@ -172,7 +167,7 @@ " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_map(\n", + " lambda tree: jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(\n", " x, (population_size_per_device, env_batch_size,) + x.shape[1:]\n", " ),\n", @@ -188,7 +183,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4dc22ec4", + "id": "6", "metadata": { "pycharm": { "name": "#%%\n" @@ -208,7 +203,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9c610ba", + "id": "7", "metadata": { "pycharm": { "name": "#%%\n" @@ -232,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f6fd3b9", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -256,7 +251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a412cd4f", + "id": "9", "metadata": { "pycharm": { "name": "#%%\n" @@ -271,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "535250a8", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -293,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d24156f4", + "id": "11", "metadata": { "pycharm": { "name": "#%%\n" @@ -316,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23037e97", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -335,7 +330,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9ebb235", + "id": "13", "metadata": { "pycharm": { "name": "#%%\n" @@ -345,8 +340,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = jax.tree_map(\n", + " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (population_size,) + 
x.shape[2:]), tree\n", " )\n", " return tree" @@ -355,7 +350,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a58253e5", + "id": "14", "metadata": { "pycharm": { "name": "#%%\n" @@ -392,7 +387,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6111e836", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index 41f2ff08..bd4f4534 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -25,7 +25,7 @@ update_running_mean_std, ) from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor +from qdax.custom_types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor class DadsTrainingState(TrainingState): @@ -430,12 +430,17 @@ def _update_dynamics( """ training_state, transitions = operand - dynamics_loss, dynamics_gradient = jax.value_and_grad(self._dynamics_loss_fn,)( + dynamics_loss, dynamics_gradient = jax.value_and_grad( + self._dynamics_loss_fn, + )( training_state.dynamics_params, transitions=transitions, ) - (dynamics_updates, dynamics_optimizer_state,) = self._dynamics_optimizer.update( + ( + dynamics_updates, + dynamics_optimizer_state, + ) = self._dynamics_optimizer.update( dynamics_gradient, training_state.dynamics_optimizer_state ) dynamics_params = optax.apply_updates( @@ -483,7 +488,11 @@ def _update_networks( random_key = training_state.random_key # Update skill-dynamics - (dynamics_params, dynamics_loss, dynamics_optimizer_state,) = jax.lax.cond( + ( + dynamics_params, + dynamics_loss, + dynamics_optimizer_state, + ) = jax.lax.cond( training_state.steps % self._config.dynamics_update_freq == 0, self._update_dynamics, self._not_update_dynamics, diff --git a/qdax/baselines/dads_smerl.py b/qdax/baselines/dads_smerl.py index 206f0012..5bd8274d 100644 --- a/qdax/baselines/dads_smerl.py +++ b/qdax/baselines/dads_smerl.py @@ -14,7 +14,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer from qdax.core.neuroevolution.normalization_utils import normalize_with_rmstd -from qdax.types import Metrics, Reward +from qdax.custom_types import Metrics, Reward @dataclass diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index c03cfb3f..0ebdfc32 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -20,7 +20,7 @@ from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode from qdax.core.neuroevolution.networks.diayn_networks import make_diayn_networks from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor +from qdax.custom_types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor class DiaynTrainingState(TrainingState): diff --git a/qdax/baselines/diayn_smerl.py b/qdax/baselines/diayn_smerl.py index 2966a692..daacaa74 100644 --- a/qdax/baselines/diayn_smerl.py +++ b/qdax/baselines/diayn_smerl.py @@ -13,7 +13,7 @@ from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer -from qdax.types import Metrics, Reward +from qdax.custom_types import Metrics, Reward @dataclass diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index a01a13b1..b4c6a32f 100644 --- 
a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -1,4 +1,5 @@ """Core components of a basic genetic algorithm.""" + from functools import partial from typing import Any, Callable, Optional, Tuple @@ -6,7 +7,7 @@ from qdax.core.containers.ga_repertoire import GARepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ExtraScores, Fitness, Genotype, Metrics, RNGKey +from qdax.custom_types import ExtraScores, Fitness, Genotype, Metrics, RNGKey class GeneticAlgorithm: diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index a889eadc..afd587af 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -13,7 +13,7 @@ from qdax.baselines.genetic_algorithm import GeneticAlgorithm from qdax.core.containers.nsga2_repertoire import NSGA2Repertoire from qdax.core.emitters.emitter import EmitterState -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class NSGA2(GeneticAlgorithm): diff --git a/qdax/baselines/pbt.py b/qdax/baselines/pbt.py index 65d1a950..6555c537 100644 --- a/qdax/baselines/pbt.py +++ b/qdax/baselines/pbt.py @@ -6,7 +6,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer -from qdax.types import RNGKey +from qdax.custom_types import RNGKey class PBTTrainingState(PyTreeNode): diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index a5ce15c5..482c5715 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -32,7 +32,7 @@ update_running_mean_std, ) from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Mask, @@ -449,7 +449,10 @@ def _update_alpha( random_key=subkey, ) alpha_optimizer = optax.adam(learning_rate=alpha_lr) - (alpha_updates, alpha_optimizer_state,) = alpha_optimizer.update( + ( + alpha_updates, + alpha_optimizer_state, + ) = alpha_optimizer.update( alpha_gradient, training_state.alpha_optimizer_state ) alpha_params = optax.apply_updates( @@ -503,7 +506,10 @@ def _update_critic( random_key=subkey, ) critic_optimizer = optax.adam(learning_rate=critic_lr) - (critic_updates, critic_optimizer_state,) = critic_optimizer.update( + ( + critic_updates, + critic_optimizer_state, + ) = critic_optimizer.update( critic_gradient, training_state.critic_optimizer_state ) critic_params = optax.apply_updates( @@ -556,7 +562,10 @@ def _update_actor( random_key=subkey, ) policy_optimizer = optax.adam(learning_rate=policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/sac_pbt.py b/qdax/baselines/sac_pbt.py index 9aa2ff4c..947a7183 100644 --- a/qdax/baselines/sac_pbt.py +++ b/qdax/baselines/sac_pbt.py @@ -22,7 +22,7 @@ ) from qdax.core.neuroevolution.normalization_utils import normalize_with_rmstd from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn -from qdax.types import Descriptor, Mask, Metrics, RNGKey +from qdax.custom_types import Descriptor, Mask, Metrics, RNGKey class PBTSacTrainingState(PBTTrainingState, SacTrainingState): diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index c52063b6..10d195ad 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -15,7 +15,7 @@ from qdax.baselines.genetic_algorithm import GeneticAlgorithm from 
qdax.core.containers.spea2_repertoire import SPEA2Repertoire from qdax.core.emitters.emitter import EmitterState -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class SPEA2(GeneticAlgorithm): diff --git a/qdax/baselines/td3.py b/qdax/baselines/td3.py index e09b5254..97f37893 100644 --- a/qdax/baselines/td3.py +++ b/qdax/baselines/td3.py @@ -23,7 +23,7 @@ from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode from qdax.core.neuroevolution.networks.td3_networks import make_td3_networks from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Mask, @@ -76,7 +76,10 @@ class TD3: def __init__(self, config: TD3Config, action_size: int): self._config = config - self._policy, self._critic, = make_td3_networks( + ( + self._policy, + self._critic, + ) = make_td3_networks( action_size=action_size, critic_hidden_layer_sizes=self._config.critic_hidden_layer_size, policy_hidden_layer_sizes=self._config.policy_hidden_layer_size, @@ -421,7 +424,10 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam( learning_rate=self._config.policy_learning_rate ) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/td3_pbt.py b/qdax/baselines/td3_pbt.py index 60cd8a38..5762956d 100644 --- a/qdax/baselines/td3_pbt.py +++ b/qdax/baselines/td3_pbt.py @@ -25,7 +25,7 @@ td3_policy_loss_fn, ) from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn -from qdax.types import Descriptor, Mask, Metrics, Params, RNGKey +from qdax.custom_types import Descriptor, Mask, Metrics, Params, RNGKey class PBTTD3TrainingState(PBTTrainingState, TD3TrainingState): @@ -291,7 +291,10 @@ def update( def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam(learning_rate=training_state.policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index a0968ccc..f67d7b4f 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -12,8 +12,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.environments.bd_extractors import AuroraExtraInfo -from qdax.types import ( +from qdax.custom_types import ( Descriptor, Fitness, Genotype, @@ -22,6 +21,7 @@ Params, RNGKey, ) +from qdax.environments.bd_extractors import AuroraExtraInfo class AURORA: diff --git a/qdax/core/cmaes.py b/qdax/core/cmaes.py index 481a49bf..0e9b4084 100644 --- a/qdax/core/cmaes.py +++ b/qdax/core/cmaes.py @@ -2,6 +2,7 @@ Definition of CMAES class, containing main functions necessary to build a CMA optimization script. 
Link to the paper: https://arxiv.org/abs/1604.00772 """ + from functools import partial from typing import Callable, Optional, Tuple @@ -9,7 +10,7 @@ import jax import jax.numpy as jnp -from qdax.types import Fitness, Genotype, Mask, RNGKey +from qdax.custom_types import Fitness, Genotype, Mask, RNGKey class CMAESState(flax.struct.PyTreeNode): diff --git a/qdax/core/containers/archive.py b/qdax/core/containers/archive.py index 8af808f3..036c5892 100644 --- a/qdax/core/containers/archive.py +++ b/qdax/core/containers/archive.py @@ -40,7 +40,7 @@ def size(self) -> float: fake_data = jnp.isnan(self.data) # count number of real data - return sum(~fake_data) + return float(sum(~fake_data)) @classmethod def create( @@ -161,9 +161,7 @@ def insert(self, state_descriptors: jnp.ndarray) -> Archive: values, _indices = knn(self.data, state_descriptors, 1) # get indices where distance bigger than threshold - relevant_indices = jnp.where( - values.squeeze() > self.acceptance_threshold, x=0, y=1 - ) + relevant_indices = jnp.where(values.squeeze() > self.acceptance_threshold, 0, 1) def iterate_fn( carry: Tuple[Archive, jnp.ndarray, int], condition_data: Dict @@ -192,7 +190,7 @@ def iterate_fn( # get indices where distance bigger than threshold not_too_close = jnp.where( - values.squeeze() > self.acceptance_threshold, x=0, y=1 + values.squeeze() > self.acceptance_threshold, 0, 1 ) second_condition = not_too_close.sum() condition = (first_condition + second_condition) == 0 @@ -280,7 +278,7 @@ def knn( dist = jnp.nan_to_num(dist, nan=jnp.inf) # clipping necessary - numerical approx make some distancies negative - dist = jnp.sqrt(jnp.clip(dist, a_min=0.0)) + dist = jnp.sqrt(jnp.clip(dist, min=0.0)) # return values, indices values, indices = qdax_top_k(-dist, k) diff --git a/qdax/core/containers/ga_repertoire.py b/qdax/core/containers/ga_repertoire.py index 87ade54f..403331ff 100644 --- a/qdax/core/containers/ga_repertoire.py +++ b/qdax/core/containers/ga_repertoire.py @@ -10,7 +10,7 @@ from jax.flatten_util import ravel_pytree from qdax.core.containers.repertoire import Repertoire -from qdax.types import Fitness, Genotype, RNGKey +from qdax.custom_types import Fitness, Genotype, RNGKey class GARepertoire(Repertoire): diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index b1145c34..b473d4b3 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -15,7 +15,14 @@ from numpy.random import RandomState from sklearn.cluster import KMeans -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) def compute_cvt_centroids( @@ -303,7 +310,7 @@ def add( # put dominated fitness to -jnp.inf batch_of_fitnesses = jnp.where( - batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf ) # get addition condition @@ -315,7 +322,7 @@ def add( # assign fake position when relevant : num_centroids is out of bound batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids + addition_condition, batch_of_indices, num_centroids ) # create new repertoire diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py index a2e99971..7ef57bb9 100644 --- a/qdax/core/containers/mels_repertoire.py +++ b/qdax/core/containers/mels_repertoire.py @@ -14,7 +14,14 @@ 
MapElitesRepertoire, get_cells_indices, ) -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, Spread +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Spread, +) def _dispersion(descriptors: jnp.ndarray) -> jnp.ndarray: @@ -232,7 +239,7 @@ def add( # assign fake position when relevant : num_centroids is out of bound batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids + addition_condition, batch_of_indices, num_centroids ) # create new repertoire diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index 0e2b6d3e..43be3835 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -15,7 +15,7 @@ MapElitesRepertoire, get_cells_indices, ) -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, diff --git a/qdax/core/containers/nsga2_repertoire.py b/qdax/core/containers/nsga2_repertoire.py index 74b0f454..331ef153 100644 --- a/qdax/core/containers/nsga2_repertoire.py +++ b/qdax/core/containers/nsga2_repertoire.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.containers.ga_repertoire import GARepertoire -from qdax.types import Fitness, Genotype +from qdax.custom_types import Fitness, Genotype from qdax.utils.pareto_front import compute_masked_pareto_front @@ -56,9 +56,9 @@ def _compute_crowding_distances( norm = jnp.max(srt_fitnesses, axis=0) - jnp.min(srt_fitnesses, axis=0) # get the distances - dists = jnp.row_stack( + dists = jnp.vstack( [srt_fitnesses, jnp.full(num_objective, jnp.inf)] - ) - jnp.row_stack([jnp.full(num_objective, -jnp.inf), srt_fitnesses]) + ) - jnp.vstack([jnp.full(num_objective, -jnp.inf), srt_fitnesses]) # Prepare the distance to last and next vectors dist_to_last, dist_to_next = dists, dists @@ -228,7 +228,7 @@ def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool: # get rid of the zeros (that correspond to the False from the mask) fake_indice = num_candidates + 1 # bigger than all the other indices - indices = jnp.where(indices == 0, x=fake_indice, y=indices) + indices = jnp.where(indices == 0, fake_indice, indices) # sort the indices to remove the fake indices indices = jnp.sort(indices)[: self.size] diff --git a/qdax/core/containers/repertoire.py b/qdax/core/containers/repertoire.py index f50d53b7..77c91683 100644 --- a/qdax/core/containers/repertoire.py +++ b/qdax/core/containers/repertoire.py @@ -4,11 +4,11 @@ from __future__ import annotations -from abc import ABC, abstractclassmethod, abstractmethod +from abc import ABC, abstractmethod import flax -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class Repertoire(flax.struct.PyTreeNode, ABC): @@ -19,7 +19,8 @@ class Repertoire(flax.struct.PyTreeNode, ABC): to keep the parent classes explicit and transparent. 
""" - @abstractclassmethod + @classmethod + @abstractmethod def init(cls) -> Repertoire: # noqa: N805 """Create a repertoire.""" pass diff --git a/qdax/core/containers/spea2_repertoire.py b/qdax/core/containers/spea2_repertoire.py index 54870db4..33c31547 100644 --- a/qdax/core/containers/spea2_repertoire.py +++ b/qdax/core/containers/spea2_repertoire.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from qdax.core.containers.ga_repertoire import GARepertoire -from qdax.types import Fitness, Genotype +from qdax.custom_types import Fitness, Genotype class SPEA2Repertoire(GARepertoire): diff --git a/qdax/core/containers/uniform_replacement_archive.py b/qdax/core/containers/uniform_replacement_archive.py index d6f233db..830878cf 100644 --- a/qdax/core/containers/uniform_replacement_archive.py +++ b/qdax/core/containers/uniform_replacement_archive.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from qdax.core.containers.archive import Archive -from qdax.types import RNGKey +from qdax.custom_types import RNGKey class UniformReplacementArchive(Archive): @@ -74,7 +74,7 @@ def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: subkey, shape=(1,), minval=0, maxval=self.max_size ) - index = jnp.where(condition=is_full, x=random_index, y=new_current_position) + index = jnp.where(is_full, random_index, new_current_position) new_data = self.data.at[index].set(state_descriptor) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index f4cc0c98..8512d3d6 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -8,7 +8,14 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from qdax.types import Centroid, Descriptor, Fitness, Genotype, Observation, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + Fitness, + Genotype, + Observation, + RNGKey, +) @partial(jax.jit, static_argnames=("k_nn",)) @@ -300,7 +307,7 @@ def add( # ReIndexing of all the inputs to the correct sorted way batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() - batch_of_genotypes = jax.tree_map( + batch_of_genotypes = jax.tree_util.tree_map( lambda x: x.at[sorted_bds].get(), batch_of_genotypes ) batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() @@ -333,7 +340,7 @@ def add( # put dominated fitness to -jnp.inf batch_of_fitnesses = jnp.where( - batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf ) # get addition condition @@ -347,12 +354,12 @@ def add( # assign fake position when relevant : num_centroids is out of bounds batch_of_indices = jnp.where( addition_condition, - x=batch_of_indices, - y=self.max_size, + batch_of_indices, + self.max_size, ) # create new grid - new_grid_genotypes = jax.tree_map( + new_grid_genotypes = jax.tree_util.tree_map( lambda grid_genotypes, new_genotypes: grid_genotypes.at[ batch_of_indices.squeeze() ].set(new_genotypes), @@ -398,7 +405,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey grid_empty = self.fitnesses == -jnp.inf p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p), self.genotypes, ) @@ -435,7 +442,7 @@ def init( # Initialize grid with default values default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) - default_genotypes = jax.tree_map( + default_genotypes = 
jax.tree_util.tree_map( lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan), genotypes, ) diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index 7b5609f2..dbc6522b 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -10,7 +11,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import Centroid, Genotype, Metrics, RNGKey +from qdax.custom_types import Centroid, Genotype, Metrics, RNGKey class DistributedMAPElites(MAPElites): @@ -189,7 +190,7 @@ def get_distributed_update_fn( of MAP-Elites updates. """ - @partial(jax.jit, static_argnames=("self",)) + @jax.jit def _scan_update( carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], unused: Any, @@ -200,7 +201,12 @@ def _scan_update( repertoire, emitter_state, random_key = carry # apply one step of update - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, @@ -214,7 +220,11 @@ def update_fn( random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]: """Apply num_iterations of update.""" - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( _scan_update, (repertoire, emitter_state, random_key), (), diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index 66e5677a..315dcd9b 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -13,7 +13,14 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class CMAEmitterState(EmitterState): diff --git a/qdax/core/emitters/cma_improvement_emitter.py b/qdax/core/emitters/cma_improvement_emitter.py index 28424f3f..7c3fc98c 100644 --- a/qdax/core/emitters/cma_improvement_emitter.py +++ b/qdax/core/emitters/cma_improvement_emitter.py @@ -6,7 +6,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype class CMAImprovementEmitter(CMAEmitter): @@ -62,13 +62,13 @@ def _ranking_criteria( condition = improvements == jnp.inf # criteria: fitness if new cell, improvement else - ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + ranking_criteria = jnp.where(condition, fitnesses, improvements) # make sure to have all the new cells first new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index 1fd0e1e6..c3f87fed 100644 --- 
a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -12,7 +12,7 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, @@ -238,13 +238,13 @@ def state_update( condition = improvements == jnp.inf # criteria: fitness if new cell, improvement else - ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + ranking_criteria = jnp.where(condition, fitnesses, improvements) # make sure to have all the new cells first new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) # sort indices according to the criteria @@ -282,12 +282,12 @@ def state_update( # update theta in case of reinit theta = jax.tree_util.tree_map( - lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta + lambda x, y: jnp.where(reinitialize, x, y), random_theta, theta ) # update cmaes state in case of reinit cmaes_state = jax.tree_util.tree_map( - lambda x, y: jnp.where(reinitialize, x=x, y=y), + lambda x, y: jnp.where(reinitialize, x, y), self._cma_initial_state, cmaes_state, ) diff --git a/qdax/core/emitters/cma_opt_emitter.py b/qdax/core/emitters/cma_opt_emitter.py index d9c5bf71..9a783585 100644 --- a/qdax/core/emitters/cma_opt_emitter.py +++ b/qdax/core/emitters/cma_opt_emitter.py @@ -6,7 +6,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype class CMAOptimizingEmitter(CMAEmitter): diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index 24556f8b..55ccaa4f 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -9,7 +9,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class CMAPoolEmitterState(EmitterState): diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index e05cc453..27e4f0db 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -9,7 +9,7 @@ from qdax.core.cmaes import CMAESState from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class CMARndEmitterState(CMAEmitterState): @@ -168,7 +168,7 @@ def _ranking_criteria( condition = improvements == jnp.inf ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcg_me_emitter.py index 94e0bb9d..fea237c6 100644 --- 
a/qdax/core/emitters/dcg_me_emitter.py +++ b/qdax/core/emitters/dcg_me_emitter.py @@ -6,8 +6,8 @@ from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Params, RNGKey @dataclass diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 2c55cbd2..ea921237 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -1,6 +1,7 @@ """ Implements the Diversity PG inspired by QDPG algorithm in jax for brax environments, based on: https://arxiv.org/abs/2006.08505 """ + from dataclasses import dataclass from functools import partial from typing import Any, Callable, Optional, Tuple @@ -17,8 +18,7 @@ QualityPGEmitterState, ) from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.environments.base_wrappers import QDEnv -from qdax.types import ( +from qdax.custom_types import ( Descriptor, ExtraScores, Fitness, @@ -28,6 +28,7 @@ RNGKey, StateDescriptor, ) +from qdax.environments.base_wrappers import QDEnv @dataclass @@ -180,7 +181,10 @@ def scan_train_critics( return new_emitter_state, () # sample transitions - (transitions, random_key,) = emitter_state.replay_buffer.sample( + ( + transitions, + random_key, + ) = emitter_state.replay_buffer.sample( random_key=emitter_state.random_key, sample_size=self._config.num_critic_training_steps * self._config.batch_size, @@ -249,7 +253,11 @@ def _train_critics( ) # Update greedy policy - (policy_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + policy_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -348,7 +356,11 @@ def scan_train_policy( transitions, ) - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (transitions), diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index 056798ba..21139356 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -6,7 +6,7 @@ from flax.struct import PyTreeNode from qdax.core.containers.repertoire import Repertoire -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class EmitterState(PyTreeNode): diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index 0a03a6ba..4d51326a 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -3,6 +3,7 @@ from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217 """ + from __future__ import annotations from dataclasses import dataclass @@ -19,7 +20,7 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class NoveltyArchive(flax.struct.PyTreeNode): @@ -362,7 +363,7 @@ def _sample( genotypes_empty = fitnesses < min_fitness p = (1.0 - genotypes_empty) / 
jnp.sum(1.0 - genotypes_empty) random_key, subkey = jax.random.split(random_key) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), genotypes, ) @@ -429,7 +430,7 @@ def _sample_explore( repertoire_empty = novelties < min_novelty p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) random_key, subkey = jax.random.split(random_key) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), repertoire.genotypes, ) @@ -486,7 +487,7 @@ def _es_emitter( # Sampling non-mirror noise else: sample_number = total_sample_number - sample_noise = jax.tree_map( + sample_noise = jax.tree_util.tree_map( lambda x: jax.random.normal( key=subkey, shape=jnp.repeat(x, sample_number, axis=0).shape, @@ -496,11 +497,11 @@ def _es_emitter( gradient_noise = sample_noise # Applying noise - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jnp.repeat(x, total_sample_number, axis=0), parent, ) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda mean, noise: mean + self._config.sample_sigma * noise, samples, sample_noise, @@ -526,7 +527,7 @@ def _es_emitter( if self._config.sample_mirror: ranks = jnp.reshape(ranks, (sample_number, 2)) ranks = jnp.apply_along_axis(lambda rank: rank[0] - rank[1], 1, ranks) - ranks = jax.tree_map( + ranks = jax.tree_util.tree_map( lambda x: jnp.reshape( jnp.repeat(ranks.ravel(), x[0].ravel().shape[0], axis=0), x.shape ), @@ -534,16 +535,16 @@ def _es_emitter( ) # Computing the gradients - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda noise, rank: jnp.multiply(noise, rank), gradient_noise, ranks, ) - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda x: jnp.reshape(x, (sample_number, -1)), gradient, ) - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda g, p: jnp.reshape( -jnp.sum(g, axis=0) / (total_sample_number * self._config.sample_sigma), p.shape, @@ -553,7 +554,7 @@ def _es_emitter( ) # Adding regularisation - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda g, p: g + self._config.l2_coefficient * p, gradient, parent, @@ -626,7 +627,7 @@ def _buffers_update( last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set( fitnesses[0] ) - last_updated_genotypes = jax.tree_map( + last_updated_genotypes = jax.tree_util.tree_map( lambda last_gen, gen: last_gen.at[ jnp.expand_dims(last_updated_position, axis=0) ].set(gen), diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index b3ad23c6..17cb8ace 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -8,7 +8,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class MultiEmitterState(EmitterState): diff --git a/qdax/core/emitters/mutation_operators.py b/qdax/core/emitters/mutation_operators.py index f39b8060..bda2daca 100644 --- a/qdax/core/emitters/mutation_operators.py +++ b/qdax/core/emitters/mutation_operators.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey def _polynomial_mutation( diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 
54766152..580bd151 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -6,7 +6,14 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class OMGMEGAEmitterState(EmitterState): diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index a2266bfa..55bded4e 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -12,8 +12,8 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey class PBTEmitterState(EmitterState): diff --git a/qdax/core/emitters/pbt_variation_operators.py b/qdax/core/emitters/pbt_variation_operators.py index bd76ecd1..c8537003 100644 --- a/qdax/core/emitters/pbt_variation_operators.py +++ b/qdax/core/emitters/pbt_variation_operators.py @@ -3,7 +3,7 @@ from qdax.baselines.sac_pbt import PBTSacTrainingState from qdax.baselines.td3_pbt import PBTTD3TrainingState from qdax.core.emitters.mutation_operators import isoline_variation -from qdax.types import RNGKey +from qdax.custom_types import RNGKey def sac_pbt_variation_fn( @@ -94,7 +94,10 @@ def td3_pbt_variation_fn( training_state1.critic_params, training_state2.critic_params, ) - (policy_params, critic_params,), random_key = isoline_variation( + ( + policy_params, + critic_params, + ), random_key = isoline_variation( x1=(policy_params1, critic_params1), x2=(policy_params2, critic_params2), random_key=random_key, diff --git a/qdax/core/emitters/pga_me_emitter.py b/qdax/core/emitters/pga_me_emitter.py index e93eb696..a4f8b33f 100644 --- a/qdax/core/emitters/pga_me_emitter.py +++ b/qdax/core/emitters/pga_me_emitter.py @@ -6,8 +6,8 @@ from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Params, RNGKey @dataclass diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/qdcg_emitter.py index 0d560cbb..0fb19c4b 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/qdcg_emitter.py @@ -16,8 +16,8 @@ from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn from qdax.core.neuroevolution.networks.networks import QModuleDC +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey @dataclass @@ -521,7 +521,11 @@ def _train_critics( ) # Update greedy actor - (actor_opt_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + actor_opt_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % 
self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -580,7 +584,7 @@ def _update_critic( critic_params = optax.apply_updates(critic_params, critic_updates) # Soft update of target critic network - target_critic_params = jax.tree_map( + target_critic_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_critic_params, @@ -612,7 +616,7 @@ def _update_actor( actor_params = optax.apply_updates(actor_params, policy_updates) # Soft update of target greedy actor - target_actor_params = jax.tree_map( + target_actor_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_actor_params, @@ -701,7 +705,11 @@ def scan_train_policy( new_policy_opt_state, ), () - (emitter_state, policy_params, policy_opt_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_opt_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_opt_state), transitions, diff --git a/qdax/core/emitters/qdpg_emitter.py b/qdax/core/emitters/qdpg_emitter.py index eefd1566..b9de6090 100644 --- a/qdax/core/emitters/qdpg_emitter.py +++ b/qdax/core/emitters/qdpg_emitter.py @@ -5,6 +5,7 @@ it has been updated to work better with Jax in term of time cost. Those changes have been made in accordance with the authors of this algorithm. """ + import functools from dataclasses import dataclass from typing import Callable @@ -17,8 +18,8 @@ from qdax.core.emitters.mutation_operators import isoline_variation from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Reward, StateDescriptor from qdax.environments.base_wrappers import QDEnv -from qdax.types import Reward, StateDescriptor @dataclass diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index 4a173b51..c6e2df7e 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -17,8 +17,8 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn from qdax.core.neuroevolution.networks.networks import QModule +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey @dataclass @@ -379,7 +379,11 @@ def _train_critics( ) # Update greedy actor - (actor_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + actor_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -439,7 +443,7 @@ def _update_critic( critic_params = optax.apply_updates(critic_params, critic_updates) # Soft update of target critic network - target_critic_params = jax.tree_map( + target_critic_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_critic_params, @@ -471,7 +475,7 @@ def _update_actor( actor_params = optax.apply_updates(actor_params, policy_updates) # Soft update of target greedy actor - target_actor_params = jax.tree_map( + target_actor_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + 
self._config.soft_tau_update * x2, target_actor_params, @@ -527,7 +531,11 @@ def scan_train_policy( new_policy_optimizer_state, ), () - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (), diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 860962d4..1d949b2d 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -6,7 +6,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ExtraScores, Genotype, RNGKey +from qdax.custom_types import ExtraScores, Genotype, RNGKey class MixingEmitter(Emitter): diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index 8b649d0c..d0b075a9 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -8,7 +9,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, @@ -169,7 +170,12 @@ def scan_update( The updated repertoire and emitter state, with a new random key and metrics. """ repertoire, emitter_state, random_key = carry - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 6dc8f551..8b0e7511 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites Low-Spread algorithm.""" + from __future__ import annotations from functools import partial @@ -9,7 +10,7 @@ from qdax.core.containers.mels_repertoire import MELSRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, diff --git a/qdax/core/mome.py b/qdax/core/mome.py index db450b9a..c239bd1f 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -9,7 +9,7 @@ from qdax.core.containers.mome_repertoire import MOMERepertoire from qdax.core.emitters.emitter import EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import Centroid, RNGKey +from qdax.custom_types import Centroid, RNGKey class MOME(MAPElites): diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index d25c8c6c..5057e5e2 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,7 @@ import jax import jax.numpy as jnp -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Done, diff --git a/qdax/core/neuroevolution/buffers/trajectory_buffer.py b/qdax/core/neuroevolution/buffers/trajectory_buffer.py index 2cc4ab69..93e1b2f9 100644 --- a/qdax/core/neuroevolution/buffers/trajectory_buffer.py +++ b/qdax/core/neuroevolution/buffers/trajectory_buffer.py @@ -8,7 +8,7 @@ from flax import struct from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Reward, RNGKey +from qdax.custom_types import Reward, 
RNGKey class TrajectoryBuffer(struct.PyTreeNode): diff --git a/qdax/core/neuroevolution/losses/dads_loss.py b/qdax/core/neuroevolution/losses/dads_loss.py index b42ca416..60edfee1 100644 --- a/qdax/core/neuroevolution/losses/dads_loss.py +++ b/qdax/core/neuroevolution/losses/dads_loss.py @@ -6,7 +6,14 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.losses.sac_loss import make_sac_loss_fn -from qdax.types import Action, Observation, Params, RNGKey, Skill, StateDescriptor +from qdax.custom_types import ( + Action, + Observation, + Params, + RNGKey, + Skill, + StateDescriptor, +) def make_dads_loss_fn( diff --git a/qdax/core/neuroevolution/losses/diayn_loss.py b/qdax/core/neuroevolution/losses/diayn_loss.py index 8bca3b4b..e25a73bd 100644 --- a/qdax/core/neuroevolution/losses/diayn_loss.py +++ b/qdax/core/neuroevolution/losses/diayn_loss.py @@ -7,7 +7,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.losses.sac_loss import make_sac_loss_fn -from qdax.types import Action, Observation, Params, RNGKey, StateDescriptor +from qdax.custom_types import Action, Observation, Params, RNGKey, StateDescriptor def make_diayn_loss_fn( diff --git a/qdax/core/neuroevolution/losses/sac_loss.py b/qdax/core/neuroevolution/losses/sac_loss.py index b3656b18..d7289292 100644 --- a/qdax/core/neuroevolution/losses/sac_loss.py +++ b/qdax/core/neuroevolution/losses/sac_loss.py @@ -6,7 +6,7 @@ from brax.training.distribution import ParametricDistribution from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.custom_types import Action, Observation, Params, RNGKey def make_sac_loss_fn( diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index e12797b9..964c2c4f 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Descriptor, Observation, Params, RNGKey +from qdax.custom_types import Action, Descriptor, Observation, Params, RNGKey def make_td3_loss_fn( diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 3b077069..f269a22b 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Descriptor, Genotype, Params, RNGKey +from qdax.custom_types import Descriptor, Genotype, Params, RNGKey class TrainingState(PyTreeNode): @@ -134,7 +134,7 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray: # the double transpose trick is here to allow easy broadcasting return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T - return jax.tree_map(mask_episodes, transition) # type: ignore + return jax.tree_util.tree_map(mask_episodes, transition) # type: ignore def init_population_controllers( diff --git a/qdax/core/neuroevolution/networks/dads_networks.py b/qdax/core/neuroevolution/networks/dads_networks.py index beb4b77a..863bdab5 100644 --- a/qdax/core/neuroevolution/networks/dads_networks.py +++ b/qdax/core/neuroevolution/networks/dads_networks.py @@ -1,128 +1,129 @@ from typing import Optional, Tuple -import haiku as hk -import jax +import flax.linen as nn import jax.numpy as jnp import 
tensorflow_probability.substrates.jax as tfp -from haiku.initializers import Initializer, VarianceScaling - -from qdax.types import Action, Observation, Skill, StateDescriptor - - -class GaussianMixture(hk.Module): - """Module that outputs a Gaussian Mixture Distribution.""" - - def __init__( - self, - num_dimensions: int, - num_components: int, - reinterpreted_batch_ndims: Optional[int] = None, - identity_covariance: bool = True, - initializer: Optional[Initializer] = None, - name: str = "GaussianMixture", - ): - """Module that outputs a Gaussian Mixture Distribution - with identity covariance matrix.""" - - super().__init__(name=name) - if initializer is None: - initializer = VarianceScaling(1.0, "fan_in", "uniform") - self._num_dimensions = num_dimensions - self._num_components = num_components - self._reinterpreted_batch_ndims = reinterpreted_batch_ndims - self._identity_covariance = identity_covariance - self.initializer = initializer - logits_size = self._num_components - - self.logit_layer = hk.Linear(logits_size, w_init=self.initializer) - - # Create two layers that outputs a location and a scale, respectively, for - # each dimension and each component. - self.loc_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=self.initializer - ) - if not self._identity_covariance: - self.scale_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=self.initializer - ) +from jax.nn import initializers + +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation, Skill, StateDescriptor + +class GaussianMixture(nn.Module): + num_dimensions: int + num_components: int + reinterpreted_batch_ndims: Optional[int] = None + identity_covariance: bool = True + initializer: Optional[initializers.Initializer] = None + + @nn.compact def __call__(self, inputs: jnp.ndarray) -> tfp.distributions.Distribution: - # Compute logits, locs, and scales if necessary. - logits = self.logit_layer(inputs) - locs = self.loc_layer(inputs) + if self.initializer is None: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + else: + init = self.initializer - shape = [-1, self._num_components, self._num_dimensions] # [B, D, C] + logits = nn.Dense(self.num_components, kernel_init=init)(inputs) + locs = nn.Dense(self.num_dimensions * self.num_components, kernel_init=init)( + inputs + ) - # Reshape the mixture's location and scale parameters appropriately. + shape = [-1, self.num_components, self.num_dimensions] # [B, D, C] locs = locs.reshape(shape) - if not self._identity_covariance: - scales = self.scale_layer(inputs) + if not self.identity_covariance: + scales = nn.Dense( + self.num_dimensions * self.num_components, kernel_init=init + )(inputs) scales = scales.reshape(shape) else: scales = jnp.ones_like(locs) - # Create the mixture distribution components = tfp.distributions.MultivariateNormalDiag( loc=locs, scale_diag=scales ) mixture = tfp.distributions.Categorical(logits=logits) - distribution = tfp.distributions.MixtureSameFamily( + return tfp.distributions.MixtureSameFamily( mixture_distribution=mixture, components_distribution=components ) - return distribution - -class DynamicsNetwork(hk.Module): - """Dynamics network (used in DADS).""" +class DynamicsNetwork(nn.Module): + hidden_layer_sizes: Tuple[int, ...] 
+ output_size: int + omit_input_dynamics_dim: int = 2 + identity_covariance: bool = True + initializer: Optional[initializers.Initializer] = None - def __init__( - self, - hidden_layer_sizes: tuple, - output_size: int, - omit_input_dynamics_dim: int = 2, - name: Optional[str] = None, - identity_covariance: bool = True, - initializer: Optional[Initializer] = None, - ): - super().__init__(name=name) - if initializer is None: - initializer = VarianceScaling(1.0, "fan_in", "uniform") + @nn.compact + def __call__( + self, obs: StateDescriptor, skill: Skill, target: StateDescriptor + ) -> jnp.ndarray: + if self.initializer is None: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + else: + init = self.initializer - self.distribution = GaussianMixture( - output_size, + distribution = GaussianMixture( + self.output_size, num_components=4, reinterpreted_batch_ndims=None, - identity_covariance=identity_covariance, - initializer=initializer, - ) - self.network = hk.Sequential( - [ - hk.nets.MLP( - list(hidden_layer_sizes), - w_init=initializer, - activation=jax.nn.relu, - activate_final=True, - ), - ] + identity_covariance=self.identity_covariance, + initializer=init, ) - self._omit_input_dynamics_dim = omit_input_dynamics_dim - def __call__( - self, obs: StateDescriptor, skill: Skill, target: StateDescriptor - ) -> jnp.ndarray: - """Normalizes the observation, predicts a distribution probability conditioned - on (obs,skill) and returns the log_prob of the target. - """ - - obs = obs[:, self._omit_input_dynamics_dim :] + obs = obs[:, self.omit_input_dynamics_dim :] obs = jnp.concatenate((obs, skill), axis=1) - out = self.network(obs) - dist = self.distribution(out) + + x = MLP( + layer_sizes=self.hidden_layer_sizes, + kernel_init=init, + activation=nn.relu, + final_activation=nn.relu, + )(obs) + + dist = distribution(x) return dist.log_prob(target) +class Actor(nn.Module): + action_size: int + hidden_layer_sizes: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + + return MLP( + layer_sizes=self.hidden_layer_sizes + (2 * self.action_size,), + kernel_init=init, + activation=nn.relu, + )(obs) + + +class Critic(nn.Module): + hidden_layer_sizes: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + input_ = jnp.concatenate([obs, action], axis=-1) + + value_1 = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + kernel_init=init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + kernel_init=init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) + + def make_dads_networks( action_size: int, descriptor_size: int, @@ -130,78 +131,16 @@ def make_dads_networks( policy_hidden_layer_size: Tuple[int, ...] = (256, 256), omit_input_dynamics_dim: int = 2, identity_covariance: bool = True, - dynamics_initializer: Optional[Initializer] = None, -) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: - """Creates networks used in DADS. - - Args: - action_size: the size of the environment's action space - descriptor_size: the size of the environment's descriptor space (i.e. the - dimension of the dynamics network's input) - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). 
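For readers migrating code that consumed the old haiku-transformed networks: the Flax modules returned by make_dads_networks are initialised and called through flax.linen's init/apply instead of hk.transform. A minimal sketch, with sizes and the observation/skill concatenation chosen purely for illustration (they are assumptions, not taken from this patch):

    import jax
    import jax.numpy as jnp

    from qdax.core.neuroevolution.networks.dads_networks import make_dads_networks

    # Illustrative sizes only.
    action_size, descriptor_size, obs_size, num_skills = 8, 2, 27, 5

    policy, critic, dynamics = make_dads_networks(action_size, descriptor_size)

    key = jax.random.PRNGKey(0)
    # Assumption: the policy sees the observation concatenated with the skill.
    fake_obs = jnp.zeros((1, obs_size + num_skills))

    # flax.linen replaces hk.without_apply_rng(hk.transform(...)):
    # init builds the parameter PyTree, apply runs the forward pass.
    policy_params = policy.init(key, fake_obs)
    actor_out = policy.apply(policy_params, fake_obs)  # shape (1, 2 * action_size)
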
- omit_input_dynamics_dim: how many descriptors we omit when creating the input - of the dynamics networks. Defaults to 2. - identity_covariance: whether to fix the covariance matrix of the Gaussian models - to identity. Defaults to True. - dynamics_initializer: the initializer of the dynamics layers. Defaults to None. - - Returns: - the policy network - the critic network - the dynamics network - """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - def _dynamics_fn( - obs: StateDescriptor, skill: Skill, target: StateDescriptor - ) -> jnp.ndarray: - dynamics_network = DynamicsNetwork( - critic_hidden_layer_size, - descriptor_size, - omit_input_dynamics_dim=omit_input_dynamics_dim, - identity_covariance=identity_covariance, - initializer=dynamics_initializer, - ) - return dynamics_network(obs, skill, target) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - dynamics = hk.without_apply_rng(hk.transform(_dynamics_fn)) + dynamics_initializer: Optional[initializers.Initializer] = None, +) -> Tuple[nn.Module, nn.Module, nn.Module]: + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) + dynamics = DynamicsNetwork( + critic_hidden_layer_size, + descriptor_size, + omit_input_dynamics_dim=omit_input_dynamics_dim, + identity_covariance=identity_covariance, + initializer=dynamics_initializer, + ) return policy, critic, dynamics diff --git a/qdax/core/neuroevolution/networks/diayn_networks.py b/qdax/core/neuroevolution/networks/diayn_networks.py index c656cace..e292e131 100644 --- a/qdax/core/neuroevolution/networks/diayn_networks.py +++ b/qdax/core/neuroevolution/networks/diayn_networks.py @@ -1,10 +1,60 @@ from typing import Tuple -import haiku as hk -import jax +import flax.linen as nn import jax.numpy as jnp -from qdax.types import Action, Observation +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation + + +class Actor(nn.Module): + action_size: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (2 * self.action_size,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + activation=nn.relu, + )(obs) + + +class Critic(nn.Module): + hidden_layer_size: Tuple[int, ...] 
+ + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + input_ = jnp.concatenate([obs, action], axis=-1) + + kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform") + + value_1 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) + + +class Discriminator(nn.Module): + num_skills: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (self.num_skills,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + activation=nn.relu, + )(obs) def make_diayn_networks( @@ -12,71 +62,22 @@ def make_diayn_networks( num_skills: int, critic_hidden_layer_size: Tuple[int, ...] = (256, 256), policy_hidden_layer_size: Tuple[int, ...] = (256, 256), -) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: +) -> Tuple[nn.Module, nn.Module, nn.Module]: """Creates networks used in DIAYN. Args: action_size: the size of the environment's action space num_skills: the number of skills set - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). + critic_hidden_layer_size: the number of neurons for critic hidden layers. + policy_hidden_layer_size: the number of neurons for policy hidden layers. Returns: the policy network the critic network the discriminator network """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - def _discriminator_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [num_skills], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - discriminator = hk.without_apply_rng(hk.transform(_discriminator_fn)) + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) + discriminator = Discriminator(num_skills, critic_hidden_layer_size) return policy, critic, discriminator diff --git a/qdax/core/neuroevolution/networks/sac_networks.py b/qdax/core/neuroevolution/networks/sac_networks.py index dcadfaa2..a236afd4 100644 --- a/qdax/core/neuroevolution/networks/sac_networks.py +++ b/qdax/core/neuroevolution/networks/sac_networks.py @@ -1,66 +1,65 @@ from typing import Tuple -import haiku as hk -import jax +import flax.linen as 
nn import jax.numpy as jnp -from qdax.types import Action, Observation +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation + + +class Actor(nn.Module): + action_size: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (2 * self.action_size,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + )(obs) + + +class Critic(nn.Module): + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + input_ = jnp.concatenate([obs, action], axis=-1) + + kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform") + + value_1 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) def make_sac_networks( action_size: int, critic_hidden_layer_size: Tuple[int, ...] = (256, 256), policy_hidden_layer_size: Tuple[int, ...] = (256, 256), -) -> Tuple[hk.Transformed, hk.Transformed]: +) -> Tuple[nn.Module, nn.Module]: """Creates networks used in SAC. Args: action_size: the size of the environment's action space - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). + critic_hidden_layer_size: the number of neurons for critic hidden layers. + policy_hidden_layer_size: the number of neurons for policy hidden layers. Returns: the policy network the critic network """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) return policy, critic diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index ea7618ba..3cb52a3e 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -7,7 +7,6 @@ Licensed under the Apache License, Version 2.0 (the "License") """ - import functools from typing import Any, Tuple diff --git a/qdax/core/neuroevolution/normalization_utils.py b/qdax/core/neuroevolution/normalization_utils.py index 63820921..0c98b29d 100644 --- a/qdax/core/neuroevolution/normalization_utils.py +++ b/qdax/core/neuroevolution/normalization_utils.py @@ -1,11 +1,10 @@ """Utilities functions to perform 
normalization (generally on observations in RL).""" - from typing import NamedTuple import jax.numpy as jnp -from qdax.types import Observation +from qdax.custom_types import Observation class RunningMeanStdState(NamedTuple): diff --git a/qdax/core/neuroevolution/sac_td3_utils.py b/qdax/core/neuroevolution/sac_td3_utils.py index 1c54511a..32bbe7a4 100644 --- a/qdax/core/neuroevolution/sac_td3_utils.py +++ b/qdax/core/neuroevolution/sac_td3_utils.py @@ -5,6 +5,7 @@ We are currently thinking about elegant ways to unify both in order to avoid code repetition. """ + # TODO: Uniformize with the functions in mdp_utils from functools import partial from typing import Any, Callable, Tuple @@ -14,7 +15,7 @@ from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.core.neuroevolution.mdp_utils import TrainingState -from qdax.types import Metrics +from qdax.custom_types import Metrics @partial( @@ -75,7 +76,8 @@ def generate_unroll( ], ], ) -> Tuple[EnvState, TrainingState, Transition]: - """Generates an episode according to the agent's policy, returns the final state of the + """ + Generates an episode according to the agent's policy, returns the final state of the episode and the transitions of the episode. """ diff --git a/qdax/types.py b/qdax/custom_types.py similarity index 100% rename from qdax/types.py rename to qdax/custom_types.py diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index af1d51ba..918fbbfb 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.types import Descriptor, Params +from qdax.custom_types import Descriptor, Params def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py index ec32e7a2..c784b045 100644 --- a/qdax/environments/exploration_wrappers.py +++ b/qdax/environments/exploration_wrappers.py @@ -436,10 +436,8 @@ def step(self, state: State, action: jp.ndarray) -> State: # this line avoid this by increasing the threshold done = jp.where( state.qp.pos[0, 2] < 0.2, - x=jp.array(1, dtype=jp.float32), - y=jp.array(0, dtype=jp.float32), - ) - done = jp.where( - state.qp.pos[0, 2] > 5.0, x=jp.array(1, dtype=jp.float32), y=done + jp.array(1, dtype=jp.float32), + jp.array(0, dtype=jp.float32), ) + done = jp.where(state.qp.pos[0, 2] > 5.0, jp.array(1, dtype=jp.float32), done) return state.replace(obs=new_obs, reward=new_reward, done=done) # type: ignore diff --git a/qdax/environments/locomotion_wrappers.py b/qdax/environments/locomotion_wrappers.py index a727479e..982f5b69 100644 --- a/qdax/environments/locomotion_wrappers.py +++ b/qdax/environments/locomotion_wrappers.py @@ -260,7 +260,7 @@ def name(self) -> str: def reset(self, rng: jp.ndarray) -> State: state = self.env.reset(rng) state.info["state_descriptor"] = jnp.clip( - state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval + state.qp.pos[self._cog_idx][:2], min=self._minval, max=self._maxval ) return state @@ -268,7 +268,7 @@ def step(self, state: State, action: jp.ndarray) -> State: state = self.env.step(state, action) # get xy position of the center of gravity state.info["state_descriptor"] = jnp.clip( - state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval + state.qp.pos[self._cog_idx][:2], min=self._minval, max=self._maxval ) return state diff 
--git a/qdax/environments/pointmaze.py b/qdax/environments/pointmaze.py index b5f86ef5..78f7c575 100644 --- a/qdax/environments/pointmaze.py +++ b/qdax/environments/pointmaze.py @@ -150,8 +150,8 @@ def step(self, state: State, action: jp.ndarray) -> State: done = jp.where( jp.array(in_zone), - x=jp.array(1.0), - y=jp.array(0.0), + jp.array(1.0), + jp.array(0.0), ) new_obs = jp.array([x_pos, y_pos]) @@ -199,8 +199,8 @@ def _collision_lower_wall( y_axis_down_contact_condition_1 & y_axis_down_contact_condition_2 & x_axis_contact_condition, - x=jp.array(self.lower_wall_height_offset), - y=y_pos, + jp.array(self.lower_wall_height_offset), + y_pos, ) # From up - boolean style @@ -217,8 +217,8 @@ def _collision_lower_wall( & y_axis_up_contact_condition_2 & y_axis_up_contact_condition_3 & x_axis_contact_condition, - x=jp.array(self.lower_wall_height_offset + self.wallheight), - y=new_y_pos, + jp.array(self.lower_wall_height_offset + self.wallheight), + new_y_pos, ) return new_y_pos @@ -250,8 +250,8 @@ def _collision_upper_wall( y_axis_up_contact_condition_1 & y_axis_up_contact_condition_2 & x_axis_contact_condition, - x=jp.array(self.upper_wall_height_offset + self.wallheight), - y=y_pos, + jp.array(self.upper_wall_height_offset + self.wallheight), + y_pos, ) # From down - boolean style @@ -264,8 +264,8 @@ def _collision_upper_wall( & y_axis_down_contact_condition_2 & y_axis_down_contact_condition_3 & x_axis_contact_condition, - x=jp.array(self.upper_wall_height_offset), - y=new_y_pos, + jp.array(self.upper_wall_height_offset), + new_y_pos, ) return new_y_pos diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index cf0c3336..e5e40e4b 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional import flax.struct import jax @@ -80,7 +80,10 @@ class ClipRewardWrapper(Wrapper): """ def __init__( - self, env: Env, clip_min: float = None, clip_max: float = None + self, + env: Env, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, ) -> None: super().__init__(env) self._clip_min = clip_min @@ -108,7 +111,10 @@ class AffineRewardWrapper(Wrapper): """ def __init__( - self, env: Env, clip_min: float = None, clip_max: float = None + self, + env: Env, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, ) -> None: super().__init__(env) self._clip_min = clip_min diff --git a/qdax/tasks/arm.py b/qdax/tasks/arm.py index 7122ed63..27782cf3 100644 --- a/qdax/tasks/arm.py +++ b/qdax/tasks/arm.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 1a588e52..07d37d59 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -12,7 +12,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc from qdax.core.neuroevolution.networks.networks import MLP -from qdax.types import ( +from qdax.custom_types import ( Descriptor, EnvState, ExtraScores, @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_brax( Returns: default_play_step_fn: A function that plays a step of the environment. 
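The reward wrappers updated above now annotate their bounds as Optional[float], so either bound can be omitted independently. A usage sketch; the base-environment factory call is only an assumption here, any brax Env instance would do:

    from qdax import environments
    from qdax.environments.wrappers import ClipRewardWrapper

    env = environments.create("walker2d_uni")   # assumption: an existing QDax env factory
    env = ClipRewardWrapper(env, clip_min=0.0)  # presumably clips rewards from below only
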
""" + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: EnvState, diff --git a/qdax/tasks/hypervolume_functions.py b/qdax/tasks/hypervolume_functions.py index f4936574..340581ab 100644 --- a/qdax/tasks/hypervolume_functions.py +++ b/qdax/tasks/hypervolume_functions.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def square(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/tasks/jumanji_envs.py b/qdax/tasks/jumanji_envs.py index 14455d66..5f861f0e 100644 --- a/qdax/tasks/jumanji_envs.py +++ b/qdax/tasks/jumanji_envs.py @@ -7,7 +7,7 @@ import jumanji from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.types import ( +from qdax.custom_types import ( Descriptor, ExtraScores, Fitness, @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_jumanji( Returns: default_play_step_fn: A function that plays a step of the environment. """ + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: JumanjiState, @@ -67,7 +68,7 @@ def default_play_step_fn( obs=timestep.observation, next_obs=next_timestep.observation, rewards=next_timestep.reward, - dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)), + dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)), actions=action, truncations=jnp.array(0), state_desc=state_desc, diff --git a/qdax/tasks/qd_suite/archimedean_spiral.py b/qdax/tasks/qd_suite/archimedean_spiral.py index 5784f596..59108ae5 100644 --- a/qdax/tasks/qd_suite/archimedean_spiral.py +++ b/qdax/tasks/qd_suite/archimedean_spiral.py @@ -4,8 +4,8 @@ import jax.lax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype class ParameterizationGenotype(Enum): diff --git a/qdax/tasks/qd_suite/deceptive_evolvability.py b/qdax/tasks/qd_suite/deceptive_evolvability.py index d5be0688..830ad523 100644 --- a/qdax/tasks/qd_suite/deceptive_evolvability.py +++ b/qdax/tasks/qd_suite/deceptive_evolvability.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype def multivariate_normal( diff --git a/qdax/tasks/qd_suite/qd_suite_task.py b/qdax/tasks/qd_suite/qd_suite_task.py index 6f1af76f..0d79317f 100644 --- a/qdax/tasks/qd_suite/qd_suite_task.py +++ b/qdax/tasks/qd_suite/qd_suite_task.py @@ -4,7 +4,7 @@ import jax from jax import numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class QDSuiteTask(abc.ABC): diff --git a/qdax/tasks/qd_suite/ssf.py b/qdax/tasks/qd_suite/ssf.py index 547bee8d..601aa6ad 100644 --- a/qdax/tasks/qd_suite/ssf.py +++ b/qdax/tasks/qd_suite/ssf.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype class SsfV0(QDSuiteTask): diff --git a/qdax/tasks/standard_functions.py b/qdax/tasks/standard_functions.py index 53d5b492..82b2f875 100644 --- 
a/qdax/tasks/standard_functions.py +++ b/qdax/tasks/standard_functions.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def rastrigin(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/utils/metrics.py b/qdax/utils/metrics.py index 2b8355af..509c6d91 100644 --- a/qdax/utils/metrics.py +++ b/qdax/utils/metrics.py @@ -12,7 +12,7 @@ from qdax.core.containers.ga_repertoire import GARepertoire from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.mome_repertoire import MOMERepertoire -from qdax.types import Metrics +from qdax.custom_types import Metrics from qdax.utils.pareto_front import compute_hypervolume diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index f9bd77ae..54fad3e6 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp -from qdax.types import Mask, ParetoFront +from qdax.custom_types import Mask, ParetoFront def compute_pareto_dominance( diff --git a/qdax/utils/plotting.py b/qdax/utils/plotting.py index 9b107c7e..7f0f086d 100644 --- a/qdax/utils/plotting.py +++ b/qdax/utils/plotting.py @@ -544,7 +544,7 @@ def _get_projection_in_1d( for all index i: x[i] < bases_tuple[i]. The vector and tuple of bases must have the same length. - For example if x=jnp.array([3, 1, 2]) and the bases are (5, 7, 3). + For example if jnp.array([3, 1, 2]) and the bases are (5, 7, 3). then the projection is 3*(7*3) + 1*(3) + 2 = 47. Args: @@ -574,7 +574,7 @@ def _get_projection_in_2d( """Projects an integer vector into a pair of integers, (given tuple of bases to consider for conversion). - For example if x=jnp.array([3, 1, 2, 5]) and the bases are (5, 2, 3, 7). + For example if jnp.array([3, 1, 2, 5]) and the bases are (5, 2, 3, 7). then the projection is obtained by: - projecting in 1D the point jnp.array([3, 2]) with the bases (5, 3) - projecting in 1D the point jnp.array([1, 5]) with the bases (2, 7) diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index bf5c1ae4..be1d336d 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -1,11 +1,12 @@ """Core components of the MAP-Elites-sampling algorithm.""" + from functools import partial from typing import Callable, Tuple import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey @jax.jit diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index acb14a9b..bd9570a9 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -16,8 +16,8 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq +from qdax.custom_types import Params, RNGKey from qdax.environments.bd_extractors import AuroraExtraInfoNormalization -from qdax.types import Params, RNGKey Array = Any PRNGKey = Any @@ -132,7 +132,7 @@ def lstm_ae_train( std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) # the std where they were NaNs was set to zero. But here we divide by the # std, so we replace the zeros by inf here. 
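The call-site changes in this part of the patch all follow the same pattern: the keyword forms jnp.where(cond, x=..., y=...) and jnp.clip(x, a_min=..., a_max=...) are replaced by the spellings accepted by the newer JAX release pinned in the updated requirements. A minimal sketch of the updated forms:

    import jax.numpy as jnp

    mask = jnp.array([True, False, True])
    values = jnp.array([1.0, 2.0, 3.0])

    # previously: jnp.where(mask, x=values, y=0.0)
    filled = jnp.where(mask, values, 0.0)  # positional operands

    # previously: jnp.clip(values, a_min=0.5, a_max=2.5)
    clipped = jnp.clip(values, min=0.5, max=2.5)  # new keyword names, as in the wrapper changes above
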
- std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) + std_obs = jnp.where(std_obs == 0, jnp.inf, std_obs) # TODO: maybe we could just compute this data on the valid dataset diff --git a/requirements.txt b/requirements.txt index 50d7899c..f6dea29a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,17 @@ ---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - absl-py==1.0.0 -brax==0.9.2 -chex==0.1.83 -dm-haiku==0.0.10 -flax==0.7.4 +brax==0.10.4 +chex==0.1.86 +flax==0.8.5 gym==0.26.2 ipython -jax[cuda12_pip] +jax==0.4.28 +jaxlib==0.4.28 jumanji==0.3.1 jupyter -numpy==1.24.1 -optax==0.1.7 -protobuf==3.19.4 -scikit-learn==1.0.2 -scipy==1.8.0 -seaborn==0.11.2 -tensorflow-probability==0.19.0 -typing-extensions==4.3.0 +numpy==1.26.4 +optax==0.1.9 +protobuf==3.19.5 +scikit-learn==1.5.1 +scipy==1.10.1 +tensorflow-probability==0.24.0 +typing-extensions==4.12.2 diff --git a/setup.py b/setup.py index 0065bf18..cd7d2b13 100644 --- a/setup.py +++ b/setup.py @@ -22,19 +22,23 @@ long_description_content_type="text/markdown", install_requires=[ "absl-py>=1.0.0", - "jax>=0.4.16", - "jaxlib>=0.4.16", # necessary to build the doc atm - "jinja2<3.1.0", + "brax>=0.10.4", + "chex>=0.1.86", + "flax>=0.8.5", + "gym>=0.26.2", + "jax>=0.4.28", + "jaxlib>=0.4.28", # necessary to build the doc atm + "jinja2>=3.1.4", "jumanji>=0.3.1", - "flax>=0.7.4", - "chex>=0.1.83", - "brax>=0.9.2", - "gym>=0.23.1", - "numpy>=1.22.3", - "optax>=0.1.7", - "scikit-learn>=1.0.2", - "scipy>=1.8.0", + "numpy>=1.26.4", + "optax>=0.1.9", + "scikit-learn>=1.5.1", + "scipy>=1.10.1", + "tensorflow-probability>=0.24.0", ], + extras_require={ + "cuda12": ["jax[cuda12]>=0.4.28"], + }, dependency_links=[ "https://storage.googleapis.com/jax-releases/jax_releases.html", ], @@ -46,7 +50,9 @@ "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) diff --git a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index c86bd622..2dc6fa10 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -16,7 +16,7 @@ from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey @pytest.mark.parametrize( @@ -25,7 +25,7 @@ ) def test_cma_me(emitter_type: Type[CMAEmitter]) -> None: - num_iterations = 1000 + num_iterations = 2000 num_dimensions = 20 grid_shape = (50, 50) batch_size = 36 @@ -43,7 +43,7 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: def clip(x: jnp.ndarray) -> jnp.ndarray: in_bound = (x <= maxval) * (x >= minval) - return jnp.where(condition=in_bound, x=x, y=(maxval / x)) + return jnp.where(in_bound, x, (maxval / x)) def _behavior_descriptor_1(x: jnp.ndarray) -> jnp.ndarray: return jnp.sum(clip(x[: x.shape[-1] // 2])) @@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( 
map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index fdd9330b..5bfdfd58 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -12,7 +12,7 @@ ) from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey def test_cma_mega() -> None: @@ -125,7 +125,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 1e782f2a..2a8d3d1f 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -1,4 +1,5 @@ """Testing script for the algorithm DADS""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 0b9af46e..77094ffd 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -1,5 +1,6 @@ """Training script for the algorithm DADS, should be launched with hydra. e.g. python train_dads.py config=dads_ant""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 5f9ec5f7..a1eb1b51 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -15,7 +15,7 @@ polynomial_mutation, ) from qdax.core.emitters.standard_emitters import MixingEmitter -from qdax.types import ExtraScores, Fitness, RNGKey +from qdax.custom_types import ExtraScores, Fitness, RNGKey from qdax.utils.metrics import default_ga_metrics @@ -32,11 +32,11 @@ def test_ga(algorithm_class: Type[GeneticAlgorithm]) -> None: batch_size = 100 genotype_dim = 6 lag = 2.2 - base_lag = 0 + base_lag = 0.0 num_neighbours = 1 def rastrigin_scorer( - genotypes: jnp.ndarray, base_lag: int, lag: int + genotypes: jnp.ndarray, base_lag: float, lag: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Rastrigin Scorer with first two dimensions as descriptors @@ -119,7 +119,11 @@ def scoring_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( algo_instance.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 98a5b960..5058bad6 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -119,10 +119,10 @@ def test_me_pbt_sac() -> None: def scoring_function(genotypes, random_key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) @@ -186,7 
+186,7 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_map( + initial_metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] @@ -196,7 +196,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys, metrics = update_fn( repertoire, emitter_state, keys ) - metrics_cpu = jax.tree_map( + metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index fc2e89b0..39c3e942 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -117,10 +117,10 @@ def test_me_pbt_td3() -> None: def scoring_function(genotypes, random_key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) @@ -184,7 +184,7 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_map( + initial_metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] @@ -194,7 +194,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys, metrics = update_fn( repertoire, emitter_state, keys ) - metrics_cpu = jax.tree_map( + metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index 3f3314fd..d1913b02 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_mees() -> None: @@ -185,7 +185,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index 7b0f0639..ad51c7ae 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -11,7 +11,7 @@ ) from qdax.core.emitters.omg_mega_emitter import OMGMEGAEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def test_omg_mega() -> None: 
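The systematic swap from jax.tree_map to jax.tree_util.tree_map throughout these tests (and in the emitters earlier in the patch) does not change behaviour; only the import location moves, since the top-level alias is deprecated in recent JAX versions. A small sketch of the mapping pattern used above to pull pmapped metrics onto the host (the metrics dict is a toy stand-in with a leading replica axis):

    import jax
    import jax.numpy as jnp

    metrics = {"qd_score": jnp.ones((2, 10)), "coverage": jnp.zeros((2, 10))}

    # identical semantics to the old jax.tree_map
    metrics_cpu = jax.tree_util.tree_map(
        lambda x: jax.device_put(x, jax.devices("cpu")[0])[0],  # move to CPU, keep first replica
        metrics,
    )
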
@@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index c83f277c..db7dc69e 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -59,7 +59,7 @@ def init_environments(random_key): # type: ignore eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) reshape_fn = jax.jit( - lambda tree: jax.tree_map( + lambda tree: jax.tree_util.tree_map( lambda x: jnp.reshape( x, ( diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 9e6134c9..0be68277 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -57,7 +57,7 @@ def init_environments(random_key): # type: ignore eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) reshape_fn = jax.jit( - lambda tree: jax.tree_map( + lambda tree: jax.tree_util.tree_map( lambda x: jnp.reshape( x, ( diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 9cb1b3fb..0490a481 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -15,8 +15,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_pgame() -> None: @@ -189,7 +189,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 1889f197..704416a4 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -17,8 +17,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_qdpg() -> None: @@ -239,7 +239,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), _metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), _metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index c667aa66..8c26b510 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -10,7 +10,7 @@ from qdax.baselines.sac import SAC, SacConfig, TrainingState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, 
warmstart_buffer -from qdax.types import EnvState +from qdax.custom_types import EnvState def test_sac() -> None: diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 2b238237..4bbb9d82 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -11,6 +11,7 @@ from qdax import environments from qdax.core.aurora import AURORA from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.custom_types import Observation from qdax.environments.bd_extractors import ( AuroraExtraInfoNormalization, get_aurora_encoding, @@ -19,7 +20,6 @@ create_default_brax_task_components, get_aurora_scoring_fn, ) -from qdax.types import Observation from qdax.utils import train_seq2seq from qdax.utils.metrics import default_qd_metrics from tests.core_test.map_elites_test import get_mixing_emitter diff --git a/tests/core_test/containers_test/mapelites_repertoire_test.py b/tests/core_test/containers_test/mapelites_repertoire_test.py index 55e6ed11..5c0d9d75 100644 --- a/tests/core_test/containers_test/mapelites_repertoire_test.py +++ b/tests/core_test/containers_test/mapelites_repertoire_test.py @@ -5,7 +5,7 @@ MapElitesRepertoire, compute_euclidean_centroids, ) -from qdax.types import ExtraScores +from qdax.custom_types import ExtraScores def test_mapelites_repertoire() -> None: diff --git a/tests/core_test/containers_test/mels_repertoire_test.py b/tests/core_test/containers_test/mels_repertoire_test.py index 2fb1bd76..0b854b32 100644 --- a/tests/core_test/containers_test/mels_repertoire_test.py +++ b/tests/core_test/containers_test/mels_repertoire_test.py @@ -2,7 +2,7 @@ import pytest from qdax.core.containers.mels_repertoire import MELSRepertoire -from qdax.types import ExtraScores +from qdax.custom_types import ExtraScores def test_add_to_mels_repertoire() -> None: diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index 93b3e081..ebf712d5 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -96,7 +96,11 @@ def test_multi_emitter() -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index b532aa65..c89ce04f 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey from qdax.utils.metrics import default_qd_metrics @@ -143,7 +143,11 @@ def play_step_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 21f90517..66bcc05f 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -15,8 +15,8 @@ from qdax.core.mels import MELS from qdax.core.neuroevolution.buffers.buffer import 
QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey @pytest.mark.parametrize( @@ -142,7 +142,11 @@ def metrics_fn(repertoire: MELSRepertoire) -> Dict: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mels.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index 103f9489..746b94a0 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -14,7 +14,7 @@ ) from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.mome import MOME -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey from qdax.utils.metrics import default_moqd_metrics @@ -36,10 +36,10 @@ def test_mome(num_descriptors: int) -> None: crossover_percentage = 1.0 batch_size = 80 lag = 2.2 - base_lag = 0 + base_lag = 0.0 def rastrigin_scorer( - genotypes: jnp.ndarray, base_lag: int, lag: int + genotypes: jnp.ndarray, base_lag: float, lag: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Rastrigin Scorer with first two dimensions as descriptors @@ -131,7 +131,11 @@ def scoring_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mome.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py index e0e298c1..06e25fcd 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py @@ -42,7 +42,9 @@ def test_insert_batch() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) + simple_transition = jax.tree_util.tree_map( + lambda x: x.repeat(3, axis=0), dummy_transition + ) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) data = QDTransition.from_flatten(replay_buffer.data, dummy_transition) pytest.assume( @@ -83,7 +85,9 @@ def test_sample() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) + simple_transition = jax.tree_util.tree_map( + lambda x: x.repeat(3, axis=0), dummy_transition + ) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) replay_buffer = replay_buffer.insert(simple_transition) @@ -91,6 +95,6 @@ def test_sample() -> None: samples, random_key = replay_buffer.sample(random_key, 3) - samples_shapes = jax.tree_map(lambda x: x.shape, samples) - transition_shapes = jax.tree_map(lambda x: x.shape, simple_transition) + samples_shapes = jax.tree_util.tree_map(lambda x: x.shape, samples) + transition_shapes = jax.tree_util.tree_map(lambda x: x.shape, simple_transition) pytest.assume((samples_shapes == transition_shapes)) diff --git a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py index 75f68b40..12ea0874 100644 --- 
a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py @@ -202,8 +202,8 @@ def test_trajectory_buffer_insert() -> None: multy_step_episodic_data, equal_nan=True, ), - "Episodic data when transitions are added sequentially is not consistent to when\ - theya are added as batch.", + "Episodic data when transitions are added sequentially is not consistent to \ + when they are added as batch.", ) pytest.assume( diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index e71e761c..98361b23 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -96,7 +96,11 @@ def test_arm(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py index f8c63259..c12518fb 100644 --- a/tests/default_tasks_test/brax_task_test.py +++ b/tests/default_tasks_test/brax_task_test.py @@ -84,7 +84,11 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) - ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index a390f709..3d619353 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -102,7 +102,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index eed90127..636a02cf 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -11,11 +11,11 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Descriptor, Observation from qdax.tasks.jumanji_envs import ( jumanji_scoring_function, make_policy_network_play_step_fn_jumanji, ) -from qdax.types import Descriptor, Observation def test_jumanji_utils() -> None: @@ -53,7 +53,13 @@ def test_jumanji_utils() -> None: def observation_processing( observation: jumanji.environments.routing.snake.types.Observation, ) -> Observation: - network_input = jnp.ravel(observation.grid) + network_input = jnp.concatenate( + [ + jnp.ravel(observation.grid), + jnp.array([observation.step_count]), + observation.action_mask.ravel(), + ] + ) return network_input play_step_fn = make_policy_network_play_step_fn_jumanji( @@ -67,7 +73,12 @@ def observation_processing( keys = jax.random.split(subkey, num=batch_size) # compute observation size from observation spec - observation_size = np.prod(np.array(env.observation_spec().grid.shape)) + obs_spec = env.observation_spec() + observation_size = 
int( + np.prod(obs_spec.grid.shape) + + np.prod(obs_spec.step_count.shape) + + np.prod(obs_spec.action_mask.shape) + ) fake_batch = jnp.zeros(shape=(batch_size, observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index a0542e9b..46f6ce9b 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -117,7 +117,11 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index 7b310389..87913364 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -92,7 +92,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/environments_test/pointmaze_test.py b/tests/environments_test/pointmaze_test.py index a13f41cc..ecc97864 100644 --- a/tests/environments_test/pointmaze_test.py +++ b/tests/environments_test/pointmaze_test.py @@ -6,8 +6,8 @@ from brax.v1.envs import Env import qdax +from qdax.custom_types import EnvState from qdax.environments.pointmaze import PointMaze -from qdax.types import EnvState def test_pointmaze() -> None: diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 6ce6cbe9..8d19379e 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -8,8 +8,8 @@ from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey from qdax.utils.sampling import ( average, closest, diff --git a/tool.Dockerfile b/tool.Dockerfile index 10b15b02..26a68236 100644 --- a/tool.Dockerfile +++ b/tool.Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9.18-slim +FROM python:3.10.14-slim ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 ENV PIPENV_VENV_IN_PROJECT=true PIP_NO_CACHE_DIR=false PIP_DISABLE_PIP_VERSION_CHECK=1 From 82d87c2418a4fce45545085c3eef1f5d7788c59d Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Thu, 5 Sep 2024 14:14:31 +0000 Subject: [PATCH 07/20] Rename DCG-ME to DCRL-ME --- .../{qdcg_emitter.py => dcrl_emitter.py} | 134 +++++++++--------- .../{dcg_me_emitter.py => dcrl_me_emitter.py} | 22 +-- qdax/core/emitters/qpg_emitter.py | 6 +- qdax/core/neuroevolution/buffers/buffer.py | 10 +- 4 files changed, 86 insertions(+), 86 deletions(-) rename qdax/core/emitters/{qdcg_emitter.py => dcrl_emitter.py} (87%) rename qdax/core/emitters/{dcg_me_emitter.py => dcrl_me_emitter.py} (84%) diff --git a/qdax/core/emitters/qdcg_emitter.py b/qdax/core/emitters/dcrl_emitter.py similarity index 87% rename from qdax/core/emitters/qdcg_emitter.py rename to qdax/core/emitters/dcrl_emitter.py index 
0fb19c4b..e7bb011d 100644 --- a/qdax/core/emitters/qdcg_emitter.py +++ b/qdax/core/emitters/dcrl_emitter.py @@ -1,4 +1,4 @@ -"""Implements the PG Emitter and Actor Injection from DCG-ME algorithm +"""Implements the DCRL Emitter from DCRL-MAP-Elites algorithm in JAX for Brax environments. """ @@ -13,7 +13,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.core.neuroevolution.buffers.buffer import DCGTransition, ReplayBuffer +from qdax.core.neuroevolution.buffers.buffer import DCRLTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn from qdax.core.neuroevolution.networks.networks import QModuleDC from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey @@ -21,10 +21,10 @@ @dataclass -class QualityDCGConfig: - """Configuration for QualityDCG Emitter""" +class DCRLConfig: + """Configuration for DCRL Emitter""" - qpg_batch_size: int = 64 + dcg_batch_size: int = 64 ai_batch_size: int = 64 lengthscale: float = 0.1 @@ -44,7 +44,7 @@ class QualityDCGConfig: policy_delay: int = 2 -class QualityDCGEmitterState(EmitterState): +class DCRLEmitterState(EmitterState): """Contains training state for the learner.""" critic_params: Params @@ -54,19 +54,19 @@ class QualityDCGEmitterState(EmitterState): target_critic_params: Params target_actor_params: Params replay_buffer: ReplayBuffer - random_key: RNGKey + key: RNGKey steps: jnp.ndarray -class QualityDCGEmitter(Emitter): +class DCRLEmitter(Emitter): """ - A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites - (PGA-Map-Elites) algorithm. + A descriptor-conditioned reinforcement learning emitter used to implement + DCRL-MAP-Elites algorithm. """ def __init__( self, - config: QualityDCGConfig, + config: DCRLConfig, policy_network: nn.Module, actor_network: nn.Module, env: QDEnv, @@ -114,7 +114,7 @@ def batch_size(self) -> int: Returns: the batch size emitted by the emitter. """ - return self._config.qpg_batch_size + self._config.ai_batch_size + return self._config.dcg_batch_size + self._config.ai_batch_size @property def use_all_data(self) -> bool: @@ -127,18 +127,18 @@ def use_all_data(self) -> bool: def init( self, - random_key: RNGKey, + key: RNGKey, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> Tuple[QualityDCGEmitterState, RNGKey]: + ) -> Tuple[DCRLEmitterState, RNGKey]: """Initializes the emitter state. Args: genotypes: The initial population. - random_key: A random key. + key: A random key. Returns: The initial state of the PGAMEEmitter, a new random key. 
@@ -149,7 +149,7 @@ def init( action_size = self._env.action_size # Initialise critic, greedy actor and population - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) fake_obs = jnp.zeros(shape=(observation_size,)) fake_desc = jnp.zeros(shape=(descriptor_size,)) fake_action = jnp.zeros(shape=(action_size,)) @@ -159,7 +159,7 @@ def init( ) target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) - random_key, subkey = jax.random.split(random_key) + key, subkey = jax.random.split(key) actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc) target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) @@ -168,7 +168,7 @@ def init( actor_opt_state = self._actor_optimizer.init(actor_params) # Initialize replay buffer - dummy_transition = DCGTransition.init_dummy( + dummy_transition = DCRLTransition.init_dummy( observation_dim=self._env.observation_size, action_dim=action_size, descriptor_dim=descriptor_size, @@ -191,8 +191,8 @@ def init( replay_buffer = replay_buffer.insert(transitions) # Initial training state - random_key, subkey = jax.random.split(random_key) - emitter_state = QualityDCGEmitterState( + key, subkey = jax.random.split(key) + emitter_state = DCRLEmitterState( critic_params=critic_params, critic_opt_state=critic_opt_state, actor_params=actor_params, @@ -200,11 +200,11 @@ def init( target_critic_params=target_critic_params, target_actor_params=target_actor_params, replay_buffer=replay_buffer, - random_key=subkey, + key=subkey, steps=jnp.array(0), ) - return emitter_state, random_key + return emitter_state, key @partial(jax.jit, static_argnames=("self",)) def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: @@ -281,28 +281,28 @@ def _compute_equivalent_params_with_desc( def emit( self, repertoire: Repertoire, - emitter_state: QualityDCGEmitterState, - random_key: RNGKey, + emitter_state: DCRLEmitterState, + key: RNGKey, ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a step of PG emission. Args: repertoire: the current repertoire of genotypes emitter_state: the state of the emitter used - random_key: a random key + key: a random key Returns: A batch of offspring, the new emitter state and a new key. """ # PG emitter - parents_pg, descs_pg, random_key = repertoire.sample_with_descs( - random_key, self._config.qpg_batch_size + parents_pg, descs_pg, key = repertoire.sample_with_descs( + key, self._config.dcg_batch_size ) genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) # Actor injection emitter - _, descs_ai, random_key = repertoire.sample_with_descs( - random_key, self._config.ai_batch_size + _, descs_ai, key = repertoire.sample_with_descs( + key, self._config.ai_batch_size ) descs_ai = descs_ai.reshape( descs_ai.shape[0], self._env.behavior_descriptor_length @@ -317,7 +317,7 @@ def emit( return ( genotypes, {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, - random_key, + key, ) @partial( @@ -326,7 +326,7 @@ def emit( ) def emit_pg( self, - emitter_state: QualityDCGEmitterState, + emitter_state: DCRLEmitterState, parents: Genotype, descs: Descriptor, ) -> Genotype: @@ -355,7 +355,7 @@ def emit_pg( static_argnames=("self",), ) def emit_ai( - self, emitter_state: QualityDCGEmitterState, descs: Descriptor + self, emitter_state: DCRLEmitterState, descs: Descriptor ) -> Genotype: """Emit the offsprings generated through pg mutation. 
@@ -376,7 +376,7 @@ def emit_ai( return offsprings @partial(jax.jit, static_argnames=("self",)) - def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: + def emit_actor(self, emitter_state: DCRLEmitterState) -> Genotype: """Emit the greedy actor. Simply needs to be retrieved from the emitter state. @@ -396,13 +396,13 @@ def emit_actor(self, emitter_state: QualityDCGEmitterState) -> Genotype: ) def state_update( self, - emitter_state: QualityDCGEmitterState, + emitter_state: DCRLEmitterState, repertoire: Repertoire, genotypes: Genotype, fitnesses: Fitness, descriptors: Descriptor, extra_scores: ExtraScores, - ) -> QualityDCGEmitterState: + ) -> DCRLEmitterState: """This function gives an opportunity to update the emitter state after the genotypes have been scored. @@ -431,7 +431,7 @@ def state_update( desc_prime = jnp.concatenate( [ extra_scores["desc_prime"], - descriptors[self._config.qpg_batch_size + self._config.ai_batch_size :], + descriptors[self._config.dcg_batch_size + self._config.ai_batch_size :], ], axis=0, ) @@ -449,8 +449,8 @@ def state_update( emitter_state = emitter_state.replace(replay_buffer=replay_buffer) # sample transitions from the replay buffer - random_key, subkey = jax.random.split(emitter_state.random_key) - transitions, random_key = replay_buffer.sample( + key, subkey = jax.random.split(emitter_state.key) + transitions, key = replay_buffer.sample( subkey, self._config.num_critic_training_steps * self._config.batch_size ) transitions = jax.tree_util.tree_map( @@ -468,12 +468,12 @@ def state_update( rewards=self._similarity(transitions.desc, transitions.desc_prime) * transitions.rewards ) - emitter_state = emitter_state.replace(random_key=random_key) + emitter_state = emitter_state.replace(key=key) def scan_train_critics( - carry: QualityDCGEmitterState, - transitions: DCGTransition, - ) -> Tuple[QualityDCGEmitterState, Any]: + carry: DCRLEmitterState, + transitions: DCRLTransition, + ) -> Tuple[DCRLEmitterState, Any]: emitter_state = carry new_emitter_state = self._train_critics(emitter_state, transitions) return new_emitter_state, () @@ -490,8 +490,8 @@ def scan_train_critics( @partial(jax.jit, static_argnames=("self",)) def _train_critics( - self, emitter_state: QualityDCGEmitterState, transitions: DCGTransition - ) -> QualityDCGEmitterState: + self, emitter_state: DCRLEmitterState, transitions: DCRLTransition + ) -> DCRLEmitterState: """Apply one gradient step to critics and to the greedy actor (contained in carry in training_state), then soft update target critics and target actor. 
@@ -510,14 +510,14 @@ def _train_critics( critic_opt_state, critic_params, target_critic_params, - random_key, + key, ) = self._update_critic( critic_params=emitter_state.critic_params, target_critic_params=emitter_state.target_critic_params, target_actor_params=emitter_state.target_actor_params, critic_opt_state=emitter_state.critic_opt_state, transitions=transitions, - random_key=emitter_state.random_key, + key=emitter_state.key, ) # Update greedy actor @@ -550,7 +550,7 @@ def _train_critics( actor_opt_state=actor_opt_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, - random_key=random_key, + key=key, steps=emitter_state.steps + 1, ) @@ -563,13 +563,13 @@ def _update_critic( target_critic_params: Params, target_actor_params: Params, critic_opt_state: Params, - transitions: DCGTransition, - random_key: RNGKey, + transitions: DCRLTransition, + key: RNGKey, ) -> Tuple[Params, Params, Params, RNGKey]: # compute loss and gradients - random_key, subkey = jax.random.split(random_key) - critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( + key, subkey = jax.random.split(key) + critic_gradient = jax.grad(self._critic_loss_fn)( critic_params, target_actor_params, target_critic_params, @@ -591,7 +591,7 @@ def _update_critic( critic_params, ) - return critic_opt_state, critic_params, target_critic_params, random_key + return critic_opt_state, critic_params, target_critic_params, key @partial(jax.jit, static_argnames=("self",)) def _update_actor( @@ -600,11 +600,11 @@ def _update_actor( actor_opt_state: optax.OptState, target_actor_params: Params, critic_params: Params, - transitions: DCGTransition, + transitions: DCRLTransition, ) -> Tuple[optax.OptState, Params, Params]: # Update greedy actor - policy_loss, policy_gradient = jax.value_and_grad(self._actor_loss_fn)( + policy_gradient = jax.grad(self._actor_loss_fn)( actor_params, critic_params, transitions, @@ -637,7 +637,7 @@ def _mutation_function_pg( self, policy_params: Genotype, descs: Descriptor, - emitter_state: QualityDCGEmitterState, + emitter_state: DCRLEmitterState, ) -> Genotype: """Apply pg mutation to a policy via multiple steps of gradient descent. First, update the rewards to be diversity rewards, then apply the gradient @@ -653,8 +653,8 @@ def _mutation_function_pg( The updated params of the neural network. 
""" # Get transitions - transitions, random_key = emitter_state.replay_buffer.sample( - emitter_state.random_key, + transitions, key = emitter_state.replay_buffer.sample( + emitter_state.key, sample_size=self._config.num_pg_training_steps * self._config.batch_size, ) descs_prime = jnp.tile( @@ -678,16 +678,16 @@ def _mutation_function_pg( transitions, ) - # Replace random_key - emitter_state = emitter_state.replace(random_key=random_key) + # Replace key + emitter_state = emitter_state.replace(key=key) # Define new policy optimizer state policy_opt_state = self._policies_optimizer.init(policy_params) def scan_train_policy( - carry: Tuple[QualityDCGEmitterState, Genotype, optax.OptState], - transitions: DCGTransition, - ) -> Tuple[Tuple[QualityDCGEmitterState, Genotype, optax.OptState], Any]: + carry: Tuple[DCRLEmitterState, Genotype, optax.OptState], + transitions: DCRLTransition, + ) -> Tuple[Tuple[DCRLEmitterState, Genotype, optax.OptState], Any]: emitter_state, policy_params, policy_opt_state = carry ( new_emitter_state, @@ -721,11 +721,11 @@ def scan_train_policy( @partial(jax.jit, static_argnames=("self",)) def _train_policy( self, - emitter_state: QualityDCGEmitterState, + emitter_state: DCRLEmitterState, policy_params: Params, policy_opt_state: optax.OptState, - transitions: DCGTransition, - ) -> Tuple[QualityDCGEmitterState, Params, optax.OptState]: + transitions: DCRLTransition, + ) -> Tuple[DCRLEmitterState, Params, optax.OptState]: """Apply one gradient step to a policy (called policy_params). Args: @@ -752,11 +752,11 @@ def _update_policy( critic_params: Params, policy_opt_state: optax.OptState, policy_params: Params, - transitions: DCGTransition, + transitions: DCRLTransition, ) -> Tuple[optax.OptState, Params]: # compute loss - _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_gradient = jax.grad(self._policy_loss_fn)( policy_params, critic_params, transitions, diff --git a/qdax/core/emitters/dcg_me_emitter.py b/qdax/core/emitters/dcrl_me_emitter.py similarity index 84% rename from qdax/core/emitters/dcg_me_emitter.py rename to qdax/core/emitters/dcrl_me_emitter.py index fea237c6..36a9f03d 100644 --- a/qdax/core/emitters/dcg_me_emitter.py +++ b/qdax/core/emitters/dcrl_me_emitter.py @@ -4,18 +4,18 @@ import flax.linen as nn from qdax.core.emitters.multi_emitter import MultiEmitter -from qdax.core.emitters.qdcg_emitter import QualityDCGConfig, QualityDCGEmitter +from qdax.core.emitters.dcrl_emitter import DCRLConfig, DCRLEmitter from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv @dataclass -class DCGMEConfig: - """Configuration for DCGME Algorithm""" +class DCRLMEConfig: + """Configuration for DCRL-MAP-Elites Algorithm""" ga_batch_size: int = 128 - qpg_batch_size: int = 64 + dcg_batch_size: int = 64 ai_batch_size: int = 64 lengthscale: float = 0.1 @@ -36,10 +36,10 @@ class DCGMEConfig: policy_delay: int = 2 -class DCGMEEmitter(MultiEmitter): +class DCRLMEEmitter(MultiEmitter): def __init__( self, - config: DCGMEConfig, + config: DCRLMEConfig, policy_network: nn.Module, actor_network: nn.Module, env: QDEnv, @@ -49,8 +49,8 @@ def __init__( self._env = env self._variation_fn = variation_fn - qdcg_config = QualityDCGConfig( - qpg_batch_size=config.qpg_batch_size, + dcrl_config = DCRLConfig( + dcg_batch_size=config.dcg_batch_size, ai_batch_size=config.ai_batch_size, lengthscale=config.lengthscale, 
critic_hidden_layer_size=config.critic_hidden_layer_size, @@ -70,8 +70,8 @@ def __init__( ) # define the quality emitter - q_emitter = QualityDCGEmitter( - config=qdcg_config, + dcrl_emitter = DCRLEmitter( + config=dcrl_config, policy_network=policy_network, actor_network=actor_network, env=env, @@ -85,4 +85,4 @@ def __init__( batch_size=config.ga_batch_size, ) - super().__init__(emitters=(q_emitter, ga_emitter)) + super().__init__(emitters=(dcrl_emitter, ga_emitter)) diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c6e2df7e..63373494 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -428,7 +428,7 @@ def _update_critic( # compute loss and gradients random_key, subkey = jax.random.split(random_key) - critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( + critic_gradient = jax.grad(self._critic_loss_fn)( critic_params, target_actor_params, target_critic_params, @@ -463,7 +463,7 @@ def _update_actor( ) -> Tuple[optax.OptState, Params, Params]: # Update greedy actor - policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_gradient = jax.grad(self._policy_loss_fn)( actor_params, critic_params, transitions, @@ -595,7 +595,7 @@ def _update_policy( ) -> Tuple[optax.OptState, Params]: # compute loss - _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_gradient = jax.grad(self._policy_loss_fn)( policy_params, critic_params, transitions, diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 5057e5e2..81f1e896 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -270,7 +270,7 @@ def init_dummy( # type: ignore return dummy_transition -class DCGTransition(QDTransition): +class DCRLTransition(QDTransition): """Stores data corresponding to a transition collected by a QD algorithm.""" desc: Descriptor @@ -325,8 +325,8 @@ def flatten(self) -> jnp.ndarray: def from_flatten( cls, flattened_transition: jnp.ndarray, - transition: QDTransition, - ) -> QDTransition: + transition: DCRLTransition, + ) -> DCRLTransition: """ Creates a transition from a flattened transition in a jnp.ndarray. Args: @@ -394,7 +394,7 @@ def from_flatten( @classmethod def init_dummy( # type: ignore cls, observation_dim: int, action_dim: int, descriptor_dim: int - ) -> QDTransition: + ) -> DCRLTransition: """ Initialize a dummy transition that then can be passed to constructors to get all shapes right. 
@@ -404,7 +404,7 @@ def init_dummy( # type: ignore Returns: a dummy transition """ - dummy_transition = DCGTransition( + dummy_transition = DCRLTransition( obs=jnp.zeros(shape=(1, observation_dim)), next_obs=jnp.zeros(shape=(1, observation_dim)), rewards=jnp.zeros(shape=(1,)), From f1541e7d338382f4e270a02de3c65c49b069041f Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Thu, 5 Sep 2024 14:14:43 +0000 Subject: [PATCH 08/20] Add test for DCRL-ME --- tests/baselines_test/dcrlme_test.py | 234 ++++++++++++++++++++++++++++ 1 file changed, 234 insertions(+) create mode 100644 tests/baselines_test/dcrlme_test.py diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py new file mode 100644 index 00000000..7968804c --- /dev/null +++ b/tests/baselines_test/dcrlme_test.py @@ -0,0 +1,234 @@ +from typing import Any, Dict, Tuple +import functools +import pytest + +import jax +import jax.numpy as jnp + +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids +from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function +from qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function +from qdax import environments +from qdax.environments import behavior_descriptor_extractor +from qdax.core.map_elites import MAPElites +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter +from qdax.core.neuroevolution.buffers.buffer import DCRLTransition +from qdax.core.neuroevolution.networks.networks import MLP, MLPDC +from qdax.utils.metrics import default_qd_metrics + + +def test_dcrlme() -> None: + seed = 42 + + env_name = "ant_omni" + episode_length = 100 + min_bd = -30. + max_bd = 30. 
+ + num_iterations = 5 + batch_size = 256 + + # Archive + num_init_cvt_samples = 50000 + num_centroids = 1024 + policy_hidden_layer_sizes = (128, 128) + + # DCRL-ME + ga_batch_size = 128 + dcg_batch_size = 64 + ai_batch_size = 64 + lengthscale = 0.1 + + # GA emitter + iso_sigma = 0.005 + line_sigma = 0.05 + + # DCRL emitter + critic_hidden_layer_size = (256, 256) + num_critic_training_steps = 3000 + num_pg_training_steps = 150 + pg_batch_size = 100 + replay_buffer_size = 1_000_000 + discount = 0.99 + reward_scaling = 1.0 + critic_learning_rate = 3e-4 + actor_learning_rate = 3e-4 + policy_learning_rate = 5e-3 + noise_clip = 0.5 + policy_noise = 0.2 + soft_tau_update = 0.005 + policy_delay = 2 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init environment + env = environments.create(env_name, episode_length=episode_length) + reset_fn = jax.jit(env.reset) + + # Compute the centroids + centroids, random_key = compute_cvt_centroids( + num_descriptors=env.behavior_descriptor_length, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=min_bd, + maxval=max_bd, + random_key=random_key, + ) + + # Init policy network + policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) + policy_network = MLP( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + actor_dc_network = MLPDC( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=batch_size) + fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size)) + init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs) + + # Define the fonction to play a step with the policy in the environment + def play_step_fn(env_state, policy_params, random_key): + actions = policy_network.apply(policy_params, env_state.obs) + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = DCRLTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + truncations=next_state.info["truncation"], + actions=actions, + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, + desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, + ) + + return next_state, policy_params, random_key, transition + + # Prepare the scoring function + bd_extraction_fn = behavior_descriptor_extractor[env_name] + scoring_fn = functools.partial( + scoring_function, + episode_length=episode_length, + play_reset_fn=reset_fn, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + + def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key): + desc_prime_normalized = dcg_emitter.emitters[0]._normalize_desc(desc) + actions = actor_dc_network.apply(actor_dc_params, env_state.obs, desc_prime_normalized) + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = DCRLTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + truncations=next_state.info["truncation"], + actions=actions, + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, + 
desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, + ) + + return next_state, actor_dc_params, desc, random_key, transition + + # Prepare the scoring function + scoring_actor_dc_fn = jax.jit(functools.partial( + scoring_actor_dc_function, + episode_length=episode_length, + play_reset_fn=reset_fn, + play_step_actor_dc_fn=play_step_actor_dc_fn, + behavior_descriptor_extractor=bd_extraction_fn, + )) + + # Get minimum reward value to make sure qd_score are positive + reward_offset = 0 + + # Define a metrics function + metrics_function = functools.partial( + default_qd_metrics, + qd_offset=reward_offset * episode_length, + ) + + # Define the DCG-emitter config + dcg_emitter_config = DCRLMEConfig( + ga_batch_size=ga_batch_size, + dcg_batch_size=dcg_batch_size, + ai_batch_size=ai_batch_size, + lengthscale=lengthscale, + critic_hidden_layer_size=critic_hidden_layer_size, + num_critic_training_steps=num_critic_training_steps, + num_pg_training_steps=num_pg_training_steps, + batch_size=batch_size, + replay_buffer_size=replay_buffer_size, + discount=discount, + reward_scaling=reward_scaling, + critic_learning_rate=critic_learning_rate, + actor_learning_rate=actor_learning_rate, + policy_learning_rate=policy_learning_rate, + noise_clip=noise_clip, + policy_noise=policy_noise, + soft_tau_update=soft_tau_update, + policy_delay=policy_delay, + ) + + # Get the emitter + variation_fn = functools.partial( + isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma + ) + + dcg_emitter = DCRLMEEmitter( + config=dcg_emitter_config, + policy_network=policy_network, + actor_network=actor_dc_network, + env=env, + variation_fn=variation_fn, + ) + + # Instantiate MAP Elites + map_elites = MAPElites( + scoring_function=scoring_fn, + emitter=dcg_emitter, + metrics_function=metrics_function, + ) + + # compute initial repertoire + repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key) + + @jax.jit + def update_scan_fn(carry: Any, unused: Any) -> Any: + # iterate over grid + repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) + + return (repertoire, emitter_state, random_key), metrics + + # Run the algorithm + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( + update_scan_fn, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) From d7fd89e289a5364c671499378906c1844bd44f15 Mon Sep 17 00:00:00 2001 From: Maxence Faldor Date: Thu, 5 Sep 2024 14:27:39 +0000 Subject: [PATCH 09/20] Fix typing in dcrlme_test.py --- qdax/core/emitters/dcrl_emitter.py | 8 +-- qdax/core/emitters/dcrl_me_emitter.py | 2 +- tests/baselines_test/dcrlme_test.py | 72 ++++++++++----------------- 3 files changed, 28 insertions(+), 54 deletions(-) diff --git a/qdax/core/emitters/dcrl_emitter.py b/qdax/core/emitters/dcrl_emitter.py index e7bb011d..0847e8a2 100644 --- a/qdax/core/emitters/dcrl_emitter.py +++ b/qdax/core/emitters/dcrl_emitter.py @@ -301,9 +301,7 @@ def emit( genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) # Actor injection emitter - _, descs_ai, key = repertoire.sample_with_descs( - key, self._config.ai_batch_size - ) + _, descs_ai, key = repertoire.sample_with_descs(key, self._config.ai_batch_size) descs_ai = descs_ai.reshape( descs_ai.shape[0], self._env.behavior_descriptor_length ) @@ -354,9 +352,7 @@ def emit_pg( jax.jit, static_argnames=("self",), ) - def emit_ai( - self, emitter_state: DCRLEmitterState, descs: Descriptor - ) -> Genotype: + def 
emit_ai(self, emitter_state: DCRLEmitterState, descs: Descriptor) -> Genotype: """Emit the offsprings generated through pg mutation. Args: diff --git a/qdax/core/emitters/dcrl_me_emitter.py b/qdax/core/emitters/dcrl_me_emitter.py index 36a9f03d..fd28555d 100644 --- a/qdax/core/emitters/dcrl_me_emitter.py +++ b/qdax/core/emitters/dcrl_me_emitter.py @@ -3,8 +3,8 @@ import flax.linen as nn -from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.dcrl_emitter import DCRLConfig, DCRLEmitter +from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index 7968804c..9b602698 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -1,20 +1,20 @@ -from typing import Any, Dict, Tuple import functools -import pytest +from typing import Any, Tuple import jax import jax.numpy as jnp +import pytest -from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids -from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function -from qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function from qdax import environments -from qdax.environments import behavior_descriptor_extractor -from qdax.core.map_elites import MAPElites -from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import DCRLTransition from qdax.core.neuroevolution.networks.networks import MLP, MLPDC +from qdax.custom_types import EnvState, Params, RNGKey +from qdax.environments import behavior_descriptor_extractor +from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs from qdax.utils.metrics import default_qd_metrics @@ -23,8 +23,8 @@ def test_dcrlme() -> None: env_name = "ant_omni" episode_length = 100 - min_bd = -30. - max_bd = 30. 
+ min_bd = -30.0 + max_bd = 30.0 num_iterations = 5 batch_size = 256 @@ -48,7 +48,6 @@ def test_dcrlme() -> None: critic_hidden_layer_size = (256, 256) num_critic_training_steps = 3000 num_pg_training_steps = 150 - pg_batch_size = 100 replay_buffer_size = 1_000_000 discount = 0.99 reward_scaling = 1.0 @@ -90,7 +89,6 @@ def test_dcrlme() -> None: final_activation=jnp.tanh, ) - # Init population of controllers random_key, subkey = jax.random.split(random_key) keys = jax.random.split(subkey, num=batch_size) @@ -98,7 +96,9 @@ def test_dcrlme() -> None: init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs) # Define the fonction to play a step with the policy in the environment - def play_step_fn(env_state, policy_params, random_key): + def play_step_fn( + env_state: EnvState, policy_params: Params, random_key: RNGKey + ) -> Tuple[EnvState, Params, RNGKey, DCRLTransition]: actions = policy_network.apply(policy_params, env_state.obs) state_desc = env_state.info["state_descriptor"] next_state = env.step(env_state, actions) @@ -112,8 +112,14 @@ def play_step_fn(env_state, policy_params, random_key): actions=actions, state_desc=state_desc, next_state_desc=next_state.info["state_descriptor"], - desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, - desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, + desc=jnp.zeros( + env.behavior_descriptor_length, + ) + * jnp.nan, + desc_prime=jnp.zeros( + env.behavior_descriptor_length, + ) + * jnp.nan, ) return next_state, policy_params, random_key, transition @@ -121,43 +127,13 @@ def play_step_fn(env_state, policy_params, random_key): # Prepare the scoring function bd_extraction_fn = behavior_descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function, + reset_based_scoring_function_brax_envs, episode_length=episode_length, play_reset_fn=reset_fn, play_step_fn=play_step_fn, behavior_descriptor_extractor=bd_extraction_fn, ) - def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key): - desc_prime_normalized = dcg_emitter.emitters[0]._normalize_desc(desc) - actions = actor_dc_network.apply(actor_dc_params, env_state.obs, desc_prime_normalized) - state_desc = env_state.info["state_descriptor"] - next_state = env.step(env_state, actions) - - transition = DCRLTransition( - obs=env_state.obs, - next_obs=next_state.obs, - rewards=next_state.reward, - dones=next_state.done, - truncations=next_state.info["truncation"], - actions=actions, - state_desc=state_desc, - next_state_desc=next_state.info["state_descriptor"], - desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, - desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan, - ) - - return next_state, actor_dc_params, desc, random_key, transition - - # Prepare the scoring function - scoring_actor_dc_fn = jax.jit(functools.partial( - scoring_actor_dc_function, - episode_length=episode_length, - play_reset_fn=reset_fn, - play_step_actor_dc_fn=play_step_actor_dc_fn, - behavior_descriptor_extractor=bd_extraction_fn, - )) - # Get minimum reward value to make sure qd_score are positive reward_offset = 0 @@ -210,7 +186,9 @@ def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key): ) # compute initial repertoire - repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key) + repertoire, emitter_state, random_key = map_elites.init( + init_params, centroids, random_key + ) @jax.jit def update_scan_fn(carry: Any, unused: Any) -> Any: From f0b2d643e915b781c0f8575de10c0b57a455cf70 Mon Sep 17 00:00:00 
2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 15:01:03 +0000 Subject: [PATCH 10/20] import qdax.custom_types instead of qdax.types in notebooks --- examples/aurora.ipynb | 2 +- examples/jumanji_snake.ipynb | 2 +- examples/me_sac_pbt.ipynb | 2 +- examples/me_td3_pbt.ipynb | 2 +- examples/mome.ipynb | 2 +- examples/nsga2_spea2.ipynb | 2 +- examples/pga_aurora.ipynb | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index 645ed911..cebfb0a1 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -93,7 +93,7 @@ "from qdax.core.emitters.mutation_operators import isoline_variation\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", "\n", - "from qdax.types import Observation\n", + "from qdax.custom_types import Observation\n", "from qdax.utils import train_seq2seq\n", "\n", "\n", diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index bfba1e5a..3c33c84a 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -72,7 +72,7 @@ "from qdax.core.emitters.mutation_operators import isoline_variation\n", "\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", - "from qdax.types import ExtraScores, Fitness, RNGKey, Descriptor\n", + "from qdax.custom_types import ExtraScores, Fitness, RNGKey, Descriptor\n", "from qdax.utils.metrics import default_ga_metrics, default_qd_metrics" ] }, diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index fc6fbe8b..db297992 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -54,7 +54,7 @@ "from qdax.core.distributed_map_elites import DistributedMAPElites\n", "from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig\n", "from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn\n", - "from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n", + "from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n", "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", "from qdax.utils.plotting import plot_map_elites_results" ] diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index f28a1db4..0a8db24a 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -56,7 +56,7 @@ "from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig\n", "from qdax.core.emitters.pbt_variation_operators import td3_pbt_variation_fn\n", "from qdax.core.distributed_map_elites import DistributedMAPElites\n", - "from qdax.types import RNGKey\n", + "from qdax.custom_types import RNGKey\n", "from qdax.utils.metrics import default_qd_metrics\n", "from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_map_elites_results" ] diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 7e28b608..17a4adbe 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -81,7 +81,7 @@ "\n", "import matplotlib.pyplot as plt\n", "\n", - "from qdax.types import Fitness, Descriptor, RNGKey, ExtraScores" + "from qdax.custom_types import Fitness, Descriptor, RNGKey, ExtraScores" ] }, { diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 4e9ab3b0..4d7d81b2 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -88,7 +88,7 @@ "from qdax.utils.plotting import plot_global_pareto_front\n", "from qdax.utils.metrics import default_ga_metrics\n", "\n", - "from qdax.types import Genotype, Fitness, Descriptor" + "from qdax.custom_types import 
Genotype, Fitness, Descriptor" ] }, { diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 6152ce63..01697d88 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -93,7 +93,7 @@ "from qdax.core.emitters.mutation_operators import isoline_variation\n", "from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter\n", "\n", - "from qdax.types import Observation\n", + "from qdax.custom_types import Observation\n", "from qdax.utils import train_seq2seq\n", "\n", "\n", From b3f0f13efdd93808c3585c837eefc4bef98ae4c4 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 20:26:40 +0000 Subject: [PATCH 11/20] add reward offset to DCRLME test --- tests/baselines_test/dcrlme_test.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index 9b602698..313d3623 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp import pytest +from brax.envs import Env, State, Wrapper from qdax import environments from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids @@ -18,6 +19,27 @@ from qdax.utils.metrics import default_qd_metrics +class RewardOffsetEnvWrapper(Wrapper): + """Wraps ant_omni environment to add and scale position.""" + + def __init__(self, env: Env, env_name: str) -> None: + super().__init__(env) + self._env_name = env_name + + @property + def name(self) -> str: + return self._env_name + + def reset(self, rng: jnp.ndarray) -> State: + state = self.env.reset(rng) + return state + + def step(self, state: State, action: jnp.ndarray) -> State: + state = self.env.step(state, action) + new_reward = state.reward + environments.reward_offset[self._env_name] + return state.replace(reward=new_reward) + + def test_dcrlme() -> None: seed = 42 @@ -64,6 +86,10 @@ def test_dcrlme() -> None: # Init environment env = environments.create(env_name, episode_length=episode_length) + env = RewardOffsetEnvWrapper( + env, env_name + ) # apply reward offset as DCG needs positive rewards + reset_fn = jax.jit(env.reset) # Compute the centroids From 3eeb4c18cf745035d7d59ca07edd9aa284588cdf Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 22:07:35 +0000 Subject: [PATCH 12/20] add environment offset and rename variables to dcrl --- tests/baselines_test/dcrlme_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index 313d3623..748da6d0 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -161,7 +161,7 @@ def play_step_fn( ) # Get minimum reward value to make sure qd_score are positive - reward_offset = 0 + reward_offset = environments.reward_offset[env_name] # Define a metrics function metrics_function = functools.partial( @@ -169,8 +169,8 @@ def play_step_fn( qd_offset=reward_offset * episode_length, ) - # Define the DCG-emitter config - dcg_emitter_config = DCRLMEConfig( + # Define the DCRL-emitter config + dcrl_emitter_config = DCRLMEConfig( ga_batch_size=ga_batch_size, dcg_batch_size=dcg_batch_size, ai_batch_size=ai_batch_size, @@ -196,8 +196,8 @@ def play_step_fn( isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma ) - dcg_emitter = DCRLMEEmitter( - config=dcg_emitter_config, + dcrl_emitter = DCRLMEEmitter( + config=dcrl_emitter_config, policy_network=policy_network, 
actor_network=actor_dc_network, env=env, @@ -207,7 +207,7 @@ def play_step_fn( # Instantiate MAP Elites map_elites = MAPElites( scoring_function=scoring_fn, - emitter=dcg_emitter, + emitter=dcrl_emitter, metrics_function=metrics_function, ) From 1c6a2335ca79396914598a059aa55eb3cc5063b8 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 22:11:09 +0000 Subject: [PATCH 13/20] update library versions in notebook downloads for Colab --- examples/aurora.ipynb | 6 +++--- examples/cmaes.ipynb | 6 +++--- examples/cmame.ipynb | 6 +++--- examples/cmamega.ipynb | 6 +++--- examples/dads.ipynb | 6 +++--- examples/diayn.ipynb | 6 +++--- examples/distributed_mapelites.ipynb | 6 +++--- examples/jumanji_snake.ipynb | 6 +++--- examples/mapelites.ipynb | 6 +++--- examples/me_sac_pbt.ipynb | 6 +++--- examples/me_td3_pbt.ipynb | 6 +++--- examples/mees.ipynb | 6 +++--- examples/mels.ipynb | 6 +++--- examples/mome.ipynb | 6 +++--- examples/nsga2_spea2.ipynb | 6 +++--- examples/omgmega.ipynb | 6 +++--- examples/pga_aurora.ipynb | 6 +++--- examples/pgame.ipynb | 6 +++--- examples/qdpg.ipynb | 6 +++--- examples/sac_pbt.ipynb | 6 +++--- examples/smerl.ipynb | 6 +++--- examples/td3_pbt.ipynb | 6 +++--- 22 files changed, 66 insertions(+), 66 deletions(-) diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index cebfb0a1..fb955a98 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -49,19 +49,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index d7b30b1d..8b87473d 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -36,19 +36,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index ff5fa5c2..7da832eb 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -41,19 +41,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip 
install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index 739ac3d5..a90f8309 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -35,19 +35,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index 47abd1ec..19348348 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -45,19 +45,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 8e085fce..fba2055f 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -45,19 +45,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index ea2f9b9b..0fe1094c 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -54,19 +54,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps 
git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 3c33c84a..dc915524 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -28,19 +28,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index b1fea651..575ee0c0 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -49,19 +49,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index db297992..778a7a5f 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -16,19 +16,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 0a8db24a..da8d7311 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -17,19 +17,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", 
"try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/mees.ipynb b/examples/mees.ipynb index ad1a4740..e9e37c2a 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -54,19 +54,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index bd489ca2..4f3fdc74 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -50,19 +50,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 17a4adbe..4661d406 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -41,19 +41,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index 4d7d81b2..d6385291 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -42,19 +42,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", 
" import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index 8d417cc0..c09deefe 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -37,19 +37,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 01697d88..29f4cc74 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -49,19 +49,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 03fd9c00..d60e246a 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -48,19 +48,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 102d5262..2082dcaa 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -48,19 +48,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", 
"try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index a484b035..c71559f7 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -22,19 +22,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index d50f448f..5f08e582 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -45,19 +45,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index 484f6d12..b3d2cbe1 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -19,19 +19,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", From 9c25a12c60e377d24a0b31f492f98837e01b5f87 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 22:12:02 +0000 Subject: [PATCH 14/20] dcg_batch_size -> dcrl_batch_size --- qdax/core/emitters/dcrl_emitter.py | 10 ++++++---- qdax/core/emitters/dcrl_me_emitter.py | 4 ++-- tests/baselines_test/dcrlme_test.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff 
--git a/qdax/core/emitters/dcrl_emitter.py b/qdax/core/emitters/dcrl_emitter.py index 0847e8a2..b353a22f 100644 --- a/qdax/core/emitters/dcrl_emitter.py +++ b/qdax/core/emitters/dcrl_emitter.py @@ -24,7 +24,7 @@ class DCRLConfig: """Configuration for DCRL Emitter""" - dcg_batch_size: int = 64 + dcrl_batch_size: int = 64 ai_batch_size: int = 64 lengthscale: float = 0.1 @@ -114,7 +114,7 @@ def batch_size(self) -> int: Returns: the batch size emitted by the emitter. """ - return self._config.dcg_batch_size + self._config.ai_batch_size + return self._config.dcrl_batch_size + self._config.ai_batch_size @property def use_all_data(self) -> bool: @@ -296,7 +296,7 @@ def emit( """ # PG emitter parents_pg, descs_pg, key = repertoire.sample_with_descs( - key, self._config.dcg_batch_size + key, self._config.dcrl_batch_size ) genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) @@ -427,7 +427,9 @@ def state_update( desc_prime = jnp.concatenate( [ extra_scores["desc_prime"], - descriptors[self._config.dcg_batch_size + self._config.ai_batch_size :], + descriptors[ + self._config.dcrl_batch_size + self._config.ai_batch_size : + ], ], axis=0, ) diff --git a/qdax/core/emitters/dcrl_me_emitter.py b/qdax/core/emitters/dcrl_me_emitter.py index fd28555d..89ffebc2 100644 --- a/qdax/core/emitters/dcrl_me_emitter.py +++ b/qdax/core/emitters/dcrl_me_emitter.py @@ -15,7 +15,7 @@ class DCRLMEConfig: """Configuration for DCRL-MAP-Elites Algorithm""" ga_batch_size: int = 128 - dcg_batch_size: int = 64 + dcrl_batch_size: int = 64 ai_batch_size: int = 64 lengthscale: float = 0.1 @@ -50,7 +50,7 @@ def __init__( self._variation_fn = variation_fn dcrl_config = DCRLConfig( - dcg_batch_size=config.dcg_batch_size, + dcrl_batch_size=config.dcrl_batch_size, ai_batch_size=config.ai_batch_size, lengthscale=config.lengthscale, critic_hidden_layer_size=config.critic_hidden_layer_size, diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index 748da6d0..eead8104 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -58,7 +58,7 @@ def test_dcrlme() -> None: # DCRL-ME ga_batch_size = 128 - dcg_batch_size = 64 + dcrl_batch_size = 64 ai_batch_size = 64 lengthscale = 0.1 @@ -172,7 +172,7 @@ def play_step_fn( # Define the DCRL-emitter config dcrl_emitter_config = DCRLMEConfig( ga_batch_size=ga_batch_size, - dcg_batch_size=dcg_batch_size, + dcrl_batch_size=dcrl_batch_size, ai_batch_size=ai_batch_size, lengthscale=lengthscale, critic_hidden_layer_size=critic_hidden_layer_size, From 64b37764ab331b23892ad41f83261d19ed9d7060 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 22:12:27 +0000 Subject: [PATCH 15/20] add notebook DCRL-ME --- examples/dcrlme.ipynb | 524 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 524 insertions(+) create mode 100644 examples/dcrlme.ipynb diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb new file mode 100644 index 00000000..36216c47 --- /dev/null +++ b/examples/dcrlme.ipynb @@ -0,0 +1,524 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing with DCRL-ME in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Descriptor-Conditioned 
Reinforcement Learning MAP-Elites (DCRL-ME)](https://arxiv.org/abs/2401.08632), also known as *Descriptor-Conditioned Gradients MAP-Elites with Actor Injection (DCG-ME-AI)*. \n", + "This algorithm extends and improves upon [Descriptor-Conditioned Gradients MAP-Elites (DCG-ME)](https://dl.acm.org/doi/abs/10.1145/3583131.3590503)\n", + "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create the DCRL emitter\n", + "- how to create a Map-elites instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualize the results of the training process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "import time\n", + "from typing import Any, Tuple\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import pytest\n", + "from brax.envs import Env, State, Wrapper\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax import environments\n", + "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", + "from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.map_elites import MAPElites\n", + "from qdax.core.neuroevolution.buffers.buffer import DCRLTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n", + "from qdax.custom_types import EnvState, Params, RNGKey\n", + "from qdax.environments import behavior_descriptor_extractor\n", + "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.utils.plotting import plot_map_elites_results\n", + "\n", + "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class RewardOffsetEnvWrapper(Wrapper):\n", + " \"\"\"Wraps ant_omni environment to add and scale position.\"\"\"\n", + "\n", + " def __init__(self, env: Env, env_name: str) -> None:\n", + " super().__init__(env)\n", + " self._env_name = env_name\n", + "\n", + " @property\n", + " def name(self) -> str:\n", + " return 
self._env_name\n", + "\n", + " def reset(self, rng: jnp.ndarray) -> State:\n", + " state = self.env.reset(rng)\n", + " return state\n", + "\n", + " def step(self, state: State, action: jnp.ndarray) -> State:\n", + " state = self.env.step(state, action)\n", + " new_reward = state.reward + environments.reward_offset[self._env_name]\n", + " return state.replace(reward=new_reward)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "seed = 0\n", + "\n", + "env_name = \"ant_omni\"\n", + "episode_length = 250\n", + "min_bd = -30.0\n", + "max_bd = 30.0\n", + "\n", + "num_iterations = 1000\n", + "batch_size = 256\n", + "\n", + "# Archive\n", + "num_init_cvt_samples = 50000\n", + "num_centroids = 1024\n", + "policy_hidden_layer_sizes = (128, 128)\n", + "\n", + "# DCRL-ME\n", + "ga_batch_size = 128\n", + "dcrl_batch_size = 64\n", + "ai_batch_size = 64\n", + "lengthscale = 0.1\n", + "\n", + "# GA emitter\n", + "iso_sigma = 0.005\n", + "line_sigma = 0.05\n", + "\n", + "# DCRL emitter\n", + "critic_hidden_layer_size = (256, 256)\n", + "num_critic_training_steps = 3000\n", + "num_pg_training_steps = 150\n", + "replay_buffer_size = 1_000_000\n", + "discount = 0.99\n", + "reward_scaling = 1.0\n", + "critic_learning_rate = 3e-4\n", + "actor_learning_rate = 3e-4\n", + "policy_learning_rate = 5e-3\n", + "noise_clip = 0.5\n", + "policy_noise = 0.2\n", + "soft_tau_update = 0.005\n", + "policy_delay = 2\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy, that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, that corresponds to their genotype." 
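[Editor's note] The next cell instantiates two networks: `MLP` is the standard policy stored in the archive, while `MLPDC` is the descriptor-conditioned actor trained by the DCRL emitter. As a rough, hypothetical sketch of what descriptor conditioning means — this is *not* QDax's `MLPDC` implementation, only an illustration assuming the descriptor is concatenated to the observation before a standard MLP:

```python
from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp


class DescriptorConditionedPolicy(nn.Module):
    """Illustrative only: an actor whose action depends on (observation, target descriptor)."""

    layer_sizes: Tuple[int, ...]  # hidden sizes followed by the action size

    @nn.compact
    def __call__(self, obs: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray:
        # Concatenate the descriptor to the observation, then apply a plain MLP.
        x = jnp.concatenate([obs, desc], axis=-1)
        for hidden_size in self.layer_sizes[:-1]:
            x = nn.relu(nn.Dense(hidden_size)(x))
        return jnp.tanh(nn.Dense(self.layer_sizes[-1])(x))


# Usage sketch: the parameters now depend on both the observation and descriptor sizes.
# net = DescriptorConditionedPolicy(layer_sizes=(128, 128, action_size))
# params = net.init(jax.random.PRNGKey(0), fake_obs, fake_desc)
```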
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "env = RewardOffsetEnvWrapper(\n", + " env, env_name\n", + ") # apply reward offset as DCG needs positive rewards\n", + "\n", + "reset_fn = jax.jit(env.reset)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "actor_dc_network = MLPDC(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=batch_size)\n", + "fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))\n", + "init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the fonction to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state: EnvState, policy_params: Params, random_key: RNGKey\n", + ") -> Tuple[EnvState, Params, RNGKey, DCRLTransition]:\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = DCRLTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " truncations=next_state.info[\"truncation\"],\n", + " actions=actions,\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " desc=jnp.zeros(\n", + " env.behavior_descriptor_length,\n", + " )\n", + " * jnp.nan,\n", + " desc_prime=jnp.zeros(\n", + " env.behavior_descriptor_length,\n", + " )\n", + " * jnp.nan,\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual." 
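[Editor's note] For `ant_omni`, the per-step `state_descriptor` recorded in `play_step_fn` holds the robot's x-y position, and the environment-specific extractor reduces a trajectory of these to a single descriptor. The sketch below is only an illustration of that idea with its own assumed mask convention; the actual extractor used here comes from `behavior_descriptor_extractor[env_name]`:

```python
import jax
import jax.numpy as jnp


def final_xy_descriptor(state_desc: jnp.ndarray, valid_mask: jnp.ndarray) -> jnp.ndarray:
    """Illustration: return the last valid x-y position of one episode.

    state_desc: (episode_length, 2) per-step x-y positions.
    valid_mask: (episode_length,) with 1.0 for steps played before termination,
        0.0 afterwards (assumed convention for this sketch only).
    """
    last_step = jnp.int32(jnp.sum(valid_mask)) - 1  # index of the last valid step
    return state_desc[last_step]


# Batched over an evaluation of `batch_size` policies:
# descriptors = jax.vmap(final_xy_descriptor)(all_state_desc, all_valid_masks)
```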
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = behavior_descriptor_extractor[env_name]\n", + "scoring_fn = functools.partial(\n", + " reset_based_scoring_function_brax_envs,\n", + " episode_length=episode_length,\n", + " play_reset_fn=reset_fn,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "metrics_function = functools.partial(\n", + " default_qd_metrics,\n", + " qd_offset=reward_offset * episode_length,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter: DCRL Emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step. In this example, the emitter is the Descriptor-Conditioned RL emitter, the one used in DCRL-ME. It trains a critic with the transitions experienced in the environment and uses the critic to apply Descriptor-Conditioned gradients updates to the policies evolved." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dcrl_emitter_config = DCRLMEConfig(\n", + " ga_batch_size=ga_batch_size,\n", + " dcrl_batch_size=dcrl_batch_size,\n", + " ai_batch_size=ai_batch_size,\n", + " lengthscale=lengthscale,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " num_critic_training_steps=num_critic_training_steps,\n", + " num_pg_training_steps=num_pg_training_steps,\n", + " batch_size=batch_size,\n", + " replay_buffer_size=replay_buffer_size,\n", + " discount=discount,\n", + " reward_scaling=reward_scaling,\n", + " critic_learning_rate=critic_learning_rate,\n", + " actor_learning_rate=actor_learning_rate,\n", + " policy_learning_rate=policy_learning_rate,\n", + " noise_clip=noise_clip,\n", + " policy_noise=policy_noise,\n", + " soft_tau_update=soft_tau_update,\n", + " policy_delay=policy_delay,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "\n", + "dcrl_emitter = DCRLMEEmitter(\n", + " config=dcrl_emitter_config,\n", + " policy_network=policy_network,\n", + " actor_network=actor_dc_network,\n", + " env=env,\n", + " variation_fn=variation_fn,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the MAP Elites algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate MAP Elites\n", + "map_elites = MAPElites(\n", + " scoring_function=scoring_fn,\n", + " emitter=dcrl_emitter,\n", + " metrics_function=metrics_function,\n", + ")\n", + "\n", + "# Compute the centroids\n", + "centroids, random_key = compute_cvt_centroids(\n", + " num_descriptors=env.behavior_descriptor_length,\n", + " num_init_cvt_samples=num_init_cvt_samples,\n", + " num_centroids=num_centroids,\n", + " minval=min_bd,\n", + " maxval=max_bd,\n", + " random_key=random_key,\n", + ")\n", + "\n", + "# compute initial repertoire\n", + "repertoire, emitter_state, random_key = map_elites.init(\n", + " init_params, centroids, random_key\n", + ")" + ] + }, + { + 
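[Editor's note] Two remarks on the cells above. First, the offspring produced per generation should add up as `ga_batch_size + dcrl_batch_size + ai_batch_size = 128 + 64 + 64 = 256`, matching the `batch_size` used for the initial population. Second, `centroids` has shape `(num_centroids, num_descriptors)` and each evaluated solution competes for the cell of its nearest centroid. A hypothetical helper (not the repertoire internals) illustrating that assignment:

```python
import jax
import jax.numpy as jnp


def nearest_cell(descriptor: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
    # descriptor: (num_descriptors,), centroids: (num_centroids, num_descriptors)
    # Returns the index of the centroid (i.e. the archive cell) closest to the descriptor.
    return jnp.argmin(jnp.linalg.norm(centroids - descriptor, axis=-1))


# cell_indices = jax.vmap(nearest_cell, in_axes=(0, None))(descriptors, centroids)
```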
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + " # iterate over grid\n", + " repertoire, emitter_state, metrics, random_key = map_elites.update(*carry)\n", + "\n", + " return (repertoire, emitter_state, random_key), metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "log_period = 10\n", + "num_loops = int(num_iterations / log_period)\n", + "\n", + "csv_logger = CSVLogger(\n", + " \"dcrlme-logs.csv\",\n", + " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + ")\n", + "all_metrics = {}\n", + "\n", + "# main loop\n", + "map_elites_scan_update = map_elites.scan_update\n", + "for i in range(num_loops):\n", + " start_time = time.time()\n", + " # main iterations\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " random_key,\n", + " ), metrics = jax.lax.scan(\n", + " update_scan_fn,\n", + " (repertoire, emitter_state, random_key),\n", + " (),\n", + " length=log_period,\n", + " )\n", + " timelapse = time.time() - start_time\n", + "\n", + " # log metrics\n", + " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", + " for key, value in metrics.items():\n", + " # take last value\n", + " logged_metrics[key] = value[-1]\n", + "\n", + " # take all values\n", + " if key in all_metrics.keys():\n", + " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", + " else:\n", + " all_metrics[key] = value\n", + "\n", + " csv_logger.log(logged_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "# create the x-axis array\n", + "env_steps = jnp.arange(740) * episode_length * batch_size\n", + "\n", + "%matplotlib inline\n", + "# create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.13 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "vscode": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 461567118c3bfe6c5c22a04984f61efd8b63da87 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 23:19:46 +0100 Subject: [PATCH 16/20] add DCRL-ME notebook to README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c537003e..e7955450 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,7 @@ QDax currently supports the following algorithms: | [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All 
Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | | [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | | [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) | +| [DCRL-ME](https://arxiv.org/abs/2401.08632) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dcrlme.ipynb) | | [QDPG](https://arxiv.org/abs/2006.08505) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/qdpg.ipynb) | | [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmame.ipynb) | | [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/omgmega.ipynb) | From 78d824941c60d910237f7a18b672cdb6e1196637 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 23:11:17 +0000 Subject: [PATCH 17/20] add params colab-compatible --- examples/dcrlme.ipynb | 52 +++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb index 36216c47..12a147b6 100644 --- a/examples/dcrlme.ipynb +++ b/examples/dcrlme.ipynb @@ -136,45 +136,45 @@ "outputs": [], "source": [ "#@title QD Training Definitions Fields\n", - "seed = 0\n", + "seed = 42 #@param {type:\"integer\"}\n", "\n", - "env_name = \"ant_omni\"\n", - "episode_length = 250\n", - "min_bd = -30.0\n", - "max_bd = 30.0\n", + "env_name = \"ant_omni\" #@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "episode_length = 250 #@param {type:\"integer\"}\n", + "min_bd = -30.0 #@param {type:\"number\"}\n", + "max_bd = 30.0 #@param {type:\"number\"}\n", "\n", - "num_iterations = 1000\n", - "batch_size = 256\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", + "batch_size = 256 #@param {type:\"integer\"}\n", "\n", "# Archive\n", - "num_init_cvt_samples = 50000\n", - "num_centroids = 1024\n", - "policy_hidden_layer_sizes = (128, 128)\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (128, 128) #@param {type:\"raw\"}\n", "\n", "# DCRL-ME\n", - "ga_batch_size = 128\n", - "dcrl_batch_size = 64\n", - "ai_batch_size = 64\n", - "lengthscale = 0.1\n", + "ga_batch_size = 128 #@param {type:\"integer\"}\n", + "dcrl_batch_size = 64 #@param {type:\"integer\"}\n", + "ai_batch_size = 64 #@param {type:\"integer\"}\n", + "lengthscale = 0.1 #@param {type:\"number\"}\n", "\n", "# GA emitter\n", - 
"iso_sigma = 0.005\n", - "line_sigma = 0.05\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", "\n", "# DCRL emitter\n", - "critic_hidden_layer_size = (256, 256)\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", "num_critic_training_steps = 3000\n", "num_pg_training_steps = 150\n", "replay_buffer_size = 1_000_000\n", - "discount = 0.99\n", - "reward_scaling = 1.0\n", - "critic_learning_rate = 3e-4\n", - "actor_learning_rate = 3e-4\n", - "policy_learning_rate = 5e-3\n", - "noise_clip = 0.5\n", - "policy_noise = 0.2\n", - "soft_tau_update = 0.005\n", - "policy_delay = 2\n", + "discount = 0.99 #@param {type:\"number\"}\n", + "reward_scaling = 1.0 #@param {type:\"number\"}\n", + "critic_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "actor_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "policy_learning_rate = 5e-3 #@param {type:\"number\"}\n", + "noise_clip = 0.5 #@param {type:\"number\"}\n", + "policy_noise = 0.2 #@param {type:\"number\"}\n", + "soft_tau_update = 0.005 #@param {type:\"number\"}\n", + "policy_delay = 2 #@param {type:\"number\"}\n", "#@markdown ---" ] }, From 0508d7c426eebd2b1753bb24cdaf3d9039bc4b2e Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 5 Sep 2024 23:16:13 +0000 Subject: [PATCH 18/20] update docs --- docs/api_documentation/core/dcrlme.md | 5 +++++ docs/api_documentation/core/map_elites.md | 2 +- mkdocs.yml | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 docs/api_documentation/core/dcrlme.md diff --git a/docs/api_documentation/core/dcrlme.md b/docs/api_documentation/core/dcrlme.md new file mode 100644 index 00000000..698d75eb --- /dev/null +++ b/docs/api_documentation/core/dcrlme.md @@ -0,0 +1,5 @@ +# Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME) + +To create an instance of DCRL-ME, one need to use an instance of [MAP-Elites](map_elites.md) with the `DCRLMEEmitter`, detailed below. + +::: qdax.core.emitters.dcrl_me_emitter.DCRLMEEmitter diff --git a/docs/api_documentation/core/map_elites.md b/docs/api_documentation/core/map_elites.md index 5da3af58..89f1948a 100644 --- a/docs/api_documentation/core/map_elites.md +++ b/docs/api_documentation/core/map_elites.md @@ -2,7 +2,7 @@ This class implement the base mechanism of MAP-Elites. It must be used with an emitter. To get the usual MAP-Elites algorithm, one must use the [mixing emitter](emitters.md#qdax.core.emitters.standard_emitters.MixingEmitter). -The MAP-Elites class can be used with other emitters to create variants, like [PGAME](pgame.md), [CMA-MEGA](cma_mega.md) and [OMG-MEGA](omg_mega.md). +The MAP-Elites class can be used with other emitters to create variants, like [PGAME](pgame.md), [DCRL-ME](dcrlme.md) [CMA-MEGA](cma_mega.md) and [OMG-MEGA](omg_mega.md). 
::: qdax.core.map_elites.MAPElites diff --git a/mkdocs.yml b/mkdocs.yml index 168c6ef5..43454ad3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -136,6 +136,7 @@ nav: - Core algorithms: - MAP Elites: api_documentation/core/map_elites.md - PGAME: api_documentation/core/pgame.md + - DCRLME: api_documentation/core/dcrlme.md - QDPG: api_documentation/core/qdpg.md - CMA ME: api_documentation/core/cmame.md - OMG MEGA: api_documentation/core/omg_mega.md From 38801bacfeea676a218543baf4b8efd1caf1443d Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Fri, 6 Sep 2024 10:24:51 +0100 Subject: [PATCH 19/20] add dcrl nb to docs --- mkdocs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/mkdocs.yml b/mkdocs.yml index 43454ad3..9207b4f2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -114,6 +114,7 @@ nav: - Examples: - MAPElites: examples/mapelites.ipynb - PGAME: examples/pgame.ipynb + - DCRL-ME: examples/dcrlme.ipynb - CMA ME: examples/cmame.ipynb - QDPG: examples/qdpg.ipynb - OMG MEGA: examples/omgmega.ipynb From fbbaa5e4631e7f12897f19195bcbb61c61749de4 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Fri, 6 Sep 2024 10:40:07 +0100 Subject: [PATCH 20/20] use appropriate reward offset and clip wrappers in test and nb --- examples/dcrlme.ipynb | 58 ++++------------------------- qdax/environments/wrappers.py | 31 --------------- tests/baselines_test/dcrlme_test.py | 33 ++++------------ 3 files changed, 15 insertions(+), 107 deletions(-) diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb index 12a147b6..5c367e37 100644 --- a/examples/dcrlme.ipynb +++ b/examples/dcrlme.ipynb @@ -89,6 +89,7 @@ "from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n", "from qdax.custom_types import EnvState, Params, RNGKey\n", "from qdax.environments import behavior_descriptor_extractor\n", + "from qdax.environments.wrappers import OffsetRewardWrapper, ClipRewardWrapper\n", "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", "from qdax.utils.plotting import plot_map_elites_results\n", "\n", @@ -102,33 +103,6 @@ "clear_output()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class RewardOffsetEnvWrapper(Wrapper):\n", - " \"\"\"Wraps ant_omni environment to add and scale position.\"\"\"\n", - "\n", - " def __init__(self, env: Env, env_name: str) -> None:\n", - " super().__init__(env)\n", - " self._env_name = env_name\n", - "\n", - " @property\n", - " def name(self) -> str:\n", - " return self._env_name\n", - "\n", - " def reset(self, rng: jnp.ndarray) -> State:\n", - " state = self.env.reset(rng)\n", - " return state\n", - "\n", - " def step(self, state: State, action: jnp.ndarray) -> State:\n", - " state = self.env.step(state, action)\n", - " new_reward = state.reward + environments.reward_offset[self._env_name]\n", - " return state.replace(reward=new_reward)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -199,9 +173,12 @@ "\n", "# Init environment\n", "env = environments.create(env_name, episode_length=episode_length)\n", - "env = RewardOffsetEnvWrapper(\n", - " env, env_name\n", - ") # apply reward offset as DCG needs positive rewards\n", + "env = OffsetRewardWrapper(\n", + " env, offset=environments.reward_offset[env_name]\n", + ") # apply reward offset as DCRL needs positive rewards\n", + "env = ClipRewardWrapper(\n", + " env, clip_min=0.,\n", + ") # apply reward clip as DCRL needs positive rewards\n", "\n", "reset_fn = jax.jit(env.reset)\n", "\n", @@ -472,27 +449,6 @@ "# create the plots and the 
grid\n", "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index e5e40e4b..babedaed 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -102,37 +102,6 @@ def step(self, state: State, action: jp.ndarray) -> State: ) -class AffineRewardWrapper(Wrapper): - """Wraps gym environments to clip the reward. - - Utilisation is simple: create an environment with Brax, pass - it to the wrapper with the name of the environment, and it will - work like before and will simply clip the reward to be greater than 0. - """ - - def __init__( - self, - env: Env, - clip_min: Optional[float] = None, - clip_max: Optional[float] = None, - ) -> None: - super().__init__(env) - self._clip_min = clip_min - self._clip_max = clip_max - - def reset(self, rng: jp.ndarray) -> State: - state = self.env.reset(rng) - return state.replace( - reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) - ) - - def step(self, state: State, action: jp.ndarray) -> State: - state = self.env.step(state, action) - return state.replace( - reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) - ) - - class OffsetRewardWrapper(Wrapper): """Wraps gym environments to offset the reward to be greater than 0. diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py index eead8104..05304944 100644 --- a/tests/baselines_test/dcrlme_test.py +++ b/tests/baselines_test/dcrlme_test.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp import pytest -from brax.envs import Env, State, Wrapper from qdax import environments from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids @@ -15,31 +14,11 @@ from qdax.core.neuroevolution.networks.networks import MLP, MLPDC from qdax.custom_types import EnvState, Params, RNGKey from qdax.environments import behavior_descriptor_extractor +from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs from qdax.utils.metrics import default_qd_metrics -class RewardOffsetEnvWrapper(Wrapper): - """Wraps ant_omni environment to add and scale position.""" - - def __init__(self, env: Env, env_name: str) -> None: - super().__init__(env) - self._env_name = env_name - - @property - def name(self) -> str: - return self._env_name - - def reset(self, rng: jnp.ndarray) -> State: - state = self.env.reset(rng) - return state - - def step(self, state: State, action: jnp.ndarray) -> State: - state = self.env.step(state, action) - new_reward = state.reward + environments.reward_offset[self._env_name] - return state.replace(reward=new_reward) - - def test_dcrlme() -> None: seed = 42 @@ -86,9 +65,13 @@ def test_dcrlme() -> None: # Init environment env = environments.create(env_name, episode_length=episode_length) - env = RewardOffsetEnvWrapper( - env, env_name - ) # apply reward offset as DCG needs positive rewards + env = OffsetRewardWrapper( + env, offset=environments.reward_offset[env_name] + ) # apply reward offset as DCRL needs positive 
rewards + env = ClipRewardWrapper( + env, + clip_min=0.0, + ) # apply reward clip as DCRL needs positive rewards reset_fn = jax.jit(env.reset)
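[Editor's note] The combined effect of the two wrappers on the per-step reward can be summarised as below (illustration only; the real implementations live in `qdax/environments/wrappers.py`):

```python
import jax.numpy as jnp


def dcrl_reward(raw_reward: jnp.ndarray, offset: float) -> jnp.ndarray:
    # OffsetRewardWrapper shifts the reward by `offset`
    # (here environments.reward_offset[env_name]), then
    # ClipRewardWrapper(clip_min=0.0) clips it from below, so the critic
    # trained by the DCRL emitter only ever sees non-negative rewards.
    return jnp.maximum(raw_reward + offset, 0.0)
```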