allow to pass optimizer args via config and bugfix ITP
thorben-frank committed Apr 5, 2024
1 parent 1cbea7c commit f2db288
Showing 6 changed files with 19 additions and 4 deletions.
1 change: 1 addition & 0 deletions mlff/config/config.yaml
@@ -41,6 +41,7 @@ model:
   input_convention: positions # Input convention.
 optimizer:
   name: adam # Name of the optimizer. See https://optax.readthedocs.io/en/latest/api.html#common-optimizers for available ones.
+  optimizer_args: null
   learning_rate: 0.001 # Learning rate to use.
   learning_rate_schedule: exponential_decay # Which learning rate schedule to use. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules for available ones.
   learning_rate_schedule_args: # Arguments passed to the learning rate schedule. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules.
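As a hedged aside (not part of the commit): the new `optimizer_args` field appears to be a mapping of extra keyword arguments for the chosen optax optimizer, with `null` meaning no extras. A minimal Python sketch of how a populated entry could look once the YAML is loaded; the keys and values (b1, b2, eps) are assumptions chosen from keyword arguments that optax.adam accepts.

# Sketch only: illustrative `optimizer_args` values and how they would be
# forwarded to the optimizer resolved by name from optax.
import optax
import yaml

config_text = """
optimizer:
  name: adam
  optimizer_args:
    b1: 0.9
    b2: 0.999
    eps: 1.0e-8
  learning_rate: 0.001
"""
opt_cfg = yaml.safe_load(config_text)["optimizer"]
optimizer = getattr(optax, opt_cfg["name"])(
    opt_cfg["learning_rate"], **opt_cfg["optimizer_args"]
)  # equivalent to optax.adam(0.001, b1=0.9, b2=0.999, eps=1e-8)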
1 change: 1 addition & 0 deletions mlff/config/config_itp_net.yaml
@@ -41,6 +41,7 @@ model:
   input_convention: positions # Input convention.
 optimizer:
   name: adam # Name of the optimizer. See https://optax.readthedocs.io/en/latest/api.html#common-optimizers for available ones.
+  optimizer_args: null
   learning_rate: 0.001 # Learning rate to use.
   learning_rate_schedule: exponential_decay # Which learning rate schedule to use. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules for available ones.
   learning_rate_schedule_args: # Arguments passed to the learning rate schedule. See https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules.
1 change: 1 addition & 0 deletions mlff/config/from_config.py
@@ -124,6 +124,7 @@ def make_optimizer_from_config(config: config_dict.ConfigDict = None):
     opt_config = config.optimizer
     return training_utils.make_optimizer(
         name=opt_config.name,
+        optimizer_args=opt_config.optimizer_args if opt_config.optimizer_args is not None else dict(),
         learning_rate_schedule=opt_config.learning_rate_schedule,
         learning_rate=opt_config.learning_rate,
         learning_rate_schedule_args=opt_config.learning_rate_schedule_args if opt_config.learning_rate_schedule_args is not None else dict(),
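A minimal sketch of the guard added here, as I read it: a YAML `null` arrives as `None` after loading, so it is normalized to an empty dict before being unpacked, which forwards no extra keyword arguments.

# Sketch of the None -> dict() normalization; `optimizer_args: null` loads as None.
optimizer_args = None
optimizer_args = optimizer_args if optimizer_args is not None else dict()
assert optimizer_args == {}  # **optimizer_args then adds nothing to the call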
10 changes: 9 additions & 1 deletion mlff/nn/layer/itp_layer.py
@@ -139,7 +139,15 @@ def __call__(self,
             dst_idx=idx_i,
             src_idx=idx_j,
             num_segments=len(x)
-        )
+        )  # (N, 1, (max_degree + 1)^2, num_features) always has 1 for parity axis, since x is invariant and parity 1
+        # and basis is parity 1 and equivariant.
+
+        # For concatenation of dense, add the pseudotensor dimension explicitly.
+        if self.include_pseudotensors and self.itp_connectivity == 'dense':
+            y = e3x.nn.change_max_degree_or_type(
+                y,
+                include_pseudotensors=True
+            )  # (N, 2, (max_degree + 1)^2, num_features)

         if self.message_normalization == 'avg_num_neighbors':
             y = jnp.divide(y,
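For shape orientation only (this is not the e3x implementation): in e3x-style features the second axis is the parity axis, size 1 without pseudotensors and size 2 with them. A hedged JAX sketch of the shape change the fix is about, with the new components shown as zeros; the actual conversion is done by e3x.nn.change_max_degree_or_type and follows e3x's parity-axis convention.

# Illustrative shapes only, not the library routine.
import jax.numpy as jnp

N, max_degree, num_features = 4, 2, 8
y = jnp.ones((N, 1, (max_degree + 1) ** 2, num_features))  # no pseudotensors
y_full = jnp.concatenate([y, jnp.zeros_like(y)], axis=1)   # placeholder pseudotensor block
print(y.shape, y_full.shape)  # (4, 1, 9, 8) (4, 2, 9, 8)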
4 changes: 3 additions & 1 deletion mlff/utils/training_utils.py
@@ -578,6 +578,7 @@ def fit_from_iterator(

 def make_optimizer(
         name: str = 'adam',
+        optimizer_args: Dict = dict(),
         learning_rate: float = 1e-3,
         learning_rate_schedule: str = 'constant_schedule',
         learning_rate_schedule_args: Dict = dict(),
@@ -589,6 +590,7 @@ def make_optimizer(
     Args:
         name (str): Name of the optimizer. Defaults to the Adam optimizer.
+        optimizer_args (dict): Arguments passed to the optimizer.
         learning_rate (float): Learning rate.
         learning_rate_schedule (str): Learning rate schedule. Defaults to no schedule, meaning learning rate is
             held constant.
@@ -603,7 +605,7 @@
     lr_schedule = getattr(optax, learning_rate_schedule)

     lr_schedule = lr_schedule(learning_rate, **learning_rate_schedule_args)
-    opt = opt(lr_schedule)
+    opt = opt(lr_schedule, **optimizer_args)

     clip_transform = getattr(optax, gradient_clipping)
     clip_transform = clip_transform(**gradient_clipping_args)
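A standalone sketch of the construction pattern visible in these hunks, with illustrative values; the real make_optimizer additionally applies gradient clipping and its defaults may differ.

# Sketch: resolve the schedule and optimizer by name from optax and forward
# optimizer_args as keyword arguments (argument values here are illustrative).
import optax

name = 'adam'
optimizer_args = {'b1': 0.9, 'eps': 1e-8}
learning_rate = 1e-3
learning_rate_schedule = 'exponential_decay'
learning_rate_schedule_args = {'transition_steps': 1000, 'decay_rate': 0.9}

lr_schedule = getattr(optax, learning_rate_schedule)(learning_rate, **learning_rate_schedule_args)
opt = getattr(optax, name)(lr_schedule, **optimizer_args)  # optimizer_args reach optax.adam here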
6 changes: 4 additions & 2 deletions tests/test_itp_layer.py
@@ -130,12 +130,14 @@ def test_apply(
 @pytest.mark.parametrize("_ipt_num_features", [None, 13])
 @pytest.mark.parametrize("_itp_growth_rate", [14])
 @pytest.mark.parametrize("_itp_dense_final_concatenation", [True])
+@pytest.mark.parametrize("_include_pseudotensors", [False, True])
 def test_apply_dense(
         ipt_connectivity,
         feature_collection_over_layers,
         _ipt_num_features,
         _itp_growth_rate,
-        _itp_dense_final_concatenation
+        _itp_dense_final_concatenation,
+        _include_pseudotensors
 ):
     layer = ITPLayer(
         mp_max_degree=2,
@@ -152,7 +154,7 @@ def test_apply_dense(
         itp_dense_final_concatenation=_itp_dense_final_concatenation,
         itp_connectivity=ipt_connectivity,
         feature_collection_over_layers=feature_collection_over_layers,
-        include_pseudotensors=False
+        include_pseudotensors=_include_pseudotensors
     )

     params = layer.init(
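With the new parametrization the dense-connectivity test also covers `include_pseudotensors=True`. If desired, the enlarged matrix can be run on its own with something like `pytest tests/test_itp_layer.py -k test_apply_dense`.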
