diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 3112887b05..352fa90e5c 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -53,6 +53,8 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/ metrics on a subset of the data {pr}`2361`. - Add `seed` argument to {func}`scvi.model.utils.mde` for reproducibility {pr}`2373`. - Add {meth}`scvi.hub.HubModel.save` and {meth}`scvi.hub.HubMetadata.save` {pr}`2382`. +- Add support for Optax 0.1.8 by renaming instances of {func}`optax.additive_weight_decay` to + {func}`optax.add_weight_decay` {pr}`2396`. #### Fixed diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index 3cea8f8f85..8c1601c37a 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -1242,7 +1242,7 @@ def get_optimizer_creator(self) -> JaxOptimizerCreator: # Replicates PyTorch Adam defaults optim = optax.chain( clip_by, - optax.additive_weight_decay(weight_decay=self.weight_decay), + optax.add_decayed_weights(weight_decay=self.weight_decay), optax.adam(self.lr, eps=self.eps), ) elif self.optimizer_name == "AdamW":