Skip to content

Commit

Permalink
add dense in exchange block
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Jan 12, 2024
1 parent 730bc46 commit b91e66e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions mlff/nn/layer/so3krates_layer_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,11 @@ def __call__(self, x, ev, *args, **kwargs):
# name='mlp_layer_1'
# )(y) # (N, num_features)
# y = self.activation_fn(y)
# y = nn.Dense(
# features=num_features + num_degrees,
# kernel_init=self.last_layer_kernel_init,
# name='mlp_layer_2'
# )(y) # (N, num_features + num_degrees)
y = nn.Dense(
features=num_features + num_degrees,
kernel_init=self.last_layer_kernel_init,
name='mlp_layer_2'
)(y) # (N, num_features + num_degrees)
cx, cev = jnp.split(
y,
indices_or_sections=np.array([num_features]),
Expand Down
2 changes: 1 addition & 1 deletion mlff/nn/representation/so3krates_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def init_so3krates_sparse(

layers = [SO3kratesLayerSparse(
degrees=degrees,
use_spherical_filter=i > 0,
use_spherical_filter=True,
num_heads=num_heads,
num_features_head=num_features_head,
qk_non_linearity=getattr(nn.activation, qk_non_linearity) if qk_non_linearity != 'identity' else lambda u: u,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pkg_resources


def test_fit(disable_jit: bool = True):
def test_fit(disable_jit: bool = False):
if disable_jit:
from jax.config import config
config.update('jax_disable_jit', True)
Expand Down

0 comments on commit b91e66e

Please sign in to comment.