Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify test_doper - softmax add up to 1 #197

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 147 additions & 108 deletions smact/dopant_prediction/doper.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,114 @@
from typing import List, Tuple

from pymatgen.util import plotting
from itertools import groupby

import numpy as np
import smact
from typing import Callable, List, Tuple, Union, Type
from pymatgen.util import plotting
from smact.structure_prediction import mutation, utilities

from smact.structure_prediction.mutation import CationMutator
from smact import Element, element_dictionary

class Doper:
"""
A class to search for n & p type dopants
Methods: get_dopants, plot_dopants

Attributes:
_original_species: A tuple which describes the constituent species of a material. For example:

>>> test= Doper(("Zn2+","S2-"))
>>> test.original_species
('Zn2+','S2-')

"""

def __init__(
self, _original_species: Tuple[str, ...], filepath: str = None
self, original_species: Tuple[str, ...], filepath: str = None
):
"""
Initialise the `Doper` class with a tuple of species

Args:
_original_species: See :class:`~.Doper`.
filepath (str): lambda table json file

original_species: See :class:`~.Doper`.
filepath (str): Path to a JSON file containing lambda table data.
"""
self._original_species = _original_species
self._filepath = filepath

@property
def original_species(self):
return self._original_species

@original_species.setter
def original_species(self, original_species):
self._original_species = original_species

@property
def filepath(self):
return self._filepath
self.original_species = original_species
self.filepath = filepath
self.results = None

def _get_selectivity(self, data_list: List[smact.Element], cations: List[smact.Element], CM:Type[CationMutator], sub):
data = data_list.copy()
for dopants in data:
if sub == "anion":
dopants.append(1.0)
continue
selected_site, original_specie, sub_prob = dopants[:3]
sum_prob = sub_prob
for cation in cations:
if cation != original_specie:
sum_prob += CM.sub_prob(cation, selected_site)

selectivity = sub_prob / sum_prob
selectivity = round(selectivity, 2)
dopants.append(selectivity)
assert len(dopants) == 4
return data

def _merge_dicts(self, keys, dopants_list, groupby_list):
merged_dict = dict()
for k, dopants, groupby in zip(keys, dopants_list, groupby_list):
merged_values = dict()
merged_values["sorted"] = dopants
for key, value in groupby.items():
merged_values[key] = sorted(value, key=lambda x:x[2], reverse=True)
merged_dict[k] = merged_values
return merged_dict

def _get_dopants(
self,
element_objects: List[smact.Element],
spicie_ions: List[str],
ion_type: str
):
"""
Get possible dopants for a given list of elements and dopants.

@filepath.setter
def filepath(self, filepath):
self._filepath = filepath
Args:
element_objects (List[smact.Element]): List of Element objects.
spicie_ions (List[str]): List of original species (anions or cations) as strings.
ion_type (str): Which type of species to check, either "anion" or "cation".

def _get_cation_dopants(
self, element_objects: List[smact.Element], cations: List[str]
):
poss_n_type_cat = set()
poss_p_type_cat = set()
Returns:
Tuple[List[str], List[str]]: Lists of possible n-type and p-type dopants.
"""
poss_n_type = set()
poss_p_type = set()

for element in element_objects:
# [-2, -1, 0, +1, +2]
# i.e. element: "Zn", [-2, -1, 0, +1, +2]
oxi_state = element.oxidation_states
el_symbol = element.symbol
for state in oxi_state:
for cation in cations:
for ion in spicie_ions:
ele = utilities.unparse_spec((el_symbol, state))
_, charge = utilities.parse_spec(cation)
if state > charge:
poss_n_type_cat.add(ele)
elif state < charge and state > 0:
poss_p_type_cat.add(ele)
_, charge = utilities.parse_spec(ion)

return list(poss_n_type_cat), list(poss_p_type_cat)
if ion_type == "anion":
if state > charge and state < 0:
poss_n_type.add(ele)
elif state < charge:
poss_p_type.add(ele)
elif ion_type == "cation":
if state > charge:
poss_n_type.add(ele)
elif state < charge and state > 0:
poss_p_type.add(ele)

def _get_anion_dopants(
self, element_objects: List[smact.Element], anions: List[str]
):
poss_n_type_an = set()
poss_p_type_an = set()

for element in element_objects:
oxi_state = element.oxidation_states
el_symbol = element.symbol
for state in oxi_state:
for anion in anions:
ele = utilities.unparse_spec((el_symbol, state))
_, charge = utilities.parse_spec(anion)
if state > charge and state < 0:
poss_n_type_an.add(ele)
elif state < charge:
poss_p_type_an.add(ele)
return list(poss_n_type_an), list(poss_p_type_an)
return list(poss_n_type), list(poss_p_type)

def get_dopants(
self,
num_dopants: int = 5,
get_selectivity=True,
group_by_charge=True
) -> dict:
"""
Args:
num_dopants (int): The number of suggestions to return for n- and p-type dopants.
apply_softmax (bool): Whether to apply softmax to probabilities. (default = True)
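get_selectivity (bool): Whether to append a selectivity value to each dopant suggestion. (default = True)
group_by_charge (bool): Whether to also group the suggestions by dopant charge. (default = True)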
Returns:
(dict): Dopant suggestions, given as a dictionary with keys
"n_type_cation", "p_type_cation", "n_type_anion", "p_type_anion".
Expand All @@ -114,99 +126,126 @@ def get_dopants(
('C4-', 9.31310255126729e-08)]}
"""

cations = []
anions = []
try:
for ion in self._original_species:
cations, anions = [], []

for ion in self.original_species:
try:
_, charge = utilities.parse_spec(ion)
if charge > 0:
cations.append(ion)
elif charge < 0:
anions.append(ion)
except Exception as e:
print(f"{e}: charge is not defined for {ion}!")
except Exception as e:
print(f"{e}: charge is not defined for {ion}!")

CM = mutation.CationMutator.from_json(self._filepath)
CM = mutation.CationMutator.from_json(self.filepath)

# call all elements
element_objects = list(smact.element_dictionary().values())
element_objects = list(element_dictionary().values())

poss_n_type_cat, poss_p_type_cat = self._get_cation_dopants(
element_objects, cations
poss_n_type_cat, poss_p_type_cat = self._get_dopants(
element_objects, cations, "cation"
)
poss_n_type_an, poss_p_type_an = self._get_anion_dopants(
element_objects, anions
poss_n_type_an, poss_p_type_an = self._get_dopants(
element_objects, anions, "anion"
)

n_type_cat, p_type_cat, n_type_an, p_type_an = [], [], [], []
for cation in cations:
cation_charge = utilities.parse_spec(cation)[1]

for n_specie in poss_n_type_cat:
n_specie_charge = utilities.parse_spec(n_specie)[1]
if cation_charge >= n_specie_charge:
continue
n_type_cat.append(
(n_specie, cation, CM.sub_prob(cation, n_specie))
[n_specie, cation, CM.sub_prob(cation, n_specie)]
)

for p_specie in poss_p_type_cat:
p_specie_charge = utilities.parse_spec(p_specie)[1]
if cation_charge <= p_specie_charge:
continue
p_type_cat.append(
(p_specie, cation, CM.sub_prob(cation, p_specie))
[p_specie, cation, CM.sub_prob(cation, p_specie)]
)

for anion in anions:
anion_charge = utilities.parse_spec(anion)[1]

for n_specie in poss_n_type_an:
n_specie_charge = utilities.parse_spec(n_specie)[1]
if anion == n_specie or anion_charge >= n_specie_charge:
if anion_charge >= n_specie_charge:
continue
n_type_an.append(
(n_specie, anion, CM.sub_prob(anion, n_specie))
[n_specie, anion, CM.sub_prob(anion, n_specie)]
)

for p_specie in poss_p_type_an:
p_specie_charge = utilities.parse_spec(p_specie)[1]
if anion == p_specie or anion_charge <= p_specie_charge:
if anion_charge <= p_specie_charge:
continue
p_type_an.append(
(p_specie, anion, CM.sub_prob(anion, p_specie))
[p_specie, anion, CM.sub_prob(anion, p_specie)]
)

# [('B3+', 0.003), ('C4+', 0.001), (), (), ...] : list(tuple(str, float))
dopants_lists = [n_type_cat, p_type_cat, n_type_an, p_type_an]

# sort by probability
n_type_cat.sort(key=lambda x: x[-1], reverse=True)
p_type_cat.sort(key=lambda x: x[-1], reverse=True)
n_type_an.sort(key=lambda x: x[-1], reverse=True)
p_type_an.sort(key=lambda x: x[-1], reverse=True)

self.results = {
"n-type cation substitutions": n_type_cat[:num_dopants],
"p-type cation substitutions": p_type_cat[:num_dopants],
"n-type anion substitutions": n_type_an[:num_dopants],
"p-type anion substitutions": p_type_an[:num_dopants],
}
for dopants_list in dopants_lists:
dopants_list.sort(key=lambda x: x[-1], reverse=True)

# if groupby
groupby_lists = [dict()] * 4 #create list of empty dict length of 4 (n-cat, p-cat, n-an, p-an)
# in case group_by_charge = False
if group_by_charge:
for i, dl in enumerate(dopants_lists):
# groupby first element charge
dl = sorted(dl, key=lambda x:utilities.parse_spec(x[0])[1])
grouped_data = groupby(dl, key=lambda x:utilities.parse_spec(x[0])[1])
grouped_top_data = {str(k): list(g)[:num_dopants] for k, g in grouped_data}
groupby_lists[i] = grouped_top_data
del grouped_data

# select top n elements
dopants_lists = [dopants_list[:num_dopants] for dopants_list in dopants_lists]

if get_selectivity:
for i in range(len(dopants_lists)):
sub = "cation"
if i > 1:
sub = "anion"
dopants_lists[i] = self._get_selectivity(dopants_lists[i], cations, CM, sub)

keys = [
"n-type cation substitutions",
"p-type cation substitutions",
"n-type anion substitutions",
"p-type anion substitutions",
]

self.results = self._merge_dicts(keys, dopants_lists, groupby_lists)

# return the top (num_dopants) results for each case
return self.results

def plot_dopants(self) -> None:
"""
Uses pymatgen plotting utilities to plot the results of doping search
Plot the dopant suggestions using the periodic table heatmap.
Args:
None
Returns:
None
"""
try:
for val in self.results.values():
dict_results = {
utilities.parse_spec(x)[0]: y for x, _, y in val
}
plotting.periodic_table_heatmap(
elemental_data=dict_results,
cmap="rainbow",
blank_color="gainsboro",
edge_color="white",
)
except AttributeError as e:
print(f"Dopants are not calculated. Run get_dopants first.")
assert self.results, "Dopants are not calculated. Run get_dopants first."

for dopant_type, dopants in self.results.items():
dict_results = {
utilities.parse_spec(x)[0]: y for x, _, y in dopants
}
plotting.periodic_table_heatmap(
elemental_data=dict_results,
cmap="rainbow",
blank_color="gainsboro",
edge_color="white",
)
18 changes: 13 additions & 5 deletions smact/tests/test_doper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,25 @@ def test_dopant_prediction(self):
self.assertIs(type(ls), list)

# Assert: (cation) higher charges for n-type and lower charges for p-type
n_sub_list_cat = test.get_dopants().get("n-type cation substitutions")
p_sub_list_cat = test.get_dopants().get("p-type cation substitutions")
n_sub_list_an = test.get_dopants().get("n-type anion substitutions")
p_sub_list_an = test.get_dopants().get("p-type anion substitutions")
n_sub_list_cat = test.get_dopants(apply_softmax=True).get("n-type cation substitutions")
p_sub_list_cat = test.get_dopants(apply_softmax=True).get("p-type cation substitutions")
n_sub_list_an = test.get_dopants(apply_softmax=True).get("n-type anion substitutions")
p_sub_list_an = test.get_dopants(apply_softmax=True).get("p-type anion substitutions")
result_list = [n_sub_list_cat, p_sub_list_cat, n_sub_list_an, p_sub_list_an]
for n_atom, p_atom in zip(n_sub_list_cat, p_sub_list_cat):
self.assertGreater(utilities.parse_spec(n_atom[0])[1], cat_charge)
self.assertLess(utilities.parse_spec(p_atom[0])[1], cat_charge)

for n_atom, p_atom in zip(n_sub_list_an, p_sub_list_an):
self.assertGreater(utilities.parse_spec(n_atom[0])[1], an_charge)
self.assertLess(utilities.parse_spec(p_atom[0])[1], an_charge)

# Assert: softmax add up to 1
for sub_list in result_list:
sum_softmax = 0
for doping_result in sub_list:
sum_softmax += doping_result[2]
self.assertAlmostEqual(1, sum_softmax)


if __name__ == "__main__":
Expand All @@ -47,4 +55,4 @@ def test_dopant_prediction(self):
)

runner = unittest.TextTestRunner()
result = runner.run(DoperTests)
result = runner.run(DoperTests)
Loading