Skip to content

Commit

Permalink
Merge pull request #222 from timcallow/calcenergy_grid_type
Browse files Browse the repository at this point in the history
Make default "grid_type" consistent and add checker
  • Loading branch information
timcallow authored Nov 9, 2023
2 parents ab00598 + f6fb204 commit cd204de
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 4 deletions.
24 changes: 24 additions & 0 deletions atoMEC/check_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,30 @@ def check_band_params(input_params):

return band_params

@staticmethod
def check_grid_type(grid_type):
r"""Check grid type.
Parameters
----------
grid_type : str
the grid type
Returns
-------
grid_type : str
the grid type
Raises
------
InputError.grid_type_error
if grid type not one of "log" or "sqrt"
"""
if grid_type not in ["log", "sqrt"]:
raise InputError.grid_error("Grid type must be either 'log' or 'sqrt'")
else:
return grid_type


class InputError(Exception):
"""Exit atoMEC and print relevant input error message."""
Expand Down
7 changes: 5 additions & 2 deletions atoMEC/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def CalcEnergy(
scf_params={},
band_params={},
force_bound=[],
grid_type="log",
grid_type="sqrt",
verbosity=0,
write_info=True,
write_density=True,
Expand Down Expand Up @@ -280,6 +280,9 @@ def CalcEnergy(
forces the orbital with quantum numbers :math:`\sigma=0,\ l=1,\ n=0` to be
always bound even if it has positive energy. This prevents convergence
issues.
grid_type : str, optional
the transformed radial grid used for the KS equations.
can be 'sqrt' (default) or 'log'
verbosity : int, optional
how much information is printed at each SCF cycle.
`verbosity=0` prints the total energy and convergence values (default).
Expand Down Expand Up @@ -332,7 +335,7 @@ def CalcEnergy(
config.conv_params = check_inputs.EnergyCalcs.check_conv_params(conv_params)
config.scf_params = check_inputs.EnergyCalcs.check_scf_params(scf_params)
config.band_params = check_inputs.EnergyCalcs.check_band_params(band_params)
config.grid_type = grid_type
config.grid_type = check_inputs.EnergyCalcs.check_grid_type(grid_type)

# experimental change
config.force_bound = force_bound
Expand Down
1 change: 1 addition & 0 deletions tests/boundary_conditions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _run(bc):
scf_params={"maxscf": 5, "mixfrac": 0.3},
grid_params={"ngrid": 1000, "ngrid_coarse": 600},
band_params={"nkpts": nkpts},
grid_type="log",
)

# extract the total free energy
Expand Down
1 change: 1 addition & 0 deletions tests/conductivity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _run_SCF():
4,
scf_params={"mixfrac": 0.3, "maxscf": 6},
grid_params={"ngrid": 1200, "ngrid_coarse": 300},
grid_type="log",
)

output_dict = {"Atom": F_at, "model": model, "SCF_out": output}
Expand Down
1 change: 1 addition & 0 deletions tests/energy_alt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _run(unbound):
5,
scf_params={"maxscf": 6, "mixfrac": 0.3},
grid_params={"ngrid": 1000, "ngrid_coarse": 300},
grid_type="log",
)

# construct the EnergyAlt object
Expand Down
8 changes: 8 additions & 0 deletions tests/exceptions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,11 @@ def test_band_params(self, bands_input):

with pytest.raises(SystemExit):
model.CalcEnergy(3, 3, band_params=bands_input)

def test_grid_type(self):
"""Test the grid type."""
atom = Atom("Al", 0.01, radius=1)
model = models.ISModel(atom, bc="bands", unbound="quantum")

with pytest.raises(SystemExit):
model.CalcEnergy(3, 3, grid_type="linear")
1 change: 1 addition & 0 deletions tests/functionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def _run(func):
scf_params={"maxscf": 5, "mixfrac": 0.3},
grid_params={"ngrid": 1000, "ngrid_coarse": 90},
band_params={"nkpts": 30},
grid_type="log",
)

# extract the total free energy
Expand Down
3 changes: 1 addition & 2 deletions tests/gramschmidt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def SCF_output(self):
)
def test_overlap(self, input_SCF, case, expected):
"""Run overlap integral after orthonormalizatation."""
config.grid_type = "log"
assert np.isclose(
self._run_overlap(input_SCF, case),
expected,
Expand Down Expand Up @@ -66,6 +65,7 @@ def _run_SCF():
4,
scf_params={"mixfrac": 0.3, "maxscf": 5},
grid_params={"ngrid": 1000, "ngrid_coarse": 300},
grid_type="log",
)

return output
Expand Down Expand Up @@ -100,7 +100,6 @@ def _run_overlap(input_SCF, case):


if __name__ == "__main__":
config.grid_type = "log"
SCF_out = TestGS._run_SCF()
print("self_overlap_expected =", TestGS._run_overlap(SCF_out, "self"))
print("overlap_expected =", TestGS._run_overlap(SCF_out, "other"))
1 change: 1 addition & 0 deletions tests/localization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _run_SCF(spinpol):
2,
scf_params={"mixfrac": 0.3, "maxscf": 50},
grid_params={"ngrid": 1000, "ngrid_coarse": 300},
grid_type="log",
)

output_dict = {"Atom": Al_at, "model": model, "SCF_out": output}
Expand Down
1 change: 1 addition & 0 deletions tests/serial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _run(ngrid):
scf_params={"maxscf": 1, "mixfrac": 0.7},
band_params={"nkpts": 30},
grid_params={"ngrid": ngrid, "ngrid_coarse": 300},
grid_type="log",
)

# extract the total free energy
Expand Down
1 change: 1 addition & 0 deletions tests/spin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _run(spinmag):
scf_params={"maxscf": 4, "mixfrac": 0.3},
band_params={"nkpts": 50},
grid_params={"ngrid": 1000, "ngrid_coarse": 300},
grid_type="log",
)

# extract the total free energy
Expand Down

0 comments on commit cd204de

Please sign in to comment.