# metax

metax is a meta-learning library in JAX for research. It bundles various meta-learning algorithms and architectures that can be flexibly combined, and it is simple to extend.

It includes the following components:
**metax/learner**
- `maml.py`: Backpropagation through the inner-loop optimization as in MAML
- `eqprop.py`: Equilibrium propagation as in contrastive meta-learning (CML)
- `evolution.py`: Evolutionary algorithms interfacing with evosax
- `implicit.py`: Implicit differentiation via Conjugate Gradient, Recurrent Backpropagation, or T1-T2
- `reptile.py`: Reptile
**metax/module**
Install metax using pip:

```bash
pip install git+https://github.com/smonsays/metax
```
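As a quick sanity check after installing, the subpackages listed above should be importable (a minimal sketch, assuming the `metax.learner` and `metax.module` layout shown in the components list):

```python
# Minimal post-install check; assumes the subpackage layout shown above.
import metax.learner
import metax.module

print("metax imported successfully")
```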
The classic MAML model meta-learns the initialization of the model parameters by backpropagating through the inner-loop optimization. In metax, the code looks as follows:
```python
import haiku as hk
import optax

import metax

# MAML meta-model: meta-learn the initialization of an MLP base learner.
meta_model = metax.module.LearnedInit(
    loss_fn_inner=metax.energy.SquaredError(),
    loss_fn_outer=metax.energy.SquaredError(),
    base_learner=hk.transform_with_state(
        lambda x, is_training: hk.nets.MLP([64, 64, 1])(x)
    ),
    output_dim=1,
)

# MAML meta-learner: backpropagate through the inner-loop optimization.
meta_learner = metax.learner.ModelAgnosticMetaLearning(
    meta_model=meta_model,
    batch_size=None,  # full-batch gradient descent in the inner loop
    steps_inner=10,
    optim_fn_inner=optax.adamw(0.1),
    optim_fn_outer=optax.adamw(0.001),
    first_order=False,  # second-order meta-gradients, as in full MAML
)
```
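To make "backpropagating through the inner-loop optimization" concrete, here is a minimal, self-contained sketch of the MAML meta-gradient in plain JAX. It uses a toy linear model and hand-rolled gradient descent rather than metax's actual internals, so treat it as an illustration of the principle only:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Squared error of a toy linear model standing in for the MLP above.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def inner_adapt(params, x, y, lr=0.1, steps=10):
    # Inner loop: a few steps of full-batch gradient descent on the support set.
    for _ in range(steps):
        grads = jax.grad(loss_fn)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params

def outer_loss(meta_params, x_s, y_s, x_q, y_q):
    # Outer loss: evaluate the adapted parameters on the query set.
    # Differentiating through inner_adapt yields second-order MAML gradients.
    return loss_fn(inner_adapt(meta_params, x_s, y_s), x_q, y_q)

key = jax.random.PRNGKey(0)
meta_params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x_s, x_q = jax.random.normal(key, (2, 8, 3))
y_s = y_q = jnp.ones((8, 1))

# Meta-gradient with respect to the parameter initialization.
meta_grads = jax.grad(outer_loss)(meta_params, x_s, y_s, x_q, y_q)
```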
`examples/` contains a number of educational examples that demonstrate various combinations of meta-algorithms with meta-architectures on a simple regression task.
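Because learners and meta-modules are decoupled, the same `meta_model` can in principle be paired with a different learner, e.g. Reptile. The sketch below is purely illustrative; the class name and constructor arguments are assumptions, not the verified metax API, so consult `metax/learner/reptile.py` for the actual interface:

```python
# Hypothetical: reuse the LearnedInit meta-model with the Reptile learner.
# Class name and arguments are assumptions; see metax/learner/reptile.py.
reptile_learner = metax.learner.Reptile(
    meta_model=meta_model,
    steps_inner=10,
    optim_fn_inner=optax.sgd(0.1),
    optim_fn_outer=optax.sgd(0.001),
)
```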
If you use metax in your research, please cite it as:

```bibtex
@software{metax2023,
  title = {metax: a jax meta-learning library},
  author = {Schug, Simon},
  url = {http://github.com/smonsays/metax},
  year = {2023}
}
```
Research supported with Cloud TPUs from Google's TPU Research Cloud (TRC).