Skip to content

Commit

Permalink
Dispersion (#192)
Browse files Browse the repository at this point in the history
* ENH: add dispersion tutorial

* STY: format dispersion notebook

* add extra deps set dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"]

also bump ruff

* tweak dispersion.ipynb var names

---------

Co-authored-by: alex <[email protected]>
Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
3 people authored Aug 31, 2024
1 parent be9bab9 commit 180ffea
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.5
rev: v0.6.2
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.8.0
rev: v9.9.0
hooks:
- id: eslint
types: [file]
Expand Down
274 changes: 274 additions & 0 deletions examples/dispersion.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Adding dispersion to the CHGNet pre-trained model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook describes the process of adding a dispersion correction to the CHGNet pre-trained model. CHGNet is trained on PBE (GGA) DFT calculations; as such, it does not include a correction for van der Waals or dispersive forces. This kind of correction may be particularly useful for those studying porous materials, such as MOFs or zeolites, but who do not wish to fine-tune the pre-trained model on data that include a dispersion correction.\n",
"\n",
"This notebook uses both the [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master) and [DFT-D4](https://dftd4.readthedocs.io/en/latest/reference/ase.html) repositories to add dispersion to CHGNet. The torch-dftd repository currently has DFT-D2 and DFT-D3 implementations and does not have the most recent DFT-D4 version, but is GPU-accelerated where DFT-D4 is not. The Grimme group has released a version of [DFT-D4 implemented in PyTorch](https://github.com/dftd4/tad-dftd4); however, this version does not have an ASE-compatible calculator available.\n",
"\n",
"You will need to install CHGNet, [ASE](https://wiki.fysik.dtu.dk/ase/install.html), [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master?tab=readme-ov-file#install), and [DFT-D4](https://dftd4.readthedocs.io/en/latest/recipe/installation.html) to run this notebook (links are to their installation instructions)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ase.build import fcc111\n",
"from ase.calculators.mixing import SumCalculator\n",
"from dftd4.ase import DFTD4\n",
"from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator\n",
"\n",
"from chgnet.model.dynamics import CHGNetCalculator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet v0.3.0 initialized with 412,525 parameters\n",
"CHGNet will run on cpu\n",
"CHGNet will run on cpu\n"
]
}
],
"source": [
"# pre-trained chgnet model\n",
"chgnet_calc = CHGNetCalculator()\n",
"\n",
"d3_calc = TorchDFTD3Calculator() # uses PBE parameters by default\n",
"\n",
"d4_calc = DFTD4(method=\"PBE\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A simple example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example shows how to initialize an Atoms object (of a Cu(111) surface) and compute its energy with and without the dispersion correction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Disp calculator: sumcalculator\n",
"Cu4\n",
"E without dispersion: -12.876540184020996\n",
"E with DFT-D3 dispersion: -13.272150007989707\n",
"E with DFT-D4 dispersion: -13.573149493981\n"
]
}
],
"source": [
"# Create a 2x2x1 fcc(111) Cu slab\n",
"atoms = fcc111(\"Cu\", (2, 2, 1), vacuum=10.0)\n",
"atoms.set_pbc([True, True, True])\n",
"\n",
"atoms_disp = atoms.copy()\n",
"atoms_d4 = atoms.copy()\n",
"\n",
"atoms.calc = chgnet_calc\n",
"\n",
"chgnet_d3 = SumCalculator([chgnet_calc, d3_calc])\n",
"chgnet_d4 = SumCalculator([chgnet_calc, d4_calc])\n",
"atoms_disp.calc = chgnet_d3\n",
"atoms_d4.calc = chgnet_d4\n",
"\n",
"e_chg = atoms.get_potential_energy()\n",
"e_disp = atoms_disp.get_potential_energy()\n",
"e_d4 = atoms_d4.get_potential_energy()\n",
"\n",
"print(f\"Disp calculator: {chgnet_d3.name}\")\n",
"print(atoms.get_chemical_formula())\n",
"print(f\"E without dispersion: {e_chg}\")\n",
"print(f\"E with DFT-D3 dispersion: {e_disp}\")\n",
"print(f\"E with DFT-D4 dispersion: {e_d4}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimization example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a simple example of an optimization of a Cu cell with a displaced atom and perturbed unit cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ase.build import bulk\n",
"from ase.filters import FrechetCellFilter\n",
"from ase.optimize import BFGS\n",
"\n",
"atoms = bulk(\"Cu\", cubic=True)\n",
"\n",
"atoms[0].x += 0.1\n",
"atoms.cell[0] += 0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Cell([[3.71, 0.1, 0.1], [0.0, 3.61, 0.0], [0.0, 0.0, 3.61]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atoms.cell"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Step Time Energy fmax\n",
"BFGS: 0 17:55:31 -18.448409 0.718576\n",
"BFGS: 1 17:55:33 -18.479250 0.649700\n",
"BFGS: 2 17:55:34 -18.602240 1352.772997\n",
"BFGS: 3 17:55:35 -18.495526 0.666888\n",
"BFGS: 4 17:55:36 -18.505485 0.674898\n",
"BFGS: 5 17:55:37 -18.524050 0.707074\n",
"BFGS: 6 17:55:39 -18.524685 0.708438\n",
"BFGS: 7 17:55:40 -18.527744 0.722830\n",
"BFGS: 8 17:55:41 -18.529436 0.740565\n",
"BFGS: 9 17:55:42 -18.530648 0.755404\n",
"BFGS: 10 17:55:43 -18.530939 0.757174\n",
"BFGS: 11 17:55:44 -18.531477 0.756699\n",
"BFGS: 12 17:55:46 -18.532681 0.750869\n",
"BFGS: 13 17:55:47 -18.534685 0.741311\n",
"BFGS: 14 17:55:48 -18.537005 0.728779\n",
"BFGS: 15 17:55:49 -18.539452 0.706232\n",
"BFGS: 16 17:55:50 -18.540848 0.684654\n",
"BFGS: 17 17:55:52 -18.541594 0.680409\n",
"BFGS: 18 17:55:53 -18.544656 0.659496\n",
"BFGS: 19 17:55:54 -18.548455 0.625415\n",
"BFGS: 20 17:55:55 -18.554630 0.578256\n",
"BFGS: 21 17:55:56 -18.562180 0.520110\n",
"BFGS: 22 17:55:58 -18.568905 0.460355\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atoms.calc = chgnet_d4\n",
"cell_filter = FrechetCellFilter(atoms)\n",
"opt = BFGS(cell_filter, trajectory=\"Cu.traj\")\n",
"\n",
"opt.run(fmax=0.5, steps=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output of this optimization can be viewed by running\n",
"\n",
"```bash\n",
"ase gui Cu.traj\n",
"```\n",
"\n",
"in the command line in this folder."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "htvs",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ test = ["pytest-cov>=4", "pytest>=8"]
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
logging = ["wandb>=0.17"]
dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"]

[project.urls]
Source = "https://github.com/CederGroupHub/chgnet"
Expand All @@ -52,7 +53,6 @@ build-backend = "setuptools.build_meta"

[tool.ruff]
target-version = "py39"
extend-include = ["*.ipynb"]

[tool.ruff.lint]
select = ["ALL"]
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
from chgnet import ROOT


@pytest.fixture()
@pytest.fixture
def li_mn_o2() -> Structure:
return Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
2 changes: 1 addition & 1 deletion tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
NaCl = Structure(lattice, species, coords)


@pytest.fixture()
@pytest.fixture
def _set_make_graph() -> Generator[None, None, None]:
# fixture to force make_graph to be None and then restore it after test
from chgnet.graph import converter
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
NaCl = Structure(lattice, species, coords)


@pytest.fixture()
@pytest.fixture
def structure_data() -> StructureData:
"""Create a graph with 3 nodes and 3 directed edges."""
random.seed(42)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chgnet.graph.graph import DirectedEdge, Graph, Node, UndirectedEdge


@pytest.fixture()
@pytest.fixture
def graph() -> Graph:
"""Create a graph with 3 nodes and 3 directed edges."""
nodes = [Node(index=idx) for idx in range(3)]
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_as_dict(graph: Graph) -> None:
assert len(graph_dict["undirected_edges_list"]) == 3


@pytest.fixture()
@pytest.fixture
def bigraph() -> Graph:
"""Create a bi-directional graph with 3 nodes and 4 bi-directed edges."""
nodes = [Node(index=idx) for idx in range(3)]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
assert torch.all(comparison == expect)


@pytest.fixture()
@pytest.fixture
def mock_wandb():
with patch("chgnet.trainer.trainer.wandb") as mock:
yield mock
Expand Down

0 comments on commit 180ffea

Please sign in to comment.