diff --git a/src/akimbo/cudf.py b/src/akimbo/cudf.py index 633ae4d..6c2bb8a 100644 --- a/src/akimbo/cudf.py +++ b/src/akimbo/cudf.py @@ -3,13 +3,15 @@ import awkward as ak import cudf -from cudf import DataFrame, Series +from cudf import DataFrame, Series, _lib as libcudf from cudf.core.column.string import StringMethods +from cudf.core.column.datetime import DatetimeColumn from akimbo.ak_from_cudf import cudf_to_awkward as from_cudf from akimbo.mixin import Accessor +from akimbo.datetimes import DatetimeAccessor, match as match_t from akimbo.strings import StringAccessor -from akimbo.apply_tree import dec +from akimbo.apply_tree import dec, leaf def match_string(arr): @@ -26,23 +28,58 @@ def encode(self, encoding: str = "utf-8"): raise NotImplementedError("cudf does not support bytearray type") +def dec_cu(op, match=match_string): + + @functools.wraps(op) + def f(lay, **kwargs): + # op(column, ...)->column + col = op(lay._to_cudf(cudf, None, len(lay)), **kwargs) + return from_cudf(cudf.Series(col)).layout + + return dec(func=f, match=match, inmode="ak") + + for meth in dir(StringMethods): if meth.startswith("_"): continue @functools.wraps(getattr(StringMethods, meth)) - def f(lay, *args, method=meth, **kwargs): - if not match_string(lay): - return - - # unnecessary round-tripping, and repeating logic from `dec`? - args = args or kwargs.pop("args", ()) - col = getattr(StringMethods(cudf.Series(lay._to_cudf(cudf, None, len(lay)))), method)(*args, **kwargs) + def f(lay, method=meth, **kwargs): + # this is different from dec_cu, because we need to instantiate StringMethods + # before getting the method from it + col = getattr(StringMethods(cudf.Series(lay._to_cudf(cudf, None, len(lay)))), method)(**kwargs) return from_cudf(col).layout setattr(CudfStringAccessor, meth, dec(func=f, match=match_string, inmode="ak")) +class CudfDatetimeAccessor(DatetimeAccessor): + + ... + + +for meth in dir(DatetimeColumn): + if meth.startswith("_"): + continue + + @functools.wraps(getattr(DatetimeColumn, meth)) + def f(lay, method=meth, **kwargs): + # this is different from dec_cu, because we need to instantiate StringMethods + # before getting the method from it + m = getattr(lay._to_cudf(cudf, None, len(lay)), method) + if callable(m): + col = m(**kwargs) + else: + # attributes giving components + col = m + return from_cudf(cudf.Series(col)).layout + + if isinstance(getattr(DatetimeColumn, meth), property): + setattr(CudfDatetimeAccessor, meth, property(dec(func=f, match=match_t, inmode="ak"))) + else: + setattr(CudfDatetimeAccessor, meth, dec(func=f, match=match_t, inmode="ak")) + + class CudfAwkwardAccessor(Accessor): series_type = Series dataframe_type = DataFrame @@ -51,6 +88,8 @@ class CudfAwkwardAccessor(Accessor): def _to_output(cls, arr): if isinstance(arr, ak.Array): return ak.to_cudf(arr) + elif isinstance(arr, ak.contents.Content): + return arr._to_cudf(cudf, None, len(arr)) return arr @classmethod @@ -67,11 +106,13 @@ def str(self): # need to find string ops within cudf return CudfStringAccessor(self) + cast = dec_cu(libcudf.unary.cast, match=leaf) + @property def dt(self): """Nested datetime operations""" # need to find datetime ops within cudf - raise NotImplementedError + return CudfDatetimeAccessor(self) def apply(self, fn: Callable, *args, **kwargs): if "CPUDispatcher" in str(fn): diff --git a/src/akimbo/io.py b/src/akimbo/io.py index 8fa36a3..a7745ca 100644 --- a/src/akimbo/io.py +++ b/src/akimbo/io.py @@ -18,9 +18,13 @@ def ak_to_series(ds, backend="pandas", extract=True): # TODO: actually don't use this, use dask-awkward, or dask.dataframe s = akimbo.polars.PolarsAwkwardAccessor._to_output(ds) + elif backend == "cudf": + import akimbo.cudf + + s = akimbo.cudf.CudfAwkwardAccessor._to_output(ds) else: raise ValueError("Backend must be in {'pandas', 'polars', 'dask'}") - if extract: + if extract and ds.fields: return s.ak.unmerge() return s diff --git a/tests/test_cudf.py b/tests/test_cudf.py index 6f65d28..b15b5c8 100644 --- a/tests/test_cudf.py +++ b/tests/test_cudf.py @@ -1,3 +1,5 @@ +import datetime + import pytest import pyarrow as pa @@ -5,6 +7,7 @@ pytest.importorskip("akimbo.cudf") +import akimbo.io import cudf @@ -37,3 +40,28 @@ def test_string_methods(): # non-str output s2 = series.ak.str.len() assert s2.ak.to_list() == [{"s": [3, 2], "i": [0]}, {"s": [3, 2], "i": [2]}] + + +def test_cast(): + s = cudf.Series([0, 1, 2]) + # shows that cast to timestamp needs to be two-step in cudf + s2 = s.ak.cast('m8[s]').ak.cast('M8[s]') + out = s2.ak.to_list() + assert out == [ + datetime.datetime(1970, 1, 1, 0, 0), + datetime.datetime(1970, 1, 1, 0, 0, 1), + datetime.datetime(1970, 1, 1, 0, 0, 2) + ] + + +def test_times(): + data = [ + datetime.datetime(1970, 1, 1, 0, 0), + datetime.datetime(1970, 1, 1, 0, 0, 1), + None, + datetime.datetime(1970, 1, 1, 0, 0, 2) + ] + arr = ak.Array([[data], [], [data]]) + s = akimbo.io.ak_to_series(arr, "cudf") + s2 = s.ak.dt.second + assert s2.ak.to_list() == [[[0, 1, None, 2]], [], [[0, 1, None, 2]]]