From 3ba6bc6ac617be0a7d0846c6ad42154effc0c983 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Wed, 18 Dec 2024 12:04:49 +0100 Subject: [PATCH] Update cebra/integrations/sklearn/utils.py Co-authored-by: Steffen Schneider --- cebra/integrations/sklearn/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index c1671b9e..1f6c621f 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -28,9 +28,15 @@ import cebra.helper from packaging import version -from sklearn import __version__ as sklearn_version -sklearn_version = version.parse(sklearn_version) +import sklearn +def _check_array_ensure_all_finite(array, **kwargs): + if version.parse(sklearn.__version__) < version.parse("1.8"): + key = "force_all_finite" + else: + key = "ensure_all_finite" + kwargs[key] = True + return sklearn_utils_validation.check_array(array, **kwargs) def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced.