From 5c500526d37b04a5eb39d59644bd63c04dec9aab Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 16 Sep 2024 06:57:24 -0400 Subject: [PATCH] fix(polars): make multi-argument udfs work again --- ibis/backends/polars/compiler.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index eb4d9cb532328..2f369a65c2235 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1287,16 +1287,18 @@ def execute_count_distinct_star(op, **kw): # Convert polars series into a list # -> map the function element by element # -> convert back to a polars series - InputType.PYTHON: lambda func, dtype, args: pl.Series( - map(func, *(arg.to_list() for arg in args)), + InputType.PYTHON: lambda func, dtype, fields, args: pl.Series( + map(func, *(arg.to_list() for arg in map(args.struct.field, fields))), dtype=PolarsType.from_ibis(dtype), ), # Convert polars series into a pyarrow array # -> invoke the function on the pyarrow array # -> cast the result to match the ibis dtype # -> convert back to a polars series - InputType.PYARROW: lambda func, dtype, args: pl.from_arrow( - func(*(arg.to_arrow() for arg in args)).cast(dtype.to_pyarrow()), + InputType.PYARROW: lambda func, dtype, fields, args: pl.from_arrow( + func(*(arg.to_arrow() for arg in map(args.struct.field, fields))).cast( + dtype.to_pyarrow() + ), ), } @@ -1306,9 +1308,12 @@ def execute_scalar_udf(op, **kw): input_type = op.__input_type__ if input_type in _UDF_INVOKERS: dtype = op.dtype - return pl.map_batches( - exprs=[translate(arg, **kw) for arg in op.args], - function=partial(_UDF_INVOKERS[input_type], op.__func__, dtype), + argnames = op.argnames + args = pl.struct( + **dict(zip(argnames, (translate(arg, **kw) for arg in op.args))) + ) + return args.map_batches( + function=partial(_UDF_INVOKERS[input_type], op.__func__, dtype, argnames), return_dtype=PolarsType.from_ibis(dtype), ) elif input_type == InputType.BUILTIN: