diff --git a/crossfit/backend/cudf/series.py b/crossfit/backend/cudf/series.py
index 79b8d66..3246256 100644
--- a/crossfit/backend/cudf/series.py
+++ b/crossfit/backend/cudf/series.py
@@ -12,28 +12,62 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import lru_cache
 import cudf
 import cupy as cp
 from cudf.core.column import as_column
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
+from cudf.core.dtypes import ListDtype
 
 if TYPE_CHECKING:
     from cudf.core.buffer import Buffer
-    from cudf.core.dtypes import ListDtype
     from cudf.core.column.numerical import NumericalColumn
     from cudf.core.column import ColumnBase
 
 from importlib_metadata import version
+from packaging.version import parse as parse_version
+
+
+@lru_cache
+def _is_cudf_gte_24_10():
+    """Return True when the installed cudf is version 24.10 or newer.
+
+    Pre-release builds of 24.10 and later (e.g. ``24.10.00a123``) count as
+    24.10+. The result is cached, so the version is only parsed once.
+    """
+    # cudf.__version__ does not depend on the installed distribution name
+    # (cudf, cudf_cu11, cudf_cu12, ...), unlike a hard-coded metadata lookup.
+    current_cudf_version = parse_version(cudf.__version__)
+    cudf_24_10_version = parse_version('24.10')
+
+    # Compare the parsed base version (pre-release suffix dropped). Comparing
+    # base_version as a plain string is lexicographic, which would rank
+    # "24.8.0" above "24.10.0".
+    return parse_version(current_cudf_version.base_version) >= cudf_24_10_version
+
+
+def _construct_series_from_list_column(
+    index: cudf.Series,
+    lc: cudf.core.column.ListColumn
+):
+    """Wrap the list column ``lc`` in a ``cudf.Series`` carrying ``index``."""
+    if not _is_cudf_gte_24_10():
+        return cudf.Series(data=lc, index=index)
+    else:
+        # unlike pre 24.10 releases, where index could be any list, the index
+        # is converted with ensure_index before the Series is built
+        from cudf.core.index import ensure_index
+        return cudf.Series._from_column(column=lc, index=ensure_index(index))
+
 
 def _construct_list_column(
     size: int,
     dtype: ListDtype,
-    mask: Buffer | None = None,
+    mask: Optional["Buffer"] = None,
     offset: int = 0,
-    null_count: int | None = None,
-    children: tuple[NumericalColumn, ColumnBase] = (),  # type: ignore[assignment]
+    null_count: Optional[int] = None,
+    children: tuple["NumericalColumn", "ColumnBase"] = (),  # type: ignore[assignment]
 ) -> cudf.core.column.ListColumn:
-
     kwargs = dict(
         size=size,
         dtype=dtype,
@@ -43,9 +77,10 @@ def _construct_list_column(
         children=children,
     )
 
-    if version('cudf') <= "24.08":
+    if not _is_cudf_gte_24_10():
         return cudf.core.column.ListColumn(**kwargs)
-    elif version('cudf') >= "24.10":
+    else:
+        # in 24.10 ListColumn added `data` kwarg see https://github.com/rapidsai/crossfit/issues/84
         return cudf.core.column.ListColumn(data = None, **kwargs)
 
 def create_list_series_from_1d_or_2d_ar(ar, index):
@@ -72,7 +107,7 @@ def create_list_series_from_1d_or_2d_ar(ar, index):
         null_count=0,
         children=(offset_col, data)
     )
-    return cudf.Series(lc, index=index)
+    return _construct_series_from_list_column(lc=lc, index=index)
 
 
 def create_nested_list_series_from_3d_ar(ar, index):
@@ -94,16 +129,23 @@ def create_nested_list_series_from_3d_ar(ar, index):
     outer_list_offsets = as_column(outer_offsets)
 
     # Constructing the nested ListColumn
-    inner_lc = cudf.core.column.ListColumn(
+    inner_lc = _construct_list_column(
         size=inner_offsets.size - 1,
         dtype=cudf.ListDtype(inner_list_data.dtype),
         children=(inner_list_offsets, inner_list_data),
+        mask = None,
+        offset=0,
+        null_count=None
     )
 
-    lc = cudf.core.column.ListColumn(
+    lc = _construct_list_column(
         size=n_slices,
         dtype=cudf.ListDtype(inner_list_data.dtype),
         children=(outer_list_offsets, inner_lc),
+        mask = None,
+        offset=0,
+        null_count=None
+
     )
 
-    return cudf.Series(lc, index=index)
+    return _construct_series_from_list_column(lc=lc, index=index)
diff --git a/tests/backend/cudf_backend/test_series.py b/tests/backend/cudf_backend/test_series.py
index 329a0a0..1c0ba78 100644
--- a/tests/backend/cudf_backend/test_series.py
+++ b/tests/backend/cudf_backend/test_series.py
@@ -37,7 +37,6 @@ def test_create_nested_list_series_from_3d_ar():
     tensor = torch.tensor(nested_list)
     index = [1, 2]
     series = create_nested_list_series_from_3d_ar(tensor, index)
-    print(series)
     assert isinstance(series, cudf.Series)
     expected = cudf.Series(nested_list, index=index)
     # convert to pandas because cudf.Series.equals doesn't work for list series