Skip to content

Commit

Permalink
Added a pyproject.toml and tested black
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed Apr 16, 2024
1 parent 62bbaeb commit 0913795
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 44 deletions.
95 changes: 51 additions & 44 deletions mala/interfaces/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from ase.calculators.calculator import Calculator, all_changes
import numpy as np

from mala import Parameters, Network, DataHandler, Predictor, LDOS, Density, \
DOS
from mala import Parameters, Network, DataHandler, Predictor, LDOS, Density, DOS
from mala.common.parallelizer import get_rank, get_comm, barrier


Expand Down Expand Up @@ -38,34 +37,40 @@ class MALA(Calculator):
from the atomic positions.
"""

implemented_properties = ['energy', 'forces']
implemented_properties = ["energy", "forces"]

def __init__(self, params: Parameters, network: Network,
data: DataHandler, reference_data=None,
predictor=None):
def __init__(
self,
params: Parameters,
network: Network,
data: DataHandler,
reference_data=None,
predictor=None,
):
super(MALA, self).__init__()

# Copy the MALA relevant objects.
self.mala_parameters: Parameters = params
if self.mala_parameters.targets.target_type != "LDOS":
raise Exception("The MALA calculator currently only works with the"
"LDOS.")
raise Exception("The MALA calculator currently only works with the" "LDOS.")

self.network: Network = network
self.data_handler: DataHandler = data

# Prepare for prediction.
if predictor is None:
self.predictor = Predictor(self.mala_parameters, self.network,
self.data_handler)
self.predictor = Predictor(
self.mala_parameters, self.network, self.data_handler
)
else:
self.predictor = predictor

if reference_data is not None:
# Get critical values from a reference file (cutoff,
# temperature, etc.)
self.data_handler.target_calculator.\
read_additional_calculation_data(reference_data)
self.data_handler.target_calculator.read_additional_calculation_data(
reference_data
)

# Needed for e.g. Monte Carlo.
self.last_energy_contributions = {}
Expand All @@ -86,15 +91,15 @@ def load_model(cls, run_name, path="./"):
path : str
Path where the model is saved.
"""
loaded_params, loaded_network, \
new_datahandler, loaded_runner = Predictor.\
load_run(run_name, path=path)
calculator = cls(loaded_params, loaded_network, new_datahandler,
predictor=loaded_runner)
loaded_params, loaded_network, new_datahandler, loaded_runner = (
Predictor.load_run(run_name, path=path)
)
calculator = cls(
loaded_params, loaded_network, new_datahandler, predictor=loaded_runner
)
return calculator

def calculate(self, atoms=None, properties=['energy'],
system_changes=all_changes):
def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
"""
Perform the calculations.
Expand Down Expand Up @@ -123,24 +128,20 @@ def calculate(self, atoms=None, properties=['energy'],

# If an MPI environment is detected, ASE will use it for writing.
# Therefore we have to do this before forking.
self.data_handler.\
target_calculator.\
write_tem_input_file(atoms,
self.data_handler.
target_calculator.qe_input_data,
self.data_handler.
target_calculator.qe_pseudopotentials,
self.data_handler.
target_calculator.grid_dimensions,
self.data_handler.
target_calculator.kpoints)
self.data_handler.target_calculator.write_tem_input_file(
atoms,
self.data_handler.target_calculator.qe_input_data,
self.data_handler.target_calculator.qe_pseudopotentials,
self.data_handler.target_calculator.grid_dimensions,
self.data_handler.target_calculator.kpoints,
)

ldos_calculator: LDOS = self.data_handler.target_calculator

ldos_calculator.read_from_array(ldos)
energy, self.last_energy_contributions \
= ldos_calculator.get_total_energy(return_energy_contributions=
True)
energy, self.last_energy_contributions = ldos_calculator.get_total_energy(
return_energy_contributions=True
)
barrier()

# Use the LDOS determined DOS and density to get energy and forces.
Expand Down Expand Up @@ -170,17 +171,23 @@ def calculate_properties(self, atoms, properties):
# TODO: Check atoms.

if "rdf" in properties:
self.results["rdf"] = self.data_handler.target_calculator.\
get_radial_distribution_function(atoms)
self.results["rdf"] = (
self.data_handler.target_calculator.get_radial_distribution_function(
atoms
)
)
if "tpcf" in properties:
self.results["tpcf"] = self.data_handler.target_calculator.\
get_three_particle_correlation_function(atoms)
self.results["tpcf"] = (
self.data_handler.target_calculator.get_three_particle_correlation_function(
atoms
)
)
if "static_structure_factor" in properties:
self.results["static_structure_factor"] = self.data_handler.\
target_calculator.get_static_structure_factor(atoms)
self.results["static_structure_factor"] = (
self.data_handler.target_calculator.get_static_structure_factor(atoms)
)
if "ion_ion_energy" in properties:
self.results["ion_ion_energy"] = self.\
last_energy_contributions["e_ewald"]
self.results["ion_ion_energy"] = self.last_energy_contributions["e_ewald"]

def save_calculator(self, filename, save_path="./"):
"""
Expand All @@ -197,6 +204,6 @@ def save_calculator(self, filename, save_path="./"):
Path where the calculator should be saved.
"""
self.predictor.save_run(filename, save_path=save_path,
additional_calculation_data=True)

self.predictor.save_run(
filename, save_path=save_path, additional_calculation_data=True
)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.black]
line-length = 88

0 comments on commit 0913795

Please sign in to comment.