Skip to content

Commit

Permalink
Blackified code
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomDefaultUser committed Apr 17, 2024
1 parent 0913795 commit 9ce07a6
Show file tree
Hide file tree
Showing 60 changed files with 6,796 additions and 3,971 deletions.
50 changes: 38 additions & 12 deletions mala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,43 @@
"""

from .version import __version__
from .common import Parameters, printout, check_modules, get_size, get_rank, \
finalize
from .descriptors import Bispectrum, Descriptor, AtomicDensity, \
MinterpyDescriptors
from .datahandling import DataHandler, DataScaler, DataConverter, Snapshot, \
DataShuffler
from .network import Network, Tester, Trainer, HyperOpt, \
HyperOptOptuna, HyperOptNASWOT, HyperOptOAT, Predictor, \
HyperparameterOAT, HyperparameterNASWOT, HyperparameterOptuna, \
HyperparameterACSD, ACSDAnalyzer, Runner
from .targets import LDOS, DOS, Density, fermi_function, \
AtomicForce, Target
from .common import (
Parameters,
printout,
check_modules,
get_size,
get_rank,
finalize,
)
from .descriptors import (
Bispectrum,
Descriptor,
AtomicDensity,
MinterpyDescriptors,
)
from .datahandling import (
DataHandler,
DataScaler,
DataConverter,
Snapshot,
DataShuffler,
)
from .network import (
Network,
Tester,
Trainer,
HyperOpt,
HyperOptOptuna,
HyperOptNASWOT,
HyperOptOAT,
Predictor,
HyperparameterOAT,
HyperparameterNASWOT,
HyperparameterOptuna,
HyperparameterACSD,
ACSDAnalyzer,
Runner,
)
from .targets import LDOS, DOS, Density, fermi_function, AtomicForce, Target
from .interfaces import MALA
from .datageneration import TrajectoryAnalyzer, OFDFTInitializer
1 change: 1 addition & 0 deletions mala/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""General functions for MALA, such as parameters."""

from .parameters import Parameters
from .parallelizer import printout, get_rank, get_size, finalize
from .check_modules import check_modules
73 changes: 48 additions & 25 deletions mala/common/check_modules.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,65 @@
"""Function to check module availability in MALA."""

import importlib


def check_modules():
"""Check whether/which optional modules MALA can access."""
# The optional libs in MALA.
optional_libs = {
"mpi4py": {"available": False, "description":
"Enables inference parallelization."},
"horovod": {"available": False, "description":
"Enables training parallelization."},
"lammps": {"available": False, "description":
"Enables descriptor calculation for data preprocessing "
"and inference."},
"oapackage": {"available": False, "description":
"Enables usage of OAT method for hyperparameter "
"optimization."},
"total_energy": {"available": False, "description":
"Enables calculation of total energy."},
"asap3": {"available": False, "description":
"Enables trajectory analysis."},
"dftpy": {"available": False, "description":
"Enables OF-DFT-MD initialization."},
"minterpy": {"available": False, "description":
"Enables minterpy descriptor calculation for data preprocessing."}
"mpi4py": {
"available": False,
"description": "Enables inference parallelization.",
},
"horovod": {
"available": False,
"description": "Enables training parallelization.",
},
"lammps": {
"available": False,
"description": "Enables descriptor calculation for data preprocessing "
"and inference.",
},
"oapackage": {
"available": False,
"description": "Enables usage of OAT method for hyperparameter "
"optimization.",
},
"total_energy": {
"available": False,
"description": "Enables calculation of total energy.",
},
"asap3": {
"available": False,
"description": "Enables trajectory analysis.",
},
"dftpy": {
"available": False,
"description": "Enables OF-DFT-MD initialization.",
},
"minterpy": {
"available": False,
"description": "Enables minterpy descriptor calculation for data preprocessing.",
},
}

# Find out if libs are available.
for lib in optional_libs:
optional_libs[lib]["available"] = importlib.util.find_spec(lib) \
is not None
optional_libs[lib]["available"] = (
importlib.util.find_spec(lib) is not None
)

# Print info about libs.
print("The following optional modules are available in MALA:")
for lib in optional_libs:
available_string = "installed" if optional_libs[lib]["available"] \
else "not installed"
print("{0}: \t {1} \t {2}".format(lib, available_string,
optional_libs[lib]["description"]))
optional_libs[lib]["available"] = \
available_string = (
"installed" if optional_libs[lib]["available"] else "not installed"
)
print(
"{0}: \t {1} \t {2}".format(
lib, available_string, optional_libs[lib]["description"]
)
)
optional_libs[lib]["available"] = (
importlib.util.find_spec(lib) is not None
)
8 changes: 4 additions & 4 deletions mala/common/json_serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def from_json(cls, json_dict):

def _standard_serializer(self):
data = {}
members = inspect.getmembers(self,
lambda a: not (inspect.isroutine(a)))
members = inspect.getmembers(
self, lambda a: not (inspect.isroutine(a))
)
for member in members:
# Filter out all private members, builtins, etc.
if member[0][0] != "_":
data[member[0]] = member[1]
json_dict = {"object": type(self).__name__,
"data": data}
json_dict = {"object": type(self).__name__, "data": data}
return json_dict

@classmethod
Expand Down
20 changes: 13 additions & 7 deletions mala/common/parallelizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions for operating MALA in parallel."""

from collections import defaultdict
import platform
import warnings
Expand Down Expand Up @@ -46,8 +47,10 @@ def set_horovod_status(new_value):
"""
if use_mpi is True and new_value is True:
raise Exception("Cannot use horovod and inference-level MPI at "
"the same time yet.")
raise Exception(
"Cannot use horovod and inference-level MPI at "
"the same time yet."
)
global use_horovod
use_horovod = new_value

Expand All @@ -66,8 +69,10 @@ def set_mpi_status(new_value):
"""
if use_horovod is True and new_value is True:
raise Exception("Cannot use horovod and inference-level MPI at "
"the same time yet.")
raise Exception(
"Cannot use horovod and inference-level MPI at "
"the same time yet."
)
global use_mpi
use_mpi = new_value
if use_mpi:
Expand Down Expand Up @@ -96,6 +101,7 @@ def set_lammps_instance(new_instance):
"""
import lammps

global lammps_instance
if isinstance(new_instance, lammps.core.lammps):
lammps_instance = new_instance
Expand Down Expand Up @@ -162,7 +168,7 @@ def get_local_rank():
ranks_nodes = comm.allgather((comm.Get_rank(), this_node))
node2rankssofar = defaultdict(int)
local_rank = None
for (rank, node) in ranks_nodes:
for rank, node in ranks_nodes:
if rank == comm.Get_rank():
local_rank = node2rankssofar[node]
node2rankssofar[node] += 1
Expand Down Expand Up @@ -204,13 +210,13 @@ def get_comm():
def barrier():
"""General interface for a barrier."""
if use_horovod:
hvd.allreduce(torch.tensor(0), name='barrier')
hvd.allreduce(torch.tensor(0), name="barrier")
if use_mpi:
comm.Barrier()
return


def printout(*values, sep=' ', min_verbosity=0):
def printout(*values, sep=" ", min_verbosity=0):
"""
Interface to built-in "print" for parallel runs. Can be used like print.
Expand Down
Loading

0 comments on commit 9ce07a6

Please sign in to comment.