From b63039c2ac9c0ba045cc320e375390d695040ed5 Mon Sep 17 00:00:00 2001
From: Thorben Frank
Date: Mon, 22 Jan 2024 17:41:57 +0100
Subject: [PATCH] use clu for metrics logging

---
 mlff/config/from_config.py     |  5 +--
 mlff/utils/evaluation_utils.py | 34 ++++++++++++---
 mlff/utils/training_utils.py   | 77 ++++++++++++++++++----------------
 setup.py                       |  1 +
 4 files changed, 72 insertions(+), 45 deletions(-)

diff --git a/mlff/config/from_config.py b/mlff/config/from_config.py
index 651d77e..1d8ae20 100644
--- a/mlff/config/from_config.py
+++ b/mlff/config/from_config.py
@@ -252,13 +252,12 @@ def run_evaluation(config, num_test: int = None, testing_targets: Sequence[str]
     num_train = config.training.num_train
     num_valid = config.training.num_valid
 
-    if (num_train + num_valid + num_test) > num_data:
+    upper_bound = (num_train + num_valid + num_test) if num_test is not None else num_data
+    if upper_bound > num_data:
         raise ValueError(
             f'num_train + num_valid + num_test = {num_train + num_valid + num_test} > num_data = {num_data}.'
         )
 
-    upper_bound = (num_train + num_valid + num_test) if num_test is not None else num_data
-
     testing_data = data.transformations.subtract_atomic_energy_shifts(
         data.transformations.unit_conversion(
             all_data[(num_train + num_valid):upper_bound],
diff --git a/mlff/utils/evaluation_utils.py b/mlff/utils/evaluation_utils.py
index 077d1ff..79b97c1 100644
--- a/mlff/utils/evaluation_utils.py
+++ b/mlff/utils/evaluation_utils.py
@@ -1,11 +1,22 @@
+from clu import metrics
+import flax.struct as flax_struct
 import jax
 import jax.numpy as jnp
 import jraph
 import numpy as np
 from tqdm import tqdm
+from typing import Any
 
 from mlff.nn.stacknet.observable_function_sparse import get_energy_and_force_fn_sparse
 
+@flax_struct.dataclass
+class MetricsTesting(metrics.Collection):
+    energy_mse: metrics.Average.from_output('energy_mse')
+    forces_mse: metrics.Average.from_output('forces_mse')
+    energy_mae: metrics.Average.from_output('energy_mae')
+    forces_mae: metrics.Average.from_output('forces_mae')
+
+
 def evaluate(
         model,
         params,
@@ -46,6 +57,7 @@ def evaluate(
 
     # Start iteration over validation batches.
     testing_metrics = []
+    test_metrics: Any = None
     for graph_batch_testing in tqdm(iterator_testing):
         batch_testing = graph_to_batch_fn(graph_batch_testing)
         batch_testing = jax.tree_map(jnp.array, batch_testing)
@@ -74,15 +86,25 @@ def evaluate(
                 y_predicted=output_prediction[t],
                 y_true=batch_testing[t],
                 msk=msk
             ),
-        testing_metrics += [metrics_dict]
+        test_metrics = (
+            MetricsTesting.single_from_model_output(**metrics_dict)
+            if test_metrics is None
+            else test_metrics.merge(MetricsTesting.single_from_model_output(**metrics_dict))
+        )
+    test_metrics = test_metrics.compute()
 
-    testing_metrics_np = jax.device_get(testing_metrics)
-    testing_metrics_np = {
-        k: np.mean([m[k] for m in testing_metrics_np]) for k in testing_metrics_np[0]
+    # testing_metrics_np = jax.device_get(testing_metrics)
+    # testing_metrics_np = {
+    #     k: np.mean([m[k] for m in testing_metrics_np]) for k in testing_metrics_np[0]
+    # }
+
+    test_metrics = {
+        f'test_{k}': float(v) for k, v in test_metrics.items()
     }
+
     for t in testing_targets:
-        testing_metrics_np[f'{t}_rmse'] = np.sqrt(testing_metrics_np[f'{t}_mse'])
-    return testing_metrics_np
+        test_metrics[f'test_{t}_rmse'] = np.sqrt(test_metrics[f'test_{t}_mse'])
+    return test_metrics
 
 def calculate_mse(y_predicted, y_true, msk):
diff --git a/mlff/utils/training_utils.py b/mlff/utils/training_utils.py
index c79ebe5..f795115 100644
--- a/mlff/utils/training_utils.py
+++ b/mlff/utils/training_utils.py
@@ -1,3 +1,4 @@
+from clu import metrics
 import jraph
 import jax
 import jax.numpy as jnp
@@ -5,9 +6,10 @@
 import optax
 from orbax import checkpoint
 from pathlib import Path
-from typing import Callable, Dict
+from typing import Any, Callable, Dict
 import wandb
 from flax.core import unfreeze
+from flax import struct as flax_struct
 
 property_to_mask = {
     'energy': 'graph_mask',
@@ -16,23 +18,20 @@
 }
 
 
-# def scaled_safe_masked_mse_loss(y, y_true, scale, msk):
-#     """
-#
-#     Args:
-#         y (): shape: (B,d1, *, dN)
-#         y_true (): (B,d1, *, dN)
-#         scale (): (d1, *, dN) or everything broadcast-able to (B, d1, *, dN)
-#         msk (): shape: (B)
-#
-#     Returns:
-#
-#     """
-#     full_mask = ~jnp.isnan(y_true) & jnp.expand_dims(msk, [y_true.ndim - 1 - o for o in range(0, y_true.ndim - 1)])
-#     diff = jnp.where(full_mask, y_true, 0.) - jnp.where(full_mask, y, 0.)
-#     v = safe_mask(full_mask, fn=lambda u: scale * u ** 2, operand=diff)
-#     den = full_mask.reshape(-1).sum().astype(dtype=v.dtype)
-#     return safe_mask(den > 0, lambda x: v.reshape(-1).sum() / x, den, 0.)
+@flax_struct.dataclass
+class MetricsTraining(metrics.Collection):
+    loss: metrics.Average.from_output('loss')
+    grad_norm: metrics.Average.from_output('grad_norm')
+    energy_mse: metrics.Average.from_output('energy_mse')
+    forces_mse: metrics.Average.from_output('forces_mse')
+
+
+@flax_struct.dataclass
+class MetricsEvaluation(metrics.Collection):
+    loss: metrics.Average.from_output('loss')
+    energy_mse: metrics.Average.from_output('energy_mse')
+    forces_mse: metrics.Average.from_output('forces_mse')
+
 
 def scaled_mse_loss(y, y_label, scale, mask):
     full_mask = ~jnp.isnan(y_label) & jnp.expand_dims(mask, [y_label.ndim - 1 - o for o in range(0, y_label.ndim - 1)])
@@ -80,7 +79,7 @@ def loss_fn(params, batch: Dict[str, jnp.ndarray]):
             )
 
             loss += weights[target] * _l
-            metrics.update({target: _l / _scales[target].mean()})
+            metrics.update({f'{target}_mse': _l / _scales[target].mean()})
 
         loss = jnp.reshape(loss, ())
         metrics.update({'loss': loss})
@@ -106,11 +105,11 @@ def training_step_fn(params, opt_state, batch):
         """
 
         (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, batch)
-        if log_gradient_values:
-            metrics['gradient_norms'] = unfreeze(jax.tree_map(lambda x: jnp.linalg.norm(x.reshape(-1), axis=0), grads))
+        # if log_gradient_values:
+        #     metrics['grad_norm'] = unfreeze(jax.tree_map(lambda x: jnp.linalg.norm(x.reshape(-1), axis=0), grads))
         updates, opt_state = optimizer.update(grads, opt_state, params)
         params = optax.apply_updates(params=params, updates=updates)
-        metrics['gradients_norm'] = optax.global_norm(grads)
+        metrics['grad_norm'] = optax.global_norm(grads)
         return params, opt_state, metrics
 
     return training_step_fn
@@ -263,7 +262,7 @@ def fit(
             # Log training metrics.
             if use_wandb:
                 wandb.log(
-                    {'Training {}'.format(k): v for (k, v) in train_metrics_np.items()},
+                    {f'train_{k}': v for (k, v) in train_metrics_np.items()},
                     step=step
                 )
 
@@ -277,34 +276,40 @@ def fit(
 
                 # Start iteration over validation batches.
-                validation_metrics = []
+                eval_metrics: Any = None
                 for graph_batch_validation in iterator_validation:
                     batch_validation = graph_to_batch_fn(graph_batch_validation)
                     batch_validation = jax.tree_map(jnp.array, batch_validation)
 
-                    validation_metrics += [
-                        validation_step_fn(
-                            params,
-                            batch_validation
-                        )
-                    ]
+                    eval_out = validation_step_fn(
+                        params,
+                        batch_validation
+                    )
+
+                    eval_metrics = (
+                        MetricsEvaluation.single_from_model_output(**eval_out)
+                        if eval_metrics is None
+                        else eval_metrics.merge(MetricsEvaluation.single_from_model_output(**eval_out))
+                    )
+
+                eval_metrics = eval_metrics.compute()
 
-                validation_metrics_np = jax.device_get(validation_metrics)
-                validation_metrics_np = {
-                    k: np.mean([metrics[k] for metrics in validation_metrics]) for k in validation_metrics_np[0]
+                # Convert to dict to log with weights and bias.
+                eval_metrics = {
+                    f'eval_{k}': float(v) for k, v in eval_metrics.items()
                 }
 
                 # Save checkpoint.
                 ckpt_mngr.save(
                     step,
                     args=checkpoint.args.Composite(params=checkpoint.args.StandardSave(params)),
-                    metrics={'loss': validation_metrics_np['loss'].item()}
+                    metrics={'loss': eval_metrics['eval_loss']}
                 )
 
                 # Log to weights and bias.
                 if use_wandb:
-                    wandb.log({
-                        f'Validation {k}': v for (k, v) in validation_metrics_np.items()},
+                    wandb.log(
+                        eval_metrics,
                         step=step
                     )
 
                 # Finished validation process.
diff --git a/setup.py b/setup.py
index db0ce0a..c88c79f 100644
--- a/setup.py
+++ b/setup.py
@@ -8,6 +8,7 @@
     packages=find_packages(),
     install_requires=[
         "numpy",
+        "clu",
        # "jax == 0.4.8",
        "flax",
        "jaxopt",
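
Note (for context, not part of the commit): the patch replaces hand-rolled lists of
per-batch metric dicts with clu metric Collections that are merged batch by batch and
reduced once with compute(). Below is a minimal, self-contained sketch of that pattern;
the Metrics collection and the loss values are illustrative placeholders, not code from
this repository.

import flax.struct as flax_struct
import jax.numpy as jnp
from clu import metrics


@flax_struct.dataclass
class Metrics(metrics.Collection):
    # Average keeps a running total and count of whatever is passed as 'loss'.
    loss: metrics.Average.from_output('loss')


accumulated = None
for batch_loss in (0.4, 0.2, 0.3):
    update = Metrics.single_from_model_output(loss=jnp.asarray(batch_loss))
    # Merge per-batch statistics instead of appending dicts to a Python list.
    accumulated = update if accumulated is None else accumulated.merge(update)

# compute() returns a plain dict of scalars, ready to prefix and pass to wandb.log.
print({f'eval_{k}': float(v) for k, v in accumulated.compute().items()})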