From cbdc546d0c2419aa0917309cc118856f5ede321a Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:22:13 +0100 Subject: [PATCH] fix the interface for multihead --- mace/cli/fine_tuning_select.py | 9 +++------ mace/cli/run_train.py | 16 +++++++++++----- mace/data/utils.py | 1 - mace/tools/arg_parser.py | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index b90b432c..5fe1f7d0 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -99,7 +99,7 @@ def calculate_descriptors( atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict ) -> None: print("Calculating descriptors") - for mol in tqdm(atoms): + for mol in atoms: descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) # average descriptors over atoms for each element descriptors_dict = { @@ -182,7 +182,8 @@ def assemble_descriptors(self) -> np.ndarray: len(self.atoms_list), len(self.species), len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), - ) + ), + dtype=np.float32, ) ) for i, atoms in enumerate(self.atoms_list): @@ -216,10 +217,6 @@ def select_samples( atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): atoms.info["mace_descriptors"] = descriptors[i] - print( - "Filtering configurations based on the finetuning set," - f"filtering type: combinations, elements: {all_species_ft}" - ) atoms_list_pt = [ x for x in atoms_list_pt diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index c39170cb..7c1d84df 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -177,7 +177,13 @@ def main() -> None: if args.multiheads_finetuning: logging.info("Using multiheads finetuning mode") - heads = list(set(["pbe_mp"] + heads)) + if heads is not None: + heads = list(set(["pbe_mp"] + heads)) + args.heads = heads + else: + heads = ["pbe_mp", "Default"] + args.heads = heads + logging.info(f"Using heads: {heads}") try: checkpoint_url = "https://tinyurl.com/mw2wetc5" cache_dir = os.path.expanduser("~/.cache/mace") @@ -214,13 +220,13 @@ def main() -> None: "weight_ft": 1.0, "filtering_type": "combination", "output": f"{cache_dir}/mp_finetuning.xyz", - "descriptors": None, + "descriptors": r"D:\Work\mace_mp\descriptors.npy", "device": args.device, "default_dtype": args.default_dtype, } select_samples(dict_to_namespace(args_samples)) collections_mp, _, _ = get_dataset_from_xyz( - train_path=dataset_mp, + train_path=f"{cache_dir}/mp_finetuning.xyz", valid_path=None, valid_fraction=args.valid_fraction, config_type_weights=config_type_weights, @@ -277,9 +283,9 @@ def main() -> None: else: atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table, heads) if args.multiheads_finetuning: - with open("mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file: + with open(r"mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file: E0s_mp = json.load(file) - atomic_energies_dict["pbe_mp"] = {E0s_mp["pbe"][z] for z in z_table.zs} + atomic_energies_dict["pbe_mp"] = {z: E0s_mp["pbe"][f"{z}"] for z in z_table.zs} if args.model == "AtomicDipolesMACE": atomic_energies = None diff --git a/mace/data/utils.py b/mace/data/utils.py index b12368bf..c55ad86b 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -231,7 +231,6 @@ def load_from_xyz( ) stress_key = "REF_stress" - # Process each atom only once for atoms in atoms_list: if energy_key == "REF_energy": try: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index fa646fbd..57249044 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -372,7 +372,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--stress_key", help="Key of reference stress in training xyz", type=str, - default="stress", + default="REF_stress", ) parser.add_argument( "--dipole_key",