Skip to content

Commit

Permalink
fix: address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Apr 12, 2024
1 parent 56accdf commit 9f437ed
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 46 deletions.
9 changes: 9 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,19 @@
from .polar_atomic_model import (
DPPolarAtomicModel,
)
from .dos_atomic_model import (
DPDOSAtomicModel
)
from .energy_atomic_model import (
DPEnergyAtomicModel
)


__all__ = [
"BaseAtomicModel",
"DPAtomicModel",
"DPDOSAtomicModel",
"DPEnergyAtomicModel",
"PairTabAtomicModel",
"LinearEnergyAtomicModel",
"DPPolarAtomicModel",
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dp_atomic_model import (
DPAtomicModel,
)

class DPDOSAtomicModel(DPAtomicModel):
pass
7 changes: 7 additions & 0 deletions deepmd/pt/model/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dp_atomic_model import (
DPAtomicModel,
)

class DPEnergyAtomicModel(DPAtomicModel):
pass
12 changes: 2 additions & 10 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
)

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
)


@BaseModel.register("dipole")
class DipoleModel(DPModel, make_model(DPDipoleAtomicModel)):
class DipoleModel(DPModelCommon, make_model(DPDipoleAtomicModel)):
model_type = "dipole"

def __init__(
Expand Down Expand Up @@ -68,14 +68,6 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net

def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor

@torch.jit.export
def forward_lower(
self,
Expand Down
14 changes: 3 additions & 11 deletions deepmd/pt/model/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@
import torch

from deepmd.pt.model.atomic_model import (
DPAtomicModel,
DPDOSAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
)


@BaseModel.register("dos")
class DOSModel(DPModel, make_model(DPAtomicModel)):
class DOSModel(DPModelCommon, make_model(DPDOSAtomicModel)):
model_type = "dos"

def __init__(
Expand Down Expand Up @@ -61,14 +61,6 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net

def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor

@torch.jit.export
def get_numb_dos(self) -> int:
"""Get the number of DOS for DOSFittingNet."""
Expand Down
11 changes: 10 additions & 1 deletion deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
)


class DPModel:
class DPModelCommon:
"""A base class to implement common methods for all the Models."""

@classmethod
Expand All @@ -23,3 +23,12 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy


def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net

def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor
4 changes: 2 additions & 2 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
Expand All @@ -24,7 +24,7 @@


@BaseModel.register("zbl")
class DPZBLModel(DPModel, DPZBLModel_):
class DPZBLModel(DPModelCommon, DPZBLModel_):
model_type = "ener"

def __init__(
Expand Down
16 changes: 4 additions & 12 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@
import torch

from deepmd.pt.model.atomic_model import (
DPAtomicModel,
DPEnergyAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
)


@BaseModel.register("standard")
class EnergyModel(DPModel, make_model(DPAtomicModel)):
@BaseModel.register("energy")
class EnergyModel(DPModelCommon, make_model(DPEnergyAtomicModel)):
model_type = "ener"

def __init__(
Expand Down Expand Up @@ -70,14 +70,6 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net

def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor

@torch.jit.export
def forward_lower(
self,
Expand Down
12 changes: 2 additions & 10 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
)

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
)


@BaseModel.register("polar")
class PolarModel(DPModel, make_model(DPPolarAtomicModel)):
class PolarModel(DPModelCommon, make_model(DPPolarAtomicModel)):
model_type = "polar"

def __init__(
Expand Down Expand Up @@ -60,14 +60,6 @@ def forward(
model_predict["updated_coord"] += coord
return model_predict

def get_fitting_net(self):
"""Get the fitting network."""
return self.atomic_model.fitting_net

def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor

@torch.jit.export
def forward_lower(
self,
Expand Down

0 comments on commit 9f437ed

Please sign in to comment.