Skip to content

Commit

Permalink
add charge and spin to ase data loader
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Apr 1, 2024
1 parent 012caf3 commit 885e71e
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 66 deletions.
48 changes: 24 additions & 24 deletions mlff/config/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,18 @@ def run_training(config: config_dict.ConfigDict, model: str = 'so3krates'):
f'Loader assumes that SPICE is in hdf5 format. Found {data_filepath.suffix} as'
f'suffix.')
loader = data.SpiceDataLoaderSparse(input_file=data_filepath)
elif data_filepath.stem[:4].lower() == 'qcml':
logging.mlff(f'Found QCML dataset at {data_filepath}.')
if data_filepath.suffix != '.hdf5':
raise ValueError(
f'Loader assumes that QCML is in hdf5 format. Found {data_filepath.suffix} as'
f'suffix.')
loader = data.QCMLLoaderSparse(
input_file=data_filepath,
# We need to do the inverse transforms, since in config everything is in ASE default units.
min_distance_filter=config.data.filter.min_distance / length_unit,
max_force_filter=config.data.filter.max_force / energy_unit * length_unit,
)
# elif data_filepath.stem[:4].lower() == 'qcml':
# logging.mlff(f'Found QCML dataset at {data_filepath}.')
# if data_filepath.suffix != '.hdf5':
# raise ValueError(
# f'Loader assumes that QCML is in hdf5 format. Found {data_filepath.suffix} as'
# f'suffix.')
# loader = data.QCMLLoaderSparse(
# input_file=data_filepath,
# # We need to do the inverse transforms, since in config everything is in ASE default units.
# min_distance_filter=config.data.filter.min_distance / length_unit,
# max_force_filter=config.data.filter.max_force / energy_unit * length_unit,
# )
elif data_filepath.is_dir():
tf_record_present = len([1 for x in os.scandir(data_filepath) if Path(x).suffix[:9] == '.tfrecord']) > 0
if tf_record_present:
Expand Down Expand Up @@ -388,18 +388,18 @@ def run_evaluation(
f'Loader assumes that SPICE is in hdf5 format. Found {data_filepath.suffix} as'
f'suffix.')
loader = data.SpiceDataLoaderSparse(input_file=data_filepath)
elif data_filepath.stem[:4].lower() == 'qcml':
logging.mlff(f'Found QCML dataset at {data_filepath}.')
if data_filepath.suffix != '.hdf5':
raise ValueError(
f'Loader assumes that QCML is in hdf5 format. Found {data_filepath.suffix} as'
f'suffix.')
loader = data.QCMLLoaderSparse(
input_file=data_filepath,
# We need to do the inverse transforms, since in config everything is in ASE default units.
min_distance_filter=config.data.filter.min_distance / length_unit,
max_force_filter=config.data.filter.max_force / energy_unit * length_unit,
)
# elif data_filepath.stem[:4].lower() == 'qcml':
# logging.mlff(f'Found QCML dataset at {data_filepath}.')
# if data_filepath.suffix != '.hdf5':
# raise ValueError(
# f'Loader assumes that QCML is in hdf5 format. Found {data_filepath.suffix} as'
# f'suffix.')
# loader = data.QCMLLoaderSparse(
# input_file=data_filepath,
# # We need to do the inverse transforms, since in config everything is in ASE default units.
# min_distance_filter=config.data.filter.min_distance / length_unit,
# max_force_filter=config.data.filter.max_force / energy_unit * length_unit,
# )
elif data_filepath.is_dir():
tf_record_present = len([1 for x in os.scandir(data_filepath) if Path(x).suffix[:9] == '.tfrecord']) > 0
if tf_record_present:
Expand Down
58 changes: 30 additions & 28 deletions mlff/data/data_loader_qcml.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,40 +63,42 @@ def keep(idx: int):
for k in tqdm(data):
if keep(i):
positions = np.array(data[k]['positions'])
atomic_numbers = data[k]['atomic_numbers']
atomic_numbers = np.array(data[k]['atomic_numbers'])
forces = data[k]['forces']
energy = data[k]['energy']
charge = data[k]['charge']
multiplicity = data[k]['multiplicity']
if len(atomic_numbers) > 1:
senders, receivers, minimal_distance = compute_senders_and_receivers_np(
positions,
cutoff=cutoff
)

senders, receivers, minimal_distance = compute_senders_and_receivers_np(
positions,
cutoff=cutoff
)

if (
minimal_distance < self.min_distance_filter or
np.abs(forces).max() > self.max_force_filter
):
g = None
if (
minimal_distance < self.min_distance_filter or
np.abs(forces).max() > self.max_force_filter
):
g = None
else:
g = jraph.GraphsTuple(
n_node=np.array([len(atomic_numbers)]),
n_edge=np.array([len(receivers)]),
globals=dict(
energy=np.array(energy).reshape(-1),
total_charge=np.array(charge).reshape(-1),
num_unpaired_electrons=np.array(multiplicity).reshape(-1) - 1
),
nodes=dict(
atomic_numbers=atomic_numbers.reshape(-1).astype(np.int16),
positions=positions,
forces=np.array(forces)
),
edges=dict(cell=None, cell_offsets=None),
receivers=np.array(senders), # opposite convention in mlff
senders=np.array(receivers)
)
else:
g = jraph.GraphsTuple(
n_node=np.array([len(atomic_numbers)]),
n_edge=np.array([len(receivers)]),
globals=dict(
energy=np.array(energy).reshape(-1),
total_charge=np.array(charge).reshape(-1),
num_unpaired_electrons=np.array(multiplicity).reshape(-1) - 1
),
nodes=dict(
atomic_numbers=np.array(atomic_numbers).reshape(-1).astype(np.int16),
positions=positions,
forces=np.array(forces)
),
edges=dict(cell=None, cell_offsets=None),
receivers=np.array(senders), # opposite convention in mlff
senders=np.array(receivers)
)
g = None

loaded_data += [g]
if g is not None:
Expand Down
75 changes: 61 additions & 14 deletions mlff/data/dataloader_sparse_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,26 @@
logging.mlff = partial(logging.log, logging.MLFF)


def compute_senders_and_receivers_np(
positions, cutoff: float
):
"""Computes an edge list from atom positions and a fixed cutoff radius."""
num_atoms = positions.shape[0]
displacements = positions[None, :, :] - positions[:, None, :]
distances = np.linalg.norm(displacements, axis=-1)
distances_min = distances[distances > 0].min()
mask = ~np.eye(num_atoms, dtype=np.bool_) # get rid of self interactions
keep_edges = np.where((distances < cutoff) & mask)
senders = keep_edges[0].astype(np.int32)
receivers = keep_edges[1].astype(np.int32)
return senders, receivers, distances_min


@dataclass
class AseDataLoaderSparse:
input_file: str
min_distance_filter: float = 0.
max_force_filter: float = 1.e6

def cardinality(self):
n = 0
Expand All @@ -44,12 +61,20 @@ def keep(idx: int):
i = 0
for a in tqdm(iread(self.input_file)):
if keep(i):
graph = ASE_to_jraph(a, cutoff=cutoff)
num_nodes = len(graph.nodes['atomic_numbers'])
num_edges = len(graph.receivers)
max_num_of_nodes = max_num_of_nodes if num_nodes <= max_num_of_nodes else num_nodes
max_num_of_edges = max_num_of_edges if num_edges <= max_num_of_edges else num_edges
graph = ASE_to_jraph(
a,
min_distance_filter=self.min_distance_filter,
max_force_filter=self.max_force_filter,
cutoff=cutoff
)

loaded_data.append(graph)

if graph is not None:
num_nodes = len(graph.nodes['atomic_numbers'])
num_edges = len(graph.receivers)
max_num_of_nodes = max_num_of_nodes if num_nodes <= max_num_of_nodes else num_nodes
max_num_of_edges = max_num_of_edges if num_edges <= max_num_of_edges else num_edges
else:
pass
i += 1
Expand All @@ -67,17 +92,21 @@ def keep(idx: int):
def ASE_to_jraph(
mol: Atoms,
cutoff: float,
min_distance_filter: float,
max_force_filter: float,
self_interaction: bool = False,
) -> jraph.GraphsTuple:
):
"""Convert an ASE Atoms object to a jraph.GraphTuple object.
Args:
mol (Atoms): ASE Atoms object.
cutoff (float): Cutoff radius for neighbor interactions.
min_distance_filter (float):
max_force_filter (float):
self_interaction (bool): Include self-interaction in neighbor list.
Returns:
jraph.GraphsTuple: Jraph graph representation of the Atoms object.
jraph.GraphsTuple: Jraph graph representation of the Atoms object if filter != True else None.
"""

atomic_numbers = mol.get_atomic_numbers()
Expand All @@ -100,35 +129,53 @@ def ASE_to_jraph(
forces = None
stress = None

total_charge = mol.info.get('total_charge')
multiplicity = mol.info.get('multiplicity')

if mol.get_pbc().any():
i, j, S = neighbor_list('ijS', mol, cutoff, self_interaction=self_interaction)
cell = np.array(mol.get_cell())
edge_features = {
"cell": np.repeat(np.array(cell)[None], repeats=len(S), axis=0),
"cell_offset": np.array(S)
}
senders = np.array(j)
receivers = np.array(i)
else:
i, j = neighbor_list('ij', mol, cutoff, self_interaction=self_interaction)
# i, j = neighbor_list('ij', mol, cutoff, self_interaction=self_interaction)
edge_features = {
"cell": None,
"cell_offset": None
}

if len(atomic_numbers) == 1:
return None

senders, receivers, minimal_distance = compute_senders_and_receivers_np(
positions,
cutoff=cutoff
)

if (
minimal_distance < min_distance_filter or
np.abs(forces).max() > max_force_filter
):
return None

node_features = {
"positions": np.array(positions),
"atomic_numbers": np.array(atomic_numbers),
"forces": np.array(forces) if forces is not None else None,
}

senders = np.array(j)
receivers = np.array(i)

n_node = np.array([mol.get_global_number_of_atoms()])
n_edge = np.array([len(i)])
n_edge = np.array([len(senders)])

global_context = {
"energy": np.array([energy]) if energy is not None else None,
"stress": np.array(stress) if stress is not None else None
"energy": np.array([energy]).reshape(-1) if energy is not None else None,
"stress": np.array(stress) if stress is not None else None,
"total_charge": np.array([total_charge]) if total_charge is not None else None,
"num_unpaired_electrons": np.array([multiplicity]) - 1 if multiplicity is not None else None,
}

return jraph.GraphsTuple(
Expand Down

0 comments on commit 885e71e

Please sign in to comment.