Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: refactor atomic bias #3654

Merged
merged 32 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9f20564
chore: change name to atom
anyangml Apr 8, 2024
9a89d74
Merge branch 'devel' into feat/atomic-bias
anyangml Apr 8, 2024
7165924
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
74cdba0
Merge branch 'devel' into feat/atomic-bias
anyangml Apr 8, 2024
6c299fa
chore: refactor global stat
anyangml Apr 8, 2024
80159e1
chore: refactor global stat
anyangml Apr 8, 2024
a83e757
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
05a06fc
Merge branch 'devel' into feat/atomic-bias
anyangml Apr 8, 2024
ef95a9c
feat: add atomic bias
anyangml Apr 8, 2024
f9278eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
f6ebec1
fix: precommit
anyangml Apr 8, 2024
f16fdb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
74b3795
fix: keys
anyangml Apr 8, 2024
1de9e71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
c0a14ea
chore: clean code
anyangml Apr 8, 2024
3851137
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
3be7e91
fix: UTs
anyangml Apr 8, 2024
1c43712
Merge branch 'devel' into feat/atomic-bias
anyangml Apr 9, 2024
b99afa0
feat: add UT missing atype
anyangml Apr 9, 2024
3baa28c
Merge branch 'devel' into feat/atomic-bias
anyangml Apr 9, 2024
bb47541
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
f504f07
fix: UTs
anyangml Apr 9, 2024
64427c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
48f10cb
fix: UTs
anyangml Apr 9, 2024
d57d561
fix: precommit
anyangml Apr 9, 2024
62288b1
chore: revert breaking changes
anyangml Apr 10, 2024
267264b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
99f18da
chore: revert breaking changes
anyangml Apr 10, 2024
fcfeed6
chore: refactor code
anyangml Apr 10, 2024
398fbac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
2cf195d
chore: refactor code
anyangml Apr 10, 2024
c585206
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@

Parameters
----------
var_name
The name of the output variable.
ntypes
The number of atom types.
dim_descrpt
Expand Down Expand Up @@ -86,7 +84,6 @@

def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
Expand Down Expand Up @@ -124,7 +121,7 @@
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name=var_name,
var_name="dipole",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -161,6 +158,7 @@
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
assert data.pop("var_name", None) == "dipole"
Fixed Show fixed Hide fixed
return super().deserialize(data)

def output_def(self):
Expand Down
6 changes: 2 additions & 4 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@

Parameters
----------
var_name
The name of the output variable.
ntypes
The number of atom types.
dim_descrpt
Expand Down Expand Up @@ -88,7 +86,6 @@

def __init__(
self,
var_name: str,
ntypes: int,
dim_descrpt: int,
embedding_width: int,
Expand Down Expand Up @@ -145,7 +142,7 @@
self.shift_diag = shift_diag
self.constant_matrix = np.zeros(ntypes, dtype=GLOBAL_NP_FLOAT_PRECISION)
super().__init__(
var_name=var_name,
var_name="polar",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -201,6 +198,7 @@
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
assert data.pop("var_name", None) == "polar"
Fixed Show fixed Hide fixed
return super().deserialize(data)

def output_def(self):
Expand Down
5 changes: 2 additions & 3 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class DipoleFittingNet(GeneralFitting):

Parameters
----------
var_name : str
The atomic property to fit, 'dipole'.
ntypes : int
Element count.
dim_descrpt : int
Expand Down Expand Up @@ -97,7 +95,7 @@ def __init__(
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
super().__init__(
var_name=kwargs.pop("var_name", "dipole"),
var_name="dipole",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -131,6 +129,7 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("var_name", None)
return super().deserialize(data)

def output_def(self) -> FittingOutputDef:
Expand Down
5 changes: 2 additions & 3 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ class PolarFittingNet(GeneralFitting):

Parameters
----------
var_name : str
The atomic property to fit, 'polar'.
ntypes : int
Element count.
dim_descrpt : int
Expand Down Expand Up @@ -127,7 +125,7 @@ def __init__(
ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
super().__init__(
var_name=kwargs.pop("var_name", "polar"),
var_name="polar",
ntypes=ntypes,
dim_descrpt=dim_descrpt,
neuron=neuron,
Expand Down Expand Up @@ -180,6 +178,7 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("var_name", None)
return super().deserialize(data)

def output_def(self) -> FittingOutputDef:
Expand Down
Loading
Loading