From 09137950e5c297b6e61c58d360b19e3b0e02726e Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Tue, 16 Apr 2024 17:38:10 +0200 Subject: [PATCH] Added a pyproject.toml and tested black --- mala/interfaces/ase_calculator.py | 95 +++++++++++++++++-------------- pyproject.toml | 2 + 2 files changed, 53 insertions(+), 44 deletions(-) create mode 100644 pyproject.toml diff --git a/mala/interfaces/ase_calculator.py b/mala/interfaces/ase_calculator.py index f935271ad..fdb5fc8b1 100644 --- a/mala/interfaces/ase_calculator.py +++ b/mala/interfaces/ase_calculator.py @@ -3,8 +3,7 @@ from ase.calculators.calculator import Calculator, all_changes import numpy as np -from mala import Parameters, Network, DataHandler, Predictor, LDOS, Density, \ - DOS +from mala import Parameters, Network, DataHandler, Predictor, LDOS, Density, DOS from mala.common.parallelizer import get_rank, get_comm, barrier @@ -38,34 +37,40 @@ class MALA(Calculator): from the atomic positions. """ - implemented_properties = ['energy', 'forces'] + implemented_properties = ["energy", "forces"] - def __init__(self, params: Parameters, network: Network, - data: DataHandler, reference_data=None, - predictor=None): + def __init__( + self, + params: Parameters, + network: Network, + data: DataHandler, + reference_data=None, + predictor=None, + ): super(MALA, self).__init__() # Copy the MALA relevant objects. self.mala_parameters: Parameters = params if self.mala_parameters.targets.target_type != "LDOS": - raise Exception("The MALA calculator currently only works with the" - "LDOS.") + raise Exception("The MALA calculator currently only works with the" "LDOS.") self.network: Network = network self.data_handler: DataHandler = data # Prepare for prediction. if predictor is None: - self.predictor = Predictor(self.mala_parameters, self.network, - self.data_handler) + self.predictor = Predictor( + self.mala_parameters, self.network, self.data_handler + ) else: self.predictor = predictor if reference_data is not None: # Get critical values from a reference file (cutoff, # temperature, etc.) - self.data_handler.target_calculator.\ - read_additional_calculation_data(reference_data) + self.data_handler.target_calculator.read_additional_calculation_data( + reference_data + ) # Needed for e.g. Monte Carlo. self.last_energy_contributions = {} @@ -86,15 +91,15 @@ def load_model(cls, run_name, path="./"): path : str Path where the model is saved. """ - loaded_params, loaded_network, \ - new_datahandler, loaded_runner = Predictor.\ - load_run(run_name, path=path) - calculator = cls(loaded_params, loaded_network, new_datahandler, - predictor=loaded_runner) + loaded_params, loaded_network, new_datahandler, loaded_runner = ( + Predictor.load_run(run_name, path=path) + ) + calculator = cls( + loaded_params, loaded_network, new_datahandler, predictor=loaded_runner + ) return calculator - def calculate(self, atoms=None, properties=['energy'], - system_changes=all_changes): + def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes): """ Perform the calculations. @@ -123,24 +128,20 @@ def calculate(self, atoms=None, properties=['energy'], # If an MPI environment is detected, ASE will use it for writing. # Therefore we have to do this before forking. - self.data_handler.\ - target_calculator.\ - write_tem_input_file(atoms, - self.data_handler. - target_calculator.qe_input_data, - self.data_handler. - target_calculator.qe_pseudopotentials, - self.data_handler. - target_calculator.grid_dimensions, - self.data_handler. - target_calculator.kpoints) + self.data_handler.target_calculator.write_tem_input_file( + atoms, + self.data_handler.target_calculator.qe_input_data, + self.data_handler.target_calculator.qe_pseudopotentials, + self.data_handler.target_calculator.grid_dimensions, + self.data_handler.target_calculator.kpoints, + ) ldos_calculator: LDOS = self.data_handler.target_calculator ldos_calculator.read_from_array(ldos) - energy, self.last_energy_contributions \ - = ldos_calculator.get_total_energy(return_energy_contributions= - True) + energy, self.last_energy_contributions = ldos_calculator.get_total_energy( + return_energy_contributions=True + ) barrier() # Use the LDOS determined DOS and density to get energy and forces. @@ -170,17 +171,23 @@ def calculate_properties(self, atoms, properties): # TODO: Check atoms. if "rdf" in properties: - self.results["rdf"] = self.data_handler.target_calculator.\ - get_radial_distribution_function(atoms) + self.results["rdf"] = ( + self.data_handler.target_calculator.get_radial_distribution_function( + atoms + ) + ) if "tpcf" in properties: - self.results["tpcf"] = self.data_handler.target_calculator.\ - get_three_particle_correlation_function(atoms) + self.results["tpcf"] = ( + self.data_handler.target_calculator.get_three_particle_correlation_function( + atoms + ) + ) if "static_structure_factor" in properties: - self.results["static_structure_factor"] = self.data_handler.\ - target_calculator.get_static_structure_factor(atoms) + self.results["static_structure_factor"] = ( + self.data_handler.target_calculator.get_static_structure_factor(atoms) + ) if "ion_ion_energy" in properties: - self.results["ion_ion_energy"] = self.\ - last_energy_contributions["e_ewald"] + self.results["ion_ion_energy"] = self.last_energy_contributions["e_ewald"] def save_calculator(self, filename, save_path="./"): """ @@ -197,6 +204,6 @@ def save_calculator(self, filename, save_path="./"): Path where the calculator should be saved. """ - self.predictor.save_run(filename, save_path=save_path, - additional_calculation_data=True) - + self.predictor.save_run( + filename, save_path=save_path, additional_calculation_data=True + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..8bb6ee5f5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 88