Skip to content

Commit

Permalink
Merge pull request #687 from ACEsuit/develop
Browse files Browse the repository at this point in the history
fix case with multihead foundation model
  • Loading branch information
ilyes319 authored Nov 12, 2024
2 parents f1e671d + 68149a7 commit 83b30b8
Show file tree
Hide file tree
Showing 10 changed files with 645 additions and 256 deletions.
File renamed without changes.
29 changes: 23 additions & 6 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
LRScheduler,
check_path_ase_read,
convert_to_json_format,
create_error_table,
dict_to_array,
extract_config_mace_model,
get_atomic_energies,
Expand All @@ -49,9 +48,11 @@
get_params_options,
get_swa,
print_git_commit,
remove_pt_head,
setup_wandb,
)
from mace.tools.slurm_distributed import DistributedEnvironment
from mace.tools.tables_utils import create_error_table
from mace.tools.utils import AtomicNumberTable


Expand Down Expand Up @@ -115,10 +116,6 @@ def run(args: argparse.Namespace) -> None:
commit = print_git_commit()
model_foundation: Optional[torch.nn.Module] = None
if args.foundation_model is not None:
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
if args.foundation_model in ["small", "medium", "large"]:
logging.info(
f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint."
Expand Down Expand Up @@ -148,6 +145,27 @@ def run(args: argparse.Namespace) -> None:
f"Using foundation model {args.foundation_model} as initial checkpoint."
)
args.r_max = model_foundation.r_max.item()
if (
args.foundation_model not in ["small", "medium", "large"]
and args.pt_train_file is None
):
logging.warning(
"Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file."
)
args.multiheads_finetuning = False
if args.multiheads_finetuning:
assert (
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
# check that the foundation model has a single head, if not, use the first head
if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1:
logging.warning(
"Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head."
)
model_foundation = remove_pt_head(
model_foundation, args.foundation_head
)
else:
args.multiheads_finetuning = False

Expand Down Expand Up @@ -587,7 +605,6 @@ def run(args: argparse.Namespace) -> None:
distributed_model = DDP(model, device_ids=[local_rank])
else:
distributed_model = None

tools.train(
model=model,
loss_fn=loss_fn,
Expand Down
33 changes: 33 additions & 0 deletions mace/cli/select_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from argparse import ArgumentParser

import torch

from mace.tools.scripts_utils import remove_pt_head


def main():
parser = ArgumentParser()
parser.add_argument(
"--head_name",
"-n",
help="name of the head to extract",
default=None,
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model_single = remove_pt_head(model, args.head_name)
torch.save(model_single, args.output_file)


if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str2bool,
default=True,
)
parser.add_argument(
"--foundation_model_head",
help="Name of the head to use for fine-tuning",
type=str,
default=None,
required=False,
)
parser.add_argument(
"--weight_pt_head",
help="Weight of the pretrained head in the loss function",
Expand Down
20 changes: 9 additions & 11 deletions mace/tools/finetuning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def load_foundations_elements(
model.interactions[i].linear.weight = torch.nn.Parameter(
model_foundations.interactions[i].linear.weight.clone()
)
if (
model.interactions[i].__class__.__name__
in ["RealAgnosticResidualInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
if model.interactions[i].__class__.__name__ in [
"RealAgnosticResidualInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
model.interactions[i].skip_tp.weight = torch.nn.Parameter(
model_foundations.interactions[i]
.skip_tp.weight.reshape(
Expand All @@ -101,19 +101,17 @@ def load_foundations_elements(
.clone()
/ (num_species_foundations / num_species) ** 0.5
)
if (
model.interactions[i].__class__.__name__
in ["RealAgnosticDensityInteractionBlock", "RealAgnosticDensityResidualInteractionBlock"]
):
if model.interactions[i].__class__.__name__ in [
"RealAgnosticDensityInteractionBlock",
"RealAgnosticDensityResidualInteractionBlock",
]:
# Assuming only 1 layer in density_fn
getattr(model.interactions[i].density_fn, "layer0").weight = (
torch.nn.Parameter(
getattr(
model_foundations.interactions[i].density_fn,
"layer0",
)
.weight
.clone()
).weight.clone()
)
)
# Transferring products
Expand Down
1 change: 0 additions & 1 deletion mace/tools/model_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def configure_model(
model_config_foundation["atomic_inter_shift"] = (
_determine_atomic_inter_shift(args.mean, heads)
)

model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads)
args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"]
args.model = "FoundationMACE"
Expand Down
Loading

0 comments on commit 83b30b8

Please sign in to comment.