Skip to content

Commit

Permalink
Test preprocessing config
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Nov 14, 2024
1 parent 0d5e222 commit 293fe60
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ase.io
import numpy as np
import pytest
import yaml
from ase.atoms import Atoms

pytest_mace_dir = Path(__file__).parent.parent
Expand Down Expand Up @@ -164,3 +165,42 @@ def test_preprocess_data(tmp_path, sample_configs):
np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8)

print("All checks passed successfully!")


def test_preprocess_config(tmp_path, sample_configs):
ase.io.write(tmp_path / "sample.xyz", sample_configs)

preprocess_params = {
"train_file": str(tmp_path / "sample.xyz"),
"r_max": 5.0,
"config_type_weights": "{'Default':1.0}",
"num_process": 2,
"valid_fraction": 0.1,
"h5_prefix": str(tmp_path / "preprocessed_"),
"compute_statistics": None,
"seed": 42,
"energy_key": "REF_energy",
"forces_key": "REF_forces",
"stress_key": "REF_stress",
}
filename = tmp_path / "config.yaml"
with open(filename, "w", encoding="utf-8") as file:
yaml.dump(preprocess_params, file)

run_env = os.environ.copy()
sys.path.insert(0, str(Path(__file__).parent.parent))
run_env["PYTHONPATH"] = ":".join(sys.path)
print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])

cmd = (
sys.executable
+ " "
+ str(preprocess_data)
+ " "
+ "--config"
+ " "
+ str(filename)
)

p = subprocess.run(cmd.split(), env=run_env, check=True)
assert p.returncode == 0

0 comments on commit 293fe60

Please sign in to comment.