Skip to content

Commit

Permalink
Allow more flexibility for pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Jun 4, 2024
1 parent aa169b5 commit 7fd83a3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions bayesflow/experimental/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ def __init__(
self,
summary_dim: int = 10,
depth: int = 2,
inner_pooling: str = "mean",
output_pooling: str = "mean",
num_dense_equivariant: int = 2,
num_dense_invariant_inner: int = 2,
num_dense_invariant_outer: int = 2,
units_equivariant: int = 128,
units_invariant_inner: int = 128,
units_invariant_outer: int = 128,
pooling: str = "mean",
activation: str | callable = "gelu",
kernel_regularizer: regularizers.Regularizer | None = None,
kernel_initializer: str = "he_uniform",
Expand Down Expand Up @@ -55,7 +56,7 @@ def __init__(
bias_regularizer=bias_regularizer,
spectral_normalization=spectral_normalization,
dropout=dropout,
pooling=pooling,
pooling=inner_pooling,
**kwargs
)
self.equivariant_modules.add(equivariant_module)
Expand All @@ -71,7 +72,7 @@ def __init__(
kernel_initializer=kernel_initializer,
bias_regularizer=bias_regularizer,
dropout=dropout,
pooling=pooling,
pooling=output_pooling,
spectral_normalization=spectral_normalization,
**kwargs
)
Expand Down

0 comments on commit 7fd83a3

Please sign in to comment.