Skip to content

Commit

Permalink
Move is_fitted to model class.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 24, 2023
1 parent a2d2e2b commit ab9e61b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 49 deletions.
44 changes: 6 additions & 38 deletions harmonic/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,20 @@ class Model(metaclass=abc.ABCMeta):
(double): Predicted log_e posterior value.
"""

@abc.abstractmethod

def is_fitted(self):
"""Specify whether model has been fitted.
Returns:
(bool): Whether the model has been fitted.
"""

return self.fitted



def serialize(self, filename):
"""Serialize Model object.
Expand Down Expand Up @@ -227,18 +230,6 @@ class HyperSphere(Model):
self.fitted = False


def is_fitted(self):
"""Specify whether model has been fitted.
Returns:
(bool): Whether the model has been fitted.
"""

return self.fitted


def set_R(self, double R):
"""Set the radius of the hyper-sphere and calculate its volume.
Expand Down Expand Up @@ -663,17 +654,6 @@ class KernelDensityEstimate(Model):

return


def is_fitted(self):
"""Specify whether model has been fitted.
Returns:
(bool): Whether the model has been fitted.
"""

return self.fitted


def set_scales(self, np.ndarray[double, ndim=2, mode="c"] X):
Expand Down Expand Up @@ -1409,18 +1389,6 @@ class ModifiedGaussianMixtureModel(Model):
self.fitted = False


def is_fitted(self):
"""Specify whether model has been fitted.
Returns:
(bool): Whether the model has been fitted.
"""

return self.fitted


def set_weights(self, np.ndarray[double, ndim=1, mode="c"] weights_in):
"""Set the weights of the Gaussians.
Expand Down
11 changes: 0 additions & 11 deletions harmonic/model_nf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,6 @@ def __init__(
self.temperature = temperature
self.flow = None

def is_fitted(self) -> bool:
"""Specify whether model has been fitted.
Returns:
(bool): Whether the model has been fitted.
"""

return self.fitted

def create_train_state(self, rng):
params = self.flow.init(rng, jnp.ones((1, self.ndim)))["params"]
tx = optax.adam(self.learning_rate, self.momentum)
Expand Down

0 comments on commit ab9e61b

Please sign in to comment.