use clu for metrics logging
thorben-frank committed Jan 22, 2024
1 parent 90c1a0c commit b63039c
Showing 4 changed files with 72 additions and 45 deletions.
5 changes: 2 additions & 3 deletions mlff/config/from_config.py
@@ -252,13 +252,12 @@ def run_evaluation(config, num_test: int = None, testing_targets: Sequence[str]
     num_train = config.training.num_train
     num_valid = config.training.num_valid
 
-    if (num_train + num_valid + num_test) > num_data:
+    upper_bound = (num_train + num_valid + num_test) if num_test is not None else num_data
+    if upper_bound > num_data:
         raise ValueError(
             f'num_train + num_valid + num_test = {num_train + num_valid + num_test} > num_data = {num_data}.'
         )
 
-    upper_bound = (num_train + num_valid + num_test) if num_test is not None else num_data
-
    testing_data = data.transformations.subtract_atomic_energy_shifts(
        data.transformations.unit_conversion(
            all_data[(num_train + num_valid):upper_bound],
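Aside from switching to clu, this hunk reorders the bounds check: with num_test left as None, the old expression num_train + num_valid + num_test would have raised a TypeError before upper_bound was ever computed, whereas None now means "evaluate on all remaining data". A minimal standalone sketch of the new logic (illustrative names, not repository code):

    # Sketch of the reordered bounds check (illustrative, not repo code).
    def test_slice(num_data: int, num_train: int, num_valid: int, num_test: int = None):
        # num_test=None means "use all remaining data"; only an explicit num_test
        # can push upper_bound past num_data and trigger the ValueError.
        upper_bound = (num_train + num_valid + num_test) if num_test is not None else num_data
        if upper_bound > num_data:
            raise ValueError(f'requested {upper_bound} points, only {num_data} available.')
        return slice(num_train + num_valid, upper_bound)

    print(test_slice(num_data=1000, num_train=800, num_valid=100))  # slice(900, 1000, None)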
34 changes: 28 additions & 6 deletions mlff/utils/evaluation_utils.py
@@ -1,11 +1,22 @@
+from clu import metrics
+import flax.struct as flax_struct
 import jax
 import jax.numpy as jnp
 import jraph
 import numpy as np
 from tqdm import tqdm
+from typing import Any
 from mlff.nn.stacknet.observable_function_sparse import get_energy_and_force_fn_sparse
 
 
+@flax_struct.dataclass
+class MetricsTesting(metrics.Collection):
+    energy_mse: metrics.Average.from_output('energy_mse')
+    forces_mse: metrics.Average.from_output('forces_mse')
+    energy_mae: metrics.Average.from_output('energy_mae')
+    forces_mae: metrics.Average.from_output('forces_mae')
+
+
 def evaluate(
     model,
     params,
@@ -46,6 +57,7 @@ def evaluate(
 
     # Start iteration over validation batches.
     testing_metrics = []
+    test_metrics: Any = None
     for graph_batch_testing in tqdm(iterator_testing):
         batch_testing = graph_to_batch_fn(graph_batch_testing)
         batch_testing = jax.tree_map(jnp.array, batch_testing)
@@ -74,15 +86,25 @@ def evaluate(
                 y_predicted=output_prediction[t], y_true=batch_testing[t], msk=msk
             ),
 
-        testing_metrics += [metrics_dict]
+        test_metrics = (
+            MetricsTesting.single_from_model_output(**metrics_dict)
+            if test_metrics is None
+            else test_metrics.merge(MetricsTesting.single_from_model_output(**metrics_dict))
+        )
+    test_metrics = test_metrics.compute()
 
-    testing_metrics_np = jax.device_get(testing_metrics)
-    testing_metrics_np = {
-        k: np.mean([m[k] for m in testing_metrics_np]) for k in testing_metrics_np[0]
-    }
+    # testing_metrics_np = jax.device_get(testing_metrics)
+    # testing_metrics_np = {
+    #     k: np.mean([m[k] for m in testing_metrics_np]) for k in testing_metrics_np[0]
+    # }
+
+    test_metrics = {
+        f'test_{k}': float(v) for k, v in test_metrics.items()
+    }
 
     for t in testing_targets:
-        testing_metrics_np[f'{t}_rmse'] = np.sqrt(testing_metrics_np[f'{t}_mse'])
-    return testing_metrics_np
+        test_metrics[f'test_{t}_rmse'] = np.sqrt(test_metrics[f'test_{t}_mse'])
+    return test_metrics
 
 
 def calculate_mse(y_predicted, y_true, msk):
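For readers unfamiliar with clu (Common Loop Utils), the pattern introduced above is: build one Collection update per batch with single_from_model_output, fold the updates together with merge, and call compute once at the end. A minimal self-contained sketch of that pattern, assuming only that clu, flax, and jax are installed (the metric name is a placeholder):

    # Minimal sketch of the clu accumulate/merge/compute pattern (illustrative).
    import flax.struct as flax_struct
    import jax.numpy as jnp
    from clu import metrics


    @flax_struct.dataclass
    class SketchMetrics(metrics.Collection):
        mse: metrics.Average.from_output('mse')


    collected = None
    for batch_mse in (0.5, 1.5, 2.5):  # stand-ins for per-batch model outputs
        update = SketchMetrics.single_from_model_output(mse=jnp.asarray(batch_mse))
        collected = update if collected is None else collected.merge(update)

    print(collected.compute())  # {'mse': Array(1.5, ...)} -- the running average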
77 changes: 41 additions & 36 deletions mlff/utils/training_utils.py
@@ -1,13 +1,15 @@
+from clu import metrics
 import jraph
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
 from orbax import checkpoint
 from pathlib import Path
-from typing import Callable, Dict
+from typing import Any, Callable, Dict
 import wandb
 from flax.core import unfreeze
+from flax import struct as flax_struct
 
 property_to_mask = {
     'energy': 'graph_mask',
@@ -16,23 +18,20 @@
 }
 
 
-# def scaled_safe_masked_mse_loss(y, y_true, scale, msk):
-#     """
-#
-#     Args:
-#         y (): shape: (B,d1, *, dN)
-#         y_true (): (B,d1, *, dN)
-#         scale (): (d1, *, dN) or everything broadcast-able to (B, d1, *, dN)
-#         msk (): shape: (B)
-#
-#     Returns:
-#
-#     """
-#     full_mask = ~jnp.isnan(y_true) & jnp.expand_dims(msk, [y_true.ndim - 1 - o for o in range(0, y_true.ndim - 1)])
-#     diff = jnp.where(full_mask, y_true, 0.) - jnp.where(full_mask, y, 0.)
-#     v = safe_mask(full_mask, fn=lambda u: scale * u ** 2, operand=diff)
-#     den = full_mask.reshape(-1).sum().astype(dtype=v.dtype)
-#     return safe_mask(den > 0, lambda x: v.reshape(-1).sum() / x, den, 0.)
+@flax_struct.dataclass
+class MetricsTraining(metrics.Collection):
+    loss: metrics.Average.from_output('loss')
+    grad_norm: metrics.Average.from_output('grad_norm')
+    energy_mse: metrics.Average.from_output('energy_mse')
+    forces_mse: metrics.Average.from_output('forces_mse')
+
+
+@flax_struct.dataclass
+class MetricsEvaluation(metrics.Collection):
+    loss: metrics.Average.from_output('loss')
+    energy_mse: metrics.Average.from_output('energy_mse')
+    forces_mse: metrics.Average.from_output('forces_mse')
 
 
 def scaled_mse_loss(y, y_label, scale, mask):
     full_mask = ~jnp.isnan(y_label) & jnp.expand_dims(mask, [y_label.ndim - 1 - o for o in range(0, y_label.ndim - 1)])
@@ -80,7 +79,7 @@ def loss_fn(params, batch: Dict[str, jnp.ndarray]):
             )
 
             loss += weights[target] * _l
-            metrics.update({target: _l / _scales[target].mean()})
+            metrics.update({f'{target}_mse': _l / _scales[target].mean()})
 
         loss = jnp.reshape(loss, ())
         metrics.update({'loss': loss})
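(Renaming the keys to f'{target}_mse' presumably lines them up with the field names of MetricsTraining and MetricsEvaluation defined earlier in this file, so the metrics dict emitted by loss_fn can be splatted directly into single_from_model_output.)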
@@ -106,11 +105,11 @@ def training_step_fn(params, opt_state, batch):
         """
         (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, batch)
-        if log_gradient_values:
-            metrics['gradient_norms'] = unfreeze(jax.tree_map(lambda x: jnp.linalg.norm(x.reshape(-1), axis=0), grads))
+        # if log_gradient_values:
+        #     metrics['grad_norm'] = unfreeze(jax.tree_map(lambda x: jnp.linalg.norm(x.reshape(-1), axis=0), grads))
         updates, opt_state = optimizer.update(grads, opt_state, params)
         params = optax.apply_updates(params=params, updates=updates)
-        metrics['gradients_norm'] = optax.global_norm(grads)
+        metrics['grad_norm'] = optax.global_norm(grads)
         return params, opt_state, metrics
 
     return training_step_fn
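For context on the grad_norm rename above: optax.global_norm reduces the entire gradient pytree to a single l2 norm, so the metric averages one scalar per step rather than the per-leaf norms the now-commented-out log_gradient_values branch produced. A toy illustration (not repository code):

    # optax.global_norm collapses a pytree of gradients into one scalar norm.
    import jax.numpy as jnp
    import optax

    grads = {'w': jnp.array([3.0, 0.0]), 'b': jnp.array([4.0])}
    print(optax.global_norm(grads))  # sqrt(3**2 + 0**2 + 4**2) = 5.0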
@@ -263,7 +262,7 @@ def fit(
             # Log training metrics.
             if use_wandb:
                 wandb.log(
-                    {'Training {}'.format(k): v for (k, v) in train_metrics_np.items()},
+                    {f'train_{k}': v for (k, v) in train_metrics_np.items()},
                     step=step
                 )
 
@@ -277,34 +276,40 @@ def fit(
                 )
 
                 # Start iteration over validation batches.
-                validation_metrics = []
+                eval_metrics: Any = None
                 for graph_batch_validation in iterator_validation:
                     batch_validation = graph_to_batch_fn(graph_batch_validation)
                     batch_validation = jax.tree_map(jnp.array, batch_validation)
 
-                    validation_metrics += [
-                        validation_step_fn(
-                            params,
-                            batch_validation
-                        )
-                    ]
+                    eval_out = validation_step_fn(
+                        params,
+                        batch_validation
+                    )
+
+                    eval_metrics = (
+                        MetricsEvaluation.single_from_model_output(**eval_out)
+                        if eval_metrics is None
+                        else eval_metrics.merge(MetricsEvaluation.single_from_model_output(**eval_out))
+                    )
+
+                eval_metrics = eval_metrics.compute()
 
-                validation_metrics_np = jax.device_get(validation_metrics)
-                validation_metrics_np = {
-                    k: np.mean([metrics[k] for metrics in validation_metrics]) for k in validation_metrics_np[0]
-                }
+                # Convert to dict to log with weights and bias.
+                eval_metrics = {
+                    f'eval_{k}': float(v) for k, v in eval_metrics.items()
+                }
 
                 # Save checkpoint.
                 ckpt_mngr.save(
                     step,
                     args=checkpoint.args.Composite(params=checkpoint.args.StandardSave(params)),
-                    metrics={'loss': validation_metrics_np['loss'].item()}
+                    metrics={'loss': eval_metrics['eval_loss']}
                 )
 
                 # Log to weights and bias.
                 if use_wandb:
-                    wandb.log({
-                        f'Validation {k}': v for (k, v) in validation_metrics_np.items()},
+                    wandb.log(
+                        eval_metrics,
                         step=step
                     )
                 # Finished validation process.
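The net effect on logging is a flat, consistently prefixed key scheme (train_*, eval_*, and test_* in evaluation_utils.py) in place of the old 'Training …'/'Validation …' labels. A tiny sketch of the prefix-and-flatten step, with the wandb call commented out so the snippet runs without an active run (values are made up):

    # Prefixing computed metrics before logging (illustrative, not repo code).
    computed = {'loss': 0.42, 'energy_mse': 1.3e-3}  # e.g. the output of Collection.compute()
    logged = {f'eval_{k}': float(v) for k, v in computed.items()}
    print(logged)  # {'eval_loss': 0.42, 'eval_energy_mse': 0.0013}
    # import wandb; wandb.log(logged, step=step)  # needs wandb.init() and a step counter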
1 change: 1 addition & 0 deletions setup.py
@@ -8,6 +8,7 @@
     packages=find_packages(),
     install_requires=[
         "numpy",
+        "clu",
         # "jax == 0.4.8",
         "flax",
         "jaxopt",
