Skip to content

Commit

Permalink
Use issubclass rather than isinstance
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Jul 12, 2024
1 parent a0a8139 commit 2ef401e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
11 changes: 9 additions & 2 deletions shakenbreak/energy_lowering_distortions.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ def _prune_dict_across_charges(
defect_species_snb_name = f"{defect}_{'+' if charge > 0 else ''}{charge}"
for i in ["+", "", "+"]: # back to SnB name with "+" if all fail
defect_species = defect_species_snb_name.replace("+", i) # try with and without '+'
distorted_struct = distortion_dict["structures"][0]
if not isinstance(distorted_struct, Structure):
raise ValueError(
f"Distorted structure for {defect_species} was not correctly parsed. Instead "
f"of a pymatgen Structure object, got: {type(distorted_struct)}; "
f"{distorted_struct}"
)
comparison_results = compare_struct_to_distortions(
distortion_dict["structures"][0],
defect_species,
Expand Down Expand Up @@ -460,7 +467,7 @@ def get_energy_lowering_distortions(
structure_filename=structure_filename,
) # get the final structure of the
# energy lowering distortion
if any(isinstance(warning.category, UserWarning) for warning in w):
if any(issubclass(warning.category, UserWarning) for warning in w):
# problem parsing structure, user will have received appropriate
# warning from io.read_vasp_structure()
print(
Expand Down Expand Up @@ -532,7 +539,7 @@ def get_energy_lowering_distortions(
structure_path=f"{output_path}/{defect_species}/{bond_distortion}",
structure_filename=structure_filename,
)
if any(isinstance(warning.category, UserWarning) for warning in w):
if any(issubclass(warning.category, UserWarning) for warning in w):
# problem parsing structure, user will have received appropriate
# warning from io.read_vasp_structure()
print(
Expand Down
19 changes: 7 additions & 12 deletions tests/test_energy_lowering_distortions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ def setUp(self):
self.FHI_AIMS_DATA_DIR = os.path.join(self.DATA_DIR, "fhi_aims")
self.ESPRESSO_DATA_DIR = os.path.join(self.DATA_DIR, "quantum_espresso")
self.V_Cd_minus_0pt55_structure = Structure.from_file(
self.VASP_CDTE_DATA_DIR + "/vac_1_Cd_0/Bond_Distortion_-55.0%/CONTCAR"
f"{self.VASP_CDTE_DATA_DIR}/vac_1_Cd_0/Bond_Distortion_-55.0%/CONTCAR"
)

# create fake distortion folders for testing functionality:
for defect_dir in ["Int_Cd_2_+1", "vac_1_Cd_-1", "vac_1_Cd_-2"]:
if not os.path.exists(f"{self.VASP_CDTE_DATA_DIR}/{defect_dir}"):
os.mkdir(self.VASP_CDTE_DATA_DIR + f"/{defect_dir}")
os.mkdir(f"{self.VASP_CDTE_DATA_DIR}/{defect_dir}")
# Int_Cd_2_+1 without data, to test warnings
V_Cd_1_dict = {"distortions": {-0.075: -205.740}, "Unperturbed": -205.800}
dumpfn(
Expand Down Expand Up @@ -79,16 +79,10 @@ def setUp(self):
"sub_1_In_on_Cd_+1",
]

self.orig_castep_0pt3_files = os.listdir(
self.CASTEP_DATA_DIR + "/vac_1_Cd_0/Bond_Distortion_30.0%"
)
self.orig_cp2k_0pt3_files = os.listdir(self.CP2K_DATA_DIR + "/vac_1_Cd_0/Bond_Distortion_30.0%")
self.orig_fhi_aims_0pt3_files = os.listdir(
self.FHI_AIMS_DATA_DIR + "/vac_1_Cd_0/Bond_Distortion_30.0%"
)
self.orig_espresso_0pt3_files = os.listdir(
self.ESPRESSO_DATA_DIR + "/vac_1_Cd_0/Bond_Distortion_30.0%"
)
self.orig_castep_0pt3_files = os.listdir(f"{self.CASTEP_DATA_DIR}/vac_1_Cd_0/Bond_Distortion_30.0%")
self.orig_cp2k_0pt3_files = os.listdir(f"{self.CP2K_DATA_DIR}/vac_1_Cd_0/Bond_Distortion_30.0%")
self.orig_fhi_aims_0pt3_files = os.listdir(f"{self.FHI_AIMS_DATA_DIR}/vac_1_Cd_0/Bond_Distortion_30.0%")
self.orig_espresso_0pt3_files = os.listdir(f"{self.ESPRESSO_DATA_DIR}/vac_1_Cd_0/Bond_Distortion_30.0%")

def tearDown(self):
# removed generated folders
Expand Down Expand Up @@ -228,6 +222,7 @@ def test_get_energy_lowering_distortions(self):
low_energy_defects_dict = energy_lowering_distortions.get_energy_lowering_distortions(
self.defect_charges_dict, self.VASP_CDTE_DATA_DIR
)
print([str(warning.message) for warning in w]) # for debugging
mock_print.assert_any_call("\nvac_1_Cd")
mock_print.assert_any_call(
"vac_1_Cd_0: Energy difference between minimum, found with -0.55 bond distortion, "
Expand Down

0 comments on commit 2ef401e

Please sign in to comment.