Skip to content

Commit

Permalink
fix(polars): make multi-argument udfs work again
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 16, 2024
1 parent cae91bc commit 5c50052
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,16 +1287,18 @@ def execute_count_distinct_star(op, **kw):
# Convert polars series into a list
# -> map the function element by element
# -> convert back to a polars series
InputType.PYTHON: lambda func, dtype, args: pl.Series(
map(func, *(arg.to_list() for arg in args)),
InputType.PYTHON: lambda func, dtype, fields, args: pl.Series(
map(func, *(arg.to_list() for arg in map(args.struct.field, fields))),
dtype=PolarsType.from_ibis(dtype),
),
# Convert polars series into a pyarrow array
# -> invoke the function on the pyarrow array
# -> cast the result to match the ibis dtype
# -> convert back to a polars series
InputType.PYARROW: lambda func, dtype, args: pl.from_arrow(
func(*(arg.to_arrow() for arg in args)).cast(dtype.to_pyarrow()),
InputType.PYARROW: lambda func, dtype, fields, args: pl.from_arrow(
func(*(arg.to_arrow() for arg in map(args.struct.field, fields))).cast(
dtype.to_pyarrow()
),
),
}

Expand All @@ -1306,9 +1308,12 @@ def execute_scalar_udf(op, **kw):
input_type = op.__input_type__
if input_type in _UDF_INVOKERS:
dtype = op.dtype
return pl.map_batches(
exprs=[translate(arg, **kw) for arg in op.args],
function=partial(_UDF_INVOKERS[input_type], op.__func__, dtype),
argnames = op.argnames
args = pl.struct(
**dict(zip(argnames, (translate(arg, **kw) for arg in op.args)))
)
return args.map_batches(
function=partial(_UDF_INVOKERS[input_type], op.__func__, dtype, argnames),
return_dtype=PolarsType.from_ibis(dtype),
)
elif input_type == InputType.BUILTIN:
Expand Down

0 comments on commit 5c50052

Please sign in to comment.