Skip to content

Commit

Permalink
Benchmark Foreign Models (#33)
Browse files Browse the repository at this point in the history
* Implement new methods to extract only the initial states

* Check if ic_sets produced are actually the zeroth frame

* Add method to compute metric rollouts based on external rollout

* Test external metric computation based on a single seed

* Extend test to multiple seeds

* Fix vectorized axis for multi seed config

* Add docstring

* Add flax.nnx foreign model benchmark guide

* add guide notebook to documentation

* Add tutorial for flax.linen

* Add flax linen example to docs

* Guide for benchmarking PyTorch models

* Add PyTorch guide to docs

* Link to foreign benchmarking notebooks
  • Loading branch information
Ceyron authored Nov 8, 2024
1 parent 3c6aebe commit 2eb8143
Show file tree
Hide file tree
Showing 8 changed files with 2,982 additions and 4 deletions.
182 changes: 181 additions & 1 deletion apebench/_base_scenario.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable
from typing import Callable, Optional, Union

import equinox as eqx
import exponax as ex
Expand Down Expand Up @@ -156,6 +156,82 @@ def get_optimizer(self) -> optax.GradientTransformation:

return optimizer

def _produce_ic_set(
self,
*,
stepper: BaseStepper,
num_samples: int,
key: PRNGKeyArray,
) -> Float[Array, "num_samples num_channels *num_points"]:
"""
Generate the number of initial conditions as samples requested and
discretize them on the grid. Requires the `stepper` to warmup the
initial conditions if necessary.
"""
ic_distribution = self.get_ic_generator()
ic_set = ex.build_ic_set(
ic_distribution,
num_points=self.num_points,
num_samples=num_samples,
key=key,
)
if self.num_warmup_steps > 0:
ic_set = jax.vmap(
ex.repeat(
stepper,
self.num_warmup_steps,
)
)(ic_set)
return ic_set

def get_train_ic_set(
self,
) -> Float[Array, "num_train_samples num_channels *num_points"]:
"""
Use the attributes to produce the reference training initial condition set.
"""
return self._produce_ic_set(
stepper=self.get_ref_stepper(),
num_samples=self.num_train_samples,
key=jax.random.PRNGKey(self.train_seed),
)

def get_train_ic_set_coarse(
self,
) -> Float[Array, "num_train_samples num_channels *num_points"]:
"""
Use the attributes to produce training initial conditions with the coarse stepper instead.
"""
return self._produce_ic_set(
stepper=self.get_coarse_stepper(),
num_samples=self.num_train_samples,
key=jax.random.PRNGKey(self.train_seed),
)

def get_test_ic_set(
self,
) -> Float[Array, "num_test_samples num_channels *num_points"]:
"""
Use the attributes to produce the reference testing initial condition set.
"""
return self._produce_ic_set(
stepper=self.get_ref_stepper(),
num_samples=self.num_test_samples,
key=jax.random.PRNGKey(self.test_seed),
)

def get_test_ic_set_coarse(
self,
) -> Float[Array, "num_test_samples num_channels *num_points"]:
"""
Use the attributes to produce testing initial conditions with the coarse stepper instead.
"""
return self._produce_ic_set(
stepper=self.get_coarse_stepper(),
num_samples=self.num_test_samples,
key=jax.random.PRNGKey(self.test_seed),
)

def produce_data(
self,
*,
Expand Down Expand Up @@ -708,6 +784,110 @@ def perform_tests(

return results

def perform_tests_on_rollout(
self,
neural_rollout: Union[
Float[
Array,
"num_samples test_temporal_horizon num_channels *num_points",
],
Float[
Array,
"num_seeds num_samples test_temporal_horizon num_channels *num_points",
],
],
test_data_no_init: Optional[
Float[
Array,
"num_samples test_temporal_horizon num_channels *num_points",
]
] = None,
) -> Union[
dict[str, Float[Array, "test_temporal_horizon"]],
dict[str, Float[Array, "num_seeds test_temporal_horizon"]],
]:
"""
Compute all error metrics of the `report_metrics` attribute on an
externally produce rollout.
!!! tip
Use this function to benchmark external models by producing the
initial states with `scenario.get_test_ic_set()`, roll them out in
their respective framework for `test_temporal_horizon` steps, and
then call this function on the produced rollout. (Some frameworks
require different array formats, e.g.,
[TensorFlow](https://github.com/tensorflow/tensorflow) and
[Flax](https://github.com/google/flax) are typically channels-last.
Hence, some reshaping might be necessary.)
!!! warning
The `neural_rollout` must **not** contain the initial conditions as
the zeroth frame.
**Arguments:**
- `neural_rollout`: The neural rollout to be tested.
- `test_data_no_init`: The test data without the initial conditions.
If not provided, the test data is procedurally generated.
**Returns:**
- `results`: A dictionary with the metric names as keys and the error
rollouts. The rollout arrays always have a leading `num_seeds` axis,
even if the `neural_rollout` did not. This is to ensure
compatibility with the `perform_tests` function. Certainly, this
will be a singleton axis if `neural_rollout` did not have a leading
axis.
"""
if test_data_no_init is None:
test_data_no_init = self.get_test_data()[:, 1:]

if neural_rollout.shape == test_data_no_init.shape:
multi_seed = False
else:
if neural_rollout.shape[1:] == test_data_no_init.shape:
multi_seed = True
else:
raise ValueError(
f"""The shape of the neural rollout is {neural_rollout.shape}
and the shape of the test data is {test_data_no_init.shape}. They should
either be the same or the neural rollout should have an additional
leading axis for the seeds."""
)

metric_function_dict = self.get_metric_fns()

results = {}

for metric_config, func in metric_function_dict.items():
if multi_seed:
func_vectorized = jax.vmap(
jax.vmap(
func,
# Vectorize over time steps
in_axes=(1, 1),
),
# Vectorize over seeds (but broadcast the reference)
in_axes=(0, None),
)
else:
func_vectorized = jax.vmap(
func,
# Vectorize over time steps
in_axes=(1, 1),
)

metric_rollout = func_vectorized(neural_rollout, test_data_no_init)
if not multi_seed:
# Add singletone seed axis for compatibility with `perform_tests`
metric_rollout = metric_rollout[None]

results[metric_config] = metric_rollout

return results

def sample_trjs(
self, neural_stepper: eqx.Module
) -> Float[
Expand Down
Loading

0 comments on commit 2eb8143

Please sign in to comment.