Skip to content

Commit

Permalink
Allow pymodule functions to take a single Bound<'_, PyModule> arg (#3905
Browse files Browse the repository at this point in the history
)
  • Loading branch information
maffoo authored Feb 27, 2024
1 parent 6f03a54 commit a15e4b1
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 1 deletion.
1 change: 1 addition & 0 deletions newsfragments/3905.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The `#[pymodule]` macro now supports module functions that take a single argument as a `&Bound<'_, PyModule>`.
10 changes: 9 additions & 1 deletion pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
let doc = get_doc(&function.attrs, None);

let initialization = module_initialization(options, ident);

// Module function called with optional Python<'_> marker as first arg, followed by the module.
let mut module_args = Vec::new();
if function.sig.inputs.len() == 2 {
module_args.push(quote!(module.py()));
}
module_args.push(quote!(::std::convert::Into::into(BoundRef(module))));

Ok(quote! {
#function
#vis mod #ident {
Expand All @@ -218,7 +226,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
use #krate::impl_::pymethods::BoundRef;

fn __pyo3_pymodule(module: &#krate::Bound<'_, #krate::types::PyModule>) -> #krate::PyResult<()> {
#ident(module.py(), ::std::convert::Into::into(BoundRef(module)))
#ident(#(#module_args),*)
}

impl #ident::MakeDef {
Expand Down
8 changes: 8 additions & 0 deletions src/tests/hygiene/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,11 @@ fn invoke_wrap_pyfunction() {
crate::py_run!(py, func, r#"func(5)"#);
});
}

#[test]
fn invoke_wrap_pyfunction_bound() {
crate::Python::with_gil(|py| {
let func = crate::wrap_pyfunction_bound!(do_something, py).unwrap();
crate::py_run!(py, func, r#"func(5)"#);
});
}
15 changes: 15 additions & 0 deletions src/tests/hygiene/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,18 @@ fn my_module(_py: crate::Python<'_>, m: &crate::types::PyModule) -> crate::PyRes

::std::result::Result::Ok(())
}

#[crate::pymodule]
#[pyo3(crate = "crate")]
fn my_module_bound(m: &crate::Bound<'_, crate::types::PyModule>) -> crate::PyResult<()> {
<crate::Bound<'_, crate::types::PyModule> as crate::types::PyModuleMethods>::add_function(
m,
crate::wrap_pyfunction_bound!(do_something, m)?,
)?;
<crate::Bound<'_, crate::types::PyModule> as crate::types::PyModuleMethods>::add_wrapped(
m,
crate::wrap_pymodule!(foo),
)?;

::std::result::Result::Ok(())
}
12 changes: 12 additions & 0 deletions tests/test_no_imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ fn basic_module(_py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyRes
Ok(())
}

#[pyo3::pymodule]
fn basic_module_bound(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyResult<()> {
#[pyfn(m)]
fn answer() -> usize {
42
}

m.add_function(pyo3::wrap_pyfunction_bound!(basic_function, m)?)?;

Ok(())
}

#[pyo3::pyclass]
struct BasicClass {
#[pyo3(get)]
Expand Down

0 comments on commit a15e4b1

Please sign in to comment.