From 5455604a5b029658e971e0f6b7c2a9cac2fb7529 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 31 Jul 2024 14:47:55 +0000 Subject: [PATCH] check for correct foreach fns --- adam_atan2_pytorch/foreach.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/adam_atan2_pytorch/foreach.py b/adam_atan2_pytorch/foreach.py index ecceae1..1ff8636 100644 --- a/adam_atan2_pytorch/foreach.py +++ b/adam_atan2_pytorch/foreach.py @@ -38,7 +38,7 @@ def __init__( assert lr > 0. assert all([0. <= beta <= 1. for beta in betas]) assert weight_decay >= 0. - assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions' + assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'lerp', 'sqrt')]), 'this version of torch does not have the prerequisite foreach functions' self._init_lr = lr diff --git a/pyproject.toml b/pyproject.toml index 832caa6..dc26459 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "adam-atan2-pytorch" -version = "0.0.7" +version = "0.0.8" description = "Adam-atan2 for Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }