Skip to content

Commit

Permalink
fix: lerp
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed May 5, 2024
1 parent 518e32d commit 0f2d6b7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pytorch_optimizer/optimizer/schedulefree.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def eval(self):
for p in group['params']:
state = self.state[p]
if 'z' in state:
p.lerp_(end=state['z'], weight=1.0 - 1.0 / momentum)
p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / momentum)
group['train_mode'] = False

def train(self):
Expand All @@ -81,7 +81,7 @@ def train(self):
for p in group['params']:
state = self.state[p]
if 'z' in state:
p.lerp_(end=state['z'], weight=1.0 - momentum)
p.data.lerp_(end=state['z'], weight=1.0 - momentum)
group['train_mode'] = True

@torch.no_grad()
Expand Down Expand Up @@ -216,7 +216,7 @@ def eval(self):
for p in group['params']:
state = self.state[p]
if 'z' in state:
p.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1)
p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1)
group['train_mode'] = False

def train(self):
Expand All @@ -226,7 +226,7 @@ def train(self):
for p in group['params']:
state = self.state[p]
if 'z' in state:
p.lerp_(end=state['z'], weight=1.0 - beta1)
p.data.lerp_(end=state['z'], weight=1.0 - beta1)
group['train_mode'] = True

@torch.no_grad()
Expand Down

0 comments on commit 0f2d6b7

Please sign in to comment.