automatic download of the descriptors
ilyes319 committed Apr 30, 2024
1 parent cbdc546 commit 757385a
Showing 2 changed files with 40 additions and 15 deletions.
mace/cli/fine_tuning_select.py (29 changes: 15 additions & 14 deletions)
@@ -98,7 +98,7 @@ def parse_args() -> argparse.Namespace:
 def calculate_descriptors(
     atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict
 ) -> None:
-    print("Calculating descriptors")
+    logging.info("Calculating descriptors")
     for mol in atoms:
         descriptors = calc.get_descriptors(mol.copy(), invariants_only=True)
         # average descriptors over atoms for each element
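Note on the change above: unlike print, logging.info only produces output once the root logger has been configured by the calling script. A minimal sketch of the kind of setup these calls rely on (the basicConfig call and the sample messages are illustrative, not taken from this commit):

# Illustrative only, not part of the commit.
import logging

# Without a configured handler, INFO records are not shown; the CLI entry
# point is assumed to set up something like this.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

logging.info("Calculating descriptors")   # now visible on stderr
logging.info("n_samples %d", 42)          # extra args are %-formatted lazily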
@@ -164,8 +164,8 @@ def run(
         """
         Run the farthest point sampling algorithm.
         """
-        print(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape)
-        print("n_samples", self.n_samples)
+        logging.info(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape)
+        logging.info("n_samples", self.n_samples)
         self.list_index = fpsample.fps_npdu_kdtree_sampling(
             self.descriptors_dataset.reshape(len(self.atoms_list), -1), self.n_samples
         )
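For context, the run method above delegates the selection to the fpsample package; a small self-contained sketch of that call, with placeholder data standing in for the flattened per-structure MACE descriptors (array sizes and variable names are mine):

# Illustrative only, not part of the commit.
import fpsample
import numpy as np

descriptors = np.random.rand(500, 128)  # one flattened descriptor row per structure
n_samples = 50

# Farthest point sampling returns the indices of the selected rows,
# matching how FPS.run() uses them to index atoms_list.
idx = fpsample.fps_npdu_kdtree_sampling(descriptors.reshape(len(descriptors), -1), n_samples)
selected = descriptors[idx]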
@@ -207,12 +207,12 @@ def select_samples(
 
     if args.filtering_type != None:
         all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms])
-        print(
-            "Filtering configurations based on the finetuning set,"
+        logging.info(
+            "Filtering configurations based on the finetuning set, "
             f"filtering type: combinations, elements: {all_species_ft}"
         )
         if args.descriptors is not None:
-            print("Loading descriptors")
+            logging.info("Loading descriptors")
             descriptors = np.load(args.descriptors, allow_pickle=True)
             atoms_list_pt = ase.io.read(args.configs_pt, index=":")
             for i, atoms in enumerate(atoms_list_pt):
@@ -222,7 +222,6 @@
                 for x in atoms_list_pt
                 if filter_atoms(x, all_species_ft, "combinations")
             ]
-
         else:
             atoms_list_pt = ase.io.read(args.configs_pt, index=":")
             atoms_list_pt = [
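filter_atoms is defined elsewhere in this file and is not part of the diff; based on how it is called above, a hedged sketch of what a "combinations" filter can look like (the helper below is illustrative, not the project's implementation):

# Illustrative only, not part of the commit.
import ase

def keep_combinations(atoms: ase.Atoms, allowed_species) -> bool:
    # Keep a configuration only if every element it contains also occurs
    # in the fine-tuning set (any combination of the allowed elements).
    return set(atoms.get_chemical_symbols()).issubset(set(allowed_species))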
@@ -233,7 +232,7 @@
     else:
         atoms_list_pt = ase.io.read(args.configs_pt, index=":")
         if args.descriptors is not None:
-            print(
+            logging.info(
                 "Loading descriptors for the pretraining set from {}".format(
                     args.descriptors
                 )
@@ -244,35 +243,37 @@
 
     if args.num_samples is not None and args.num_samples < len(atoms_list_pt):
         if args.descriptors is None:
-            print("Calculating descriptors for the pretraining set")
+            logging.info("Calculating descriptors for the pretraining set")
             calculate_descriptors(atoms_list_pt, calc, None)
             descriptors_list = [
                 atoms.info["mace_descriptors"] for atoms in atoms_list_pt
             ]
-            print(
+            logging.info(
                 "Saving descriptors at {}".format(
                     args.output.replace(".xyz", "descriptors.npy")
                 )
             )
             np.save(args.output.replace(".xyz", "descriptors.npy"), descriptors_list)
-        print("Selecting configurations using Farthest Point Sampling")
+        logging.info("Selecting configurations using Farthest Point Sampling")
         fps_pt = FPS(atoms_list_pt, args.num_samples)
         idx_pt = fps_pt.run()
-        print(f"Selected {len(idx_pt)} configurations")
+        logging.info(f"Selected {len(idx_pt)} configurations")
         atoms_list_pt = [atoms_list_pt[i] for i in idx_pt]
     for atoms in atoms_list_pt:
         # del atoms.info["mace_descriptors"]
         atoms.info["pretrained"] = True
         atoms.info["config_weight"] = args.weight_pt
+        atoms.info["mace_descriptors"] = None
         if args.head_pt is not None:
             atoms.info["head"] = args.head_pt
 
-    print("Saving the selected configurations")
+    logging.info("Saving the selected configurations")
     ase.io.write(args.output, atoms_list_pt, format="extxyz")
-    print("Saving a combined XYZ file")
+    logging.info("Saving a combined XYZ file")
     for atoms in atoms_list_ft:
         atoms.info["pretrained"] = False
         atoms.info["config_weight"] = args.weight_ft
+        atoms.info["mace_descriptors"] = None
         if args.head_ft is not None:
             atoms.info["head"] = args.head_ft
     atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft
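The selected pretraining frames and the fine-tuning frames are tagged through atoms.info and written as extended XYZ; a short sketch of reading such a combined file back (the file name follows the run_train.py defaults further down, and the snippet itself is illustrative):

# Illustrative only, not part of the commit.
import ase.io

# Scalar atoms.info entries such as pretrained, config_weight and head are
# stored on the extxyz comment line, so they survive the round trip.
frames = ase.io.read("mp_finetuning.xyz", index=":", format="extxyz")
pretrained = [a for a in frames if a.info.get("pretrained")]
finetuning = [a for a in frames if not a.info.get("pretrained")]
print(f"{len(pretrained)} pretraining frames, {len(finetuning)} fine-tuning frames")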
mace/cli/run_train.py (26 changes: 25 additions & 1 deletion)
@@ -186,13 +186,20 @@ def main() -> None:
         logging.info(f"Using heads: {heads}")
         try:
             checkpoint_url = "https://tinyurl.com/mw2wetc5"
+            descriptors_url = "https://tinyurl.com/mpe7br4d"
             cache_dir = os.path.expanduser("~/.cache/mace")
             checkpoint_url_name = "".join(
                 c
                 for c in os.path.basename(checkpoint_url)
                 if c.isalnum() or c in "_"
             )
             cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}"
+            descriptors_url_name = "".join(
+                c
+                for c in os.path.basename(descriptors_url)
+                if c.isalnum() or c in "_"
+            )
+            cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}"
             if not os.path.isfile(cached_dataset_path):
                 os.makedirs(cache_dir, exist_ok=True)
                 # download and save to disk
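The new descriptors URL is cached exactly like the checkpoint URL: the basename is reduced to alphanumerics and underscores and joined onto ~/.cache/mace. The same logic pulled out into a small helper (the function name is mine, not from the diff):

# Illustrative only, not part of the commit.
import os

def cached_path(url: str, cache_dir: str = os.path.expanduser("~/.cache/mace")) -> str:
    # Keep only alphanumerics and underscores from the URL's basename.
    name = "".join(c for c in os.path.basename(url) if c.isalnum() or c in "_")
    return f"{cache_dir}/{name}"

print(cached_path("https://tinyurl.com/mpe7br4d"))  # .../.cache/mace/mpe7br4d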
@@ -205,9 +212,26 @@
                         f"Dataset download failed, please check the URL {checkpoint_url}"
                     )
             logging.info(f"Materials Project dataset to {cached_dataset_path}")
+            if not os.path.isfile(cached_descriptors_path):
+                os.makedirs(cache_dir, exist_ok=True)
+                # download and save to disk
+                logging.info("Downloading MP descriptors for finetuning")
+                _, http_msg = urllib.request.urlretrieve(
+                    descriptors_url, cached_descriptors_path
+                )
+                if "Content-Type: text/html" in http_msg:
+                    raise RuntimeError(
+                        f"Descriptors download failed, please check the URL {descriptors_url}"
+                    )
+            logging.info(
+                f"Materials Project descriptors to {cached_descriptors_path}"
+            )
             dataset_mp = cached_dataset_path
+            descriptors_mp = cached_descriptors_path
             msg = f"Using Materials Project dataset with {dataset_mp}"
             logging.info(msg)
+            msg = f"Using Materials Project descriptors with {descriptors_mp}"
+            logging.info(msg)
             args_samples = {
                 "configs_pt": dataset_mp,
                 "configs_ft": args.train_file,
@@ -220,7 +244,7 @@
                 "weight_ft": 1.0,
                 "filtering_type": "combination",
                 "output": f"{cache_dir}/mp_finetuning.xyz",
-                "descriptors": r"D:\Work\mace_mp\descriptors.npy",
+                "descriptors": descriptors_mp,
                 "device": args.device,
                 "default_dtype": args.default_dtype,
             }
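The args_samples block now points "descriptors" at the cached file instead of a hard-coded local Windows path. The download logic added in the hunks above follows the existing dataset handling: fetch once with urllib, store under ~/.cache/mace, and treat an HTML response as a failed download. A hedged sketch of that pattern as a standalone helper; it checks the content type via HTTPMessage.get_content_type(), a slight variation on the string test used in the diff:

# Illustrative only, not part of the commit.
import os
import urllib.request

def download_once(url: str, dest: str) -> str:
    # Download url to dest unless a cached copy already exists; a failed
    # download typically serves an HTML error page instead of the payload.
    if not os.path.isfile(dest):
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        _, headers = urllib.request.urlretrieve(url, dest)
        if headers.get_content_type() == "text/html":
            raise RuntimeError(f"Download failed, please check the URL {url}")
    return dest

download_once("https://tinyurl.com/mpe7br4d", os.path.expanduser("~/.cache/mace/mpe7br4d"))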
