diff --git a/queries/dask/utils.py b/queries/dask/utils.py
index 191acc9..7f3b46c 100644
--- a/queries/dask/utils.py
+++ b/queries/dask/utils.py
@@ -24,59 +24,66 @@ def read_ds(path: Path) -> DataFrame:
-    if settings.run.file_type != "parquet":
-        msg = f"unsupported file type: {settings.run.file_type!r}"
-        raise ValueError(msg)
-
-    if settings.run.include_io:
-        return dd.read_parquet(path, dtype_backend="pyarrow")  # type: ignore[attr-defined,no-any-return]
-
     # TODO: Load into memory before returning the Dask DataFrame.
     # Code below is tripped up by date types
     # df = pd.read_parquet(path, dtype_backend="pyarrow")
     # return dd.from_pandas(df, npartitions=os.cpu_count())
-    msg = "cannot run Dask starting from an in-memory representation"
-    raise RuntimeError(msg)
+    if not settings.run.include_io:
+        msg = "cannot run Dask starting from an in-memory representation"
+        raise RuntimeError(msg)
+
+    path_str = f"{path}.{settings.run.file_type}"
+    if settings.run.file_type == "parquet":
+        return dd.read_parquet(path_str, dtype_backend="pyarrow")  # type: ignore[attr-defined,no-any-return]
+    elif settings.run.file_type == "csv":
+        df = dd.read_csv(path_str, dtype_backend="pyarrow")  # type: ignore[attr-defined]
+        for c in df.columns:
+            if c.endswith("date"):
+                df[c] = df[c].astype("date32[day][pyarrow]")
+        return df  # type: ignore[no-any-return]
+    else:
+        msg = f"unsupported file type: {settings.run.file_type!r}"
+        raise ValueError(msg)


 @on_second_call
 def get_line_item_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "lineitem.parquet")
+    return read_ds(settings.dataset_base_dir / "lineitem")


 @on_second_call
 def get_orders_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "orders.parquet")
+    return read_ds(settings.dataset_base_dir / "orders")


 @on_second_call
 def get_customer_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "customer.parquet")
+    return read_ds(settings.dataset_base_dir / "customer")


 @on_second_call
 def get_region_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "region.parquet")
+    return read_ds(settings.dataset_base_dir / "region")


 @on_second_call
 def get_nation_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "nation.parquet")
+    return read_ds(settings.dataset_base_dir / "nation")


 @on_second_call
 def get_supplier_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "supplier.parquet")
+    return read_ds(settings.dataset_base_dir / "supplier")


 @on_second_call
 def get_part_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "part.parquet")
+    return read_ds(settings.dataset_base_dir / "part")


 @on_second_call
 def get_part_supp_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "partsupp.parquet")
+    return read_ds(settings.dataset_base_dir / "partsupp")


 def run_query(query_number: int, query: Callable[..., Any]) -> None:
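The Dask reader above and the pandas reader later in this patch share the same CSV handling: read through the pyarrow dtype backend, then cast every column whose name ends in "date" to pyarrow-backed date32 so the schema matches what the parquet path produces. A self-contained toy sketch of that pattern, using synthetic data and a hypothetical column name rather than the real benchmark tables:

from io import StringIO

import pandas as pd

# Toy stand-in for a benchmark table; the column names are hypothetical.
csv = StringIO("l_shipdate,l_quantity\n1995-01-01,17\n1996-06-15,36\n")

# Read via the pyarrow dtype backend, then cast *date columns to date32,
# mirroring the dask and pandas readers in this patch.
df = pd.read_csv(csv, dtype_backend="pyarrow")
for c in df.columns:
    if c.endswith("date"):
        df[c] = df[c].astype("date32[day][pyarrow]")

print(df.dtypes)

The cast is needed because CSV has no native date type: prepare_data's sink_csv writes dates out as plain ISO text, so every reader has to turn them back into dates on load.
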
"duckdb does not support feather for now" raise ValueError(msg) diff --git a/queries/pandas/utils.py b/queries/pandas/utils.py index d682ebb..3849b76 100644 --- a/queries/pandas/utils.py +++ b/queries/pandas/utils.py @@ -24,6 +24,12 @@ def _read_ds(path: Path) -> pd.DataFrame: path_str = f"{path}.{settings.run.file_type}" if settings.run.file_type == "parquet": return pd.read_parquet(path_str, dtype_backend="pyarrow") + elif settings.run.file_type == "csv": + df = pd.read_csv(path_str, dtype_backend="pyarrow") + for c in df.columns: + if c.endswith("date"): + df[c] = df[c].astype("date32[day][pyarrow]") # type: ignore[call-overload] + return df elif settings.run.file_type == "feather": return pd.read_feather(path_str, dtype_backend="pyarrow") else: diff --git a/queries/polars/utils.py b/queries/polars/utils.py index c85a48e..2ff26aa 100644 --- a/queries/polars/utils.py +++ b/queries/polars/utils.py @@ -16,16 +16,21 @@ def _scan_ds(path: Path) -> pl.LazyFrame: path_str = f"{path}.{settings.run.file_type}" + if settings.run.file_type == "parquet": scan = pl.scan_parquet(path_str) elif settings.run.file_type == "feather": scan = pl.scan_ipc(path_str) + elif settings.run.file_type == "csv": + scan = pl.scan_csv(path_str, try_parse_dates=True) else: msg = f"unsupported file type: {settings.run.file_type!r}" raise ValueError(msg) + if settings.run.include_io: return scan - return scan.collect().rechunk().lazy() + else: + return scan.collect().rechunk().lazy() def get_line_item_ds() -> pl.LazyFrame: diff --git a/queries/pyspark/executor.py b/queries/pyspark/executor.py index f4fa8cb..11c52a9 100644 --- a/queries/pyspark/executor.py +++ b/queries/pyspark/executor.py @@ -1,40 +1,4 @@ -from linetimer import CodeTimer - -# TODO: works for now, but need dynamic imports for this. 
diff --git a/queries/pyspark/executor.py b/queries/pyspark/executor.py
index f4fa8cb..11c52a9 100644
--- a/queries/pyspark/executor.py
+++ b/queries/pyspark/executor.py
@@ -1,40 +1,4 @@
-from linetimer import CodeTimer
-
-# TODO: works for now, but need dynamic imports for this.
-from queries.pyspark import (  # noqa: F401
-    q1,
-    q2,
-    q3,
-    q4,
-    q5,
-    q6,
-    q7,
-    q8,
-    q9,
-    q10,
-    q11,
-    q12,
-    q13,
-    q14,
-    q15,
-    q16,
-    q17,
-    q18,
-    q19,
-    q20,
-    q21,
-    q22,
-)
+from queries.common_utils import execute_all

 if __name__ == "__main__":
-    num_queries = 22
-
-    with CodeTimer(name="Overall execution of ALL spark queries", unit="s"):
-        for query_number in range(1, num_queries + 1):
-            submodule = f"q{query_number}"
-            try:
-                eval(f"{submodule}.q()")
-            except Exception as exc:
-                print(
-                    f"Exception occurred while executing PySpark query {query_number}:\n{exc}"
-                )
+    execute_all("pyspark")
diff --git a/queries/pyspark/utils.py b/queries/pyspark/utils.py
index 140ec32..711bc97 100644
--- a/queries/pyspark/utils.py
+++ b/queries/pyspark/utils.py
@@ -4,17 +4,11 @@

 from pyspark.sql import SparkSession

-from queries.common_utils import (
-    check_query_result_pd,
-    on_second_call,
-    run_query_generic,
-)
+from queries.common_utils import check_query_result_pd, run_query_generic
 from settings import Settings

 if TYPE_CHECKING:
-    from pathlib import Path
-
-    from pyspark.sql import DataFrame as SparkDF
+    from pyspark.sql import DataFrame

 settings = Settings()
@@ -31,62 +25,59 @@ def get_or_create_spark() -> SparkSession:
     return spark


-def _read_parquet_ds(path: Path, table_name: str) -> SparkDF:
-    df = get_or_create_spark().read.parquet(str(path))
-    df.createOrReplaceTempView(table_name)
-    return df
+def _read_ds(table_name: str) -> DataFrame:
+    # TODO: Persist data in memory before query
+    if not settings.run.include_io:
+        msg = "cannot run PySpark starting from an in-memory representation"
+        raise RuntimeError(msg)

+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"

-@on_second_call
-def get_line_item_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "lineitem.parquet", "lineitem")
+    if settings.run.file_type == "parquet":
+        df = get_or_create_spark().read.parquet(str(path))
+    elif settings.run.file_type == "csv":
+        df = get_or_create_spark().read.csv(str(path), header=True, inferSchema=True)
+    else:
+        msg = f"unsupported file type: {settings.run.file_type!r}"
+        raise ValueError(msg)
+
+    df.createOrReplaceTempView(table_name)
+    return df


-@on_second_call
-def get_orders_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "orders.parquet", "orders")
+def get_line_item_ds() -> DataFrame:
+    return _read_ds("lineitem")


-@on_second_call
-def get_customer_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "customer.parquet", "customer")
+def get_orders_ds() -> DataFrame:
+    return _read_ds("orders")


-@on_second_call
-def get_region_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "region.parquet", "region")
+def get_customer_ds() -> DataFrame:
+    return _read_ds("customer")


-@on_second_call
-def get_nation_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "nation.parquet", "nation")
+def get_region_ds() -> DataFrame:
+    return _read_ds("region")


-@on_second_call
-def get_supplier_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "supplier.parquet", "supplier")
+def get_nation_ds() -> DataFrame:
+    return _read_ds("nation")


-@on_second_call
-def get_part_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "part.parquet", "part")
+def get_supplier_ds() -> DataFrame:
+    return _read_ds("supplier")


-@on_second_call
-def get_part_supp_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "partsupp.parquet", "partsupp")
+def get_part_ds() -> DataFrame:
+    return _read_ds("part")


-def drop_temp_view() -> None:
-    spark = get_or_create_spark()
-    [
-        spark.catalog.dropTempView(t.name)
-        for t in spark.catalog.listTables()
-        if t.isTemporary
-    ]
+def get_part_supp_ds() -> DataFrame:
+    return _read_ds("partsupp")


-def run_query(query_number: int, df: SparkDF) -> None:
+def run_query(query_number: int, df: DataFrame) -> None:
     query = df.toPandas
     run_query_generic(
         query, query_number, "pyspark", query_checker=check_query_result_pd
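The PySpark reader keeps the temp-view registration so the SQL-style queries can keep addressing tables by name; only the load step branches on the file type. A minimal standalone sketch of the CSV branch (path and app name are made up):

from pyspark.sql import SparkSession

# Made-up path and app name; this only shows the shape of the CSV branch.
spark = SparkSession.builder.appName("csv-sketch").getOrCreate()
df = spark.read.csv("data/tables/lineitem.csv", header=True, inferSchema=True)

# Registering the temp view is what lets the SQL-style queries refer to the
# table by name, exactly as the parquet branch already did.
df.createOrReplaceTempView("lineitem")
spark.sql("select count(*) as n from lineitem").show()

Note that inferSchema costs Spark an extra pass over the file to work out column types, which is part of the I/O cost the include_io mode is meant to measure (and the reader refuses to run without it).
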
"partsupp") +def get_part_ds() -> DataFrame: + return _read_ds("part") -def drop_temp_view() -> None: - spark = get_or_create_spark() - [ - spark.catalog.dropTempView(t.name) - for t in spark.catalog.listTables() - if t.isTemporary - ] +def get_part_supp_ds() -> DataFrame: + return _read_ds("partsupp") -def run_query(query_number: int, df: SparkDF) -> None: +def run_query(query_number: int, df: DataFrame) -> None: query = df.toPandas run_query_generic( query, query_number, "pyspark", query_checker=check_query_result_pd diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 38cc886..a5aca1a 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -101,6 +101,7 @@ lf = lf.select(columns) lf.sink_parquet(settings.dataset_base_dir / f"{table_name}.parquet") + lf.sink_csv(settings.dataset_base_dir / f"{table_name}.csv") # IPC currently not relevant - # lf.sink_ipc(base_path / f"{table_name}.ipc") + # lf.sink_ipc(base_path / f"{table_name}.feather") diff --git a/settings.py b/settings.py index 3110d54..8319fd6 100644 --- a/settings.py +++ b/settings.py @@ -1,8 +1,11 @@ from pathlib import Path +from typing import Literal, TypeAlias from pydantic import computed_field from pydantic_settings import BaseSettings, SettingsConfigDict +FileType: TypeAlias = Literal["parquet", "feather", "csv"] + class Paths(BaseSettings): answers: Path = Path("data/answers") @@ -20,7 +23,7 @@ class Paths(BaseSettings): class Run(BaseSettings): include_io: bool = False - file_type: str = "parquet" + file_type: FileType = "parquet" log_timings: bool = False show_results: bool = False