Change to a dictionary-based components system (#14)
* Move optimization system into components

* Move components test into separate folder

* Add simple test for optimization components

* Move Initial conditions into separate components dict

* Fix copilot autocomplete

* Test set for new initial condition components

* Move activation functions into separate dict

* Add tests for activation fn components

* Add identity activation function

* Port architectures to components system

* Ensure dictionary access uses lower-case letters

* Add simple test set for architecture components

* Architecture extension dict no longer needed

* Adapt example of extending APEBench

* Fix wrong argument
Ceyron authored Oct 25, 2024
1 parent b5f86ea commit 2597057
Showing 14 changed files with 416 additions and 276 deletions.
20 changes: 12 additions & 8 deletions README.md
@@ -463,17 +463,19 @@ print(metric_df.groupby(
## Extending APEBench

You can run experiments with your own architectures. For this, you have to
register them in the `apebench.arch_extensions` dictionary.
register them in the `apebench.components.architecture_dict` dictionary.

```python
import apebench
import pdequinox as pdeqx

def conv_net_extension(
config_str: str,
num_spatial_dims: int,
num_points: int,
num_channels: int,
*,
key: PRNGKeyArray,
activation_fn,
key,
):
config_args = config_str.split(";")

@@ -485,22 +487,24 @@ def conv_net_extension(
out_channels=num_channels,
hidden_channels=42,
depth=depth,
activation=jax.nn.relu,
activation=activation_fn,
key=key,
)

apebench.arch_extensions.update(
{"MyConvNet": conv_net_extension}
apebench.components.architecture_dict.update(
# Ensure that the key is in lower case
{"myconvnet": conv_net_extension}
)
```

Then you can use the `Conv` architecture in the `net` configuration string.
Then you can use the `MyConvNet` architecture in the `net` configuration string.
We append `"relu"` to identify the activation function.

```python
apebench.run_experiment(
scenario="diff_adv",
task="predict",
net="MyConvNet;42",
net="MyConvNet;42;relu",
train="one",
start_seed=0,
num_seeds=10,
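All component registries introduced by this commit are plain Python dictionaries keyed by lower-case names, so you can inspect what is available before writing a configuration string. A minimal sketch, assuming the dictionaries are exposed directly on `apebench.components` as the import block further below suggests:

```python
import apebench

# Each registry maps a lower-case name to a factory callable.
for registry_name in (
    "architecture_dict",
    "activation_fn_dict",
    "ic_dict",
    "optimizer_dict",
    "lr_scheduler_dict",
    "metric_dict",
):
    registry = getattr(apebench.components, registry_name)
    print(registry_name, "->", sorted(registry.keys()))
```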
1 change: 0 additions & 1 deletion apebench/__init__.py
@@ -5,7 +5,6 @@
from . import _scraper as scraper
from . import components, scenarios
from ._base_scenario import BaseScenario
from ._extensions import arch_extensions
from ._run import (
get_experiment_name,
melt_concat_from_list,
256 changes: 34 additions & 222 deletions apebench/_base_scenario.py
@@ -14,8 +14,14 @@
from jaxtyping import Array, Float, PRNGKeyArray

from ._corrected_stepper import CorrectedStepper
from ._extensions import arch_extensions
from .components import metric_dict
from .components import (
activation_fn_dict,
architecture_dict,
ic_dict,
lr_scheduler_dict,
metric_dict,
optimizer_dict,
)


class BaseScenario(eqx.Module, ABC):
@@ -79,44 +85,8 @@ def get_ic_generator(self) -> BaseRandomICGenerator:
"""

def _get_single_channel(config):
ic_args = config.split(";")
if ic_args[0].lower() == "fourier":
cutoff = int(ic_args[1])
zero_mean = ic_args[2].lower() == "true"
max_one = ic_args[3].lower() == "true"
if zero_mean:
offset_range = (0.0, 0.0)
else:
offset_range = (-0.5, 0.5)
ic_gen = ex.ic.RandomTruncatedFourierSeries(
num_spatial_dims=self.num_spatial_dims,
cutoff=cutoff,
offset_range=offset_range,
max_one=max_one,
)
elif ic_args[0].lower() == "diffused":
intensity = float(ic_args[1])
zero_mean = ic_args[2].lower() == "true"
max_one = ic_args[3].lower() == "true"
ic_gen = ex.ic.DiffusedNoise(
num_spatial_dims=self.num_spatial_dims,
intensity=intensity,
zero_mean=zero_mean,
max_one=max_one,
)
elif ic_args[0].lower() == "grf":
powerlaw_exponent = float(ic_args[1])
zero_mean = ic_args[2].lower() == "true"
max_one = ic_args[3].lower() == "true"
ic_gen = ex.ic.GaussianRandomField(
num_spatial_dims=self.num_spatial_dims,
powerlaw_exponent=powerlaw_exponent,
zero_mean=zero_mean,
max_one=max_one,
)
else:
raise ValueError("Unknown IC configuration")

ic_name = config.split(";")[0]
ic_gen = ic_dict[ic_name.lower()](config, self.num_spatial_dims)
return ic_gen

ic_args = self.ic_config.split(";")
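The dispatch above implies that an entry in `ic_dict` is a callable taking the per-channel config string and the number of spatial dimensions and returning an exponax IC generator. A hedged sketch of registering a custom entry under that assumption; the `"clamped_grf"` name and its argument layout are illustrative, not part of this commit:

```python
import apebench
import exponax as ex

def clamped_grf_ic(config_str: str, num_spatial_dims: int):
    # Assumed config layout: "clamped_grf;<powerlaw_exponent>"
    args = config_str.split(";")
    powerlaw_exponent = float(args[1])
    return ex.ic.GaussianRandomField(
        num_spatial_dims=num_spatial_dims,
        powerlaw_exponent=powerlaw_exponent,
        zero_mean=True,
        max_one=True,
    )

# Keys must be lower-case to match the `.lower()` lookup above.
apebench.components.ic_dict.update({"clamped_grf": clamped_grf_ic})
```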
@@ -167,37 +137,22 @@ def num_training_steps(self):
optim_args = self.optim_config.split(";")
return int(optim_args[1])

def get_optimizer(self):
def get_optimizer(self) -> optax.GradientTransformation:
"""
Returns the optimizer used in the scenario.
"""
optim_args = self.optim_config.split(";")
optimizer_name = optim_args[0]
num_training_steps = int(optim_args[1])
schedule_args = optim_args[2:]
if schedule_args[0] == "constant":
lr_schedule = optax.constant_schedule(float(schedule_args[1]))
elif schedule_args[0] == "warmup_cosine":
lr_schedule = optax.warmup_cosine_decay_schedule(
init_value=float(schedule_args[1]),
peak_value=float(schedule_args[2]),
warmup_steps=int(schedule_args[3]),
decay_steps=num_training_steps,
)
elif schedule_args[0] == "exp":
lr_schedule = optax.exponential_decay(
init_value=float(schedule_args[1]),
transition_steps=int(schedule_args[2]),
decay_rate=float(schedule_args[3]),
staircase=schedule_args[4].lower() == "true",
)
else:
raise ValueError("Unknown schedule")
scheduler_args = optim_args[2:]
scheduler_name = scheduler_args[0]

if optimizer_name == "adam":
optimizer = optax.adam(lr_schedule)
else:
raise ValueError("Unknown optimizer")
lr_scheduler = lr_scheduler_dict[scheduler_name.lower()](
";".join(scheduler_args), num_training_steps
)
optimizer = optimizer_dict[optimizer_name.lower()](self.optim_config)(
lr_scheduler
)

return optimizer

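From this dispatch, an `lr_scheduler_dict` entry receives the scheduler portion of the optimizer config string together with the total number of training steps and returns an optax schedule, while an `optimizer_dict` entry receives the full optimizer config string and returns a callable that maps a schedule to an `optax.GradientTransformation`. A hedged sketch of registering new entries under those assumptions; the `"linear"` and `"sgd"` names are illustrative, not part of this commit:

```python
import apebench
import optax

def linear_scheduler(scheduler_config: str, num_training_steps: int):
    # Assumed config layout: "linear;<init_value>;<end_value>"
    args = scheduler_config.split(";")
    return optax.linear_schedule(
        init_value=float(args[1]),
        end_value=float(args[2]),
        transition_steps=num_training_steps,
    )

def sgd_optimizer(optim_config: str):
    # The full optimizer config string is available if extra arguments are needed.
    return lambda lr_schedule: optax.sgd(lr_schedule)

apebench.components.lr_scheduler_dict.update({"linear": linear_scheduler})
apebench.components.optimizer_dict.update({"sgd": sgd_optimizer})
```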
@@ -428,16 +383,9 @@ def get_activation(
"""
Parse a string to a callable activation function.
"""
if activation.lower() == "tanh":
return jax.nn.tanh
elif activation.lower() == "relu":
return jax.nn.relu
elif activation.lower() == "silu":
return jax.nn.silu
elif activation.lower() == "gelu":
return jax.nn.gelu
else:
raise ValueError("unknown activation string")
activation_fn_name = activation.split(";")[0]
activation_fn = activation_fn_dict[activation_fn_name.lower()](activation)
return activation_fn

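Analogously, an `activation_fn_dict` entry takes the full activation config string and returns the activation callable itself, which also allows parametrized activations. A hedged sketch under that assumption; the `"leaky_relu"` key and its slope argument are illustrative, not part of this commit:

```python
import apebench
import jax

def leaky_relu_component(activation_config: str):
    # Assumed config layout: "leaky_relu;<negative_slope>" (slope is optional)
    args = activation_config.split(";")
    negative_slope = float(args[1]) if len(args) > 1 else 0.01
    return lambda x: jax.nn.leaky_relu(x, negative_slope=negative_slope)

apebench.components.activation_fn_dict.update(
    {"leaky_relu": leaky_relu_component}
)
```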
def get_network(
self,
@@ -515,154 +463,18 @@ def get_network(
"""
network_args = network_config.split(";")

if network_args[0].lower() == "conv":
hidden_channels = int(network_args[1])
depth = int(network_args[2])
activation = self.get_activation(network_args[3])
network = pdeqx.arch.ConvNet(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
hidden_channels=hidden_channels,
depth=depth,
activation=activation,
boundary_mode="periodic",
key=key,
)
elif network_args[0].lower() == "res":
hidden_channels = int(network_args[1])
num_blocks = int(network_args[2])
activation = self.get_activation(network_args[3])

network = pdeqx.arch.ClassicResNet(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
hidden_channels=hidden_channels,
num_blocks=num_blocks,
activation=activation,
boundary_mode="periodic",
key=key,
)
elif network_args[0].lower() == "unet":
hidden_channels = int(network_args[1])
num_levels = int(network_args[2])
activation = self.get_activation(network_args[3])

network = pdeqx.arch.ClassicUNet(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
hidden_channels=hidden_channels,
num_levels=num_levels,
activation=activation,
boundary_mode="periodic",
key=key,
)
elif network_args[0].lower() == "dil":
dilation_depth = int(network_args[1])
hidden_channels = int(network_args[2])
num_blocks = int(network_args[3])
activation = self.get_activation(network_args[4])

dilation_rates = [2**i for i in range(dilation_depth + 1)]
dilation_rates = dilation_rates + dilation_rates[::-1][1:]

network = pdeqx.arch.DilatedResNet(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
hidden_channels=hidden_channels,
num_blocks=num_blocks,
dilation_rates=dilation_rates,
activation=activation,
boundary_mode="periodic",
key=key,
)
elif network_args[0].lower() == "fno":
num_modes = int(network_args[1])
hidden_channels = int(network_args[2])
num_blocks = int(network_args[3])
activation = self.get_activation(network_args[4])

network = pdeqx.arch.ClassicFNO(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
hidden_channels=hidden_channels,
num_blocks=num_blocks,
num_modes=num_modes,
activation=activation,
key=key,
)
elif network_args[0].lower() == "mlp":
width_size = int(network_args[1])
depth = int(network_args[2])
activation = self.get_activation(network_args[3])

network = pdeqx.arch.MLP(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
num_points=self.num_points, # Has to be know a priori
width_size=width_size,
depth=depth,
activation=activation,
key=key,
)
elif network_args[0].lower() == "pure":
kernel_size = int(network_args[1])

network = pdeqx.conv.PhysicsConv(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
kernel_size=kernel_size,
use_bias=False, # !!! no bias,
key=key,
boundary_mode="periodic",
)
elif network_args[0].lower() == "mores":
# Modern ResNet using pre-activation and group normalization
hidden_channels = int(network_args[1])
num_blocks = int(network_args[2])
activation = self.get_activation(network_args[3])

network = pdeqx.arch.ModernResNet(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
hidden_channels=hidden_channels,
num_blocks=num_blocks,
activation=activation,
boundary_mode="periodic",
key=key,
)
elif network_args[0].lower() == "mounet":
# Modern UNet using two resnet blocks per level
hidden_channels = int(network_args[1])
num_levels = int(network_args[2])
activation = self.get_activation(network_args[3])
network = pdeqx.arch.ModernUNet(
num_spatial_dims=self.num_spatial_dims,
in_channels=self.num_channels,
out_channels=self.num_channels,
hidden_channels=hidden_channels,
num_levels=num_levels,
activation=activation,
boundary_mode="periodic",
key=key,
)
else:
try:
network = arch_extensions[network_args[0].lower()](
network_config,
self.num_spatial_dims,
self.num_channels,
key=key,
)
except KeyError:
raise ValueError("Unknown network argument")
network_name = network_args[0]
activation_fn_config = network_args[-1]
activation_fn = self.get_activation(activation_fn_config)

network = architecture_dict[network_name.lower()](
network_config,
self.num_spatial_dims,
self.num_points,
self.num_channels,
activation_fn,
key,
)

return network

