Skip to content

Commit

Permalink
Add limit_nus calc to ACE coeffs calculation
Browse files Browse the repository at this point in the history
Co-authored-by: James Goff <[email protected]>
  • Loading branch information
timcallow and jmgoff committed Jun 11, 2024
1 parent 808af1a commit c29119c
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 2 deletions.
8 changes: 7 additions & 1 deletion mala/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,14 @@ def __init__(self):
self.ace_M_R = 0

self.ace_coupling_type = "cg"

self.ace_lmax_traditional = 12

# TODO: add consistency check for these
# if grid_filter, types_like_snap must be False
# if grid_filter, padfunc must be True
self.ace_grid_filter = True
self.ace_types_like_snap = False
self.ace_padfunc = True


@property
Expand Down
65 changes: 64 additions & 1 deletion mala/descriptors/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ def calculate_coupling_coeffs(self):
+ " not recongised"
)

limit_nus = self.calc_limit_nus()

def calc_limit_nus(self):
ranked_chem_nus = []
for ind, rank in enumerate(self.parameters.ace_ranks):
rank = int(rank)
Expand All @@ -532,6 +535,66 @@ def calculate_coupling_coeffs(self):
)
ranked_chem_nus.append(PA_lammps)

nus_unsort = [item for sublist in ranked_chem_nus for item in sublist]
nus = nus_unsort.copy()
mu0s = []
mus = []
ns = []
ls = []
for nu in nus_unsort:
mu0ii, muii, nii, lii = acu.get_mu_n_l(nu)
mu0s.append(mu0ii)
mus.append(tuple(muii))
ns.append(tuple(nii))
ls.append(tuple(lii))
nus.sort(key=lambda x: mus[nus_unsort.index(x)], reverse=False)
nus.sort(key=lambda x: ns[nus_unsort.index(x)], reverse=False)
nus.sort(key=lambda x: ls[nus_unsort.index(x)], reverse=False)
nus.sort(key=lambda x: mu0s[nus_unsort.index(x)], reverse=False)
nus.sort(key=lambda x: len(x), reverse=False)
nus.sort(key=lambda x: mu0s[nus_unsort.index(x)], reverse=False)
musins = range(len(self.parameters.ace_elements) - 1)
all_funcs = {}
if self.parameters.ace_types_like_snap:
byattyp, byattypfiltered = self.srt_by_attyp(nus, 1)
if self.parameters.ace_grid_filter:
assert (
self.parameters.ace_padfunc
), "must pad with at least 1 other basis function for other element types to work in LAMMPS - set padfunc=True"
limit_nus = byattypfiltered["%d" % 0]
if self.parameters.ace_padfunc:
for muii in musins:
limit_nus.append(byattypfiltered["%d" % muii][0])
elif not grid_filter:
limit_nus = byattyp["%d" % 0]

else:
byattyp, byattypfiltered = self.srt_by_attyp(
nus, len(self.parameters.ace_elements)
)
if self.parameters.ace_grid_filter:
limit_nus = byattypfiltered[
"%d" % (len(self.parameters.ace_elements) - 1)
]
assert (
self.parameters.ace_padfunc
), "must pad with at least 1 other basis function for other element types to work in LAMMPS - set padfunc=True"
if self.parameters.ace_padfunc:
for muii in musins:
limit_nus.append(byattypfiltered["%d" % muii][0])
elif not grid_filter:
limit_nus = byattyp[
"%d" % (len(self.parameters.ace_elements) - 1)
]
if self.parameters.ace_padfunc:
for muii in musins:
limit_nus.append(byattyp["%d" % muii][0])
printout(
"all basis functions", len(nus), "grid subset", len(limit_nus)
)

return limit_nus

def get_default_settings(self):
rc_range = {bp: None for bp in self.bonds}
rin_def = {bp: None for bp in self.bonds}
Expand Down Expand Up @@ -653,7 +716,7 @@ def srt_by_attyp(self, nulst, remove_type=2):
for nu in nulst:
mu0 = nu.split("_")[0]
byattyp[mu0].append(nu)
mu0ii, muii, nii, lii = get_mu_n_l(nu)
mu0ii, muii, nii, lii = acu.get_mu_n_l(nu)
if mumax not in muii:
byattypfiltered[mu0].append(nu)
return byattyp, byattypfiltered
Expand Down
47 changes: 47 additions & 0 deletions mala/descriptors/ace_coupling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,50 @@ def ind_vec(lrng, size):
if pstr not in uniques:
uniques.append(pstr)
return uniques


def get_mu_n_l(nu_in, return_L=False, **kwargs):
rank = get_mu_nu_rank(nu_in)
if len(nu_in.split("_")) > 1:
if len(nu_in.split("_")) == 2:
nu = nu_in.split("_")[-1]
Lstr = ""
else:
nu = nu_in.split("_")[1]
Lstr = nu_in.split("_")[-1]
mu0 = int(nu_in.split("_")[0])
nusplt = [int(k) for k in nu.split(",")]
mu = nusplt[:rank]
n = nusplt[rank : 2 * rank]
l = nusplt[2 * rank :]
if len(Lstr) >= 1:
L = tuple([int(k) for k in Lstr.split("-")])
else:
L = None
if return_L:
return mu0, mu, n, l, L
else:
return mu0, mu, n, l
# provide option to get n,l for depricated descriptor labels
else:
nu = nu_in
mu0 = 0
mu = [0] * rank
nusplt = [int(k) for k in nu.split(",")]
n = nusplt[:rank]
l = nusplt[rank : 2 * rank]
return mu0, mu, n, l


def get_mu_nu_rank(nu_in):
if len(nu_in.split("_")) > 1:
assert (
len(nu_in.split("_")) <= 3
), "make sure your descriptor label is in proper format: mu0_mu1,mu2,mu3,n1,n2,n3,l1,l2,l3_L1"
nu = nu_in.split("_")[1]
nu_splt = nu.split(",")
return int(len(nu_splt) / 3)
else:
nu = nu_in
nu_splt = nu.split(",")
return int(len(nu_splt) / 2)

0 comments on commit c29119c

Please sign in to comment.