Skip to content

Commit

Permalink
replace value_and_grad by jacrec
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Jan 11, 2024
1 parent a254a9c commit d74009f
Showing 1 changed file with 79 additions and 1 deletion.
80 changes: 79 additions & 1 deletion mlff/nn/stacknet/observable_function_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def observable_fn(p, x):
return observable_fn


def get_energy_and_force_fn_sparse(model: StackNetSparse):
def get_energy_and_force_fn_sparse_(model: StackNetSparse):
def energy_fn(params,
positions: jnp.ndarray,
atomic_numbers: jnp.ndarray,
Expand Down Expand Up @@ -99,3 +99,81 @@ def energy_and_force_fn(params,
return dict(energy=energy, forces=forces)

return energy_and_force_fn


def get_energy_and_force_fn_sparse(model: StackNetSparse):
def energy_fn(params,
positions: jnp.ndarray,
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
node_mask: jnp.ndarray = None,
graph_mask: jnp.ndarray = None):
if batch_segments is None:
assert graph_mask is None
assert node_mask is None

graph_mask = jnp.ones((1,)).astype(jnp.bool_) # (1)
node_mask = jnp.ones((len(positions),)).astype(jnp.bool_) # (num_nodes)
batch_segments = jnp.zeros_like(atomic_numbers) # (num_nodes)

inputs = dict(positions=positions,
atomic_numbers=atomic_numbers,
idx_i=idx_i,
idx_j=idx_j,
cell=cell,
cell_offset=cell_offset,
batch_segments=batch_segments,
node_mask=node_mask,
graph_mask=graph_mask
)

energy = model.apply(params, inputs)['energy'] # (num_graphs)
energy = jnp.where(graph_mask, energy, jnp.asarray(0., dtype=energy.dtype)) # (num_graphs)
return -jnp.sum(energy), energy # (), (num_graphs)

def energy_and_force_fn(params,
positions: jnp.ndarray,
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
node_mask: jnp.ndarray = None,
graph_mask: jnp.ndarray = None,
*args,
**kwargs):

# _, energy = energy_fn(
# params,
# positions,
# atomic_numbers,
# idx_i,
# idx_j,
# cell,
# cell_offset,
# batch_segments,
# node_mask,
# graph_mask
# )

forces, energy = jax.jacrev(energy_fn, argnums=1, has_aux=True)(
params,
positions,
atomic_numbers,
idx_i,
idx_j,
cell,
cell_offset,
batch_segments,
node_mask,
graph_mask
)

return dict(energy=energy, forces=forces)

return energy_and_force_fn

0 comments on commit d74009f

Please sign in to comment.