Skip to content

Commit

Permalink
Add dt methods for cudf
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Durant committed Aug 16, 2024
1 parent b55ea39 commit 9b9f27f
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 11 deletions.
61 changes: 51 additions & 10 deletions src/akimbo/cudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

import awkward as ak
import cudf
from cudf import DataFrame, Series
from cudf import DataFrame, Series, _lib as libcudf
from cudf.core.column.string import StringMethods
from cudf.core.column.datetime import DatetimeColumn

from akimbo.ak_from_cudf import cudf_to_awkward as from_cudf
from akimbo.mixin import Accessor
from akimbo.datetimes import DatetimeAccessor, match as match_t
from akimbo.strings import StringAccessor
from akimbo.apply_tree import dec
from akimbo.apply_tree import dec, leaf


def match_string(arr):
Expand All @@ -26,23 +28,58 @@ def encode(self, encoding: str = "utf-8"):
raise NotImplementedError("cudf does not support bytearray type")


def dec_cu(op, match=match_string):
    """Adapt a column-level cudf operation for use through ``dec``.

    ``op`` is invoked as ``op(column, **kwargs)``, where ``column`` is the
    cudf column built from the incoming awkward layout. Its result is wrapped
    in a ``cudf.Series`` and converted back to an awkward layout.

    The adapted function is handed to ``dec`` together with ``match`` and
    ``inmode="ak"``; whatever ``dec`` returns is returned from here.
    """

    def f(lay, **kwargs):
        # op(column, ...) -> column
        column = lay._to_cudf(cudf, None, len(lay))
        result = op(column, **kwargs)
        series = cudf.Series(result)
        return from_cudf(series).layout

    # carry over the name/docstring of the wrapped cudf operation
    adapted = functools.wraps(op)(f)
    return dec(func=adapted, match=match, inmode="ak")


# Attach every public attribute name of cudf's StringMethods to
# CudfStringAccessor as a decorated, layout-level function.
for meth in dir(StringMethods):
    if meth.startswith("_"):
        continue

    @functools.wraps(getattr(StringMethods, meth))
    def f(lay, method=meth, **kwargs):
        # ``method=meth`` binds the current name at definition time, so each
        # generated function keeps its own method rather than the last one.
        #
        # this is different from dec_cu, because we need to instantiate
        # StringMethods before getting the method from it
        series = cudf.Series(lay._to_cudf(cudf, None, len(lay)))
        col = getattr(StringMethods(series), method)(**kwargs)
        return from_cudf(col).layout

    setattr(CudfStringAccessor, meth, dec(func=f, match=match_string, inmode="ak"))


class CudfDatetimeAccessor(DatetimeAccessor):
    """Datetime accessor for cudf-backed data.

    The class body is intentionally empty: its attributes are attached at
    import time by the module-level loop over the public names of cudf's
    ``DatetimeColumn``.
    """

    ...


# Mirror every public attribute of cudf's DatetimeColumn onto
# CudfDatetimeAccessor; properties stay properties, everything else is
# attached as a decorated function.
for meth in (name for name in dir(DatetimeColumn) if not name.startswith("_")):
    _column_attr = getattr(DatetimeColumn, meth)

    @functools.wraps(_column_attr)
    def f(lay, method=meth, **kwargs):
        # The attribute is read straight off the cudf column built from the
        # layout. ``method=meth`` pins the current name for this function.
        target = getattr(lay._to_cudf(cudf, None, len(lay)), method)
        # callables are invoked with the keyword arguments; plain attributes
        # (e.g. datetime components) are used as they are
        col = target(**kwargs) if callable(target) else target
        return from_cudf(cudf.Series(col)).layout

    _decorated = dec(func=f, match=match_t, inmode="ak")
    setattr(
        CudfDatetimeAccessor,
        meth,
        property(_decorated) if isinstance(_column_attr, property) else _decorated,
    )


class CudfAwkwardAccessor(Accessor):
series_type = Series
dataframe_type = DataFrame
Expand All @@ -51,6 +88,8 @@ class CudfAwkwardAccessor(Accessor):
def _to_output(cls, arr):
    """Convert awkward objects to their cudf form; pass anything else through.

    An ``ak.Array`` goes through ``ak.to_cudf``; a bare ``ak.contents.Content``
    layout is converted with its own ``_to_cudf`` method. Any other value is
    returned unchanged.
    """
    if isinstance(arr, ak.Array):
        return ak.to_cudf(arr)
    if isinstance(arr, ak.contents.Content):
        return arr._to_cudf(cudf, None, len(arr))
    return arr

@classmethod
Expand All @@ -67,11 +106,13 @@ def str(self):
# need to find string ops within cudf
return CudfStringAccessor(self)

# dtype cast built from libcudf's unary cast via dec_cu, matched with ``leaf``
cast = dec_cu(libcudf.unary.cast, match=leaf)

@property
def dt(self):
    """Nested datetime operations

    Returns a ``CudfDatetimeAccessor`` wrapping this accessor.
    """
    return CudfDatetimeAccessor(self)

def apply(self, fn: Callable, *args, **kwargs):
if "CPUDispatcher" in str(fn):
Expand Down
6 changes: 5 additions & 1 deletion src/akimbo/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ def ak_to_series(ds, backend="pandas", extract=True):

# TODO: actually don't use this, use dask-awkward, or dask.dataframe
s = akimbo.polars.PolarsAwkwardAccessor._to_output(ds)
elif backend == "cudf":
import akimbo.cudf

s = akimbo.cudf.CudfAwkwardAccessor._to_output(ds)
else:
raise ValueError("Backend must be in {'pandas', 'polars', 'dask'}")
if extract:
if extract and ds.fields:
return s.ak.unmerge()
return s

Expand Down
28 changes: 28 additions & 0 deletions tests/test_cudf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import datetime

import pytest

import pyarrow as pa
import awkward as ak

pytest.importorskip("akimbo.cudf")

import akimbo.io
import cudf


Expand Down Expand Up @@ -37,3 +40,28 @@ def test_string_methods():
# non-str output
s2 = series.ak.str.len()
assert s2.ak.to_list() == [{"s": [3, 2], "i": [0]}, {"s": [3, 2], "i": [2]}]


def test_cast():
    """Integers cast to seconds-resolution timestamps via an intermediate duration."""
    ints = cudf.Series([0, 1, 2])
    # shows that cast to timestamp needs to be two-step in cudf
    as_duration = ints.ak.cast('m8[s]')
    as_timestamp = as_duration.ak.cast('M8[s]')

    epoch = datetime.datetime(1970, 1, 1, 0, 0)
    expected = [epoch + datetime.timedelta(seconds=offset) for offset in range(3)]
    assert as_timestamp.ak.to_list() == expected


def test_times():
    """The ``second`` component is extracted through nesting, keeping missing values."""
    epoch = datetime.datetime(1970, 1, 1, 0, 0)
    seconds = [0, 1, None, 2]
    stamps = [
        None if sec is None else epoch + datetime.timedelta(seconds=sec)
        for sec in seconds
    ]

    nested = ak.Array([[stamps], [], [stamps]])
    series = akimbo.io.ak_to_series(nested, "cudf")
    result = series.ak.dt.second

    assert result.ak.to_list() == [[seconds], [], [seconds]]

0 comments on commit 9b9f27f

Please sign in to comment.