diff --git a/tests/tpch/dask_queries.py b/tests/tpch/dask_queries.py index 43b015ad55..cb4d6bcaa7 100644 --- a/tests/tpch/dask_queries.py +++ b/tests/tpch/dask_queries.py @@ -577,8 +577,7 @@ def query_11(dataset_path, fs, scale): joined["value"] = joined.ps_supplycost * joined.ps_availqty - # FIXME: https://github.com/dask-contrib/dask-expr/issues/867 - res = joined.groupby("ps_partkey")["value"].sum(split_out=True) + res = joined.groupby("ps_partkey")["value"].sum() res = ( res[res > threshold] .round(2) @@ -678,8 +677,7 @@ def query_13(dataset_path, fs, scale): ) subquery = ( subquery.groupby("c_custkey") - # FIXME: https://github.com/dask-contrib/dask-expr/issues/867 - .o_orderkey.count(split_out=True) + .o_orderkey.count() .to_frame() .reset_index() .rename(columns={"o_orderkey": "c_count"})[["c_custkey", "c_count"]] @@ -790,7 +788,7 @@ def query_15(dataset_path, fs, scale): lineitem["revenue"] = lineitem.l_extendedprice * (1 - lineitem.l_discount) revenue = ( lineitem.groupby("l_suppkey") - .revenue.sum(split_out=True) + .revenue.sum() .to_frame() .reset_index() .rename(columns={"revenue": "total_revenue", "l_suppkey": "supplier_no"}) @@ -843,15 +841,17 @@ def query_16(dataset_path, fs, scale): supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs) supplier["is_complaint"] = supplier.s_comment.str.contains("Customer.*Complaints") - # FIXME: We have to compute this early because passing a `dask_expr.Series` to `isin` is not supported - complaint_suppkeys = supplier[supplier.is_complaint].s_suppkey.compute() + # We can only broadcast 1 partition series objects + complaint_suppkeys = supplier[supplier.is_complaint].s_suppkey.repartition( + npartitions=1 + ) + partsupp = partsupp[~partsupp.ps_suppkey.isin(complaint_suppkeys)] table = partsupp.merge(part, left_on="ps_partkey", right_on="p_partkey") table = table[ (table.p_brand != "Brand#45") & (~table.p_type.str.startswith("MEDIUM POLISHED")) & (table.p_size.isin((49, 14, 23, 45, 19, 3, 36, 9))) - & (~table.ps_suppkey.isin(complaint_suppkeys)) ] return ( table.groupby(by=["p_brand", "p_type", "p_size"])