Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Praateek Mahajan <[email protected]>
  • Loading branch information
praateekmahajan committed Sep 24, 2024
1 parent 2686459 commit 224dd58
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 13 deletions.
62 changes: 50 additions & 12 deletions crossfit/backend/cudf/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache
import cudf
import cupy as cp
from cudf.core.column import as_column
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from cudf.core.dtypes import ListDtype

if TYPE_CHECKING:
from cudf.core.buffer import Buffer
from cudf.core.dtypes import ListDtype
from cudf.core.column.numerical import NumericalColumn
from cudf.core.column import ColumnBase

from importlib_metadata import version
from packaging.version import parse as parse_version



@lru_cache
def _is_cudf_gte_24_10():
current_cudf_version = parse_version(version('cudf_cu12'))
cudf_24_10_version = parse_version('24.10')

if current_cudf_version >= cudf_24_10_version or (current_cudf_version.base_version >= "24.10.0" and current_cudf_version.is_prerelease):
return True
elif current_cudf_version < cudf_24_10_version:
return False
else:
msg = f"Found uncaught cudf version {current_cudf_version}"
raise NotImplementedError(msg)


def _construct_series_from_list_column(
index : cudf.Series,
lc : cudf.core.column.ListColumn
):
if not _is_cudf_gte_24_10():
return cudf.Series(data=lc, index=index)
else:
# in pre 24.10 releases index could be any list
from cudf.core.index import ensure_index
return cudf.Series._from_column(column=lc, index=ensure_index(index))



def _construct_list_column(
size: int,
dtype: ListDtype,
mask: Buffer | None = None,
mask: Optional["Buffer"] = None,
offset: int = 0,
null_count: int | None = None,
children: tuple[NumericalColumn, ColumnBase] = (), # type: ignore[assignment]
null_count: Optional[int] = None,
children: tuple["NumericalColumn", "ColumnBase"] = (), # type: ignore[assignment]
) -> cudf.core.column.ListColumn:

kwargs = dict(
size=size,
dtype=dtype,
Expand All @@ -43,9 +73,10 @@ def _construct_list_column(
children=children,
)

if version('cudf') <= "24.08":
if not _is_cudf_gte_24_10():
return cudf.core.column.ListColumn(**kwargs)
elif version('cudf') >= "24.10":
else:
# in 24.10 ListColumn added `data` kwarg see https://github.com/rapidsai/crossfit/issues/84
return cudf.core.column.ListColumn(data = None, **kwargs)

def create_list_series_from_1d_or_2d_ar(ar, index):
Expand All @@ -72,7 +103,7 @@ def create_list_series_from_1d_or_2d_ar(ar, index):
null_count=0,
children=(offset_col, data)
)
return cudf.Series(lc, index=index)
return _construct_series_from_list_column(lc=lc, index=index)


def create_nested_list_series_from_3d_ar(ar, index):
Expand All @@ -94,16 +125,23 @@ def create_nested_list_series_from_3d_ar(ar, index):
outer_list_offsets = as_column(outer_offsets)

# Constructing the nested ListColumn
inner_lc = cudf.core.column.ListColumn(
inner_lc = _construct_list_column(
size=inner_offsets.size - 1,
dtype=cudf.ListDtype(inner_list_data.dtype),
children=(inner_list_offsets, inner_list_data),
mask = None,
offset=0,
null_count=None
)

lc = cudf.core.column.ListColumn(
lc = _construct_list_column(
size=n_slices,
dtype=cudf.ListDtype(inner_list_data.dtype),
children=(outer_list_offsets, inner_lc),
mask = None,
offset=0,
null_count=None

)

return cudf.Series(lc, index=index)
return _construct_series_from_list_column(lc=lc, index=index)
1 change: 0 additions & 1 deletion tests/backend/cudf_backend/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def test_create_nested_list_series_from_3d_ar():
tensor = torch.tensor(nested_list)
index = [1, 2]
series = create_nested_list_series_from_3d_ar(tensor, index)
print(series)
assert isinstance(series, cudf.Series)
expected = cudf.Series(nested_list, index=index)
# convert to pandas because cudf.Series.equals doesn't work for list series
Expand Down

0 comments on commit 224dd58

Please sign in to comment.