Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
Signed-off-by: Praateek Mahajan <[email protected]>
  • Loading branch information
praateekmahajan committed Sep 24, 2024
1 parent 224dd58 commit c142902
Showing 1 changed file with 28 additions and 30 deletions.
58 changes: 28 additions & 30 deletions crossfit/backend/cudf/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional

import cudf
import cupy as cp
from cudf.core.column import as_column
from typing import TYPE_CHECKING, Optional
from cudf.core.dtypes import ListDtype

if TYPE_CHECKING:
Expand All @@ -28,13 +29,14 @@
from packaging.version import parse as parse_version



@lru_cache
def _is_cudf_gte_24_10():
current_cudf_version = parse_version(version('cudf_cu12'))
cudf_24_10_version = parse_version('24.10')
current_cudf_version = parse_version(version("cudf_cu12"))
cudf_24_10_version = parse_version("24.10")

if current_cudf_version >= cudf_24_10_version or (current_cudf_version.base_version >= "24.10.0" and current_cudf_version.is_prerelease):
if current_cudf_version >= cudf_24_10_version or (
current_cudf_version.base_version >= "24.10.0" and current_cudf_version.is_prerelease
):
return True
elif current_cudf_version < cudf_24_10_version:
return False
Expand All @@ -43,41 +45,38 @@ def _is_cudf_gte_24_10():
raise NotImplementedError(msg)


def _construct_series_from_list_column(index: Any, lc: cudf.core.column.ListColumn) -> cudf.Series:
    """Wrap a list column and an index into a ``cudf.Series``.

    The construction path is selected by ``_is_cudf_gte_24_10()``.

    Args:
        index: Index to attach to the resulting series.
        lc: List column holding the series values.

    Returns:
        A ``cudf.Series`` built from ``lc`` and ``index``.
    """
    if _is_cudf_gte_24_10():
        # 24.10+ path: build through ``Series._from_column`` and pass the index
        # through ``ensure_index`` first (in pre 24.10 releases index could be
        # any list).
        from cudf.core.index import ensure_index

        return cudf.Series._from_column(column=lc, index=ensure_index(index))

    # Older cudf: the public constructor takes the column and index directly.
    return cudf.Series(data=lc, index=index)



def _construct_list_column(
    size: int,
    dtype: ListDtype,
    mask: Optional["Buffer"] = None,
    offset: int = 0,
    null_count: Optional[int] = None,
    children: tuple["NumericalColumn", "ColumnBase"] = (),  # type: ignore[assignment]
) -> cudf.core.column.ListColumn:
    """Create a ``cudf.core.column.ListColumn`` in a version-compatible way.

    All arguments are forwarded unchanged as keyword arguments to the
    ``ListColumn`` constructor; when ``_is_cudf_gte_24_10()`` is true an
    additional ``data=None`` keyword is passed.

    Args:
        size: Forwarded as ``size``.
        dtype: Forwarded as ``dtype``.
        mask: Forwarded as ``mask``.
        offset: Forwarded as ``offset``.
        null_count: Forwarded as ``null_count``.
        children: Forwarded as ``children``.

    Returns:
        The constructed ``ListColumn``.
    """
    column_kwargs = {
        "size": size,
        "dtype": dtype,
        "mask": mask,
        "offset": offset,
        "null_count": null_count,
        "children": children,
    }

    if _is_cudf_gte_24_10():
        # in 24.10 ListColumn added `data` kwarg see https://github.com/rapidsai/crossfit/issues/84
        return cudf.core.column.ListColumn(data=None, **column_kwargs)

    return cudf.core.column.ListColumn(**column_kwargs)


def create_list_series_from_1d_or_2d_ar(ar, index):
"""
Expand All @@ -101,7 +100,7 @@ def create_list_series_from_1d_or_2d_ar(ar, index):
mask=mask,
offset=0,
null_count=0,
children=(offset_col, data)
children=(offset_col, data),
)
return _construct_series_from_list_column(lc=lc, index=index)

Expand Down Expand Up @@ -129,19 +128,18 @@ def create_nested_list_series_from_3d_ar(ar, index):
size=inner_offsets.size - 1,
dtype=cudf.ListDtype(inner_list_data.dtype),
children=(inner_list_offsets, inner_list_data),
mask = None,
mask=None,
offset=0,
null_count=None
null_count=None,
)

lc = _construct_list_column(
size=n_slices,
dtype=cudf.ListDtype(inner_list_data.dtype),
children=(outer_list_offsets, inner_lc),
mask = None,
mask=None,
offset=0,
null_count=None

null_count=None,
)

return _construct_series_from_list_column(lc=lc, index=index)

0 comments on commit c142902

Please sign in to comment.