Skip to content

Commit

Permalink
change interface for automatic multihead finetuning
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Apr 29, 2024
1 parent 6a189dd commit 41d77c4
Show file tree
Hide file tree
Showing 18 changed files with 396 additions and 234 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Linting and code formatting

on: []
# Trigger the workflow on push or pull request,
# but only for the main branch
# push:
# branches: []
# pull_request:
# branches: []


jobs:
build-linux:
runs-on: ubuntu-latest

steps:
# Setup
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8.10
- name: Get cache
uses: actions/cache@v2
with:
path: /opt/hostedtoolcache/Python/3.8.10/x64/lib/python3.8/site-packages
# Look to see if there is a cache hit for the corresponding requirements file
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}

# Install packages
- name: Install packages required for installation
run: python -m pip install --upgrade pip setuptools wheel
- name: Install dependencies
run: pip install -r requirements.txt

# Check code
- name: Check formatting with yapf
run: python -m yapf --style=.style.yapf --diff --recursive .
# - name: Lint with flake8
# run: flake8 --config=.flake8 .
# - name: Check type annotations with mypy
# run: mypy --config-file=.mypy.ini .

- name: Test with pytest
run: python -m pytest tests
12 changes: 6 additions & 6 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def __init__(
)
self.charges_key = charges_key
try:
self.theories = self.models[0].theories
self.heads = self.models[0].heads
except:
self.theories = ["Default"]
self.heads = ["Default"]
model_dtype = get_model_dtype(self.models[0])
if default_dtype == "":
print(
Expand Down Expand Up @@ -223,7 +223,7 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
config,
z_table=self.z_table,
cutoff=self.r_max,
theories=self.theories,
heads=self.heads,
)
],
batch_size=1,
Expand All @@ -233,10 +233,10 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):

if self.model_type in ["MACE", "EnergyDipoleMACE"]:
batch = next(iter(data_loader)).to(self.device)
node_theories = batch["theory"][batch["batch"]]
node_heads = batch["head"][batch["batch"]]
num_atoms_arange = torch.arange(batch["positions"].shape[0])
node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[
num_atoms_arange, node_theories
num_atoms_arange, node_heads
]
compute_stress = not self.use_compile
else:
Expand Down Expand Up @@ -339,7 +339,7 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
config,
z_table=self.z_table,
cutoff=self.r_max,
theories=self.theories,
heads=self.heads,
)
],
batch_size=1,
Expand Down
32 changes: 16 additions & 16 deletions mace/cli/fine_tuning_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,14 @@ def parse_args() -> argparse.Namespace:
default="float64",
)
parser.add_argument(
"--info_prefix",
help="prefix for energy, forces and stress keys",
type=str,
default="MACE_",
)
parser.add_argument(
"--theory_pt",
help="level of theory for the pretraining set",
"--head_pt",
help="level of head for the pretraining set",
type=str,
default=None,
)
parser.add_argument(
"--theory_ft",
help="level of theory for the finetuning set",
"--head_ft",
help="level of head for the finetuning set",
type=str,
default=None,
)
Expand Down Expand Up @@ -199,8 +193,9 @@ def assemble_descriptors(self) -> np.ndarray:
)


def main():
args = parse_args()
def select_samples(
args: argparse.Namespace,
) -> None:
if args.model in ["small", "medium", "large"]:
calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype)
else:
Expand Down Expand Up @@ -272,22 +267,27 @@ def main():
# del atoms.info["mace_descriptors"]
atoms.info["pretrained"] = True
atoms.info["config_weight"] = args.weight_pt
if args.theory_pt is not None:
atoms.info["theory"] = args.theory_pt
if args.head_pt is not None:
atoms.info["head"] = args.head_pt

print("Saving the selected configurations")
ase.io.write(args.output, atoms_list_pt, format="extxyz")
print("Saving a combined XYZ file")
for atoms in atoms_list_ft:
atoms.info["pretrained"] = False
atoms.info["config_weight"] = args.weight_ft
if args.theory_ft is not None:
atoms.info["theory"] = args.theory_ft
if args.head_ft is not None:
atoms.info["head"] = args.head_ft
atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft
ase.io.write(
args.output.replace(".xyz", "_combined.xyz"), atoms_fps_pt_ft, format="extxyz"
)


def main():
args = parse_args()
select_samples(args)


if __name__ == "__main__":
main()
Loading

0 comments on commit 41d77c4

Please sign in to comment.