Skip to content

Commit

Permalink
Minor extension to pass multiple arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tiemo Bang committed Jun 30, 2023
1 parent bd2e3b0 commit 9600ed0
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 10 deletions.
28 changes: 28 additions & 0 deletions hydroflow/examples/python_udf/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use hydroflow_macro::hydroflow_syntax;
use pyo3::{PyResult, PyAny, Py, Python};

#[hydroflow::main]
async fn main() {
eprintln!("Vec sender starting...");

let v = vec![1, 2, 3, 4, 5];

let mut df = hydroflow_syntax! {
source_iter(v) -> inspect(
|x| println!("input:\t{:?}", x)
)
// Map to tuples
-> map(|x| (x, 1))
-> py_udf(r#"
def add(a, b):
return a + 1
"#, "add")
-> map(|x: PyResult<Py<PyAny>>| -> i32 {Python::with_gil(|py| {
x.unwrap().extract(py).unwrap()
})})
-> for_each(|x| println!("output:\t{:?}", x));
};

df.run_available();

}
24 changes: 23 additions & 1 deletion hydroflow/tests/surface_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use pyo3::prelude::*;
pub fn test_python_basic() {
let mut hf = hydroflow_syntax! {
source_iter(0..10)
-> map(|x| (x,))
-> py_udf(r#"
def fib(n):
if n < 2:
Expand All @@ -20,6 +21,7 @@ def fib(n):
}))
-> assert([0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
};
// FIXME
assert_graphvis_snapshots!(hf);

hf.run_available();
Expand All @@ -28,7 +30,7 @@ def fib(n):
#[multiplatform_test(test)]
pub fn test_python_too_many_args() {
let mut hf = hydroflow_syntax! {
source_iter([5])
source_iter([(5,)])
-> py_udf(r#"
def add(a, b):
return a + b
Expand All @@ -37,6 +39,26 @@ def add(a, b):
-> map(|py_err| py_err.to_string())
-> assert(["TypeError: add() missing 1 required positional argument: 'b'"]);
};
// FIXME
assert_graphvis_snapshots!(hf);

hf.run_available();
}

#[multiplatform_test(test)]
pub fn test_python_two_args() {
let mut hf = hydroflow_syntax! {
source_iter([(5,1)])
-> py_udf(r#"
def add(a, b):
return a + b
"#, "add")
-> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
usize::extract(x.unwrap().as_ref(py)).unwrap()
}))
-> assert([6]);
};
// FIXME
assert_graphvis_snapshots!(hf);

hf.run_available();
Expand Down
21 changes: 12 additions & 9 deletions hydroflow_lang/src/graph/ops/py_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use super::{
///
/// **Requires the "python" feature to be enabled.**
///
/// An operator which allows you to run a python udf. Input arguments must be a stream of items
/// which implement [`IntoPy`](https://docs.rs/pyo3/latest/pyo3/conversion/trait.IntoPy.html).
/// An operator which allows you to run a python udf. Input arguments must be a stream of tuples
/// whose items implement [`IntoPy`](https://docs.rs/pyo3/latest/pyo3/conversion/trait.IntoPy.html).
/// See the [relevant pyo3 docs here](https://pyo3.rs/latest/conversions/tables#mapping-of-rust-types-to-python-types).
///
/// Output items are of type `PyResult<Py<PyAny>>`. Rust native types can be extracted using
Expand All @@ -21,6 +21,7 @@ use super::{
///
/// ```hydroflow
/// source_iter(0..10)
/// -> map(|x| (x,))
/// -> py_udf(r#"
/// def fib(n):
/// if n < 2:
Expand All @@ -35,14 +36,15 @@ use super::{
/// ```
///
/// ```hydroflow
/// source_iter([5])
/// -> py_udf(r#"
/// source_iter([(5,1)])
/// -> py_udf(r#"
/// def add(a, b):
/// return a + b
/// "#, "add")
/// -> map(PyResult::<Py<PyAny>>::unwrap_err)
/// -> map(|py_err| py_err.to_string())
/// -> assert(["TypeError: add() missing 1 required positional argument: 'b'"]);
/// "#, "add")
/// -> map(|x: PyResult<Py<PyAny>>| Python::with_gil(|py| {
/// usize::extract(x.unwrap().as_ref(py)).unwrap()
/// }))
/// -> assert([6]);
/// ```
pub const PY_UDF: OperatorConstraints = OperatorConstraints {
name: "py_udf",
Expand Down Expand Up @@ -112,7 +114,8 @@ pub const PY_UDF: OperatorConstraints = OperatorConstraints {
{
// TODO(mingwei): maybe this can be outside the closure?
let py_func = #context.state_ref(#py_func_ident);
::pyo3::Python::with_gil(|py| py_func.call1(py, (x,)))
//::pyo3::Python::with_gil(|py| py_func.call1(py, (x,)))
::pyo3::Python::with_gil(|py| py_func.call1(py, x))
}
#[cfg(not(feature = "python"))]
panic!()
Expand Down

0 comments on commit 9600ed0

Please sign in to comment.