Skip to content

Commit

Permalink
feat: support kwargs in plugin 'field' functions and raise error on u…
Browse files Browse the repository at this point in the history
…nsupported binview layout (pola-rs#13944)
  • Loading branch information
ritchie46 authored Jan 24, 2024
1 parent d03d86f commit d6c6cef
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 30 deletions.
15 changes: 15 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,21 @@ impl DataType {
matches!(self, DataType::Binary)
}

pub fn contains_views(&self) -> bool {
use DataType::*;
match self {
Binary | String => true,
#[cfg(feature = "dtype-categorical")]
Categorical(_, _) => true,
List(inner) => inner.contains_views(),
#[cfg(feature = "dtype-array")]
Array(inner, _) => inner.contains_views(),
#[cfg(feature = "dtype-struct")]
Struct(fields) => fields.iter().any(|field| field.dtype.contains_views()),
_ => false,
}
}

/// Check if type is sortable
pub fn is_ord(&self) -> bool {
#[cfg(feature = "dtype-categorical")]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use polars_core::error::PolarsResult;
use polars_core::prelude::{ArrowField, Series};

pub const MAJOR: u16 = 0;
pub const MINOR: u16 = 0;
pub const MINOR: u16 = 1;

pub const fn get_version() -> (u16, u16) {
(MAJOR, MINOR)
Expand Down
78 changes: 56 additions & 22 deletions crates/polars-plan/src/dsl/function_expr/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,34 +120,68 @@ pub(super) unsafe fn plugin_field(
fields: &[Field],
lib: &str,
symbol: &str,
kwargs: &[u8],
) -> PolarsResult<Field> {
let plugin = get_lib(lib)?;
let lib = &plugin.0;
let major = plugin.1;
let minor = plugin.2;

if major == 0 {
// *const ArrowSchema: pointer to heap Box<ArrowSchema>
// usize: length of the boxed slice
// *mut ArrowSchema: pointer where the return value can be written
let symbol: libloading::Symbol<
unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema),
> = lib
.get((format!("_polars_plugin_field_{}", symbol)).as_bytes())
.unwrap();
// we deallocate the fields buffer
let ffi_fields = fields
.iter()
.map(|field| arrow::ffi::export_field_to_c(&field.to_arrow(true)))
.collect::<Vec<_>>()
.into_boxed_slice();
let n_args = ffi_fields.len();
let slice_ptr = ffi_fields.as_ptr();

// we deallocate the fields buffer
let fields = fields
.iter()
.map(|field| arrow::ffi::export_field_to_c(&field.to_arrow(true)))
.collect::<Vec<_>>()
.into_boxed_slice();
let n_args = fields.len();
let slice_ptr = fields.as_ptr();

let mut return_value = ArrowSchema::empty();
let return_value_ptr = &mut return_value as *mut ArrowSchema;
symbol(slice_ptr, n_args, return_value_ptr);
let mut return_value = ArrowSchema::empty();
let return_value_ptr = &mut return_value as *mut ArrowSchema;

if major == 0 {
match minor {
0 => {
let views = fields.iter().any(|field| field.dtype.contains_views());
polars_ensure!(!views, ComputeError: "cannot call plugin\n\nThis Polars' version has a different 'binary/string' layout. Please compile with latest 'pyo3-polars'");

// *const ArrowSchema: pointer to heap Box<ArrowSchema>
// usize: length of the boxed slice
// *mut ArrowSchema: pointer where the return value can be written
let symbol: libloading::Symbol<
unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema),
> = lib
.get((format!("_polars_plugin_field_{}", symbol)).as_bytes())
.unwrap();
symbol(slice_ptr, n_args, return_value_ptr);
},
1 => {
// *const ArrowSchema: pointer to heap Box<ArrowSchema>
// usize: length of the boxed slice
// *mut ArrowSchema: pointer where the return value can be written
// *const u8: pointer to &[u8] (kwargs)
// usize: length of the u8 slice
let symbol: libloading::Symbol<
unsafe extern "C" fn(
*const ArrowSchema,
usize,
*mut ArrowSchema,
*const u8,
usize,
),
> = lib
.get((format!("_polars_plugin_field_{}", symbol)).as_bytes())
.unwrap();

let kwargs_ptr = kwargs.as_ptr();
let kwargs_len = kwargs.len();

symbol(slice_ptr, n_args, return_value_ptr, kwargs_ptr, kwargs_len);
},
_ => {
polars_bail!(ComputeError: "this Polars engine doesn't support plugin version: {}-{}", major, minor)
},
}
if !return_value.is_null() {
let arrow_field = import_field_from_c(&return_value)?;
let out = Field::from(&arrow_field);
Expand All @@ -159,7 +193,7 @@ pub(super) unsafe fn plugin_field(
polars_bail!(ComputeError: "the plugin failed with message: {}", msg)
}
} else {
polars_bail!(ComputeError: "this polars engine doesn't support plugin version: {}", major)
polars_bail!(ComputeError: "this Polars engine doesn't support plugin version: {}", major)
}
}

Expand Down
16 changes: 9 additions & 7 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,11 @@ impl FunctionExpr {
Random { .. } => mapper.with_same_dtype(),
SetSortedFlag(_) => mapper.with_same_dtype(),
#[cfg(feature = "ffi_plugin")]
FfiPlugin { lib, symbol, .. } => unsafe {
plugin::plugin_field(fields, lib, symbol.as_ref())
},
FfiPlugin {
lib,
symbol,
kwargs,
} => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) },
BackwardFill { .. } => mapper.with_same_dtype(),
ForwardFill { .. } => mapper.with_same_dtype(),
SumHorizontal => mapper.map_to_supertype(),
Expand Down Expand Up @@ -313,7 +315,7 @@ impl<'a> FieldsMapper<'a> {
}

/// Map a single dtype.
pub fn map_dtype(&self, func: impl Fn(&DataType) -> DataType) -> PolarsResult<Field> {
pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult<Field> {
let dtype = func(self.fields[0].data_type());
Ok(Field::new(self.fields[0].name(), dtype))
}
Expand All @@ -325,7 +327,7 @@ impl<'a> FieldsMapper<'a> {
/// Map a single field with a potentially failing mapper function.
pub fn try_map_field(
&self,
func: impl Fn(&Field) -> PolarsResult<Field>,
func: impl FnOnce(&Field) -> PolarsResult<Field>,
) -> PolarsResult<Field> {
func(&self.fields[0])
}
Expand Down Expand Up @@ -360,7 +362,7 @@ impl<'a> FieldsMapper<'a> {
/// Map a single dtype with a potentially failing mapper function.
pub fn try_map_dtype(
&self,
func: impl Fn(&DataType) -> PolarsResult<DataType>,
func: impl FnOnce(&DataType) -> PolarsResult<DataType>,
) -> PolarsResult<Field> {
let dtype = func(self.fields[0].data_type())?;
Ok(Field::new(self.fields[0].name(), dtype))
Expand All @@ -369,7 +371,7 @@ impl<'a> FieldsMapper<'a> {
/// Map all dtypes with a potentially failing mapper function.
pub fn try_map_dtypes(
&self,
func: impl Fn(&[&DataType]) -> PolarsResult<DataType>,
func: impl FnOnce(&[&DataType]) -> PolarsResult<DataType>,
) -> PolarsResult<Field> {
let mut fld = self.fields[0].clone();
let dtypes = self
Expand Down

0 comments on commit d6c6cef

Please sign in to comment.