diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index e0258bd4..1f3068ba 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -6,6 +6,7 @@ import ase.io import numpy as np import pytest +import yaml from ase.atoms import Atoms pytest_mace_dir = Path(__file__).parent.parent @@ -164,3 +165,42 @@ def test_preprocess_data(tmp_path, sample_configs): np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8) print("All checks passed successfully!") + + +def test_preprocess_config(tmp_path, sample_configs): + ase.io.write(tmp_path / "sample.xyz", sample_configs) + + preprocess_params = { + "train_file": str(tmp_path / "sample.xyz"), + "r_max": 5.0, + "config_type_weights": "{'Default':1.0}", + "num_process": 2, + "valid_fraction": 0.1, + "h5_prefix": str(tmp_path / "preprocessed_"), + "compute_statistics": None, + "seed": 42, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + } + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + yaml.dump(preprocess_params, file) + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(preprocess_data) + + " " + + "--config" + + " " + + str(filename) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0