diff --git a/examples/normal_gamma_splines.py b/examples/normal_gamma_splines.py
index 5497d850..3a720269 100644
--- a/examples/normal_gamma_splines.py
+++ b/examples/normal_gamma_splines.py
@@ -268,7 +268,7 @@ def run_example(
     # Fit model
     # =======================================================================
     hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
-    model = model_nf.RQSplineFlow(
+    model = model_nf.RQSplineModel(
         ndim, standardize=standardize, temperature=var_scale
     )
     model.fit(chains_train.samples, epochs=epochs_num)
diff --git a/examples/radiata_pine_splines.py b/examples/radiata_pine_splines.py
index 5da4c721..a2d1885f 100644
--- a/examples/radiata_pine_splines.py
+++ b/examples/radiata_pine_splines.py
@@ -433,7 +433,7 @@ def run_example(
     Fit model by selecing the configuration of hyper-parameters which
     minimises the validation variances.
     """
-    model = model_nf.RQSplineFlow(ndim, standardize=standardize, temperature=var_scale)
+    model = model_nf.RQSplineModel(ndim, standardize=standardize, temperature=var_scale)
     model.fit(chains_train.samples, epochs=epochs_num)
 
     # ===========================================================================
diff --git a/examples/rastrigin_splines.py b/examples/rastrigin_splines.py
index e515033f..61768d4e 100644
--- a/examples/rastrigin_splines.py
+++ b/examples/rastrigin_splines.py
@@ -200,7 +200,7 @@ def run_example(
     """
     Fit model.
     """
-    model = model_nf.RQSplineFlow(
+    model = model_nf.RQSplineModel(
         ndim,
         n_layers=n_layers,
         n_bins=n_bins,
diff --git a/examples/rosenbrock_splines.py b/examples/rosenbrock_splines.py
index c7a590d9..285962c4 100644
--- a/examples/rosenbrock_splines.py
+++ b/examples/rosenbrock_splines.py
@@ -220,7 +220,7 @@ def run_example(
     # Fit model
     # =======================================================================
     hm.logs.info_log("Fit model for {} epochs...".format(epochs_num))
-    model = model_nf.RQSplineFlow(
+    model = model_nf.RQSplineModel(
         ndim, standardize=standardize, temperature=var_scale
     )
     model.fit(chains_train.samples, epochs=epochs_num)
diff --git a/harmonic/model_nf.py b/harmonic/model_nf.py
index 4fe3153f..41e2fa6c 100644
--- a/harmonic/model_nf.py
+++ b/harmonic/model_nf.py
@@ -94,56 +94,20 @@ def train_flow(
     return train_flow, train_epoch, train_step
 
 
-# ===============================================================================
-# NVP Flow - will generalise this to take a custom flow
-# ===============================================================================
-
-
-class RealNVPModel(md.Model):
-    """Normalizing flow model to approximate the log_e posterior by a NVP normalizing flow."""
+class FlowModel(md.Model):
+    """Normalizing flow model to approximate the log_e posterior by a normalizing flow."""
 
     def __init__(
         self,
         ndim_in: int,
-        n_scaled_layers: int = 2,
-        n_unscaled_layers: int = 4,
         learning_rate: float = 0.001,
         momentum: float = 0.9,
         standardize: bool = False,
         temperature: float = 0.8,
     ):
-        """Constructor setting the hyper-parameters of the model.
-
-        Args:
-
-            ndim_in (int): Dimension of the problem to solve.
-
-            n_scaled_layers (int, optional): Number of layers with scaler in RealNVP flow. Default = 2.
-
-            n_unscaled_layers (int, optional): Number of layers without scaler in RealNVP flow. Default = 4.
-
-            learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Default = 0.001.
-
-            momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Default = 0.9
-
-            standardize(bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Default = False
-
-            temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Default = 0.8.
-
-        Raises:
-
-            ValueError: If the ndim_in is not positive.
-
-            ValueError: If n_scaled_layers is not positive.
-
-        """
-
         if ndim_in < 1:
             raise ValueError("Dimension must be greater than 0.")
 
-        if n_scaled_layers <= 0:
-            raise ValueError("Number of scaled layers must be greater than 0.")
-
         self.ndim = ndim_in
         self.fitted = False
         self.state = None
@@ -151,11 +115,9 @@ def __init__(
         # Model parameters
         self.learning_rate = learning_rate
         self.momentum = momentum
-        self.n_scaled_layers = n_scaled_layers
-        self.n_unscaled_layers = n_unscaled_layers
         self.standardize = standardize
-        self.flow = flows.RealNVP(ndim_in, self.n_scaled_layers, self.n_unscaled_layers)
         self.temperature = temperature
+        self.flow = None
 
     def is_fitted(self) -> bool:
         """Specify whether model has been fitted.
@@ -202,8 +164,15 @@ def fit(
 
             ValueError: Raised if the second dimension of X is not the same as ndim.
 
+            NotImplementedError: If called directly from FlowModel class.
+
         """
 
+        if self.flow is None:
+            raise NotImplementedError(
+                "This method cannot be used in the FlowModel class directly. Use a class with a specific flow implemented (RealNVPModel, RQSplineModel)."
+            )
+
         if X.shape[1] != self.ndim:
             raise ValueError("X second dimension not the same as ndim.")
 
@@ -314,220 +283,129 @@ def sample(self, n_sample: int, rng_key=jax.random.PRNGKey(0)) -> jnp.ndarray:
 
 
 # ===============================================================================
-# Rational Quadratic Spline Flow
+# NVP Flow
 # ===============================================================================
 
 
-class RQSplineFlow(md.Model):
-    """Rational quadratic spline flow model to approximate the log_e posterior by a normalizing flow."""
+class RealNVPModel(FlowModel):
+    """Normalizing flow model to approximate the log_e posterior by a NVP normalizing flow."""
 
     def __init__(
         self,
         ndim_in: int,
-        n_layers: int = 8,
-        n_bins: int = 8,
-        hidden_size: Sequence[int] = [64, 64],
-        spline_range: Sequence[float] = (-10.0, 10.0),
-        standardize: bool = False,
+        n_scaled_layers: int = 2,
+        n_unscaled_layers: int = 4,
         learning_rate: float = 0.001,
         momentum: float = 0.9,
+        standardize: bool = False,
         temperature: float = 0.8,
     ):
-        """Constructor setting the hyper-parameters and domains of the model.
-
-        Must be implemented by derived class (currently abstract).
+        """Constructor setting the hyper-parameters of the model.
 
         Args:
 
             ndim_in (int): Dimension of the problem to solve.
 
-            n_layers (int, optional): Number of layers in the flow. Defaults to 8.
-
-            n_bins (int, optional): Number of bins in the spline. Defaults to 8.
-
-            hidden_size (Sequence[int], optional): Size of the hidden layers in the conditioner. Defaults to [64, 64].
+            n_scaled_layers (int, optional): Number of layers with scaler in RealNVP flow. Default = 2.
 
-            spline_range (Sequence[float], optional): Range of the spline. Defaults to (-10.0, 10.0).
+            n_unscaled_layers (int, optional): Number of layers without scaler in RealNVP flow. Default = 4.
 
-            standardize (bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Defaults to False.
+            learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Default = 0.001.
 
-            learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Defaults to 0.001.
+            momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Default = 0.9
 
-            momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Defaults to 0.9.
+            standardize(bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Default = False
 
-            temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Defaults to 0.8.
+            temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Default = 0.8.
 
         Raises:
 
             ValueError: If the ndim_in is not positive.
 
-        """
-
-        if ndim_in < 1:
-            raise ValueError("Dimension must be greater than 0.")
-
-        self.ndim = ndim_in
-        self.fitted = False
-        self.state = None
-        self.standardize = standardize
-
-        # Flow parameters
-        self.n_layers = n_layers
-        self.hidden_size = hidden_size
-        self.n_bins = n_bins
-        self.spline_range = spline_range
-        self.flow = flows.RQSpline(ndim_in, n_layers, hidden_size, n_bins, spline_range)
-        self.temperature = temperature
-
-        # Optimizer parameters
-        self.learning_rate = learning_rate
-        self.momentum = momentum
-
-    def is_fitted(self):
-        """Specify whether model has been fitted.
-
-        Returns:
-
-            (bool): Whether the model has been fitted.
+            ValueError: If n_scaled_layers is not positive.
 
         """
 
-        return self.fitted
+        if n_scaled_layers <= 0:
+            raise ValueError("Number of scaled layers must be greater than 0.")
 
-    def create_train_state(self, rng):
-        params = self.flow.init(rng, jnp.ones((1, self.ndim)))["params"]
-        tx = optax.adam(self.learning_rate, self.momentum)
-        return train_state.TrainState.create(
-            apply_fn=self.flow.apply, params=params, tx=tx
+        FlowModel.__init__(
+            self,
+            ndim_in,
+            learning_rate,
+            momentum,
+            standardize,
+            temperature,
         )
 
-    def fit(
-        self,
-        X: jnp.ndarray,
-        batch_size: int = 64,
-        epochs: int = 3,
-        key=jax.random.PRNGKey(1000),
-        verbose: bool = False,
-    ):
-        """Fit the parameters of the model.
-
-        Args:
-
-            X (jnp.ndarray (nsamples, ndim)): Sample x coordinates.
-
-            batch_size (int, optional): Batch size used when training flow. Defaults to 64.
-
-            epochs (int, optional): Number of epochs flow is trained for. Defaults to 3.
-
-            key (Union[jax.Array, jax.random.PRNGKeyArray], optional): Key used in random number generation process.
-
-            verbose (bool, optional): Controls if progress bar and current loss are displayed when training. Defaults to False.
-
+        # Model parameters
+        self.n_scaled_layers = n_scaled_layers
+        self.n_unscaled_layers = n_unscaled_layers
+        self.flow = flows.RealNVP(ndim_in, self.n_scaled_layers, self.n_unscaled_layers)
 
-        Raises:
 
-            ValueError: Raised if the second dimension of X is not the same as ndim.
+# ===============================================================================
+# Rational Quadratic Spline Flow
+# ===============================================================================
 
-        """
-        if X.shape[1] != self.ndim:
-            raise ValueError("X second dimension not the same as ndim.")
 
+class RQSplineModel(FlowModel):
+    """Rational quadratic spline flow model to approximate the log_e posterior by a normalizing flow."""
 
-        key, rng_model, rng_init, rng_train = jax.random.split(key, 4)
+    def __init__(
+        self,
+        ndim_in: int,
+        n_layers: int = 8,
+        n_bins: int = 8,
+        hidden_size: Sequence[int] = [64, 64],
+        spline_range: Sequence[float] = (-10.0, 10.0),
+        standardize: bool = False,
+        learning_rate: float = 0.001,
+        momentum: float = 0.9,
+        temperature: float = 0.8,
+    ):
+        """Constructor setting the hyper-parameters and domains of the model.
 
-        variables = self.flow.init(rng_model, jnp.ones((1, self.ndim)))
-        state = self.create_train_state(rng_init)
+        Must be implemented by derived class (currently abstract).
 
-        # set up standardisation
-        if self.standardize:
-            # self.pre_offset = jnp.min(X, axis = 0) #maxmin
-            self.pre_offset = jnp.mean(X, axis=0)
-            # self.pre_amp = (jnp.max(X, axis=0) - self.pre_offset)
-            self.pre_amp = jnp.sqrt(jnp.diag(jnp.cov(X.T)))
+        Args:
 
-            X = (X - self.pre_offset) / self.pre_amp
+            ndim_in (int): Dimension of the problem to solve.
 
-        train_flow, train_epoch, train_step = make_training_loop(self.flow)
-        rng, state, loss_values = train_flow(
-            rng_train, state, variables, X, epochs, batch_size, verbose=verbose
-        )
+            n_layers (int, optional): Number of layers in the flow. Defaults to 8.
 
-        self.state = state
-        self.variables = variables
-        self.fitted = True
+            n_bins (int, optional): Number of bins in the spline. Defaults to 8.
 
-        return
+            hidden_size (Sequence[int], optional): Size of the hidden layers in the conditioner. Defaults to [64, 64].
 
-    def predict(self, x) -> jnp.ndarray:
-        """Predict the value of log_e posterior at batched input x.
+            spline_range (Sequence[float], optional): Range of the spline. Defaults to (-10.0, 10.0).
 
-        Args:
+            standardize (bool, optional): Indicates if mean and variance should be removed from training data when training the flow. Defaults to False.
 
-            x (jnp.ndarray (batch_size, ndim)): Batched sample for which to
-                predict posterior values.
+            learning_rate (float, optional): Learning rate for adam optimizer used in the fit method. Defaults to 0.001.
 
-        Returns:
+            momentum (float, optional): Learning rate for Adam optimizer used in the fit method. Defaults to 0.9.
 
-            jnp.ndarray (batch_size,): Predicted log_e posterior value.
+            temperature (float, optional): Scale factor by which the base distribution Gaussian is compressed in the prediction step. Should be positive and <=1. Defaults to 0.8.
 
         Raises:
 
-            ValueError: If var_scale is negative or greater than 1.
+            ValueError: If the ndim_in is not positive.
 
""" - var_scale = self.temperature - - if var_scale > 1: - raise ValueError("Scaling must not be greater than 1.") - - if var_scale <= 0: - raise ValueError("Scaling must be positive.") - - if self.standardize: - x = (x - self.pre_offset) / self.pre_amp - - logprob = self.flow.apply( - {"params": self.state.params, "variables": self.variables}, - x, - var_scale, - method=self.flow.log_prob, + FlowModel.__init__( + self, + ndim_in, + learning_rate, + momentum, + standardize, + temperature, ) - if self.standardize: - logprob -= sum(jnp.log(self.pre_amp)) - - return logprob - - def sample(self, n_sample: int, rng_key=jax.random.PRNGKey(0)) -> jnp.ndarray: - """Sample from trained flow. - - Args: - nsample (int): Number of samples generated. - - rng_key (Union[jax.Array, jax.random.PRNGKeyArray], optional): Key used in random number generation process. - - Returns: - - jnp.array (n_sample, ndim): Samples from fitted distribution.""" - - var_scale = self.temperature - - if var_scale > 1: - raise ValueError("Scaling must not be greater than 1.") - - if var_scale <= 0: - raise ValueError("Scaling must be positive.") - - samples = self.flow.apply( - {"params": self.state.params, "variables": self.variables}, - rng_key, - n_sample, - var_scale, - method=self.flow.sample, - ) - - if self.standardize: - samples = (samples * self.pre_amp) + self.pre_offset - - return samples + # Flow parameters + self.n_layers = n_layers + self.hidden_size = hidden_size + self.n_bins = n_bins + self.spline_range = spline_range + self.flow = flows.RQSpline(ndim_in, n_layers, hidden_size, n_bins, spline_range)