Skip to content

Commit

Permalink
fix(trino): ensure that NULLs are preserved in array filter
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 10, 2024
1 parent 6a1fa4f commit 4e851e5
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,52 @@ def visit_ArrayFilter(self, op, *, arg, param, body, index):
else:
placeholder = sg.to_identifier("__trino_filter__")
index = sg.to_identifier(index)
return self.f.filter(
self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
# semantics are: arg if predicate(arg, index) else null
this=self.if_(body, param, NULL),
expressions=[param, index],
keep, value = map(sg.to_identifier, ("keep", "value"))

# first, zip the array with the index and call the user's function,
# returning a struct of {"keep": value-of-predicate, "value": array-element}
zipped = self.f.zip_with(
arg,
# users are limited to 10_000 elements here because it
# seems like trino won't ever actually address the limit
self.f.sequence(0, self.f.cardinality(arg) - 1),
sge.Lambda(
# semantics are: arg if predicate(arg, index) else null
this=self.cast(
sge.Struct(
expressions=[
sge.PropertyEQ(this=keep, expression=body),
sge.PropertyEQ(this=value, expression=param),
]
),
dt.Struct(
{
"keep": dt.boolean,
"value": op.arg.dtype.value_type,
}
),
),
# this=struct(keep=body, value=param),
expressions=[param, index],
),
)

# second, keep only the elements whose predicate returned true
filtered = self.f.filter(
# then, filter out elements that are null
zipped,
sge.Lambda(
this=sge.Dot(this=placeholder, expression=keep),
expressions=[placeholder],
),
)

# finally, extract the "value" field from the struct
return self.f.transform(
filtered,
sge.Lambda(
this=placeholder.is_(sg.not_(NULL)), expressions=[placeholder]
this=sge.Dot(this=placeholder, expression=value),
expressions=[placeholder],
),
)

Expand Down

0 comments on commit 4e851e5

Please sign in to comment.