Skip to content

Commit

Permalink
Add goodness of fit metric
Browse files Browse the repository at this point in the history
  • Loading branch information
CeliaBenquet committed Nov 29, 2024
1 parent 9e14790 commit 297ee92
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 0 deletions.
62 changes: 62 additions & 0 deletions cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,68 @@ def infonce_loss(
return avg_loss


def goodness_of_fit(model: cebra_sklearn_cebra.CEBRA) -> List[float]:
    """Evaluate the goodness of fit (in bits) of a CEBRA model.

    The goodness of fit is computed offline from the loss stored in
    ``model.state_dict_["loss"]`` and from ``model.batch_size``. It is a way
    to normalize the loss with respect to the batch size, so that models
    trained with different batch sizes or different implementations can be
    compared.

    Args:
        model: The model to evaluate. It must be a single
            ``cebra_sklearn_cebra.CEBRA`` instance whose ``batch_size`` is
            not ``None``. The function reads ``model.solver_name_`` and
            ``model.state_dict_`` (presumably only available once the model
            has been fitted -- confirm against the estimator).

    Returns:
        A list of float values representing the goodness of fit, one value
        per entry of the loss stored in the model.

    Raises:
        ValueError: If ``model`` is a list of models, or is not a CEBRA
            model at all.
        NotImplementedError: If ``model.batch_size`` is ``None``, or if
            ``model.solver_name_`` is neither ``'single-session'`` nor
            ``'multi-session'``.
    """
    # Reject invalid inputs first so the main path below stays flat.
    if isinstance(model, list):
        raise ValueError(
            "Model should correspond to a single CEBRA model, "
            f"got {type(model)}, containing {len(model)} elements.")
    if not isinstance(model, cebra_sklearn_cebra.CEBRA):
        raise ValueError(f"Provide CEBRA model, got {type(model)}.")
    if model.batch_size is None:
        raise NotImplementedError(
            "Batch size is None, please provide a model with a batch size to compute the goodness of fit."
        )

    if model.solver_name_ == 'single-session':
        effective_batch_size = model.batch_size
    elif model.solver_name_ == 'multi-session':
        # For the multisession implementation, the batch size is multiplied
        # by the number of datasets to get the correct comparison.
        effective_batch_size = model.batch_size * model.num_sessions_
    else:
        raise NotImplementedError(f"Invalid solver: {model.solver_name_}.")

    return _goodness_of_fit(loss=model.state_dict_["loss"],
                            batch_size=effective_batch_size)


def _goodness_of_fit(loss: List[float], batch_size: int) -> List[float]:
"""
Compute offline the goodness of fit (bits) from a provided loss.
This is a way to normalize wrt batch size to compare models with
different batch sizes or different implementations.
Args:
loss: A list of size `max_iteration`, corresponding to the loss across training.
batch_size: Batch size used to train the model. For multisession implementation,
you need to multiply the batch size by the number of datasets to get the correct
comparison.
Returns:
A list of float corresponding to the goodness of fit for the provided loss and batch size.
"""
log_batch_size = np.log(batch_size)
return [(1 / np.log(2)) * (log_batch_size - lb) for lb in loss]


def _consistency_scores(
embeddings: List[Union[npt.NDArray, torch.Tensor]],
datasets: List[Union[int, str]],
Expand Down
51 changes: 51 additions & 0 deletions tests/test_sklearn_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,57 @@ def test_sklearn_infonce_loss():
)


def test_sklearn_goodness_of_fit():
    """Check ``goodness_of_fit`` on single- and multi-session models.

    The public function must agree with the private ``_goodness_of_fit``
    helper when the helper is given the stored loss and the batch size
    (multiplied by the number of sessions in the multi-session case), and it
    must reject invalid inputs with the documented exception types.
    """
    max_loss_iterations = 2
    cebra_model = cebra_sklearn_cebra.CEBRA(
        model_architecture="offset10-model",
        max_iterations=5,
        batch_size=128,
    )

    # Example data
    X = torch.tensor(np.random.uniform(0, 1, (1000, 50)))
    y_c1 = torch.tensor(np.random.uniform(0, 1, (1000, 5)))

    X2 = torch.tensor(np.random.uniform(0, 1, (500, 20)))
    y2_c1 = torch.tensor(np.random.uniform(0, 1, (500, 5)))

    # Single session
    cebra_model.fit(X, y_c1)

    gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model)
    assert isinstance(gof, list)
    _gof = cebra.sklearn.metrics._goodness_of_fit(
        cebra_model.state_dict_["loss"], batch_size=128)
    assert isinstance(_gof, list)
    assert gof == _gof

    # Multisession: two datasets, so the reference uses twice the batch size.
    cebra_model.fit([X, X2], [y_c1, y2_c1])

    gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model)
    assert isinstance(gof, list)
    _gof = cebra.sklearn.metrics._goodness_of_fit(
        cebra_model.state_dict_["loss"], batch_size=128 * 2)
    assert isinstance(_gof, list)
    assert gof == _gof

    # Multiple models passed
    with pytest.raises(ValueError, match="single.*model"):
        cebra.sklearn.metrics.goodness_of_fit([cebra_model, cebra_model])

    # Something that is neither a CEBRA model nor a list of models
    with pytest.raises(ValueError, match="Provide CEBRA model"):
        cebra.sklearn.metrics.goodness_of_fit("not a model")

    # No batch size
    cebra_model_no_bs = cebra_sklearn_cebra.CEBRA(
        model_architecture="offset10-model",
        max_iterations=max_loss_iterations,
        batch_size=None,
    )

    cebra_model_no_bs.fit(X)
    with pytest.raises(NotImplementedError, match="Batch.*size"):
        cebra.sklearn.metrics.goodness_of_fit(cebra_model_no_bs)


def test_sklearn_datasets_consistency():
# Example data
np.random.seed(42)
Expand Down

0 comments on commit 297ee92

Please sign in to comment.