Skip to content

Commit

Permalink
add missing dtype conversion in calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Aug 23, 2024
1 parent ba31537 commit 630f6b5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
19 changes: 12 additions & 7 deletions mlff/cAPI/mlff_structure_relaxation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import numpy as np
import os
import logging

Expand Down Expand Up @@ -57,6 +58,8 @@ def run_relaxation():
parser.add_argument('--qn_tol', type=float, required=False, default=1e-4)
parser.add_argument('--qn_max_steps', type=int, required=False, default=200)

parser.add_argument('--optimizer', type=str, required=False, default='QuasiNewton')

parser.add_argument('--mic', type=str, required=False, default=None,
help='Minimal image convention.')

Expand Down Expand Up @@ -221,20 +224,22 @@ def load_start_geometry(f: str) -> Atoms:
#
# scales = read_json(os.path.join(ckpt_dir, 'scales.json'))

potential = mdx.MLFFPotential.create_from_ckpt_dir(ckpt_dir=ckpt_dir, dtype=_mdx_dtype)
calc = mlffCalculator(potential=potential,
capacity_multiplier=1.25,
F_to_eV_Ang=default_access(conversion_table, key=F_key, default=eV),
E_to_eV=default_access(conversion_table, key=E_key, default=eV),
)
calc = mlffCalculator.create_from_ckpt_dir(
ckpt_dir=ckpt_dir,
capacity_multiplier=1.25,
add_energy_shift=False,
F_to_eV_Ang=default_access(conversion_table, key=F_key, default=eV),
E_to_eV=default_access(conversion_table, key=E_key, default=eV),
dtype=np.float64,
)

molecule.set_calculator(calc)

# save the structure before the relaxation
from ase.io import write
write(os.path.join(save_dir, 'init_structure.xyz'), molecule)
# do a geometry relaxation
qn = ase_opt.LBFGS(molecule)
qn = getattr(ase_opt, args.optimizer)(molecule)
converged = qn.run(qn_tol, qn_max_steps)
if converged:
write(os.path.join(save_dir, 'relaxed_structure.xyz'), molecule)
Expand Down
6 changes: 3 additions & 3 deletions mlff/md/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def create_from_ckpt_dir(cls,
F_to_eV_Ang: float = 1.,
capacity_multiplier: float = 1.25,
add_energy_shift: bool = False,
dtype: np.dtype = np.float32):
dtype: np.dtype = np.float64):

mlff_potential = MLFFPotential.create_from_ckpt_dir(
ckpt_dir=ckpt_dir,
Expand All @@ -60,7 +60,7 @@ def __init__(
F_to_eV_Ang: float = 1.,
capacity_multiplier: float = 1.25,
calculate_stress: bool = False,
dtype: np.dtype = np.float32,
dtype: np.dtype = np.float64,
*args,
**kwargs
):
Expand Down Expand Up @@ -144,7 +144,7 @@ def calculate(self, atoms=None, *args, **kwargs):

output = self.calculate_fn(System(R=R, Z=z, cell=cell), neighbors=neighbors) # note different cell convention

self.results = jax.tree_map(lambda x: np.array(x), output)
self.results = jax.tree_map(lambda x: np.array(x, dtype=self.dtype), output)


def to_displacement(cell):
Expand Down

0 comments on commit 630f6b5

Please sign in to comment.