diff --git a/mlff/nn/layer/so3krates_layer_sparse.py b/mlff/nn/layer/so3krates_layer_sparse.py index 616ef63..b765b6d 100644 --- a/mlff/nn/layer/so3krates_layer_sparse.py +++ b/mlff/nn/layer/so3krates_layer_sparse.py @@ -90,7 +90,9 @@ def __call__(self, idx_i=idx_i, idx_j=idx_j) # (num_nodes, num_features), (num_nodes, num_orders) - x = x + x_att # (num_nodes, num_features) + x = nn.silu( + nn.Dense(num_features)(x) + nn.Dense(num_features)(x_att) + ) # (num_nodes, num_features) ev = ev + ev_att # (num_nodes, num_orders) if self.layer_normalization_1: @@ -155,7 +157,7 @@ def __dict_repr__(self) -> Dict[str, Dict[str, Any]]: } -class AttentionBlock_(nn.Module): +class AttentionBlock(nn.Module): degrees: Sequence[int] num_heads: int = 4 num_features_head: int = 32 @@ -214,6 +216,8 @@ def __call__(self, contraction_fn = make_l0_contraction_fn(self.degrees, dtype=ev.dtype) degree_repeat_fn = make_degree_repeat_fn(self.degrees, axis=-1) + rbf_ij = rbf_ij * jnp.expand_dims(cut, axis=-1) + w_ij = nn.Dense( features=tot_num_features, name='radial_filter_layer_2' @@ -235,7 +239,7 @@ def __call__(self, )( self.activation_fn( nn.Dense( - features=tot_num_features // 4, + features=tot_num_features // 2, name='spherical_filter_layer_1')(contraction_fn(ev_j - ev_i)) ) ) # (num_pairs, tot_num_features) @@ -262,8 +266,7 @@ def __call__(self, k_j = self.qk_non_linearity(jnp.einsum('Hij, NHj -> NHi', Wk, x_H))[idx_j] # (num_pairs, tot_num_heads, num_features_head) - alpha_ij = (q_i * w_ij * k_j).sum(axis=-1) / jnp.sqrt(self.num_features_head).astype(x.dtype) * jnp.expand_dims(cut, - axis=-1) + alpha_ij = mask.safe_scale((q_i * w_ij * k_j).sum(axis=-1), jnp.expand_dims(cut, axis=-1)) # (num_pairs, tot_num_heads) alpha1_ij, alpha2_ij = jnp.split(alpha_ij, indices_or_sections=np.array([self.num_heads]), axis=-1) @@ -335,12 +338,12 @@ def __call__(self, x, ev, *args, **kwargs): degree_repeat_fn = make_degree_repeat_fn(self.degrees, axis=-1) y = jnp.concatenate([x, contraction_fn(ev)], axis=-1) # shape: (N, num_features+num_degrees) - # y = self.activation_fn(y) - # y = nn.Dense( - # features=num_features // 2, - # name='mlp_layer_1' - # )(y) # (N, num_features // 2) - # y = self.activation_fn(y) + y = self.activation_fn(y) + y = nn.Dense( + features=num_features, + name='mlp_layer_1' + )(y) # (N, num_features) + y = self.activation_fn(y) y = nn.Dense( features=num_features + num_degrees, kernel_init=self.last_layer_kernel_init, @@ -354,7 +357,7 @@ def __call__(self, x, ev, *args, **kwargs): return cx, degree_repeat_fn(cev) * ev -class AttentionBlock(nn.Module): +class AttentionBlock_(nn.Module): degrees: Sequence[int] num_heads: int = 4 num_features_head: int = 32