Skip to content

Commit

Permalink
Repair broken queries (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Feb 22, 2024
1 parent 01e248b commit 2329e6d
Show file tree
Hide file tree
Showing 22 changed files with 86 additions and 91 deletions.
3 changes: 1 addition & 2 deletions queries/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
CWD = Path(__file__).parent
ROOT = CWD.parent
DATASET_BASE_DIR = ROOT / "data" / "tables" / f"scale-{SCALE_FACTOR}"
ANSWERS_BASE_DIR = ROOT / "tpch-dbgen/answers"
ANSWERS_PARQUET_BASE_DIR = ROOT / "data" / "answers"
ANSWERS_BASE_DIR = ROOT / "data" / "answers"
TIMINGS_FILE = ROOT / os.environ.get("TIMINGS_FILE", "timings.csv")
DEFAULT_PLOTS_DIR = ROOT / "plots"

Expand Down
4 changes: 2 additions & 2 deletions queries/dask/q3.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import datetime
from datetime import datetime

from queries.dask import utils

Q_NUM = 3


def q():
var1 = datetime.datetime.strptime("1995-03-15", "%Y-%m-%d")
var1 = datetime(1995, 3, 15)
var2 = "BUILDING"

line_item_ds = utils.get_line_item_ds
Expand Down
4 changes: 2 additions & 2 deletions queries/dask/q4.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


def q():
date1 = datetime.strptime("1993-10-01", "%Y-%m-%d")
date2 = datetime.strptime("1993-07-01", "%Y-%m-%d")
date1 = datetime(1993, 10, 1)
date2 = datetime(1993, 7, 1)

line_item_ds = utils.get_line_item_ds
orders_ds = utils.get_orders_ds
Expand Down
4 changes: 2 additions & 2 deletions queries/dask/q5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


def q():
date1 = datetime.datetime.strptime("1994-01-01", "%Y-%m-%d")
date2 = datetime.datetime.strptime("1995-01-01", "%Y-%m-%d")
date1 = datetime(1994, 1, 1)
date2 = datetime(1995, 1, 1)

region_ds = utils.get_region_ds
nation_ds = utils.get_nation_ds
Expand Down
6 changes: 3 additions & 3 deletions queries/dask/q6.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import datetime
from datetime import datetime

import pandas as pd

Expand All @@ -8,8 +8,8 @@


def q():
date1 = datetime.datetime.strptime("1994-01-01", "%Y-%m-%d")
date2 = datetime.datetime.strptime("1995-01-01", "%Y-%m-%d")
date1 = datetime(1994, 1, 1)
date2 = datetime(1995, 1, 1)
var3 = 24

line_item_ds = utils.get_line_item_ds
Expand Down
12 changes: 8 additions & 4 deletions queries/dask/q7.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import datetime
import warnings
from datetime import datetime

import dask.dataframe as dd
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
import dask.dataframe as dd

from queries.dask import utils

Q_NUM = 7


def q():
var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")
var1 = datetime(1995, 1, 1)
var2 = datetime(1997, 1, 1)

nation_ds = utils.get_nation_ds
customer_ds = utils.get_customer_ds
line_item_ds = utils.get_line_item_ds
Expand Down
14 changes: 5 additions & 9 deletions queries/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dask.dataframe as dd
import pandas as pd
from linetimer import CodeTimer, linetimer
from pandas.testing import assert_series_equal

from queries.common_utils import (
ANSWERS_BASE_DIR,
Expand All @@ -30,14 +31,9 @@ def read_ds(path: str) -> Union:
return dd.from_pandas(pd.read_parquet(path), npartitions=os.cpu_count())


def get_query_answer(query: int, base_dir: Path = ANSWERS_BASE_DIR) -> dd.DataFrame:
answer_df = pd.read_csv(
base_dir / f"q{query}.out",
sep="|",
parse_dates=True,
infer_datetime_format=True,
)
return answer_df.rename(columns=lambda x: x.strip())
def get_query_answer(query: int, base_dir: str = ANSWERS_BASE_DIR) -> pd.DataFrame:
path = base_dir / f"q{query}.parquet"
return pd.read_parquet(path)


def test_results(q_num: int, result_df: pd.DataFrame):
Expand All @@ -52,7 +48,7 @@ def test_results(q_num: int, result_df: pd.DataFrame):
s1 = s1.astype("string").apply(lambda x: x.strip())
s2 = s2.astype("string").apply(lambda x: x.strip())

pd.testing.assert_series_equal(left=s1, right=s2, check_index=False)
assert_series_equal(left=s1, right=s2, check_index=False, check_dtype=False)


@on_second_call
Expand Down
13 changes: 6 additions & 7 deletions queries/duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import polars as pl
from duckdb import DuckDBPyRelation
from linetimer import CodeTimer, linetimer
from polars import testing as pl_test
from polars.testing import assert_frame_equal

from queries.common_utils import (
ANSWERS_PARQUET_BASE_DIR,
ANSWERS_BASE_DIR,
DATASET_BASE_DIR,
FILE_TYPE,
INCLUDE_IO,
Expand Down Expand Up @@ -40,16 +40,15 @@ def _scan_ds(path: Path):
return path


def get_query_answer(
query: int, base_dir: str = ANSWERS_PARQUET_BASE_DIR
) -> pl.LazyFrame:
return pl.scan_parquet(Path(base_dir) / f"q{query}.parquet")
def get_query_answer(query: int, base_dir: Path = ANSWERS_BASE_DIR) -> pl.LazyFrame:
path = base_dir / f"q{query}.parquet"
return pl.scan_parquet(path)


def test_results(q_num: int, result_df: pl.DataFrame):
with CodeTimer(name=f"Testing result of duckdb Query {q_num}", unit="s"):
answer = get_query_answer(q_num).collect()
pl_test.assert_frame_equal(left=result_df, right=answer, check_dtype=False)
assert_frame_equal(left=result_df, right=answer, check_dtype=False)


def get_line_item_ds(base_dir: str = DATASET_BASE_DIR) -> str:
Expand Down
4 changes: 2 additions & 2 deletions queries/modin/q3.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import datetime
from datetime import datetime

from queries.modin import utils

Q_NUM = 3


def q():
var1 = var2 = datetime.datetime.strptime("1995-03-15", "%Y-%m-%d")
var1 = var2 = datetime(1995, 3, 15)
var3 = "BUILDING"

customer_ds = utils.get_customer_ds
Expand Down
6 changes: 3 additions & 3 deletions queries/modin/q4.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import datetime
from datetime import datetime

from queries.modin import utils

Q_NUM = 4


def q():
date1 = datetime.datetime.strptime("1993-10-01", "%Y-%m-%d")
date2 = datetime.datetime.strptime("1993-07-01", "%Y-%m-%d")
date1 = datetime(1993, 10, 1)
date2 = datetime(1993, 7, 1)

line_item_ds = utils.get_line_item_ds
orders_ds = utils.get_orders_ds
Expand Down
6 changes: 3 additions & 3 deletions queries/modin/q5.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import datetime
from datetime import datetime

from queries.modin import utils

Q_NUM = 5


def q():
date1 = datetime.datetime.strptime("1994-01-01", "%Y-%m-%d")
date2 = datetime.datetime.strptime("1995-01-01", "%Y-%m-%d")
date1 = datetime(1994, 1, 1)
date2 = datetime(1995, 1, 1)

region_ds = utils.get_region_ds
nation_ds = utils.get_nation_ds
Expand Down
6 changes: 3 additions & 3 deletions queries/modin/q6.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import datetime
from datetime import datetime

import modin.pandas as pd

Expand All @@ -8,8 +8,8 @@


def q():
date1 = datetime.datetime.strptime("1994-01-01", "%Y-%m-%d")
date2 = datetime.datetime.strptime("1995-01-01", "%Y-%m-%d")
date1 = datetime(1994, 1, 1)
date2 = datetime(1995, 1, 1)
var3 = 24

line_item_ds = utils.get_line_item_ds
Expand Down
8 changes: 5 additions & 3 deletions queries/modin/q7.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import datetime
from datetime import datetime

import modin.pandas as pd

Expand All @@ -8,6 +8,9 @@


def q():
var1 = datetime(1995, 1, 1)
var2 = datetime(1997, 1, 1)

nation_ds = utils.get_nation_ds
customer_ds = utils.get_customer_ds
line_item_ds = utils.get_line_item_ds
Expand Down Expand Up @@ -35,8 +38,7 @@ def query():
supplier_ds = supplier_ds()

lineitem_filtered = line_item_ds[
(line_item_ds["l_shipdate"] >= datetime.strptime("1995-01-01", "%Y-%m-%d"))
& (line_item_ds["l_shipdate"] < datetime.strptime("1997-01-01", "%Y-%m-%d"))
(line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
]
lineitem_filtered["l_year"] = lineitem_filtered["l_shipdate"].dt.year
lineitem_filtered["revenue"] = lineitem_filtered["l_extendedprice"] * (
Expand Down
11 changes: 3 additions & 8 deletions queries/modin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,11 @@ def __read_parquet_ds(path: str) -> PandasDF:
return pd.read_parquet(path, dtype_backend="pyarrow", engine="pyarrow")


def get_query_answer(query: int, base_dir: Path = ANSWERS_BASE_DIR) -> PandasDF:
def get_query_answer(query: int, base_dir: str = ANSWERS_BASE_DIR) -> PandasDF:
import pandas as pd

answer_df = pd.read_csv(
base_dir / f"q{query}.out",
sep="|",
parse_dates=True,
infer_datetime_format=True,
)
return answer_df.rename(columns=lambda x: x.strip())
path = base_dir / f"q{query}.parquet"
return pd.read_parquet(path)


def test_results(q_num: int, result_df: PandasDF):
Expand Down
6 changes: 3 additions & 3 deletions queries/pandas/q5.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import datetime
from datetime import datetime

from queries.pandas import utils

Q_NUM = 5


def q():
date1 = datetime.datetime.strptime("1994-01-01", "%Y-%m-%d")
date2 = datetime.datetime.strptime("1995-01-01", "%Y-%m-%d")
date1 = datetime(1994, 1, 1)
date2 = datetime(1995, 1, 1)

region_ds = utils.get_region_ds
nation_ds = utils.get_nation_ds
Expand Down
2 changes: 1 addition & 1 deletion queries/pandas/q8.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def udf(df):
numerator = df["volume"].sum()
return round(numerator / demonimator, 2)

total = total.groupby("o_year", as_index=False).apply(udf)
total = total.groupby("o_year", as_index=False).apply(udf, include_groups=False)
total.columns = ["o_year", "mkt_share"]
total = total.sort_values(by=["o_year"], ascending=[True])

Expand Down
28 changes: 15 additions & 13 deletions queries/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import pandas as pd
from linetimer import CodeTimer, linetimer
from pandas.api.types import is_string_dtype
from pandas.core.frame import DataFrame as PandasDF
from pandas.testing import assert_series_equal

from queries.common_utils import (
ANSWERS_BASE_DIR,
Expand All @@ -16,26 +18,23 @@
on_second_call,
)

pd.options.mode.copy_on_write = True


def _read_ds(path: Path) -> PandasDF:
path = f"{path}.{FILE_TYPE}"
if FILE_TYPE == "parquet":
return pd.read_parquet(path, dtype_backend="pyarrow", engine="pyarrow")
return pd.read_parquet(path, dtype_backend="pyarrow")
elif FILE_TYPE == "feather":
return pd.read_feather(path)
return pd.read_feather(path, dtype_backend="pyarrow")
else:
msg = f"file type: {FILE_TYPE} not expected"
raise ValueError(msg)


def get_query_answer(query: int, base_dir: str = ANSWERS_BASE_DIR) -> PandasDF:
answer_df = pd.read_csv(
Path(base_dir) / f"q{query}.out",
sep="|",
parse_dates=True,
infer_datetime_format=True,
)
return answer_df.rename(columns=lambda x: x.strip())
path = base_dir / f"q{query}.parquet"
return pd.read_parquet(path, dtype_backend="pyarrow")


def test_results(q_num: int, result_df: PandasDF):
Expand All @@ -46,11 +45,14 @@ def test_results(q_num: int, result_df: PandasDF):
s1 = result_df[c]
s2 = answer[c]

if t.name == "object":
s1 = s1.astype("string").apply(lambda x: x.strip())
s2 = s2.astype("string").apply(lambda x: x.strip())
if is_string_dtype(t):
s1 = s1.apply(lambda x: x.strip())

# TODO: Remove this cast
if s2.dtype == "date32[day][pyarrow]":
s2 = s2.astype("timestamp[us][pyarrow]")

pd.testing.assert_series_equal(left=s1, right=s2, check_index=False)
assert_series_equal(left=s1, right=s2, check_index=False, check_dtype=False)


@on_second_call
Expand Down
8 changes: 3 additions & 5 deletions queries/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from polars.testing import assert_frame_equal

from queries.common_utils import (
ANSWERS_PARQUET_BASE_DIR,
ANSWERS_BASE_DIR,
DATASET_BASE_DIR,
FILE_TYPE,
INCLUDE_IO,
Expand All @@ -34,10 +34,8 @@ def _scan_ds(path: Path):
return scan.collect().rechunk().lazy()


def get_query_answer(
query: int, base_dir: str = ANSWERS_PARQUET_BASE_DIR
) -> pl.LazyFrame:
return pl.scan_parquet(Path(base_dir) / f"q{query}.parquet")
def get_query_answer(query: int, base_dir: Path = ANSWERS_BASE_DIR) -> pl.LazyFrame:
return pl.scan_parquet(base_dir / f"q{query}.parquet")


def test_results(q_num: int, result_df: pl.DataFrame):
Expand Down
Loading

0 comments on commit 2329e6d

Please sign in to comment.