diff --git a/harmonic/model.pyx b/harmonic/model.pyx index fe9b73f4..c68c130b 100644 --- a/harmonic/model.pyx +++ b/harmonic/model.pyx @@ -68,17 +68,20 @@ class Model(metaclass=abc.ABCMeta): (double): Predicted log_e posterior value. """ - - @abc.abstractmethod + def is_fitted(self): """Specify whether model has been fitted. - + Returns: (bool): Whether the model has been fitted. """ + return self.fitted + + + def serialize(self, filename): """Serialize Model object. @@ -227,18 +230,6 @@ class HyperSphere(Model): self.fitted = False - def is_fitted(self): - """Specify whether model has been fitted. - - Returns: - - (bool): Whether the model has been fitted. - - """ - - return self.fitted - - def set_R(self, double R): """Set the radius of the hyper-sphere and calculate its volume. @@ -663,17 +654,6 @@ class KernelDensityEstimate(Model): return - - def is_fitted(self): - """Specify whether model has been fitted. - - Returns: - - (bool): Whether the model has been fitted. - - """ - - return self.fitted def set_scales(self, np.ndarray[double, ndim=2, mode="c"] X): @@ -1409,18 +1389,6 @@ class ModifiedGaussianMixtureModel(Model): self.fitted = False - def is_fitted(self): - """Specify whether model has been fitted. - - Returns: - - (bool): Whether the model has been fitted. - - """ - - return self.fitted - - def set_weights(self, np.ndarray[double, ndim=1, mode="c"] weights_in): """Set the weights of the Gaussians. diff --git a/harmonic/model_nf.py b/harmonic/model_nf.py index 192121f2..c95d94be 100644 --- a/harmonic/model_nf.py +++ b/harmonic/model_nf.py @@ -119,17 +119,6 @@ def __init__( self.temperature = temperature self.flow = None - def is_fitted(self) -> bool: - """Specify whether model has been fitted. - - Returns: - - (bool): Whether the model has been fitted. - - """ - - return self.fitted - def create_train_state(self, rng): params = self.flow.init(rng, jnp.ones((1, self.ndim)))["params"] tx = optax.adam(self.learning_rate, self.momentum)