Skip to content

Commit

Permalink
go to initial version
Browse files (browse the repository at this point in the history)
  • Loading branch information
thorben-frank committed Jan 12, 2024
1 parent da56def commit afab961
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions mlff/nn/layer/so3krates_layer_sparse.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,7 +90,9 @@ def __call__(self,
idx_i=idx_i,
idx_j=idx_j) # (num_nodes, num_features), (num_nodes, num_orders)

x = x + x_att # (num_nodes, num_features)
x = nn.silu(
nn.Dense(num_features)(x) + nn.Dense(num_features)(x_att)
) # (num_nodes, num_features)
ev = ev + ev_att # (num_nodes, num_orders)

if self.layer_normalization_1:
Expand Down Expand Up @@ -155,7 +157,7 @@ def __dict_repr__(self) -> Dict[str, Dict[str, Any]]:
}


class AttentionBlock_(nn.Module):
class AttentionBlock(nn.Module):
degrees: Sequence[int]
num_heads: int = 4
num_features_head: int = 32
Expand Down Expand Up @@ -214,6 +216,8 @@ def __call__(self,
contraction_fn = make_l0_contraction_fn(self.degrees, dtype=ev.dtype)
degree_repeat_fn = make_degree_repeat_fn(self.degrees, axis=-1)

rbf_ij = rbf_ij * jnp.expand_dims(cut, axis=-1)

w_ij = nn.Dense(
features=tot_num_features,
name='radial_filter_layer_2'
Expand All @@ -235,7 +239,7 @@ def __call__(self,
)(
self.activation_fn(
nn.Dense(
features=tot_num_features // 4,
features=tot_num_features // 2,
name='spherical_filter_layer_1')(contraction_fn(ev_j - ev_i))
)
) # (num_pairs, tot_num_features)
Expand All @@ -262,8 +266,7 @@ def __call__(self,
k_j = self.qk_non_linearity(jnp.einsum('Hij, NHj -> NHi', Wk, x_H))[idx_j]
# (num_pairs, tot_num_heads, num_features_head)

alpha_ij = (q_i * w_ij * k_j).sum(axis=-1) / jnp.sqrt(self.num_features_head).astype(x.dtype) * jnp.expand_dims(cut,
axis=-1)
alpha_ij = mask.safe_scale((q_i * w_ij * k_j).sum(axis=-1), jnp.expand_dims(cut, axis=-1))
# (num_pairs, tot_num_heads)

alpha1_ij, alpha2_ij = jnp.split(alpha_ij, indices_or_sections=np.array([self.num_heads]), axis=-1)
Expand Down Expand Up @@ -335,12 +338,12 @@ def __call__(self, x, ev, *args, **kwargs):
degree_repeat_fn = make_degree_repeat_fn(self.degrees, axis=-1)

y = jnp.concatenate([x, contraction_fn(ev)], axis=-1) # shape: (N, num_features+num_degrees)
# y = self.activation_fn(y)
# y = nn.Dense(
# features=num_features // 2,
# name='mlp_layer_1'
# )(y) # (N, num_features // 2)
# y = self.activation_fn(y)
y = self.activation_fn(y)
y = nn.Dense(
features=num_features,
name='mlp_layer_1'
)(y) # (N, num_features)
y = self.activation_fn(y)
y = nn.Dense(
features=num_features + num_degrees,
kernel_init=self.last_layer_kernel_init,
Expand All @@ -354,7 +357,7 @@ def __call__(self, x, ev, *args, **kwargs):
return cx, degree_repeat_fn(cev) * ev


class AttentionBlock(nn.Module):
class AttentionBlock_(nn.Module):
degrees: Sequence[int]
num_heads: int = 4
num_features_head: int = 32
Expand Down

0 comments on commit afab961

Please sign in to comment.