metax

metax is a meta-learning library in JAX for research. It bundles various meta-learning algorithms and architectures that can be flexibly combined, and it is simple to extend. It includes the following components:

Installation

Install metax using pip:

pip install git+https://github.com/smonsays/metax

Examples

The classic MAML (model-agnostic meta-learning) algorithm meta-learns the initialization of the model parameters by backpropagating through the inner-loop optimizer. In metax, this looks as follows:

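import haiku as hk
import metax
import optax

# Meta-model: the initialization of a Haiku MLP is the meta-learned quantity
meta_model = metax.module.LearnedInit(
    loss_fn_inner=metax.energy.SquaredError(),  # loss minimized during inner-loop adaptation
    loss_fn_outer=metax.energy.SquaredError(),  # loss used for the meta-objective
    base_learner=hk.transform_with_state(
        lambda x, is_training: hk.nets.MLP([64, 64, 1])(x)
    ),
    output_dim=1,
)
# MAML: adapt for 10 inner steps, then update the initialization by backpropagating
# through those steps (first_order=False keeps the second-order terms)
meta_learner = metax.learner.ModelAgnosticMetaLearning(
    meta_model=meta_model,
    batch_size=None,  # full batch GD
    steps_inner=10,
    optim_fn_inner=optax.adamw(0.1),
    optim_fn_outer=optax.adamw(0.001),
    first_order=False,
)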
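To make the mechanism concrete, here is a minimal sketch of MAML written in plain JAX, without metax or Haiku. It is not how metax implements `ModelAgnosticMetaLearning`; the helpers `init_params`, `mlp`, `inner_loop`, `maml_loss` and `meta_step` are hypothetical, and the inner loop uses plain gradient descent instead of the AdamW optimizer from the example above. The point it illustrates is that `jax.grad(maml_loss)` differentiates through the unrolled inner loop, and that a first-order variant can be obtained by wrapping the inner gradients in `jax.lax.stop_gradient`.

```python
import jax
import jax.numpy as jnp


def init_params(key, sizes=(1, 64, 64, 1)):
    """Hypothetical helper: MLP parameters as a list of (weights, bias) pairs."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(subkey, (d_in, d_out)) * jnp.sqrt(2.0 / d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def mlp(params, x):
    """ReLU MLP; the last layer is linear."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def squared_error(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)


def inner_loop(params, x, y, steps=10, lr=0.01, first_order=False):
    """Adapt to the support set with plain full-batch gradient descent."""
    for _ in range(steps):
        grads = jax.grad(squared_error)(params, x, y)
        if first_order:
            # First-order MAML: treat inner-loop gradients as constants,
            # which drops the second-order (Hessian) terms from the meta-gradient
            grads = jax.lax.stop_gradient(grads)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params


def maml_loss(init, x_support, y_support, x_query, y_query, first_order=False):
    """Query-set loss after adapting the shared initialization to the support set."""
    adapted = inner_loop(init, x_support, y_support, first_order=first_order)
    return squared_error(adapted, x_query, y_query)


# Gradient w.r.t. the initialization (argument 0), backpropagated through all inner steps
meta_grad = jax.grad(maml_loss)


def meta_step(init, x_support, y_support, x_query, y_query, lr=1e-3, first_order=False):
    """One outer-loop SGD update of the initialization."""
    grads = meta_grad(init, x_support, y_support, x_query, y_query, first_order=first_order)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, init, grads)
```

In metax the same trade-off is exposed through the `first_order` argument of `ModelAgnosticMetaLearning`. A full training loop would typically average the meta-gradient over a batch of tasks before each outer update rather than stepping on a single task.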

The examples/ directory contains a number of educational examples that demonstrate how to combine various meta-algorithms with meta-architectures on a simple regression task.
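As a hypothetical illustration of what a distribution over simple regression tasks can look like, the sketch below samples sinusoid regression tasks in the style of the benchmark from the original MAML paper (amplitude in [0.1, 5.0], phase in [0, π], inputs in [-5, 5]). The function names, ranges and support/query sizes are assumptions and are not taken from examples/.

```python
from functools import partial

import jax
import jax.numpy as jnp


def sample_sinusoid_task(key, num_support=10, num_query=10):
    """Hypothetical task: regress y = amplitude * sin(x + phase)."""
    k_amp, k_phase, k_x = jax.random.split(key, 3)
    amplitude = jax.random.uniform(k_amp, minval=0.1, maxval=5.0)
    phase = jax.random.uniform(k_phase, minval=0.0, maxval=jnp.pi)
    x = jax.random.uniform(k_x, (num_support + num_query, 1), minval=-5.0, maxval=5.0)
    y = amplitude * jnp.sin(x + phase)
    # Support set drives the inner loop, query set the outer (meta) loss
    support = (x[:num_support], y[:num_support])
    query = (x[num_support:], y[num_support:])
    return support, query


def sample_task_batch(key, num_tasks=4, num_support=10, num_query=10):
    """Sample a batch of tasks; vmap stacks each array along a leading task axis."""
    keys = jax.random.split(key, num_tasks)
    sampler = partial(sample_sinusoid_task, num_support=num_support, num_query=num_query)
    return jax.vmap(sampler)(keys)
```

For example, `(xs, ys), (xq, yq) = sample_task_batch(jax.random.PRNGKey(0))` yields arrays of shape `(4, 10, 1)`, one slice per task. Using `jax.vmap` rather than a Python loop keeps the batch as stacked arrays, which is convenient for vectorizing the per-task inner loop as well.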

Citation

If you use metax in your research, please cite it as:

@software{metax2023,
  title={metax: a jax meta-learning library},
  author={Schug, Simon},
  url = {http://github.com/smonsays/metax},
  year={2023}
}

Acknowledgements

Research supported with Cloud TPUs from Google's TPU Research Cloud (TRC).