Skip to content

Commit

Permalink
feat(rust): Expose a few more expression nodes in the expression IR (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored Jun 7, 2024
1 parent 1fc9a59 commit 1df442f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 18 deletions.
49 changes: 37 additions & 12 deletions py-polars/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use polars_time::prelude::RollingGroupOptions;
use pyo3::exceptions::PyNotImplementedError;
use pyo3::prelude::*;

use crate::series::PySeries;
use crate::Wrap;

#[pyclass]
Expand Down Expand Up @@ -342,6 +343,16 @@ pub struct Function {
options: PyObject,
}

#[pyclass]
pub struct Slice {
#[pyo3(get)]
input: usize,
#[pyo3(get)]
offset: usize,
#[pyo3(get)]
length: usize,
}

#[pyclass]
pub struct Len {}

Expand Down Expand Up @@ -545,9 +556,18 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
value: Wrap(lit.to_any_value().unwrap()).to_object(py),
dtype,
},
Duration(_, _) => return Err(PyNotImplementedError::new_err("duration literal")),
Time(_) => return Err(PyNotImplementedError::new_err("time literal")),
Series(_) => return Err(PyNotImplementedError::new_err("series literal")),
Duration(v, _) => Literal {
value: v.to_object(py),
dtype,
},
Time(ns) => Literal {
value: ns.to_object(py),
dtype,
},
Series(s) => Literal {
value: PySeries::new((**s).clone()).into_py(py),
dtype,
},
}
}
.into_py(py),
Expand Down Expand Up @@ -995,9 +1015,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
return Err(PyNotImplementedError::new_err("search sorted"))
},
FunctionExpr::Range(_) => return Err(PyNotImplementedError::new_err("range")),
FunctionExpr::DateOffset => {
return Err(PyNotImplementedError::new_err("date offset"))
},
FunctionExpr::DateOffset => ("offset_by",).to_object(py),
FunctionExpr::Trigonometry(trigfun) => match trigfun {
TrigonometricFunction::Cos => ("cos",),
TrigonometricFunction::Cot => ("cot",),
Expand Down Expand Up @@ -1107,13 +1125,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
parallel: _,
name: _,
} => return Err(PyNotImplementedError::new_err("value counts")),
FunctionExpr::UniqueCounts => {
return Err(PyNotImplementedError::new_err("unique counts"))
},
FunctionExpr::UniqueCounts => ("unique_counts",).to_object(py),
FunctionExpr::ApproxNUnique => {
return Err(PyNotImplementedError::new_err("approx nunique"))
},
FunctionExpr::Coalesce => return Err(PyNotImplementedError::new_err("coalesce")),
FunctionExpr::Coalesce => ("coalesce",).to_object(py),
FunctionExpr::ShrinkType => {
return Err(PyNotImplementedError::new_err("shrink type"))
},
Expand All @@ -1134,7 +1150,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::Log { base: _ } => return Err(PyNotImplementedError::new_err("log")),
FunctionExpr::Log1p => return Err(PyNotImplementedError::new_err("log1p")),
FunctionExpr::Exp => return Err(PyNotImplementedError::new_err("exp")),
FunctionExpr::Unique(_) => return Err(PyNotImplementedError::new_err("unique")),
FunctionExpr::Unique(maintain_order) => ("unique", maintain_order).to_object(py),
FunctionExpr::Round { decimals } => ("round", decimals).to_object(py),
FunctionExpr::RoundSF { digits } => ("round_sig_figs", digits).to_object(py),
FunctionExpr::Floor => ("floor",).to_object(py),
Expand Down Expand Up @@ -1262,7 +1278,16 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
.into_py(py)
},
AExpr::Wildcard => return Err(PyNotImplementedError::new_err("wildcard")),
AExpr::Slice { .. } => return Err(PyNotImplementedError::new_err("slice")),
AExpr::Slice {
input,
offset,
length,
} => Slice {
input: input.0,
offset: offset.0,
length: length.0,
}
.into_py(py),
AExpr::Nth(_) => return Err(PyNotImplementedError::new_err("nth")),
AExpr::Len => Len {}.into_py(py),
};
Expand Down
12 changes: 6 additions & 6 deletions py-polars/src/lazyframe/visitor/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub struct Select {
#[pyo3(get)]
expr: Vec<PyExprIR>,
#[pyo3(get)]
options: (), //ProjectionOptions,
should_broadcast: bool,
}

#[pyclass]
Expand Down Expand Up @@ -195,7 +195,7 @@ pub struct HStack {
#[pyo3(get)]
exprs: Vec<PyExprIR>,
#[pyo3(get)]
options: (), // ProjectionOptions,
should_broadcast: bool,
}

#[pyclass]
Expand Down Expand Up @@ -338,11 +338,11 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult<PyObject> {
input,
expr,
schema: _,
options: _,
options,
} => Select {
expr: expr.iter().map(|e| e.into()).collect(),
input: input.0,
options: (),
should_broadcast: options.should_broadcast,
}
.into_py(py),
IR::Sort {
Expand Down Expand Up @@ -428,11 +428,11 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult<PyObject> {
input,
exprs,
schema: _,
options: _,
options,
} => HStack {
input: input.0,
exprs: exprs.iter().map(|e| e.into()).collect(),
options: (),
should_broadcast: options.should_broadcast,
}
.into_py(py),
IR::Reduce {
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ fn _expr_nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<Agg>().unwrap();
m.add_class::<Ternary>().unwrap();
m.add_class::<Function>().unwrap();
m.add_class::<Slice>().unwrap();
m.add_class::<Len>().unwrap();
m.add_class::<Window>().unwrap();
m.add_class::<PyOperator>().unwrap();
Expand Down

0 comments on commit 1df442f

Please sign in to comment.