Skip to content

Commit

Permalink
fix: import
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Apr 11, 2024
1 parent 81573d5 commit 3ad9637
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 79 deletions.
20 changes: 2 additions & 18 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
BaseModel,
)

from .dp_model import DPModel
from .make_model import (
make_model,
)


@BaseModel.register("standard")
class DipoleModel(make_model(DPDipoleAtomicModel)):
class DipoleModel(make_model(DPDipoleAtomicModel), DPModel):
model_type = "dipole"

def __init__(
Expand Down Expand Up @@ -65,23 +66,6 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = cls.get_descriptor().update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net
Expand Down
21 changes: 2 additions & 19 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from deepmd.pt.model.model.model import (
BaseModel,
)

from .dp_model import DPModel
from .make_model import (
make_model,
)


@BaseModel.register("standard")
class DOSModel(make_model(DPAtomicModel)):
class DOSModel(make_model(DPAtomicModel), DPModel):
model_type = "dos"

def __init__(
Expand Down Expand Up @@ -58,23 +58,6 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = cls.get_descriptor().update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net
Expand Down
23 changes: 23 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)

class DPModel:
"""A base class to implement common methods for all the Models. """

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = BaseDescriptor.update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy
20 changes: 2 additions & 18 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
make_model,
)

from .dp_model import DPModel
DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel)


@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_):
class DPZBLModel(DPZBLModel_, DPModel):
model_type = "ener"

def __init__(
Expand Down Expand Up @@ -103,20 +104,3 @@ def forward_lower(
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
return model_predict

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"] = DPModel.update_sel(
global_jdata, local_jdata["dpmodel"]
)
return local_jdata_cpy
13 changes: 11 additions & 2 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,18 @@
from .dp_model import (
DPModel,
)
from .make_model import (
make_model,
)
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)


class EnergyModel(DPModel):
@BaseModel.register("standard")
class EnergyModel(make_model(DPAtomicModel),DPModel):
model_type = "ener"

def __init__(
Expand Down
21 changes: 2 additions & 19 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from deepmd.pt.model.model.model import (
BaseModel,
)

from .dp_model import DPModel
from .make_model import (
make_model,
)


@BaseModel.register("standard")
class PolarModel(make_model(DPPolarAtomicModel)):
class PolarModel(make_model(DPPolarAtomicModel), DPModel):
model_type = "polar"

def __init__(
Expand Down Expand Up @@ -57,23 +57,6 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = cls.get_descriptor().update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net
Expand Down
9 changes: 6 additions & 3 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
Spin,
)

from .dp_model import (
DPModel,
from .make_model import (
make_model,
)
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)


Expand Down Expand Up @@ -474,7 +477,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "SpinModel":
backbone_model_obj = DPModel.deserialize(data["backbone_model"])
backbone_model_obj = make_model(DPAtomicModel).deserialize(data["backbone_model"])
spin = Spin.deserialize(data["spin"])
return cls(
backbone_model=backbone_model_obj,
Expand Down

0 comments on commit 3ad9637

Please sign in to comment.