Skip to content

Commit

Permalink
fix the interface for multihead
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Apr 30, 2024
1 parent 41d77c4 commit cbdc546
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 13 deletions.
9 changes: 3 additions & 6 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def calculate_descriptors(
atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict
) -> None:
print("Calculating descriptors")
for mol in tqdm(atoms):
for mol in atoms:
descriptors = calc.get_descriptors(mol.copy(), invariants_only=True)
# average descriptors over atoms for each element
descriptors_dict = {
Expand Down Expand Up @@ -182,7 +182,8 @@ def assemble_descriptors(self) -> np.ndarray:
len(self.atoms_list),
len(self.species),
len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]),
)
),
dtype=np.float32,
)
)
for i, atoms in enumerate(self.atoms_list):
Expand Down Expand Up @@ -216,10 +217,6 @@ def select_samples(
atoms_list_pt = ase.io.read(args.configs_pt, index=":")
for i, atoms in enumerate(atoms_list_pt):
atoms.info["mace_descriptors"] = descriptors[i]
print(
"Filtering configurations based on the finetuning set,"
f"filtering type: combinations, elements: {all_species_ft}"
)
atoms_list_pt = [
x
for x in atoms_list_pt
Expand Down
16 changes: 11 additions & 5 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,13 @@ def main() -> None:

if args.multiheads_finetuning:
logging.info("Using multiheads finetuning mode")
heads = list(set(["pbe_mp"] + heads))
if heads is not None:
heads = list(set(["pbe_mp"] + heads))
args.heads = heads
else:
heads = ["pbe_mp", "Default"]
args.heads = heads
logging.info(f"Using heads: {heads}")
try:
checkpoint_url = "https://tinyurl.com/mw2wetc5"
cache_dir = os.path.expanduser("~/.cache/mace")
Expand Down Expand Up @@ -214,13 +220,13 @@ def main() -> None:
"weight_ft": 1.0,
"filtering_type": "combination",
"output": f"{cache_dir}/mp_finetuning.xyz",
"descriptors": None,
"descriptors": r"D:\Work\mace_mp\descriptors.npy",
"device": args.device,
"default_dtype": args.default_dtype,
}
select_samples(dict_to_namespace(args_samples))
collections_mp, _, _ = get_dataset_from_xyz(
train_path=dataset_mp,
train_path=f"{cache_dir}/mp_finetuning.xyz",
valid_path=None,
valid_fraction=args.valid_fraction,
config_type_weights=config_type_weights,
Expand Down Expand Up @@ -277,9 +283,9 @@ def main() -> None:
else:
atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table, heads)
if args.multiheads_finetuning:
with open("mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file:
with open(r"mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file:
E0s_mp = json.load(file)
atomic_energies_dict["pbe_mp"] = {E0s_mp["pbe"][z] for z in z_table.zs}
atomic_energies_dict["pbe_mp"] = {z: E0s_mp["pbe"][f"{z}"] for z in z_table.zs}

if args.model == "AtomicDipolesMACE":
atomic_energies = None
Expand Down
1 change: 0 additions & 1 deletion mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def load_from_xyz(
)
stress_key = "REF_stress"

# Process each atom only once
for atoms in atoms_list:
if energy_key == "REF_energy":
try:
Expand Down
2 changes: 1 addition & 1 deletion mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--stress_key",
help="Key of reference stress in training xyz",
type=str,
default="stress",
default="REF_stress",
)
parser.add_argument(
"--dipole_key",
Expand Down

0 comments on commit cbdc546

Please sign in to comment.