Skip to content

Commit

Permalink
Merge branch 'devel' into feat/dos-train
Browse files (browse the repository at this point in the history)
  • Loading branch information
anyangml authored Mar 20, 2024
2 parents a73d392 + 9c861c2 commit bf8fac2
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 17 deletions.
8 changes: 8 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import copy
import json

import numpy as np

from deepmd.pt.model.atomic_model import (
DPAtomicModel,
PairTabAtomicModel,
Expand Down Expand Up @@ -57,6 +59,12 @@

def get_spin_model(model_params):
model_params = copy.deepcopy(model_params)
if not model_params["spin"]["use_spin"] or isinstance(
model_params["spin"]["use_spin"][0], int
):
use_spin = np.full(len(model_params["type_map"]), False)
use_spin[model_params["spin"]["use_spin"]] = True
model_params["spin"]["use_spin"] = use_spin.tolist()
# include virtual spin and placeholder types
model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]]
spin = Spin(
Expand Down
44 changes: 31 additions & 13 deletions deepmd/pt/utils/learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,35 @@


class LearningRateExp:
def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
"""Construct an exponential-decayed learning rate.
def __init__(
self,
start_lr,
stop_lr,
decay_steps,
stop_steps,
decay_rate=None,
**kwargs,
):
"""
Construct an exponential-decayed learning rate.
Args:
- start_lr: Initial learning rate.
- stop_lr: Learning rate at the last step.
- decay_steps: Decay learning rate every N steps.
- stop_steps: When is the last step.
Parameters
----------
start_lr
The learning rate at the start of the training.
stop_lr
The desired learning rate at the end of the training.
When decay_rate is explicitly set, this value will serve as
the minimum learning rate during training. In other words,
if the learning rate decays below stop_lr, stop_lr will be applied instead.
decay_steps
The learning rate is decaying every this number of training steps.
stop_steps
The total training steps for learning rate scheduler.
decay_rate
The decay rate for the learning rate.
If provided, the decay rate will be set instead of
calculating it through interpolation between start_lr and stop_lr.
"""
self.start_lr = start_lr
default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
Expand All @@ -20,12 +41,9 @@ def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
self.decay_rate = np.exp(
np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
)
if "decay_rate" in kwargs:
self.decay_rate = kwargs["decay_rate"]
if "min_lr" in kwargs:
self.min_lr = kwargs["min_lr"]
else:
self.min_lr = 3e-10
if decay_rate is not None:
self.decay_rate = decay_rate
self.min_lr = stop_lr

def value(self, step):
"""Get the learning rate at the given step."""
Expand Down
30 changes: 26 additions & 4 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ def type_embedding_args():


def spin_args():
doc_use_spin = "Whether to use atomic spin model for each atom type"
doc_use_spin = (
"Whether to use atomic spin model for each atom type. "
"List of boolean values with the shape of [ntypes] to specify which types use spin, "
f"or a list of integer values {doc_only_pt_supported} "
"to indicate the index of the type that uses spin."
)
doc_spin_norm = "The magnitude of atomic spin for each atom type with spin"
doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin"
doc_virtual_scale = (
Expand All @@ -106,7 +111,7 @@ def spin_args():
)

return [
Argument("use_spin", List[bool], doc=doc_use_spin),
Argument("use_spin", [List[bool], List[int]], doc=doc_use_spin),
Argument(
"spin_norm",
List[float],
Expand All @@ -121,7 +126,7 @@ def spin_args():
),
Argument(
"virtual_scale",
List[float],
[List[float], float],
optional=True,
doc=doc_only_pt_supported + doc_virtual_scale,
),
Expand Down Expand Up @@ -1517,15 +1522,32 @@ def linear_ener_model_args() -> Argument:
# --- Learning rate configurations: --- #
def learning_rate_exp():
    """Build the argument definitions for the exponential-decay learning rate.

    Returns
    -------
    list
        ``Argument`` entries for ``start_lr``, ``stop_lr``, ``decay_steps``
        and ``decay_rate``, in that order. Every entry is optional and
        carries a default value.
    """
    doc_start_lr = "The learning rate at the start of the training."
    doc_stop_lr = (
        "The desired learning rate at the end of the training. "
        f"When decay_rate {doc_only_pt_supported}is explicitly set, "
        "this value will serve as the minimum learning rate during training. "
        "In other words, if the learning rate decays below stop_lr, stop_lr will be applied instead."
    )
    doc_decay_steps = (
        "The learning rate is decaying every this number of training steps."
    )
    doc_decay_rate = (
        "The decay rate for the learning rate. "
        "If this is provided, it will be used directly as the decay rate for learning rate "
        "instead of calculating it through interpolation between start_lr and stop_lr."
    )

    args = [
        Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
        Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
        Argument("decay_steps", int, optional=True, default=5000, doc=doc_decay_steps),
        # default=None means "no explicit rate given"; the documentation string
        # is prefixed with the backend-support note.
        Argument(
            "decay_rate",
            float,
            optional=True,
            default=None,
            doc=doc_only_pt_supported + doc_decay_rate,
        ),
    ]
    return args

Expand Down
47 changes: 47 additions & 0 deletions source/tests/pt/test_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_consistency(self):
self.decay_step = decay_step
self.stop_step = stop_step
self.judge_it()
self.decay_rate_pt()

def judge_it(self):
base_lr = learning_rate.LearningRateExp(
Expand Down Expand Up @@ -54,6 +55,52 @@ def judge_it(self):
self.assertTrue(np.allclose(base_vals, my_vals))
tf.reset_default_graph()

def decay_rate_pt(self):
    """Check that an explicit ``decay_rate`` reproduces the interpolated schedule.

    Three schedulers are compared on every step that is not a multiple of
    ``self.decay_step``:

    * a reference built from ``start_lr``/``stop_lr`` only,
    * one given the equivalent ``decay_rate`` and a very small ``stop_lr``
      (1e-10), expected to match the reference,
    * one given the same ``decay_rate`` and ``stop_lr`` of 1e-5, expected to
      match the reference clipped from below at 1e-5.
    """
    # The reference is created before self.decay_step is possibly replaced.
    reference_lr = LearningRateExp(
        self.start_lr, self.stop_lr, self.decay_step, self.stop_step
    )

    # Presumably mirrors the decay-step fallback inside LearningRateExp
    # (its full constructor is not visible here) -- verify against the class.
    if self.stop_step // 10 > 100:
        fallback_decay_step = 100
    else:
        fallback_decay_step = self.stop_step // 100 + 1
    if self.decay_step >= self.stop_step:
        self.decay_step = fallback_decay_step

    # Per-interval rate that takes start_lr to stop_lr over stop_step steps.
    explicit_rate = np.exp(
        np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step)
    )

    floor_lr = 1e-5
    unclipped_lr = LearningRateExp(
        self.start_lr,
        1e-10,
        self.decay_step,
        self.stop_step,
        decay_rate=explicit_rate,
    )
    clipped_lr = LearningRateExp(
        self.start_lr,
        floor_lr,
        self.decay_step,
        self.stop_step,
        decay_rate=explicit_rate,
    )

    # Steps that are exact multiples of decay_step are left out of the check.
    sampled_steps = [
        step for step in range(self.stop_step) if step % self.decay_step != 0
    ]
    expected = [reference_lr.value(step) for step in sampled_steps]
    got_unclipped = [unclipped_lr.value(step) for step in sampled_steps]
    got_clipped = [clipped_lr.value(step) for step in sampled_steps]

    self.assertTrue(np.allclose(got_unclipped, expected))
    self.assertTrue(
        np.allclose(got_clipped, np.clip(expected, a_min=floor_lr, a_max=None))
    )


if __name__ == "__main__":
unittest.main()

0 comments on commit bf8fac2

Please sign in to comment.