From 74a6a70c4ddfbd1fa3afb5cbcfff827b89f9502d Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 23 Oct 2023 07:17:30 -0700 Subject: [PATCH] Add `CHGNet.version` property (#86) * add CHGNet property version * refactor CHGNet.load() and pass model_name as version to CHGNet.__init__() * add test_model_load() * use new pymatgen (Structure|Molecule).to_ase_atoms() convenience methods * fixed typo * fix pkg not bundling pretrained checkpoints * pkg also bundle readmes for pretrained checkpoints * default model initialization to be same as 0.3.0 weights --------- Co-authored-by: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +-- chgnet/model/dynamics.py | 4 +- chgnet/model/model.py | 77 ++++++++++++++++++------------- chgnet/pretrained/0.2.0/README.md | 2 +- pyproject.toml | 2 +- tests/test_model.py | 18 ++++++++ 6 files changed, 71 insertions(+), 38 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 50e0fc8c..f4307d2f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,13 +4,13 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.292 + rev: v0.1.1 hooks: - id: ruff args: [--fix] - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.10.0 hooks: - id: black-jupyter @@ -49,7 +49,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v8.51.0 + rev: v8.52.0 hooks: - id: eslint types: [file] diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 5cb618e4..fc0f4d9a 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -222,7 +222,7 @@ def relax( A dictionary with 'final_structure' and 'trajectory'. """ if isinstance(atoms, Structure): - atoms = AseAtomsAdaptor.get_atoms(atoms) + atoms = atoms.to_ase_atoms() atoms.calc = self.calculator # assign model used to predict forces @@ -432,7 +432,7 @@ def __init__( self.ensemble = ensemble self.thermostat = thermostat if isinstance(atoms, (Structure, Molecule)): - atoms = AseAtomsAdaptor.get_atoms(atoms) + atoms = atoms.to_ase_atoms() self.atoms = atoms if isinstance(model, CHGNetCalculator): diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 0850807d..5e6a636d 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from chgnet import PredTask +module_dir = os.path.dirname(os.path.abspath(__file__)) + class CHGNet(nn.Module): """Crystal Hamiltonian Graph neural Network @@ -38,8 +40,8 @@ def __init__( bond_fea_dim: int = 64, angle_fea_dim: int = 64, composition_model: str | nn.Module = "MPtrj", - num_radial: int = 9, - num_angular: int = 9, + num_radial: int = 31, + num_angular: int = 31, n_conv: int = 4, atom_conv_hidden_dim: Sequence[int] | int = 64, update_bond: bool = True, @@ -48,19 +50,22 @@ def __init__( angle_layer_hidden_dim: Sequence[int] | int = 0, conv_dropout: float = 0, read_out: str = "ave", - mlp_hidden_dims: Sequence[int] | int = (64, 64), + mlp_hidden_dims: Sequence[int] | int = (64, 64, 64), mlp_dropout: float = 0, mlp_first: bool = True, is_intensive: bool = True, non_linearity: Literal["silu", "relu", "tanh", "gelu"] = "silu", - atom_graph_cutoff: float = 5, + atom_graph_cutoff: float = 6, bond_graph_cutoff: float = 3, graph_converter_algorithm: Literal["legacy", "fast"] = "fast", - cutoff_coeff: int = 5, + cutoff_coeff: int = 8, learnable_rbf: bool = True, + gMLP_norm: str | None = "layer", + readout_norm: str | None = "layer", + version: str 
| None = None, **kwargs, ) -> None: - """Initialize the CHGNet. + """Initialize CHGNet. Args: atom_fea_dim (int): atom feature vector embedding dimension. @@ -135,6 +140,11 @@ def __init__( learnable_rbf (bool): whether to set the frequencies in rbf and Fourier basis functions learnable. Default = True + gMLP_norm (str): normalization layer to use in gate-MLP + Default = 'layer' + readout_norm (str): normalization layer to use before readout layer + Default = 'layer' + version (str): Pretrained checkpoint version. **kwargs: Additional keyword arguments """ # Store model args for reconstruction @@ -144,6 +154,8 @@ def __init__( if k not in ["self", "__class__", "kwargs"] } self.model_args.update(kwargs) + if version: + self.model_args["version"] = version super().__init__() self.atom_fea_dim = atom_fea_dim @@ -200,7 +212,6 @@ def __init__( # Define convolutional layers conv_norm = kwargs.pop("conv_norm", None) - gMLP_norm = kwargs.pop("gMLP_norm", None) mlp_out_bias = kwargs.pop("mlp_out_bias", False) atom_graph_layers = [ AtomConv( @@ -261,9 +272,7 @@ def __init__( # Define readout layer self.site_wise = nn.Linear(atom_fea_dim, 1) - self.readout_norm = find_normalization( - name=kwargs.pop("readout_norm", None), dim=atom_fea_dim - ) + self.readout_norm = find_normalization(readout_norm, dim=atom_fea_dim) self.mlp_first = mlp_first if mlp_first: self.read_out_type = "sum" @@ -306,6 +315,11 @@ def __init__( f"parameters" ) + @property + def version(self) -> str | None: + """Return the version of the loaded checkpoint.""" + return self.model_args.get("version") + def forward( self, graphs: Sequence[CrystalGraph], @@ -313,12 +327,11 @@ def forward( return_site_energies: bool = False, return_atom_feas: bool = False, return_crystal_feas: bool = False, - ) -> dict: + ) -> dict[str, Tensor]: """Get prediction associated with input graphs Args: graphs (List): a list of CrystalGraphs - task (str): the prediction task - eg: 'e', 'em', 'ef', 'efs', 'efsm' + task (str): the prediction task. One of 'e', 'em', 'ef', 'efs', 'efsm'. Default = 'e' return_site_energies (bool): whether to return per-site energies, only available if self.mlp_first == True @@ -651,26 +664,28 @@ def from_file(cls, path, **kwargs): @classmethod def load(cls, model_name="0.3.0"): - """Load pretrained CHGNet.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - if model_name == "0.3.0": - return cls.from_file( - os.path.join( - current_dir, - "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar", - ) - ) - elif model_name == "0.2.0": # noqa: RET505 - return cls.from_file( - os.path.join( - current_dir, - "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar", - ), - mlp_out_bias=True, - ) - else: + """Load pretrained CHGNet model. + + Args: + model_name (str, optional): Defaults to "0.3.0". + + Raises: + ValueError: On unknown model_name. 
+ """ + checkpoint_path = { + "0.3.0": "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar", + "0.2.0": "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar", + }.get(model_name) + + if checkpoint_path is None: raise ValueError(f"Unknown {model_name=}") + return cls.from_file( + os.path.join(module_dir, checkpoint_path), + mlp_out_bias=model_name == "0.2.0", + version=model_name, + ) + @dataclass class BatchedGraph: diff --git a/chgnet/pretrained/0.2.0/README.md b/chgnet/pretrained/0.2.0/README.md index 87ba34c0..6dcf0356 100755 --- a/chgnet/pretrained/0.2.0/README.md +++ b/chgnet/pretrained/0.2.0/README.md @@ -70,5 +70,5 @@ trainer = Trainer( | partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) | | ---------- | ----------------- | ------------- | ------------ | ------------ | | Train | 22 | 59 | 0.246 | 0.030 | -| Validation | 20 | 75 | 0.350 | 0.033 | +| Validation | 30 | 75 | 0.350 | 0.033 | | Test | 30 | 77 | 0.348 | 0.032 | diff --git a/pyproject.toml b/pyproject.toml index 329cc9de..c6caf92a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] } [tool.setuptools.package-data] "chgnet" = ["*.json"] -"chgnet.pretrained" = ["*.tar"] +"chgnet.pretrained" = ["**/*"] [tool.ruff] target-version = "py39" diff --git a/tests/test_model.py b/tests/test_model.py index 378ecdd1..6e7964a1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -223,3 +223,21 @@ def test_as_to_from_dict() -> None: model_3 = CHGNet(**to_dict["model_args"]) assert model_3.todict() == to_dict + + +def test_model_load(capsys: pytest.CaptureFixture) -> None: + model = CHGNet.load() + assert model.version == "0.3.0" + stdout, stderr = capsys.readouterr() + assert stdout == "CHGNet initialized with 412,525 parameters\n" + assert stderr == "" + + model = CHGNet.load(model_name="0.2.0") + assert model.version == "0.2.0" + stdout, stderr = capsys.readouterr() + assert stdout == "CHGNet initialized with 400,438 parameters\n" + assert stderr == "" + + model_name = "0.1.0" # invalid + with pytest.raises(ValueError, match=f"Unknown {model_name=}"): + CHGNet.load(model_name=model_name)
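

Below is a minimal usage sketch (not part of the patch) illustrating the behavior this change introduces: `CHGNet.load()` now records which pretrained checkpoint was loaded via the new `CHGNet.version` property, raises `ValueError` for unknown names, and the ASE-facing code converts pymatgen objects with the new `Structure.to_ase_atoms()` convenience method (so it assumes a recent pymatgen release). The NaCl toy structure and the `steps=10` relaxation budget are purely illustrative.

```python
from pymatgen.core import Lattice, Structure

from chgnet.model import CHGNet, StructOptimizer

# Default load uses the 0.3.0 weights and prints the parameter count,
# e.g. "CHGNet initialized with 412,525 parameters".
model = CHGNet.load()
print(model.version)  # -> "0.3.0"

# The 0.2.0 checkpoint is still available by name.
legacy = CHGNet.load(model_name="0.2.0")
print(legacy.version)  # -> "0.2.0"

# Unknown checkpoint names now raise instead of failing later on a bad path.
try:
    CHGNet.load(model_name="0.1.0")
except ValueError as exc:
    print(exc)  # Unknown model_name='0.1.0'

# Toy structure (illustrative only) to exercise prediction and relaxation.
structure = Structure(
    lattice=Lattice.cubic(4.2),
    species=["Na", "Cl"],
    coords=[[0, 0, 0], [0.5, 0.5, 0.5]],
)
prediction = model.predict_structure(structure)
print(prediction["e"])  # energy per atom in eV/atom

# StructOptimizer.relax() accepts a pymatgen Structure directly; internally it
# now converts it via structure.to_ase_atoms() instead of AseAtomsAdaptor.
relaxer = StructOptimizer(model=model)
result = relaxer.relax(structure, steps=10)
print(result["final_structure"])
```

Because the checkpoint version is stored in `model_args`, it is carried along when a model is reconstructed from its args (as in the existing `test_as_to_from_dict`), and the dict-based checkpoint lookup in `load()` keeps all pretrained paths in one place instead of one branch per release.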