Skip to content

Commit

Permalink
Update cebra/integrations/sklearn/utils.py
Browse files Browse the repository at this point in the history
Co-authored-by: Steffen Schneider <[email protected]>
  • Loading branch information
icarosadero and stes authored Dec 18, 2024
1 parent 3e4ee86 commit 3ba6bc6
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions cebra/integrations/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@
import cebra.helper

from packaging import version
from sklearn import __version__ as sklearn_version
sklearn_version = version.parse(sklearn_version)
import sklearn

def _check_array_ensure_all_finite(array, **kwargs):
if version.parse(sklearn.__version__) < version.parse("1.8"):
key = "force_all_finite"
else:
key = "ensure_all_finite"
kwargs[key] = True
return sklearn_utils_validation.check_array(array, **kwargs)

def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple:
"""Handle deprecated arguments of a function until they are replaced.
Expand Down

0 comments on commit 3ba6bc6

Please sign in to comment.