Skip to content

Commit

Permalink
Add CHGNet.version property (#86)
Browse files Browse the repository at this point in the history
* add CHGNet property version

* refactor CHGNet.load() and pass model_name as version to CHGNet.__init__()

* add test_model_load()

* use new pymatgen (Structure|Molecule).to_ase_atoms() convenience methods

* fixed typo

* fix pkg not bundling pretrained checkpoints

* pkg also bundle readmes for pretrained checkpoints

* default model initialization to be same as 0.3.0 weights

---------

Co-authored-by: BowenD-UCB <[email protected]>
  • Loading branch information
janosh and bowen-bd authored Oct 23, 2023
1 parent ee03e31 commit 74a6a70
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 38 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
rev: v0.1.1
hooks:
- id: ruff
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.0
hooks:
- id: black-jupyter

Expand Down Expand Up @@ -49,7 +49,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.51.0
rev: v8.52.0
hooks:
- id: eslint
types: [file]
Expand Down
4 changes: 2 additions & 2 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def relax(
A dictionary with 'final_structure' and 'trajectory'.
"""
if isinstance(atoms, Structure):
atoms = AseAtomsAdaptor.get_atoms(atoms)
atoms = atoms.to_ase_atoms()

atoms.calc = self.calculator # assign model used to predict forces

Expand Down Expand Up @@ -432,7 +432,7 @@ def __init__(
self.ensemble = ensemble
self.thermostat = thermostat
if isinstance(atoms, (Structure, Molecule)):
atoms = AseAtomsAdaptor.get_atoms(atoms)
atoms = atoms.to_ase_atoms()

self.atoms = atoms
if isinstance(model, CHGNetCalculator):
Expand Down
77 changes: 46 additions & 31 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
if TYPE_CHECKING:
from chgnet import PredTask

module_dir = os.path.dirname(os.path.abspath(__file__))


class CHGNet(nn.Module):
"""Crystal Hamiltonian Graph neural Network
Expand All @@ -38,8 +40,8 @@ def __init__(
bond_fea_dim: int = 64,
angle_fea_dim: int = 64,
composition_model: str | nn.Module = "MPtrj",
num_radial: int = 9,
num_angular: int = 9,
num_radial: int = 31,
num_angular: int = 31,
n_conv: int = 4,
atom_conv_hidden_dim: Sequence[int] | int = 64,
update_bond: bool = True,
Expand All @@ -48,19 +50,22 @@ def __init__(
angle_layer_hidden_dim: Sequence[int] | int = 0,
conv_dropout: float = 0,
read_out: str = "ave",
mlp_hidden_dims: Sequence[int] | int = (64, 64),
mlp_hidden_dims: Sequence[int] | int = (64, 64, 64),
mlp_dropout: float = 0,
mlp_first: bool = True,
is_intensive: bool = True,
non_linearity: Literal["silu", "relu", "tanh", "gelu"] = "silu",
atom_graph_cutoff: float = 5,
atom_graph_cutoff: float = 6,
bond_graph_cutoff: float = 3,
graph_converter_algorithm: Literal["legacy", "fast"] = "fast",
cutoff_coeff: int = 5,
cutoff_coeff: int = 8,
learnable_rbf: bool = True,
gMLP_norm: str | None = "layer",
readout_norm: str | None = "layer",
version: str | None = None,
**kwargs,
) -> None:
"""Initialize the CHGNet.
"""Initialize CHGNet.
Args:
atom_fea_dim (int): atom feature vector embedding dimension.
Expand Down Expand Up @@ -135,6 +140,11 @@ def __init__(
learnable_rbf (bool): whether to set the frequencies in rbf and Fourier
basis functions learnable.
Default = True
gMLP_norm (str): normalization layer to use in gate-MLP
Default = 'layer'
readout_norm (str): normalization layer to use before readout layer
Default = 'layer'
version (str): Pretrained checkpoint version.
**kwargs: Additional keyword arguments
"""
# Store model args for reconstruction
Expand All @@ -144,6 +154,8 @@ def __init__(
if k not in ["self", "__class__", "kwargs"]
}
self.model_args.update(kwargs)
if version:
self.model_args["version"] = version

super().__init__()
self.atom_fea_dim = atom_fea_dim
Expand Down Expand Up @@ -200,7 +212,6 @@ def __init__(

# Define convolutional layers
conv_norm = kwargs.pop("conv_norm", None)
gMLP_norm = kwargs.pop("gMLP_norm", None)
mlp_out_bias = kwargs.pop("mlp_out_bias", False)
atom_graph_layers = [
AtomConv(
Expand Down Expand Up @@ -261,9 +272,7 @@ def __init__(

# Define readout layer
self.site_wise = nn.Linear(atom_fea_dim, 1)
self.readout_norm = find_normalization(
name=kwargs.pop("readout_norm", None), dim=atom_fea_dim
)
self.readout_norm = find_normalization(readout_norm, dim=atom_fea_dim)
self.mlp_first = mlp_first
if mlp_first:
self.read_out_type = "sum"
Expand Down Expand Up @@ -306,19 +315,23 @@ def __init__(
f"parameters"
)

@property
def version(self) -> str | None:
"""Return the version of the loaded checkpoint."""
return self.model_args.get("version")

def forward(
self,
graphs: Sequence[CrystalGraph],
task: PredTask = "e",
return_site_energies: bool = False,
return_atom_feas: bool = False,
return_crystal_feas: bool = False,
) -> dict:
) -> dict[str, Tensor]:
"""Get prediction associated with input graphs
Args:
graphs (List): a list of CrystalGraphs
task (str): the prediction task
eg: 'e', 'em', 'ef', 'efs', 'efsm'
task (str): the prediction task. One of 'e', 'em', 'ef', 'efs', 'efsm'.
Default = 'e'
return_site_energies (bool): whether to return per-site energies,
only available if self.mlp_first == True
Expand Down Expand Up @@ -651,26 +664,28 @@ def from_file(cls, path, **kwargs):

@classmethod
def load(cls, model_name="0.3.0"):
"""Load pretrained CHGNet."""
current_dir = os.path.dirname(os.path.abspath(__file__))
if model_name == "0.3.0":
return cls.from_file(
os.path.join(
current_dir,
"../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar",
)
)
elif model_name == "0.2.0": # noqa: RET505
return cls.from_file(
os.path.join(
current_dir,
"../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar",
),
mlp_out_bias=True,
)
else:
"""Load pretrained CHGNet model.
Args:
model_name (str, optional): Defaults to "0.3.0".
Raises:
ValueError: On unknown model_name.
"""
checkpoint_path = {
"0.3.0": "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar",
"0.2.0": "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar",
}.get(model_name)

if checkpoint_path is None:
raise ValueError(f"Unknown {model_name=}")

return cls.from_file(
os.path.join(module_dir, checkpoint_path),
mlp_out_bias=model_name == "0.2.0",
version=model_name,
)


@dataclass
class BatchedGraph:
Expand Down
2 changes: 1 addition & 1 deletion chgnet/pretrained/0.2.0/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ trainer = Trainer(
| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) |
| ---------- | ----------------- | ------------- | ------------ | ------------ |
| Train | 22 | 59 | 0.246 | 0.030 |
| Validation | 20 | 75 | 0.350 | 0.033 |
| Validation | 30 | 75 | 0.350 | 0.033 |
| Test | 30 | 77 | 0.348 | 0.032 |
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }

[tool.setuptools.package-data]
"chgnet" = ["*.json"]
"chgnet.pretrained" = ["*.tar"]
"chgnet.pretrained" = ["**/*"]

[tool.ruff]
target-version = "py39"
Expand Down
18 changes: 18 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,21 @@ def test_as_to_from_dict() -> None:

model_3 = CHGNet(**to_dict["model_args"])
assert model_3.todict() == to_dict


def test_model_load(capsys: pytest.CaptureFixture) -> None:
model = CHGNet.load()
assert model.version == "0.3.0"
stdout, stderr = capsys.readouterr()
assert stdout == "CHGNet initialized with 412,525 parameters\n"
assert stderr == ""

model = CHGNet.load(model_name="0.2.0")
assert model.version == "0.2.0"
stdout, stderr = capsys.readouterr()
assert stdout == "CHGNet initialized with 400,438 parameters\n"
assert stderr == ""

model_name = "0.1.0" # invalid
with pytest.raises(ValueError, match=f"Unknown {model_name=}"):
CHGNet.load(model_name=model_name)

0 comments on commit 74a6a70

Please sign in to comment.