diff --git a/mala/common/parameters.py b/mala/common/parameters.py index 64325a77b..c1829b6d2 100644 --- a/mala/common/parameters.py +++ b/mala/common/parameters.py @@ -382,8 +382,14 @@ def __init__(self): self.ace_M_R = 0 self.ace_coupling_type = "cg" - self.ace_lmax_traditional = 12 + + # TODO: add consistency check for these + # if grid_filter, types_like_snap must be False + # if grid_filter, padfunc must be True + self.ace_grid_filter = True + self.ace_types_like_snap = False + self.ace_padfunc = True @property diff --git a/mala/descriptors/ace.py b/mala/descriptors/ace.py index 8b8b6adc8..d6f9ec8b0 100755 --- a/mala/descriptors/ace.py +++ b/mala/descriptors/ace.py @@ -520,6 +520,9 @@ def calculate_coupling_coeffs(self): + " not recongised" ) + limit_nus = self.calc_limit_nus() + + def calc_limit_nus(self): ranked_chem_nus = [] for ind, rank in enumerate(self.parameters.ace_ranks): rank = int(rank) @@ -532,6 +535,66 @@ def calculate_coupling_coeffs(self): ) ranked_chem_nus.append(PA_lammps) + nus_unsort = [item for sublist in ranked_chem_nus for item in sublist] + nus = nus_unsort.copy() + mu0s = [] + mus = [] + ns = [] + ls = [] + for nu in nus_unsort: + mu0ii, muii, nii, lii = acu.get_mu_n_l(nu) + mu0s.append(mu0ii) + mus.append(tuple(muii)) + ns.append(tuple(nii)) + ls.append(tuple(lii)) + nus.sort(key=lambda x: mus[nus_unsort.index(x)], reverse=False) + nus.sort(key=lambda x: ns[nus_unsort.index(x)], reverse=False) + nus.sort(key=lambda x: ls[nus_unsort.index(x)], reverse=False) + nus.sort(key=lambda x: mu0s[nus_unsort.index(x)], reverse=False) + nus.sort(key=lambda x: len(x), reverse=False) + nus.sort(key=lambda x: mu0s[nus_unsort.index(x)], reverse=False) + musins = range(len(self.parameters.ace_elements) - 1) + all_funcs = {} + if self.parameters.ace_types_like_snap: + byattyp, byattypfiltered = self.srt_by_attyp(nus, 1) + if self.parameters.ace_grid_filter: + assert ( + self.parameters.ace_padfunc + ), "must pad with at least 1 other basis function for other element types to work in LAMMPS - set padfunc=True" + limit_nus = byattypfiltered["%d" % 0] + if self.parameters.ace_padfunc: + for muii in musins: + limit_nus.append(byattypfiltered["%d" % muii][0]) + elif not grid_filter: + limit_nus = byattyp["%d" % 0] + + else: + byattyp, byattypfiltered = self.srt_by_attyp( + nus, len(self.parameters.ace_elements) + ) + if self.parameters.ace_grid_filter: + limit_nus = byattypfiltered[ + "%d" % (len(self.parameters.ace_elements) - 1) + ] + assert ( + self.parameters.ace_padfunc + ), "must pad with at least 1 other basis function for other element types to work in LAMMPS - set padfunc=True" + if self.parameters.ace_padfunc: + for muii in musins: + limit_nus.append(byattypfiltered["%d" % muii][0]) + elif not grid_filter: + limit_nus = byattyp[ + "%d" % (len(self.parameters.ace_elements) - 1) + ] + if self.parameters.ace_padfunc: + for muii in musins: + limit_nus.append(byattyp["%d" % muii][0]) + printout( + "all basis functions", len(nus), "grid subset", len(limit_nus) + ) + + return limit_nus + def get_default_settings(self): rc_range = {bp: None for bp in self.bonds} rin_def = {bp: None for bp in self.bonds} @@ -653,7 +716,7 @@ def srt_by_attyp(self, nulst, remove_type=2): for nu in nulst: mu0 = nu.split("_")[0] byattyp[mu0].append(nu) - mu0ii, muii, nii, lii = get_mu_n_l(nu) + mu0ii, muii, nii, lii = acu.get_mu_n_l(nu) if mumax not in muii: byattypfiltered[mu0].append(nu) return byattyp, byattypfiltered diff --git a/mala/descriptors/ace_coupling_utils.py b/mala/descriptors/ace_coupling_utils.py index 573a18bdd..9bee88529 100644 --- a/mala/descriptors/ace_coupling_utils.py +++ b/mala/descriptors/ace_coupling_utils.py @@ -543,3 +543,50 @@ def ind_vec(lrng, size): if pstr not in uniques: uniques.append(pstr) return uniques + + +def get_mu_n_l(nu_in, return_L=False, **kwargs): + rank = get_mu_nu_rank(nu_in) + if len(nu_in.split("_")) > 1: + if len(nu_in.split("_")) == 2: + nu = nu_in.split("_")[-1] + Lstr = "" + else: + nu = nu_in.split("_")[1] + Lstr = nu_in.split("_")[-1] + mu0 = int(nu_in.split("_")[0]) + nusplt = [int(k) for k in nu.split(",")] + mu = nusplt[:rank] + n = nusplt[rank : 2 * rank] + l = nusplt[2 * rank :] + if len(Lstr) >= 1: + L = tuple([int(k) for k in Lstr.split("-")]) + else: + L = None + if return_L: + return mu0, mu, n, l, L + else: + return mu0, mu, n, l + # provide option to get n,l for depricated descriptor labels + else: + nu = nu_in + mu0 = 0 + mu = [0] * rank + nusplt = [int(k) for k in nu.split(",")] + n = nusplt[:rank] + l = nusplt[rank : 2 * rank] + return mu0, mu, n, l + + +def get_mu_nu_rank(nu_in): + if len(nu_in.split("_")) > 1: + assert ( + len(nu_in.split("_")) <= 3 + ), "make sure your descriptor label is in proper format: mu0_mu1,mu2,mu3,n1,n2,n3,l1,l2,l3_L1" + nu = nu_in.split("_")[1] + nu_splt = nu.split(",") + return int(len(nu_splt) / 3) + else: + nu = nu_in + nu_splt = nu.split(",") + return int(len(nu_splt) / 2)