Skip to content

Commit

Permalink
update scan_parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
sorhawell committed Nov 9, 2023
1 parent c97f6bd commit 4fa1156
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 64 deletions.
49 changes: 26 additions & 23 deletions R/parquet.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,23 @@
#' file.path(temp_dir, "**/*.parquet")
#' )$collect()
# pl$scan_parquet(): lazily scan a parquet file (or glob pattern) and return a
# LazyFrame; the actual work happens Rust-side in new_from_parquet().
# NOTE(review): this span is a commit-diff rendering — the pre-change and
# post-change parameter lists and bodies are interleaved below, so it is not
# runnable R source as-is.
pl$scan_parquet = function(
file, # : str | Path,
n_rows = NULL, # : int | None = None,
cache = TRUE, # : bool = True,
file,
n_rows = NULL,
cache = TRUE,
parallel = c(
"Auto", # default: automatically pick the unit (columns vs row groups) to parallelize over
"None",
"Columns", # parallelize over the columns
"RowGroups" # parallelize over the row groups
), # first element is the default; caller may pass any one of the four choices
rechunk = TRUE, # : bool = True,
row_count_name = NULL, # : str | None = None,
row_count_offset = 0L, # : int = 0,
"Columns",
"RowGroups"
),
rechunk = TRUE,
row_count_name = NULL,
row_count_offset = 0L,
# storage_options,#: dict[str, object] | None = None, #seems fsspec specific
low_memory = FALSE, # : bool = False,
low_memory = FALSE,
hive_partitioning = TRUE) { #-> LazyFrame

# old version: R-side validation of the parallel strategy (removed — the
# new code validates Rust-side via robj_to!(ParallelStrategy, ...))
parallel = parallel[1L]
if (!parallel %in% c("None", "Columns", "RowGroups", "Auto")) {
stop("unknown parallel strategy")
}

result_lf = new_from_parquet(
new_from_parquet(
path = file,
n_rows = n_rows,
cache = cache,
Expand All @@ -63,9 +58,8 @@ pl$scan_parquet = function(
row_count = row_count_offset,
low_memory = low_memory,
hive_partitioning = hive_partitioning
)

unwrap(result_lf)
) |>
# unwrap() raises a formatted R error (with the given context prefix) if the
# Rust side returned an Err, otherwise yields the LazyFrame
unwrap("in pl$scan_parquet(): ")
}


Expand All @@ -85,14 +79,23 @@ pl$read_parquet = function(
file,
n_rows = NULL,
cache = TRUE,
parallel = c("Auto", "None", "Columns", "RowGroups"),
parallel = c(
"Auto", # default
"None",
"Columns",
"RowGroups"
),
rechunk = TRUE,
row_count_name = NULL,
row_count_offset = 0L,
low_memory = FALSE) {
# storage_options,#: dict[str, object] | None = None, #seems fsspec specific
low_memory = FALSE,
hive_partitioning = TRUE
) {
# Eager counterpart of pl$scan_parquet: rewrite this call into a
# pl$scan_parquet call, evaluate it, and collect immediately.
mc = match.call()
mc[[1]] = get("pl", envir = asNamespace("polars"))$scan_parquet
eval.parent(mc)$collect()
mc[[1]] = pl$scan_parquet
# NOTE(review): old code used eval.parent(mc) so argument expressions were
# evaluated in the caller's frame; the new eval(mc) evaluates them in this
# function's frame — presumably fine since all formals are already bound
# here, but confirm no caller relies on parent-frame evaluation.
result(eval(mc)$collect()) |>
unwrap("in pl$read_parquet(): ")
}


Expand Down
68 changes: 27 additions & 41 deletions src/rust/src/rdataframe/read_parquet.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,41 @@
// NOTE(review): commit-diff rendering — removed (old) and added (new) import
// lines, signatures and bodies are interleaved here; not compilable as-is.
use crate::utils::r_result_list;

use crate::lazy::dataframe::LazyFrame;
use crate::robj_to;
use crate::rpolarserr::{polars_to_rpolars_err, RResult};

//use crate::utils::wrappers::*;
use crate::utils::wrappers::null_to_opt;
use extendr_api::{extendr, prelude::*};
use extendr_api::Rinternals;
use extendr_api::{extendr, extendr_module, Robj};
use polars::io::RowCount;
use polars::prelude::{self as pl};
// derived from polars/py-polars/src/lazy/DataFrame.rs new_from_csv
// new_from_parquet(): build a LazyFrame by scanning a parquet path.
// New version takes every argument as a raw Robj and converts/validates each
// one via robj_to!, returning RResult instead of panicking or wrapping in a
// result-list.

#[allow(clippy::too_many_arguments)]
#[extendr]
pub fn new_from_parquet(
path: String,
n_rows: Nullable<i32>,
cache: bool,
parallel: String, //Wrap<ParallelStrategy>,
rechunk: bool,
row_name: Nullable<String>,
row_count: u32,
low_memory: bool,
hive_partitioning: bool,
) -> List {
// old version: only "Auto" was supported, anything else panicked
let parallel_strategy = match parallel {
x if x == "Auto" => pl::ParallelStrategy::Auto,
_ => panic!("not implemented"),
};

let row_name = null_to_opt(row_name);

let row_count = row_name.map(|name| polars::io::RowCount {
name,
offset: row_count,
});
let n_rows = null_to_opt(n_rows);

// new signature: untyped Robj per argument, converted below with robj_to!
path: Robj,
n_rows: Robj,
cache: Robj,
parallel: Robj,
rechunk: Robj,
row_name: Robj,
row_count: Robj,
low_memory: Robj,
hive_partitioning: Robj,
) -> RResult<LazyFrame> {
// row_count (offset) defaults to 0; a RowCount is only attached when a
// row-count column name was supplied
let offset = robj_to!(Option, u32, row_count)?.unwrap_or(0);
let opt_rowcount = robj_to!(Option, String, row_name)?.map(|name| RowCount { name, offset });
let args = pl::ScanArgsParquet {
n_rows: n_rows.map(|x| x as usize),
cache,
parallel: parallel_strategy,
rechunk,
row_count,
low_memory,
n_rows: robj_to!(Option, usize, n_rows)?,
cache: robj_to!(bool, cache)?,
parallel: robj_to!(ParallelStrategy, parallel)?,
rechunk: robj_to!(bool, rechunk)?,
row_count: opt_rowcount,
low_memory: robj_to!(bool, low_memory)?,
cloud_options: None, //TODO implement cloud options
use_statistics: true, //TODO expose use statistics
hive_partitioning,
hive_partitioning: robj_to!(bool, hive_partitioning)?,
};

let lf_result = pl::LazyFrame::scan_parquet(path, args)
.map_err(|x| x.to_string())
.map(LazyFrame);
r_result_list(lf_result)
// scan errors are converted to RPolarsErr so the R side can unwrap() them
pl::LazyFrame::scan_parquet(robj_to!(String, path)?, args)
.map_err(polars_to_rpolars_err)
.map(LazyFrame)
}

extendr_module! {
Expand Down
14 changes: 14 additions & 0 deletions src/rust/src/rdatatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,20 @@ pub fn robj_to_closed_window(robj: Robj) -> RResult<pl::ClosedWindow> {
}
}

/// Convert an R character choice into a polars `ParallelStrategy`.
///
/// Accepts the four documented variants case-insensitively ("auto", "none",
/// "columns", "rowgroups"), as is the norm for most other enum arguments.
/// Any other string yields a BadValue error; non-character input already
/// fails inside `robj_to_rchoice` with a NotAChoice error.
pub fn robj_to_parallel_strategy(robj: extendr_api::Robj) -> RResult<pl::ParallelStrategy> {
    use pl::ParallelStrategy as PS;
    match robj_to_rchoice(robj)?.to_lowercase().as_str() {
        //accept also lowercase as normal for most other enums
        // BUG FIX: previously every accepted choice mapped to PS::Auto,
        // silently ignoring the strategy the user asked for.
        "auto" => Ok(PS::Auto),
        "none" => Ok(PS::None),
        "columns" => Ok(PS::Columns),
        "rowgroups" => Ok(PS::RowGroups),
        s => rerr().bad_val(format!(
            "ParallelStrategy choice ['{s}'] should be one of 'Auto', 'None', 'Columns', 'RowGroups'"
        )),
    }
}

extendr_module! {
mod rdatatype;
impl RPolarsDataType;
Expand Down
4 changes: 4 additions & 0 deletions src/rust/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,10 @@ macro_rules! robj_to_inner {
$crate::rdatatype::robj_to_join_type($a)
};

(ParallelStrategy, $a:ident) => {
$crate::rdatatype::robj_to_parallel_strategy($a)
};

(PathBuf, $a:ident) => {
$crate::utils::robj_to_pathbuf($a)
};
Expand Down
44 changes: 44 additions & 0 deletions tests/testthat/test-parquet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

# Fixture: sink mtcars to a temporary parquet file once, and keep the
# expected data.frame for all comparisons below.
parquet_path = tempfile()
on.exit(unlink(parquet_path))
lazy_expected = pl$LazyFrame(mtcars)
lazy_expected$sink_parquet(parquet_path, compression = "snappy")
expected_df = lazy_expected$collect()$to_data_frame()

test_that("scan read parquet", {
  # lazy scan + collect round-trips the data
  expect_identical(
    pl$scan_parquet(parquet_path)$collect()$to_data_frame(),
    expected_df
  )

  # eager read gives the identical result
  expect_identical(
    pl$read_parquet(parquet_path)$to_data_frame(),
    expected_df
  )

  # a row-count column with an offset is prepended as the first column
  expect_identical(
    pl$read_parquet(parquet_path, row_count_name = "rc", row_count_offset = 5)$to_data_frame(),
    data.frame(rc = as.numeric(5:36), expected_df)
  )

  # every parallel strategy is accepted, case-insensitively
  for (choice in c("auto", "COLUMNS", "None", "rowGroups")) {
    expect_identical(
      pl$read_parquet(parquet_path, parallel = choice)$to_data_frame(),
      expected_df
    )
  }

  # an unknown strategy string is rejected with a BadValue error naming the arg
  ctx = pl$read_parquet(parquet_path, parallel = "34") |> get_err_ctx()
  expect_true(startsWith(ctx$BadValue, "ParallelStrategy choice"))
  expect_identical(ctx$BadArgument, "parallel")

  # non-character input is rejected before the choice is even inspected
  ctx = pl$read_parquet(parquet_path, parallel = 42) |> get_err_ctx()
  expect_identical(ctx$NotAChoice, "input is not a character vector")
})

0 comments on commit 4fa1156

Please sign in to comment.