Skip to content

Commit

Permalink
int_range: make datatype check in Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher committed Apr 15, 2024
1 parent e392caa commit 6995969
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
8 changes: 0 additions & 8 deletions R/functions__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1313,10 +1313,6 @@ pl_arg_sort_by = function(
#' pl$all()
#' )
pl_int_range = function(start = 0, end = NULL, step = 1, ..., dtype = pl$Int64) {
if (!dtype$is_integer()) {
Err_plain("`dtype` must be of type integer") |>
unwrap("in pl$int_range():")
}
if (is.null(end)) {
end = start
start = 0
Expand All @@ -1342,10 +1338,6 @@ pl_int_range = function(start = 0, end = NULL, step = 1, ..., dtype = pl$Int64)
#'
#' df$with_columns(int_range = pl$int_ranges("start", "end", dtype = pl$Int16))
pl_int_ranges = function(start = 0, end = NULL, step = 1, ..., dtype = pl$Int64) {
if (!dtype$is_integer()) {
Err_plain("`dtype` must be of type integer") |>
unwrap("in pl$int_ranges():")
}
if (is.null(end)) {
end = start
start = 0
Expand Down
8 changes: 7 additions & 1 deletion src/rust/src/rlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,13 @@ pub fn int_ranges(start: Robj, end: Robj, step: Robj, dtype: Robj) -> RResult<RP
let start = robj_to!(PLExprCol, start)?;
let end = robj_to!(PLExprCol, end)?;
let step = robj_to!(PLExprCol, step)?;
let dtype = robj_to!(RPolarsDataType, dtype)?.into();
let dtype: pl::DataType = robj_to!(RPolarsDataType, dtype)?.into();
if !dtype.is_integer() {
return Err(pl::PolarsError::ComputeError(
format!("non-integer `dtype` passed to `int_ranges`: {:?}", dtype,).into(),
)
.into());
}
let mut result = polars::lazy::dsl::int_ranges(start, end, step);
if dtype != pl::DataType::Int64 {
result = result.cast(pl::DataType::List(Box::new(dtype)))
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-lazy_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -568,12 +568,12 @@ test_that("pl$int_range() works", {

expect_grepl_error(
pl$int_range(0, 3, dtype = pl$String) |> as_polars_series(),
"must be of type integer"
"non-integer `dtype` passed"
)

expect_grepl_error(
pl$int_range(0, 3, dtype = pl$Float32) |> as_polars_series(),
"must be of type integer"
"non-integer `dtype` passed"
)

# "step" works
Expand Down Expand Up @@ -612,7 +612,7 @@ test_that("pl$int_ranges() works", {

expect_grepl_error(
df$select(int_range = pl$int_ranges("start", "end", dtype = pl$String)),
"must be of type integer"
"non-integer `dtype` passed"
)

# "step" works
Expand Down

0 comments on commit 6995969

Please sign in to comment.