Skip to content

Commit

Permalink
Robj to R choice (#437)
Browse files Browse the repository at this point in the history
Co-authored-by: Etienne Bacher <[email protected]>
  • Loading branch information
sorhawell and etiennebacher authored Nov 7, 2023
1 parent 04bbe96 commit 6322b56
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 77 deletions.
4 changes: 4 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ test_robj_to_expr <- function(robj) .Call(wrap__test_robj_to_expr, robj)

test_wrong_call_pl_lit <- function(robj) .Call(wrap__test_wrong_call_pl_lit, robj)

test_robj_to_rchoice <- function(robj) .Call(wrap__test_robj_to_rchoice, robj)

polars_features <- function() .Call(wrap__polars_features)

concat_lf <- function(l, rechunk, parallel, to_supertypes) .Call(wrap__concat_lf, l, rechunk, parallel, to_supertypes)
Expand Down Expand Up @@ -317,6 +319,8 @@ RPolarsErr$mistyped <- function(s) .Call(wrap__RPolarsErr__mistyped, self, s)

RPolarsErr$misvalued <- function(s) .Call(wrap__RPolarsErr__misvalued, self, s)

RPolarsErr$notachoice <- function(s) .Call(wrap__RPolarsErr__notachoice, self, s)

RPolarsErr$plain <- function(s) .Call(wrap__RPolarsErr__plain, self, s)

RPolarsErr$rcall <- function(c) .Call(wrap__RPolarsErr__rcall, self, c)
Expand Down
31 changes: 13 additions & 18 deletions R/lazyframe__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -982,35 +982,30 @@ LazyFrame_join = function(
suffix = "_right",
allow_parallel = TRUE,
force_parallel = FALSE) {
if (inherits(other, "LazyFrame")) {
# nothing
} else if (inherits(other, "DataFrame")) {
other = other
} else {
stop(paste("Expected a `LazyFrame` as join table, got ", class(other)))
}

how_opts = c("inner", "left", "outer", "semi", "anti", "cross")
how = match.arg(how[1L], how_opts)
uw = \(res) unwrap(res, "in $join():")

if (inherits(other, "DataFrame")) {
other = other$lazy()
}

if (!is.null(on)) {
rexprs = do.call(construct_ProtoExprArray, as.list(on))
rexprs_left = rexprs
rexprs_right = rexprs
rexprs_right = rexprs_left = as.list(on)
} else if ((!is.null(left_on) && !is.null(right_on))) {
rexprs_left = do.call(construct_ProtoExprArray, as.list(left_on))
rexprs_right = do.call(construct_ProtoExprArray, as.list(right_on))
rexprs_left = as.list(left_on)
rexprs_right = as.list(right_on)
} else if (how != "cross") {
stop("must specify `on` OR ( `left_on` AND `right_on` ) ")
Err_plain("must specify `on` OR ( `left_on` AND `right_on` ) ") |> uw()
} else {
rexprs_left = do.call(construct_ProtoExprArray, as.list(self$columns))
rexprs_right = do.call(construct_ProtoExprArray, as.list(other$columns))
rexprs_left = as.list(self$columns)
rexprs_right = as.list(other$columns)
}

.pr$LazyFrame$join(
self, other, rexprs_left, rexprs_right,
how, suffix, allow_parallel, force_parallel
)
) |>
uw()
}


Expand Down
48 changes: 22 additions & 26 deletions src/rust/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::lazy::dsl::*;

use crate::rdataframe::DataFrame as RDF;
use crate::rdatatype::{
new_asof_strategy, new_ipc_compression, new_join_type, new_parquet_compression,
new_unique_keep_strategy, RPolarsDataType,
new_asof_strategy, new_ipc_compression, new_parquet_compression, new_unique_keep_strategy,
RPolarsDataType,
};
use crate::robj_to;
use crate::rpolarserr::{polars_to_rpolars_err, RResult, Rctx, WithRctx};
Expand Down Expand Up @@ -406,31 +406,27 @@ impl LazyFrame {
#[allow(clippy::too_many_arguments)]
fn join(
&self,
other: &LazyFrame,
left_on: &ProtoExprArray,
right_on: &ProtoExprArray,
how: &str,
suffix: &str,
allow_parallel: bool,
force_parallel: bool,
) -> LazyFrame {
let ldf = self.0.clone();
let other = other.0.clone();
let left_on = pra_to_vec(left_on, "select");
let right_on = pra_to_vec(right_on, "select");
let how = new_join_type(how);

LazyFrame(
ldf.join_builder()
.with(other)
.left_on(left_on)
.right_on(right_on)
.allow_parallel(allow_parallel)
.force_parallel(force_parallel)
.how(how)
.suffix(suffix)
other: Robj,
left_on: Robj,
right_on: Robj,
how: Robj,
suffix: Robj,
allow_parallel: Robj,
force_parallel: Robj,
) -> RResult<LazyFrame> {
Ok(LazyFrame(
self.0
.clone()
.join_builder()
.with(robj_to!(PLLazyFrame, other)?)
.left_on(robj_to!(VecPLExprCol, left_on)?)
.right_on(robj_to!(VecPLExprCol, right_on)?)
.allow_parallel(robj_to!(bool, allow_parallel)?)
.force_parallel(robj_to!(bool, force_parallel)?)
.how(robj_to!(JoinType, how)?)
.suffix(robj_to!(str, suffix)?)
.finish(),
)
))
}

pub fn sort_by_exprs(
Expand Down
4 changes: 2 additions & 2 deletions src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ impl Expr {
min_periods: robj_to!(usize, min_periods)?,
center: robj_to!(bool, center)?,
by: robj_to!(Option, String, by)?,
closed_window: robj_to!(Option, new_closed_window, closed)?,
closed_window: robj_to!(Option, ClosedWindow, closed)?,
fn_params: Some(pl::Arc::new(pl::RollingQuantileParams {
prob: robj_to!(f64, quantile)?,
interpol: robj_to!(new_quantile_interpolation_option, interpolation)?,
Expand Down Expand Up @@ -2472,7 +2472,7 @@ pub fn make_rolling_options(
min_periods: robj_to!(usize, min_periods)?,
center: robj_to!(bool, center)?,
by: robj_to!(Option, String, by_null)?,
closed_window: robj_to!(Option, new_closed_window, closed_null)?,
closed_window: robj_to!(Option, ClosedWindow, closed_null)?,
..Default::default()
})
}
Expand Down
56 changes: 30 additions & 26 deletions src/rust/src/rdatatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pub struct RField(pub pl::Field);
use pl::UniqueKeepStrategy;
use polars::prelude::AsofStrategy;

use crate::utils::robj_to_rchoice;

#[extendr]
impl RField {
fn new(name: String, datatype: &RPolarsDataType) -> RField {
Expand Down Expand Up @@ -245,18 +247,6 @@ impl DataTypeVector {
}
}

pub fn new_join_type(s: &str) -> pl::JoinType {
match s {
"cross" => pl::JoinType::Cross,
"inner" => pl::JoinType::Inner,
"left" => pl::JoinType::Left,
"outer" => pl::JoinType::Outer,
"semi" => pl::JoinType::Semi,
"anti" => pl::JoinType::Anti,
_ => panic!("polars internal error: jointype not recognized"),
}
}

pub fn new_asof_strategy(s: &str) -> Result<AsofStrategy, String> {
match s {
"forward" => Ok(AsofStrategy::Forward),
Expand Down Expand Up @@ -296,20 +286,6 @@ pub fn new_quantile_interpolation_option(robj: Robj) -> RResult<QuantileInterpol
}
}

pub fn new_closed_window(robj: Robj) -> RResult<pl::ClosedWindow> {
let s = robj_to_string(robj.clone())?;
use pl::ClosedWindow as CW;
match s.as_str() {
"both" => Ok(CW::Both),
"left" => Ok(CW::Left),
"none" => Ok(CW::None),
"right" => Ok(CW::Right),
_ => rerr()
.bad_val("ClosedWindow choice: [{}] is not any of 'both', 'left', 'none' or 'right'")
.bad_robj(&robj),
}
}

pub fn new_null_behavior(
s: &str,
) -> std::result::Result<polars::series::ops::NullBehavior, String> {
Expand Down Expand Up @@ -512,6 +488,34 @@ pub fn new_rolling_cov_options(
})
}

pub fn robj_to_join_type(robj: Robj) -> RResult<pl::JoinType> {
let s = robj_to_rchoice(robj)?;
match s.as_str() {
"cross" => Ok(pl::JoinType::Cross),
"inner" => Ok(pl::JoinType::Inner),
"left" => Ok(pl::JoinType::Left),
"outer" => Ok(pl::JoinType::Outer),
"semi" => Ok(pl::JoinType::Semi),
"anti" => Ok(pl::JoinType::Anti),
s => rerr().bad_val(format!(
"JoinType choice ['{s}'] should be one of 'cross', 'inner', 'left', 'outer', 'semi', 'anti'"
)),
}
}

pub fn robj_to_closed_window(robj: Robj) -> RResult<pl::ClosedWindow> {
use pl::ClosedWindow as CW;
match robj_to_rchoice(robj)?.as_str() {
"both" => Ok(CW::Both),
"left" => Ok(CW::Left),
"none" => Ok(CW::None),
"right" => Ok(CW::Right),
s => rerr().bad_val(format!(
"ClosedWindow choice ['{s}'] should be one of 'both', 'left', 'none', 'right'"
)),
}
}

extendr_module! {
mod rdatatype;
impl RPolarsDataType;
Expand Down
11 changes: 9 additions & 2 deletions src/rust/src/rlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use crate::robj_to;
use crate::rpolarserr::{rdbg, RResult};
use crate::series::Series;
use crate::utils::extendr_concurrent::{ParRObj, ThreadCom};
use crate::utils::robj_to_rchoice;
use crate::RFnSignature;
use crate::CONFIG;
use extendr_api::prelude::*;
use polars::prelude as pl;

use std::result::Result;

#[extendr]
Expand Down Expand Up @@ -65,7 +65,7 @@ fn r_date_range_lazy(
robj_to!(PLExprCol, start)?,
robj_to!(PLExprCol, end)?,
robj_to!(pl_duration, every)?,
robj_to!(new_closed_window, closed)?,
robj_to!(ClosedWindow, closed)?,
robj_to!(Option, timeunit, time_unit)?,
robj_to!(Option, String, time_zone)?,
);
Expand Down Expand Up @@ -221,6 +221,12 @@ fn test_wrong_call_pl_lit(robj: Robj) -> RResult<Robj> {
Ok(R!("pl$lit({{robj}})")?) // this call should have been polars::pl$lit(...
}

#[extendr]
fn test_robj_to_rchoice(robj: Robj) -> RResult<String> {
// robj can be any non-zero length char vec, will return first string.
robj_to_rchoice(robj)
}

#[extendr]
fn polars_features() -> List {
list!(
Expand Down Expand Up @@ -301,6 +307,7 @@ extendr_module! {
fn test_print_string;
fn test_robj_to_expr;
fn test_wrong_call_pl_lit;
fn test_robj_to_rchoice;

//feature flags
fn polars_features;
Expand Down
12 changes: 12 additions & 0 deletions src/rust/src/rpolarserr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub enum Rctx {
Mistyped(String),
#[error("Expected a value that {0}")]
Misvalued(String),
#[error("Not a valid R choice because {0}")]
NotAChoice(String),
#[error("{0}")]
Plain(String),
#[error("Encountered the following error in Rust-Polars:\n\t{0}")]
Expand All @@ -51,6 +53,7 @@ pub trait WithRctx<T> {
fn hint(self, cause: impl Into<String>) -> RResult<T>;
fn mistyped(self, ty: impl Into<String>) -> RResult<T>;
fn misvalued(self, scope: impl Into<String>) -> RResult<T>;
fn notachoice(self, scope: impl Into<String>) -> RResult<T>;
fn plain(self, msg: impl Into<String>) -> RResult<T>;
fn when(self, env: impl Into<String>) -> RResult<T>;
}
Expand Down Expand Up @@ -96,6 +99,10 @@ impl<T, E: Into<RPolarsErr>> WithRctx<T> for core::result::Result<T, E> {
self.ctx(Rctx::Misvalued(scope.into()))
}

fn notachoice(self, scope: impl Into<String>) -> RResult<T> {
self.ctx(Rctx::NotAChoice(scope.into()))
}

fn plain(self, msg: impl Into<String>) -> RResult<T> {
self.ctx(Rctx::Plain(msg.into()))
}
Expand Down Expand Up @@ -129,6 +136,7 @@ impl RPolarsErr {
Hint(msg) => ("Hint", msg),
Mistyped(ty) => ("TypeMismatch", ty),
Misvalued(scope) => ("ValueOutOfScope", scope),
NotAChoice(err) => ("NotAChoice", err),
Plain(msg) => ("PlainErrorMessage", msg),
Polars(err) => ("PolarsError", err),
When(target) => ("When", target),
Expand Down Expand Up @@ -166,6 +174,10 @@ impl RPolarsErr {
self.push_back_rctx(Rctx::Misvalued(s))
}

pub fn notachoice(&self, s: String) -> Self {
self.push_back_rctx(Rctx::NotAChoice(s))
}

pub fn plain(&self, s: String) -> Self {
self.push_back_rctx(Rctx::Plain(s))
}
Expand Down
Loading

0 comments on commit 6322b56

Please sign in to comment.