
jax-meta-learning

Simple, flexible implementations of some meta-learning algorithms in Jax.

The goal is that you can just specify hyperparameters and "drop in" your choice of model, gradient-based optimizer, and distribution over tasks, and these implementations will work out of the box with minimal code overhead, whether your tasks are classification, regression, reinforcement learning, or something weird and wonderful.
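As a rough illustration, here is a minimal sketch of one such task, written against the contract spelled out in the caveats below (a function from a random seed and a model to a scalar loss). The sine-regression setup, the `MLP` module, and the use of `flax.linen` are illustrative assumptions, not code from this repo:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    """A small regression model; any Flax module (or similar) should do."""
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(1)(x)

def sine_task(rng, model):
    """A task: map a random seed and a (callable) model to a scalar loss.
    Here the task is regression onto a sine wave with a random phase."""
    phase_rng, x_rng = jax.random.split(rng)
    phase = jax.random.uniform(phase_rng, minval=0.0, maxval=jnp.pi)
    x = jax.random.uniform(x_rng, (16, 1), minval=-5.0, maxval=5.0)
    y = jnp.sin(x + phase)
    return jnp.mean((model(x) - y) ** 2)

# Usage: bind parameters to get a callable model, then evaluate the task.
module = MLP()
params = module.init(jax.random.PRNGKey(0), jnp.zeros((1, 1)))
loss = sine_task(jax.random.PRNGKey(1), module.bind(params))
```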

The caveats are that you need to use Flax models/optimizers (or write classes with a similar API), and your "tasks" must be written as functions that map a random seed and a model to a scalar loss. The MAML implementation also omits improvements from subsequent papers, such as trainable inner-loop learning rates.
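For orientation, the sketch below shows what a single MAML meta-loss evaluation looks like under that task contract, with a fixed (non-trainable) inner-loop learning rate as noted above. It is a generic rendering of the algorithm with assumed names (`apply_fn`, `inner_lr`, `inner_steps`), not necessarily how this repo structures it:

```python
import jax

def maml_meta_loss(params, rng, task, apply_fn, inner_lr=0.01, inner_steps=1):
    """Adapt params to one task with a few SGD steps, then return the
    post-adaptation loss. Differentiating this w.r.t. params gives the
    MAML meta-gradient. inner_lr is a fixed hyperparameter here."""
    adapt_rng, eval_rng = jax.random.split(rng)

    def task_loss(p, r):
        # Present the parameterized model to the task as a plain callable.
        return task(r, lambda x: apply_fn(p, x))

    adapted = params
    for _ in range(inner_steps):
        # Reusing adapt_rng keeps the support set fixed across inner steps.
        grads = jax.grad(task_loss)(adapted, adapt_rng)
        adapted = jax.tree_util.tree_map(lambda p, g: p - inner_lr * g,
                                         adapted, grads)
    # Evaluate on a fresh draw (the "query" set) from the same task.
    return task_loss(adapted, eval_rng)

# The meta-update then follows the gradient of this loss, averaged over
# sampled tasks, e.g. jax.grad(maml_meta_loss)(params, rng, sine_task,
# MLP().apply) with the task defined earlier.
```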

Requires:

  • Jax
  • Flax

Done

  • MAML (maml.py)

Todo

  • Usage guide, incl. code snippets in the README
  • Examples on Omniglot
  • Migrate examples from "maml.py" etc. to their own folder
