Skip to content

Commit

Permalink
Add transformation attributes to flow derived classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Aug 14, 2024
1 parent e39a420 commit 3cc2ab8
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions harmonic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def __init__(
momentum: float = 0.9,
standardize: bool = False,
temperature: float = 0.8,
transformation: Callable = None,
log_J_det: Callable = None,
):
"""Constructor setting the hyper-parameters of the model.
Expand Down Expand Up @@ -348,6 +350,8 @@ def __init__(
momentum,
standardize,
temperature,
transformation,
log_J_det,
)

# Model parameters
Expand Down Expand Up @@ -377,6 +381,8 @@ def __init__(
temperature: float = 0.8,
multimodal_base: bool = False,
base_centers: Sequence[jax.Array] = None,
transformation: Callable = None,
log_J_det: Callable = None,
):
"""Constructor setting the hyper-parameters and domains of the model.
Expand Down Expand Up @@ -415,6 +421,8 @@ def __init__(
momentum,
standardize,
temperature,
transformation,
log_J_det,
)

# Flow parameters
Expand Down

0 comments on commit 3cc2ab8

Please sign in to comment.