
Commit

Backport PR #2396: Switch from `optax.additive_weight_decay` to `optax.add_decayed_weights` (#2397)

Co-authored-by: Martin Kim <[email protected]>
meeseeksmachine and martinkim0 authored Jan 17, 2024
1 parent 921f44e commit fc75cfe
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/release_notes/index.md
@@ -53,6 +53,8 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
   metrics on a subset of the data {pr}`2361`.
 - Add `seed` argument to {func}`scvi.model.utils.mde` for reproducibility {pr}`2373`.
 - Add {meth}`scvi.hub.HubModel.save` and {meth}`scvi.hub.HubMetadata.save` {pr}`2382`.
+- Add support for Optax 0.1.8 by renaming instances of {func}`optax.additive_weight_decay` to
+  {func}`optax.add_decayed_weights` {pr}`2396`.

#### Fixed

2 changes: 1 addition & 1 deletion scvi/train/_trainingplans.py
@@ -1242,7 +1242,7 @@ def get_optimizer_creator(self) -> JaxOptimizerCreator:
             # Replicates PyTorch Adam defaults
             optim = optax.chain(
                 clip_by,
-                optax.additive_weight_decay(weight_decay=self.weight_decay),
+                optax.add_decayed_weights(weight_decay=self.weight_decay),
                 optax.adam(self.lr, eps=self.eps),
             )
         elif self.optimizer_name == "AdamW":
