diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index c1671b9e..1f6c621f 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -28,9 +28,15 @@ import cebra.helper from packaging import version -from sklearn import __version__ as sklearn_version -sklearn_version = version.parse(sklearn_version) +import sklearn +def _check_array_ensure_all_finite(array, **kwargs): + if version.parse(sklearn.__version__) < version.parse("1.8"): + key = "force_all_finite" + else: + key = "ensure_all_finite" + kwargs[key] = True + return sklearn_utils_validation.check_array(array, **kwargs) def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced.