Skip to content

Commit

Permalink
update loading of params in calcualator and default total charge and …
Browse files Browse the repository at this point in the history
…spin to zero
  • Loading branch information
thorben-frank committed Jun 24, 2024
1 parent 88e25e6 commit 6ae8d8b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 22 deletions.
32 changes: 23 additions & 9 deletions mlff/mdx/potential/mlff_potential_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,31 @@ def load_model_from_workdir(workdir: str, model='so3krates'):

loaded_mngr = checkpoint.CheckpointManager(
pathlib.Path(workdir) / "checkpoints",
{
"params": checkpoint.PyTreeCheckpointer(),
},
item_names=('params',),
item_handlers={'params': checkpoint.StandardCheckpointHandler()},
options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"),
)
mgr_state = loaded_mngr.restore(
loaded_mngr.latest_step(),
{
"params": checkpoint.PyTreeCheckpointer(),
})
params = mgr_state.get("params")

# loaded_mngr = checkpoint.CheckpointManager(
# pathlib.Path(workdir) / "checkpoints",
# {
# "params": checkpoint.PyTreeCheckpointer(),
# },
# options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"),
# )

mngr_state = loaded_mngr.restore(
loaded_mngr.latest_step()
)

# mgr_state = loaded_mngr.restore(
# loaded_mngr.latest_step(),
# {
# "params": checkpoint.PyTreeCheckpointer(),
# })
# params = mgr_state.get("params")

params = mngr_state.get('params')

if model == 'so3krates':
net = from_config.make_so3krates_sparse_from_config(cfg)
Expand Down
13 changes: 8 additions & 5 deletions mlff/nn/observable/observable_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,14 @@ def __call__(self, inputs: Dict, *args, **kwargs):

num_graphs = len(graph_mask)
if self.learn_atomic_type_shifts:
energy_offset = self.param(
'energy_offset',
nn.initializers.zeros_init(),
(self.zmax + 1, )
)[atomic_numbers] # (num_nodes)
energy_offset = jnp.take(
self.param(
'energy_offset',
nn.initializers.zeros_init(),
(self.zmax + 1, )
),
atomic_numbers
) # (num_nodes)
else:
energy_offset = jnp.zeros((1,), dtype=x.dtype)

Expand Down
16 changes: 8 additions & 8 deletions mlff/nn/stacknet/observable_function_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def observable_fn(
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
total_charge: jnp.ndarray = None,
num_unpaired_electrons: jnp.ndarray = None,
total_charge: jnp.ndarray = jnp.array([0]),
num_unpaired_electrons: jnp.ndarray = jnp.array([0]),
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
Expand Down Expand Up @@ -71,8 +71,8 @@ def observable_fn(
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
total_charge: jnp.ndarray = None,
num_unpaired_electrons: jnp.ndarray = None,
total_charge: jnp.ndarray = jnp.array([0]),
num_unpaired_electrons: jnp.ndarray = jnp.array([0]),
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
Expand Down Expand Up @@ -113,8 +113,8 @@ def energy_fn(params,
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
total_charge: jnp.ndarray = None,
num_unpaired_electrons: jnp.ndarray = None,
total_charge: jnp.ndarray = jnp.array([0]),
num_unpaired_electrons: jnp.ndarray = jnp.array([0]),
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
Expand Down Expand Up @@ -150,8 +150,8 @@ def energy_and_force_fn(params,
atomic_numbers: jnp.ndarray,
idx_i: jnp.ndarray,
idx_j: jnp.ndarray,
total_charge: jnp.ndarray = None,
num_unpaired_electrons: jnp.ndarray = None,
total_charge: jnp.ndarray = jnp.array([0]),
num_unpaired_electrons: jnp.ndarray = jnp.array([0]),
cell: jnp.ndarray = None,
cell_offset: jnp.ndarray = None,
batch_segments: jnp.ndarray = None,
Expand Down

0 comments on commit 6ae8d8b

Please sign in to comment.