Skip to content

Commit

Permalink
import config from jax not from jax.config
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Jul 4, 2024
1 parent eae3c0b commit ecc9a6c
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion mlff/cAPI/mlff_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def evaluate():

jax_dtype = args.jax_dtype
if jax_dtype == 'x64':
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

ckpt_dir = (Path(args.ckpt_dir).absolute().resolve()).as_posix()
Expand Down
2 changes: 1 addition & 1 deletion mlff/cAPI/mlff_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def run_md():
mdx_scan_interval = args.mdx_scan_interval

if args.mdx_dtype == 'x64':
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
_mdx_dtype = jnp.float64 if args.mdx_dtype == 'x64' else jnp.float32
Expand Down
4 changes: 2 additions & 2 deletions mlff/cAPI/mlff_structure_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def run_relaxation():
x64 = args.x64

if args.mdx_dtype == 'x64':
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
_mdx_dtype = jnp.float64 if args.mdx_dtype == 'x64' else jnp.float32
Expand All @@ -104,7 +104,7 @@ def run_relaxation():
set_seeds(seed)

if x64:
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

h = read_json(os.path.join(ckpt_dir, 'hyperparameters.json'))
Expand Down
2 changes: 1 addition & 1 deletion mlff/cAPI/mlff_train_so3kratace.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def train_so3kratace():

jax_dtype = args.jax_dtype
if jax_dtype == 'x64':
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

r_cut = args.r_cut
Expand Down
2 changes: 1 addition & 1 deletion mlff/cAPI/mlff_train_so3krates.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def parse_data_file(x):

jax_dtype = args.jax_dtype
if jax_dtype == 'x64':
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

r_cut = args.r_cut
Expand Down
2 changes: 1 addition & 1 deletion tests/test_neighbor_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def filter_and_sort(distances, cutoff):


def test_neighborhood_list():
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

# load all data from the new ZrO2 data set
Expand Down

0 comments on commit ecc9a6c

Please sign in to comment.