diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_device.py
similarity index 100%
rename from mace/cli/convert_dev.py
rename to mace/cli/convert_device.py
diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 7f1a5e74..3813b055 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -36,7 +36,6 @@
     LRScheduler,
     check_path_ase_read,
     convert_to_json_format,
-    create_error_table,
     dict_to_array,
     extract_config_mace_model,
     get_atomic_energies,
@@ -49,9 +48,11 @@
     get_params_options,
     get_swa,
     print_git_commit,
+    remove_pt_head,
     setup_wandb,
 )
 from mace.tools.slurm_distributed import DistributedEnvironment
+from mace.tools.tables_utils import create_error_table
 from mace.tools.utils import AtomicNumberTable
@@ -115,10 +116,6 @@ def run(args: argparse.Namespace) -> None:
     commit = print_git_commit()
     model_foundation: Optional[torch.nn.Module] = None
     if args.foundation_model is not None:
-        if args.multiheads_finetuning:
-            assert (
-                args.E0s != "average"
-            ), "average atomic energies cannot be used for multiheads finetuning"
         if args.foundation_model in ["small", "medium", "large"]:
             logging.info(
                 f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint."
             )
@@ -148,6 +145,27 @@ def run(args: argparse.Namespace) -> None:
             f"Using foundation model {args.foundation_model} as initial checkpoint."
         )
         args.r_max = model_foundation.r_max.item()
+        if (
+            args.foundation_model not in ["small", "medium", "large"]
+            and args.pt_train_file is None
+        ):
+            logging.warning(
+                "Using multiheads finetuning with a foundation model that is not a Materials Project model; you need to provide a path to a pretraining file with --pt_train_file."
+            )
+            args.multiheads_finetuning = False
+        if args.multiheads_finetuning:
+            assert (
+                args.E0s != "average"
+            ), "average atomic energies cannot be used for multiheads finetuning"
+            # check that the foundation model has a single head; if not, use the first head
+            if hasattr(model_foundation, "heads"):
+                if len(model_foundation.heads) > 1:
+                    logging.warning(
+                        "Multihead finetuning with models that have more than one head is not supported; using the first head as the foundation head."
+                    )
+                    model_foundation = remove_pt_head(
+                        model_foundation, args.foundation_head
+                    )
     else:
         args.multiheads_finetuning = False
@@ -587,7 +605,6 @@ def run(args: argparse.Namespace) -> None:
         distributed_model = DDP(model, device_ids=[local_rank])
     else:
         distributed_model = None
-
     tools.train(
         model=model,
         loss_fn=loss_fn,
diff --git a/mace/cli/select_head.py b/mace/cli/select_head.py
new file mode 100644
index 00000000..a1e27229
--- /dev/null
+++ b/mace/cli/select_head.py
@@ -0,0 +1,33 @@
+from argparse import ArgumentParser
+
+import torch
+
+from mace.tools.scripts_utils import remove_pt_head
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--head_name",
+        "-n",
+        help="name of the head to extract",
+        default=None,
+    )
+    parser.add_argument(
+        "--output_file",
+        "-o",
+        help="name for output model, defaults to model_file.head_name",
+    )
+    parser.add_argument("model_file", help="input model file path")
+    args = parser.parse_args()
+
+    if args.output_file is None:
+        args.output_file = args.model_file + "." + args.head_name
+
+    model = torch.load(args.model_file)
+    model_single = remove_pt_head(model, args.head_name)
+    torch.save(model_single, args.output_file)
+
+
+if __name__ == "__main__":
+    main()
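Once installed via the mace_select_head entry point (registered in setup.cfg further down), the script above can be invoked as, for example, mace_select_head -n DFT -o finetuned.model.DFT finetuned.model — the checkpoint names here are placeholders, not files from this PR. If -o is omitted, the output path defaults to <model_file>.<head_name>, as implemented above.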
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index e492c827..6d8be783 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -360,6 +360,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         type=str2bool,
         default=True,
     )
+    parser.add_argument(
+        "--foundation_head",
+        help="Name of the head to use for fine-tuning",
+        type=str,
+        default=None,
+        required=False,
+    )
     parser.add_argument(
         "--weight_pt_head",
         help="Weight of the pretrained head in the loss function",
diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py
index 71ca6a7c..8df0b0d1 100644
--- a/mace/tools/finetuning_utils.py
+++ b/mace/tools/finetuning_utils.py
@@ -73,10 +73,10 @@ def load_foundations_elements(
         model.interactions[i].linear.weight = torch.nn.Parameter(
             model_foundations.interactions[i].linear.weight.clone()
         )
-        if (
-            model.interactions[i].__class__.__name__
-            in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
-        ):
+        if model.interactions[i].__class__.__name__ in [
+            "RealAgnosticResidualInteractionBlock",
+            "RealAgnosticDensityResidualInteractionBlock",
+        ]:
             model.interactions[i].skip_tp.weight = torch.nn.Parameter(
                 model_foundations.interactions[i]
                 .skip_tp.weight.reshape(
@@ -101,19 +101,17 @@ def load_foundations_elements(
                 .clone()
                 / (num_species_foundations / num_species) ** 0.5
             )
-        if (
-            model.interactions[i].__class__.__name__
-            in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
-        ):
+        if model.interactions[i].__class__.__name__ in [
+            "RealAgnosticDensityInteractionBlock",
+            "RealAgnosticDensityResidualInteractionBlock",
+        ]:
             # Assuming only 1 layer in density_fn
             getattr(model.interactions[i].density_fn, "layer0").weight = (
                 torch.nn.Parameter(
                     getattr(
                         model_foundations.interactions[i].density_fn,
                         "layer0",
-                    )
-                    .weight
-                    .clone()
+                    ).weight.clone()
                 )
             )
     # Transferring products
diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py
index 8e8c2877..3f49eb41 100644
--- a/mace/tools/model_script_utils.py
+++ b/mace/tools/model_script_utils.py
@@ -53,7 +53,6 @@ def configure_model(
         model_config_foundation["atomic_inter_shift"] = (
             _determine_atomic_inter_shift(args.mean, heads)
         )
-        model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
         args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
         args.model = "FoundationMACE"
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index d20e942b..be96558d 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -17,11 +17,9 @@
 import torch
 import torch.distributed
 from e3nn import o3
-from prettytable import PrettyTable
 from torch.optim.swa_utils import SWALR, AveragedModel

 from mace import data, modules, tools
-from mace.tools import evaluate
 from mace.tools.train import SWAContainer
@@ -224,6 +222,98 @@ def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module:
     )
+
+def remove_pt_head(
+    model: torch.nn.Module, head_to_keep: Optional[str] = None
+) -> torch.nn.Module:
+    """Converts a multihead MACE model to a single head model by removing the pretraining head.
+
+    Args:
+        model (ScaleShiftMACE): The multihead MACE model to convert
+        head_to_keep (Optional[str]): The name of the head to keep.
If None, keeps the first non-PT head. + + Returns: + ScaleShiftMACE: A new MACE model with only the specified head + + Raises: + ValueError: If the model is not a multihead model or if the specified head is not found + """ + if not hasattr(model, "heads") or len(model.heads) <= 1: + raise ValueError("Model must be a multihead model with more than one head") + + # Get index of head to keep + if head_to_keep is None: + # Find first non-PT head + try: + head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head") + except StopIteration as e: + raise ValueError("No non-PT head found in model") from e + else: + try: + head_idx = model.heads.index(head_to_keep) + except ValueError as e: + raise ValueError(f"Head {head_to_keep} not found in model") from e + + # Extract config and modify for single head + model_config = extract_config_mace_model(model) + model_config["heads"] = [model.heads[head_idx]] + model_config["atomic_energies"] = ( + model.atomic_energies_fn.atomic_energies[head_idx] + .unsqueeze(0) + .detach() + .cpu() + .numpy() + ) + model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item() + model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item() + mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) // len(model.heads) + model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e") + + new_model = model.__class__(**model_config) + state_dict = model.state_dict() + new_state_dict = {} + + for name, param in state_dict.items(): + if "atomic_energies" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "scale" in name or "shift" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "readouts" in name: + channels_per_head = param.shape[0] // len(model.heads) + start_idx = head_idx * channels_per_head + end_idx = start_idx + channels_per_head + if "linear_2.weight" in name: + end_idx = start_idx + channels_per_head // 2 + # if ( + # "readouts.0.linear.weight" in name + # or "readouts.1.linear_2.weight" in name + # ): + # new_state_dict[name] = param[start_idx:end_idx] / ( + # len(model.heads) ** 0.5 + # ) + if "readouts.0.linear.weight" in name: + new_state_dict[name] = param.reshape(-1, len(model.heads))[ + :, head_idx + ].flatten() + elif "readouts.1.linear_1.weight" in name: + new_state_dict[name] = param.reshape( + -1, len(model.heads), mlp_count_irreps + )[:, head_idx, :].flatten() + elif "readouts.1.linear_2.weight" in name: + new_state_dict[name] = param.reshape( + len(model.heads), -1, len(model.heads) + )[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5) + else: + new_state_dict[name] = param[start_idx:end_idx] + + else: + new_state_dict[name] = param + + # Load state dict into new model + new_model.load_state_dict(new_state_dict) + + return new_model + + def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: model_copy = model.__class__(**extract_config_mace_model(model)) model_copy.load_state_dict(model.state_dict()) @@ -613,19 +703,6 @@ def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: ] -def custom_key(key): - """ - Helper function to sort the keys of the data loader dictionary - to ensure that the training set, and validation set - are evaluated first - """ - if key == "train": - return (0, key) - if key == "valid": - return (1, key) - return (2, key) - - def dict_to_array(input_data, heads): if all(isinstance(value, np.ndarray) for value in input_data.values()): return np.array([input_data[head] for head in heads]) 
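As an aside, the per-head slicing of the readout weights in remove_pt_head relies on the head axis being interleaved in the flattened weight tensors. Below is a minimal, self-contained sketch of the "readouts.1.linear_1.weight" case; the dimensions and head index are chosen purely for illustration and are not taken from this diff:

    import torch

    num_heads, mlp_count_irreps, channels = 4, 16, 32
    head_idx = 2  # hypothetical head to keep
    # flattened multihead weight, with the head axis in the middle
    param = torch.randn(channels * num_heads * mlp_count_irreps)
    # same reshape-and-slice that remove_pt_head applies above
    single = param.reshape(-1, num_heads, mlp_count_irreps)[:, head_idx, :].flatten()
    assert single.shape[0] == channels * mlp_count_irreps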
@@ -680,227 +757,6 @@ def __getattr__(self, name): return getattr(self.lr_scheduler, name) -def create_error_table( - table_type: str, - all_data_loaders: dict, - model: torch.nn.Module, - loss_fn: torch.nn.Module, - output_args: Dict[str, bool], - log_wandb: bool, - device: str, - distributed: bool = False, -) -> PrettyTable: - if log_wandb: - import wandb - table = PrettyTable() - if table_type == "TotalRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSEstressvirials": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - "RMSE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "PerAtomMAEstressvirials": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - "MAE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "TotalMAE": - table.field_names = [ - "config_type", - "MAE E / meV", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "PerAtomMAE": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "DipoleRMSE": - table.field_names = [ - "config_type", - "RMSE MU / mDebye / atom", - "relative MU RMSE %", - ] - elif table_type == "DipoleMAE": - table.field_names = [ - "config_type", - "MAE MU / mDebye / atom", - "relative MU MAE %", - ] - elif table_type == "EnergyDipoleRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "rel F RMSE %", - "RMSE MU / mDebye / atom", - "rel MU RMSE %", - ] - - for name in sorted(all_data_loaders, key=custom_key): - data_loader = all_data_loaders[name] - logging.info(f"Evaluating {name} ...") - _, metrics = evaluate( - model, - loss_fn=loss_fn, - data_loader=data_loader, - output_args=output_args, - device=device, - ) - if distributed: - torch.distributed.barrier() - - del data_loader - torch.cuda.empty_cache() - if log_wandb: - wandb_log_dict = { - name - + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] - * 1e3, # meV / atom - name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A - name + "_final_rel_rmse_f": metrics["rel_rmse_f"], - } - wandb.log(wandb_log_dict) - if table_type == "TotalRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif table_type == "PerAtomRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_virials'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_stress"] is not None - ): - table.add_row( - [ - 
name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_virials'] * 1000:8.1f}", - ] - ) - elif table_type == "TotalMAE": - table.add_row( - [ - name, - f"{metrics['mae_e'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "PerAtomMAE": - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "DipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - elif table_type == "DipoleMAE": - table.add_row( - [ - name, - f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_mae_mu']:8.1f}", - ] - ) - elif table_type == "EnergyDipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.1f}", - f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - return table - - def check_folder_subfolder(folder_path): entries = os.listdir(folder_path) for entry in entries: diff --git a/mace/tools/tables_utils.py b/mace/tools/tables_utils.py new file mode 100644 index 00000000..07f41401 --- /dev/null +++ b/mace/tools/tables_utils.py @@ -0,0 +1,241 @@ +import logging +from typing import Dict + +import torch +from prettytable import PrettyTable + +from mace.tools import evaluate + + +def custom_key(key): + """ + Helper function to sort the keys of the data loader dictionary + to ensure that the training set, and validation set + are evaluated first + """ + if key == "train": + return (0, key) + if key == "valid": + return (1, key) + return (2, key) + + +def create_error_table( + table_type: str, + all_data_loaders: dict, + model: torch.nn.Module, + loss_fn: torch.nn.Module, + output_args: Dict[str, bool], + log_wandb: bool, + device: str, + distributed: bool = False, +) -> PrettyTable: + if log_wandb: + import wandb + table = PrettyTable() + if table_type == "TotalRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSEstressvirials": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + "RMSE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "TotalMAE": + table.field_names = [ + "config_type", + "MAE E / meV", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "PerAtomMAE": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "DipoleRMSE": + table.field_names = [ + "config_type", + "RMSE MU / mDebye / atom", + "relative MU RMSE %", + ] + elif 
table_type == "DipoleMAE": + table.field_names = [ + "config_type", + "MAE MU / mDebye / atom", + "relative MU MAE %", + ] + elif table_type == "EnergyDipoleRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "rel F RMSE %", + "RMSE MU / mDebye / atom", + "rel MU RMSE %", + ] + + for name in sorted(all_data_loaders, key=custom_key): + data_loader = all_data_loaders[name] + logging.info(f"Evaluating {name} ...") + _, metrics = evaluate( + model, + loss_fn=loss_fn, + data_loader=data_loader, + output_args=output_args, + device=device, + ) + if distributed: + torch.distributed.barrier() + + del data_loader + torch.cuda.empty_cache() + if log_wandb: + wandb_log_dict = { + name + + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] + * 1e3, # meV / atom + name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A + name + "_final_rel_rmse_f": metrics["rel_rmse_f"], + } + wandb.log(wandb_log_dict) + if table_type == "TotalRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif table_type == "PerAtomRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", + ] + ) + elif table_type == "TotalMAE": + table.add_row( + [ + name, + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "PerAtomMAE": + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "DipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + elif table_type == "DipoleMAE": + table.add_row( + [ + name, + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", + ] + ) + elif table_type == "EnergyDipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + return table diff --git a/setup.cfg b/setup.cfg index 139f914e..76467fda 100644 --- 
a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,8 @@ console_scripts = mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main mace_finetuning = mace.cli.fine_tuning_select:main - mace_convert_dev = mace.cli.convert_dev:main + mace_convert_device = mace.cli.convert_device:main + mace_select_head = mace.cli.select_head:main [options.extras_require] wandb = wandb diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 03ea85c3..44879395 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -12,7 +12,7 @@ from mace.calculators import mace_mp, mace_off from mace.tools import torch_geometric from mace.tools.finetuning_utils import load_foundations_elements -from mace.tools.scripts_utils import extract_config_mace_model +from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head from mace.tools.utils import AtomicNumberTable MODEL_PATH = ( @@ -208,3 +208,240 @@ def test_extract_config(model): for key in output.keys(): if isinstance(output[key], torch.Tensor): assert torch.allclose(output[key], output_copy[key], atol=1e-5) + + +def test_remove_pt_head(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT"], + "atomic_inter_scale": [1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test molecule + mol = molecule("H2O") + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + energy=1.0, + forces=np.random.randn(len(mol), 3), + head="DFT", + ) + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["pt_head", "DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + # Test original mode + output_orig = model(batch) + + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep="DFT") + + # Basic structure tests + assert len(new_model.heads) == 1 + assert new_model.heads[0] == "DFT" + assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + + # Test output consistency + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + output_new = new_model(batch) + torch.testing.assert_close( + output_orig["energy"], output_new["energy"], rtol=1e-5, atol=1e-5 + ) + torch.testing.assert_close( + output_orig["forces"], output_new["forces"], rtol=1e-5, atol=1e-5 + ) + + 
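+# The next test repeats the consistency check above for a four-head model:
+# each head is converted to a single-head model with remove_pt_head, and the
+# converted model's energies, forces, and scale/shift values are compared
+# against the multihead reference outputs.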
+def test_remove_pt_head_multihead():
+    # Set up test data
+    torch.manual_seed(42)
+    atomic_energies_pt_head = np.array(
+        [
+            [1.0, 2.0],  # E0s (H, O) for one head
+            [3.0, 4.0],  # E0s (H, O) for the next head
+        ]
+        * 2  # repeated so there is one row per head (four heads)
+    )
+    z_table = AtomicNumberTable([1, 8])  # H and O
+
+    # Create multihead model
+    model_config = {
+        "r_max": 5.0,
+        "num_bessel": 8,
+        "num_polynomial_cutoff": 5,
+        "max_ell": 2,
+        "interaction_cls": modules.interaction_classes[
+            "RealAgnosticResidualInteractionBlock"
+        ],
+        "interaction_cls_first": modules.interaction_classes[
+            "RealAgnosticResidualInteractionBlock"
+        ],
+        "num_interactions": 2,
+        "num_elements": len(z_table),
+        "hidden_irreps": o3.Irreps("32x0e + 32x1o"),
+        "MLP_irreps": o3.Irreps("16x0e"),
+        "gate": torch.nn.functional.silu,
+        "atomic_energies": atomic_energies_pt_head,
+        "avg_num_neighbors": 8,
+        "atomic_numbers": z_table.zs,
+        "correlation": 3,
+        "heads": ["pt_head", "DFT", "MP2", "CCSD"],
+        "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0],
+        "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3],
+    }
+
+    model = modules.ScaleShiftMACE(**model_config)
+
+    # Create test configurations for each head
+    mol = molecule("H2O")
+    configs = {}
+    atomic_datas = {}
+    dataloaders = {}
+    original_outputs = {}
+
+    # First get outputs from original model for each head
+    for head in model.heads:
+        config_pt_head = data.Configuration(
+            atomic_numbers=mol.numbers,
+            positions=mol.positions,
+            energy=1.0,
+            forces=np.random.randn(len(mol), 3),
+            head=head,
+        )
+        configs[head] = config_pt_head
+
+        atomic_data = data.AtomicData.from_config(
+            config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads
+        )
+        atomic_datas[head] = atomic_data
+
+        dataloader = torch_geometric.dataloader.DataLoader(
+            dataset=[atomic_data], batch_size=1, shuffle=False
+        )
+        dataloaders[head] = dataloader
+
+        batch = next(iter(dataloader))
+        output = model(batch)
+        original_outputs[head] = output
+
+    # Now test each head separately
+    for i, head in enumerate(model.heads):
+        # Convert to single head model
+        new_model = remove_pt_head(model, head_to_keep=head)
+
+        # Basic structure tests
+        assert len(new_model.heads) == 1, f"Failed for head {head}"
+        assert new_model.heads[0] == head, f"Failed for head {head}"
+        assert (
+            new_model.atomic_energies_fn.atomic_energies.shape[0] == 1
+        ), f"Failed for head {head}"
+        assert (
+            len(torch.atleast_1d(new_model.scale_shift.scale)) == 1
+        ), f"Failed for head {head}"
+        assert (
+            len(torch.atleast_1d(new_model.scale_shift.shift)) == 1
+        ), f"Failed for head {head}"
+
+        # Verify scale and shift values
+        assert torch.allclose(
+            new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1]
+        ), f"Failed for head {head}"
+        assert torch.allclose(
+            new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1]
+        ), f"Failed for head {head}"
+
+        # Test output consistency
+        single_head_data = data.AtomicData.from_config(
+            configs[head], z_table=z_table, cutoff=5.0, heads=[head]
+        )
+        single_head_loader = torch_geometric.dataloader.DataLoader(
+            dataset=[single_head_data], batch_size=1, shuffle=False
+        )
+        batch = next(iter(single_head_loader))
+        new_output = new_model(batch)
+
+        # Compare outputs
+        print(
+            original_outputs[head]["energy"],
+            new_output["energy"],
+        )
+        torch.testing.assert_close(
+            original_outputs[head]["energy"],
+            new_output["energy"],
+            rtol=1e-5,
+            atol=1e-5,
+            msg=f"Energy mismatch for head {head}",
+        )
+        torch.testing.assert_close(
+            original_outputs[head]["forces"],
+            new_output["forces"],
+            rtol=1e-5,
+            atol=1e-5,
+            msg=f"Forces mismatch for
head {head}", + ) + + # Test error cases + with pytest.raises(ValueError, match="Head non_existent not found in model"): + remove_pt_head(model, head_to_keep="non_existent") + + # Test default behavior (first non-PT head) + default_model = remove_pt_head(model) + assert default_model.heads[0] == "DFT" + + # Additional test: check if each model's computation graph is independent + models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads} + results = {} + + for head, head_model in models.items(): + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + results[head] = head_model(batch) + + # Verify each model produces different outputs + energies = torch.stack([results[head]["energy"] for head in model.heads]) + assert not torch.allclose( + energies[0], energies[1], rtol=1e-3 + ), "Different heads should produce different outputs"