diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 198546f474..1eb118e575 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -13,13 +13,14 @@ BaseModel, ) +from .dp_model import DPModel from .make_model import ( make_model, ) @BaseModel.register("standard") -class DipoleModel(make_model(DPDipoleAtomicModel)): +class DipoleModel(make_model(DPDipoleAtomicModel), DPModel): model_type = "dipole" def __init__( @@ -65,23 +66,6 @@ def forward( model_predict["updated_coord"] += coord return model_predict - @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - global_jdata : dict - The global data, containing the training section - local_jdata : dict - The local data refer to the current class - """ - local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["descriptor"] = cls.get_descriptor().update_sel( - global_jdata, local_jdata["descriptor"] - ) - return local_jdata_cpy - def get_fitting_net(self): """Get the fitting network.""" return self.atomic_model.fitting_net diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py index a7a656d6bf..0bfe711067 100644 --- a/deepmd/pt/model/model/dos_model.py +++ b/deepmd/pt/model/model/dos_model.py @@ -12,14 +12,14 @@ from deepmd.pt.model.model.model import ( BaseModel, ) - +from .dp_model import DPModel from .make_model import ( make_model, ) @BaseModel.register("standard") -class DOSModel(make_model(DPAtomicModel)): +class DOSModel(make_model(DPAtomicModel), DPModel): model_type = "dos" def __init__( @@ -58,23 +58,6 @@ def forward( model_predict["updated_coord"] += coord return model_predict - @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - global_jdata : dict - The global data, containing the training section - local_jdata : dict - The local data refer to the current class - """ - local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["descriptor"] = cls.get_descriptor().update_sel( - global_jdata, local_jdata["descriptor"] - ) - return local_jdata_cpy - def get_fitting_net(self): """Get the fitting network.""" return self.atomic_model.fitting_net diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py new file mode 100644 index 0000000000..19a9a74db5 --- /dev/null +++ b/deepmd/pt/model/model/dp_model.py @@ -0,0 +1,23 @@ +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) + +class DPModel: + """A base class to implement common methods for all the Models. """ + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["descriptor"] = BaseDescriptor.update_sel( + global_jdata, local_jdata["descriptor"] + ) + return local_jdata_cpy \ No newline at end of file diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index bbc82b8d77..800f222765 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -20,11 +20,12 @@ make_model, ) +from .dp_model import DPModel DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel) @BaseModel.register("zbl") -class DPZBLModel(DPZBLModel_): +class DPZBLModel(DPZBLModel_, DPModel): model_type = "ener" def __init__( @@ -103,20 +104,3 @@ def forward_lower( assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] return model_predict - - @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - global_jdata : dict - The global data, containing the training section - local_jdata : dict - The local data refer to the current class - """ - local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["dpmodel"] = DPModel.update_sel( - global_jdata, local_jdata["dpmodel"] - ) - return local_jdata_cpy diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 5217293623..932b115f33 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -9,9 +9,18 @@ from .dp_model import ( DPModel, ) +from .make_model import ( + make_model, +) +from deepmd.pt.model.atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) - -class EnergyModel(DPModel): +@BaseModel.register("standard") +class EnergyModel(make_model(DPAtomicModel),DPModel): model_type = "ener" def __init__( diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 4eff8992d1..cee497f764 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -12,14 +12,14 @@ from deepmd.pt.model.model.model import ( BaseModel, ) - +from .dp_model import DPModel from .make_model import ( make_model, ) @BaseModel.register("standard") -class PolarModel(make_model(DPPolarAtomicModel)): +class PolarModel(make_model(DPPolarAtomicModel), DPModel): model_type = "polar" def __init__( @@ -57,23 +57,6 @@ def forward( model_predict["updated_coord"] += coord return model_predict - @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - global_jdata : dict - The global data, containing the training section - local_jdata : dict - The local data refer to the current class - """ - local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["descriptor"] = cls.get_descriptor().update_sel( - global_jdata, local_jdata["descriptor"] - ) - return local_jdata_cpy - def get_fitting_net(self): """Get the fitting network.""" return self.atomic_model.fitting_net diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index df2f48e2e4..6cc523e7d3 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -18,8 +18,11 @@ Spin, ) -from .dp_model import ( - DPModel, +from .make_model import ( + make_model, +) +from deepmd.pt.model.atomic_model import ( + DPAtomicModel, ) @@ -474,7 +477,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "SpinModel": - backbone_model_obj = DPModel.deserialize(data["backbone_model"]) + backbone_model_obj = make_model(DPAtomicModel).deserialize(data["backbone_model"]) spin = Spin.deserialize(data["spin"]) return cls( backbone_model=backbone_model_obj,