From 0f2d6b725c0b4cad61f0f3a0b4f0900a37bfacc1 Mon Sep 17 00:00:00 2001 From: kozistr Date: Sun, 5 May 2024 16:53:52 +0900 Subject: [PATCH] fix: lerp --- pytorch_optimizer/optimizer/schedulefree.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_optimizer/optimizer/schedulefree.py b/pytorch_optimizer/optimizer/schedulefree.py index 5d4546d06..0898f531d 100644 --- a/pytorch_optimizer/optimizer/schedulefree.py +++ b/pytorch_optimizer/optimizer/schedulefree.py @@ -71,7 +71,7 @@ def eval(self): for p in group['params']: state = self.state[p] if 'z' in state: - p.lerp_(end=state['z'], weight=1.0 - 1.0 / momentum) + p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / momentum) group['train_mode'] = False def train(self): @@ -81,7 +81,7 @@ def train(self): for p in group['params']: state = self.state[p] if 'z' in state: - p.lerp_(end=state['z'], weight=1.0 - momentum) + p.data.lerp_(end=state['z'], weight=1.0 - momentum) group['train_mode'] = True @torch.no_grad() @@ -216,7 +216,7 @@ def eval(self): for p in group['params']: state = self.state[p] if 'z' in state: - p.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1) + p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1) group['train_mode'] = False def train(self): @@ -226,7 +226,7 @@ def train(self): for p in group['params']: state = self.state[p] if 'z' in state: - p.lerp_(end=state['z'], weight=1.0 - beta1) + p.data.lerp_(end=state['z'], weight=1.0 - beta1) group['train_mode'] = True @torch.no_grad()