Skip to content

Commit

Permalink
Update cebra/integrations/sklearn/utils.py
Browse files Browse the repository at this point in the history
Co-authored-by: Steffen Schneider <[email protected]>
  • Loading branch information
icarosadero and stes authored Dec 18, 2024
1 parent 3ba6bc6 commit 128257b
Showing 1 changed file with 12 additions and 29 deletions.
41 changes: 12 additions & 29 deletions cebra/integrations/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,35 +84,18 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
Returns:
The converted and validated array.
"""

if sklearn_version < version.parse("1.8"):
return sklearn_utils_validation.check_array(
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
order=None,
copy=False,
force_all_finite=True,
ensure_2d=True,
allow_nd=False,
ensure_min_samples=min_samples,
ensure_min_features=1,
)
else:
return sklearn_utils_validation.check_array(
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
order=None,
copy=False,
ensure_all_finite=True,
ensure_2d=True,
allow_nd=False,
ensure_min_samples=min_samples,
ensure_min_features=1,
)
return _check_array_ensure_all_finite(
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
order=None,
copy=False,
ensure_2d=True,
allow_nd=False,
ensure_min_samples=min_samples,
ensure_min_features=1,
)


def check_label_array(y: npt.NDArray, *, min_samples: int):
Expand Down

0 comments on commit 128257b

Please sign in to comment.