diff --git a/harmonic/model_nf.py b/harmonic/model_nf.py index 6e13ca63..59a36b81 100644 --- a/harmonic/model_nf.py +++ b/harmonic/model_nf.py @@ -95,7 +95,8 @@ def train_flow( class FlowModel(md.Model): - """Normalizing flow model to approximate the log_e posterior by a normalizing flow.""" + """Normalizing flow model to approximate the log_e posterior by a normalizing + flow.""" def __init__( self, @@ -140,13 +141,16 @@ def fit( X (jnp.ndarray (nsamples, ndim)): Training samples. - batch_size (int, optional): Batch size used when training flow. Default = 64. + batch_size (int, optional): Batch size used when training flow. Default = + 64. epochs (int, optional): Number of epochs flow is trained for. Default = 3. - key (Union[jax.Array, jax.random.PRNGKeyArray], optional): Key used in random number generation process. + key (Union[jax.Array, jax.random.PRNGKeyArray], optional): Key used in + random number generation process. - verbose (bool, optional): Controls if progress bar and current loss are displayed when training. Default = False. + verbose (bool, optional): Controls if progress bar and current loss are + displayed when training. Default = False. Raises: @@ -295,17 +299,24 @@ def __init__( ndim_in (int): Dimension of the problem to solve. - n_scaled_layers (int, optional): Number of layers with scaler in RealNVP flow. Default = 2. + n_scaled_layers (int, optional): Number of layers with scaler in RealNVP + flow. Default = 2. - n_unscaled_layers (int, optional): Number of layers without scaler in RealNVP flow. Default = 4. + n_unscaled_layers (int, optional): Number of layers without scaler in + RealNVP flow. Default = 4. - learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Default = 0.001. + learning_rate (float, optional): Learning rate for adam optimizer used in + the fit method. Default = 0.001. - momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Default = 0.9 + momentum (float, optional): Learning rate for Adam optimizer used in the fit + method. Default = 0.9 - standardize(bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Default = False + standardize(bool, optional): Indicates if mean and variance should be + removed from training data when training the flow. Default = False - temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Default = 0.8. + temperature (float, optional): Scale factor by which the base distribution + Gaussian is compressed in the prediction step. Should be positive and <=1. + Default = 0.8. Raises: @@ -365,17 +376,24 @@ def __init__( n_bins (int, optional): Number of bins in the spline. Defaults to 8. - hidden_size (Sequence[int], optional): Size of the hidden layers in the conditioner. Defaults to [64, 64]. + hidden_size (Sequence[int], optional): Size of the hidden layers in the + conditioner. Defaults to [64, 64]. - spline_range (Sequence[float], optional): Range of the spline. Defaults to (-10.0, 10.0). + spline_range (Sequence[float], optional): Range of the spline. Defaults to + (-10.0, 10.0). - standardize (bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Defaults to False. + standardize (bool, optional): Indicates if mean and variance should be + removed from training data when training the flow. Defaults to False. - learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Defaults to 0.001. + learning_rate (float, optional): Learning rate for adam optimizer used in + the fit method. Defaults to 0.001. - momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Defaults to 0.9. + momentum (float, optional): Learning rate for Adam optimizer used in the fit + method. Defaults to 0.9. - temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Defaults to 0.8. + temperature (float, optional): Scale factor by which the base distribution + Gaussian is compressed in the prediction step. Should be positive and <=1. + Defaults to 0.8. Raises: