Skip to content

Commit

Permalink
Inline struct, enum, fn and mod inside of declarative modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Tpt committed Feb 26, 2024
1 parent 404161c commit bc2df2b
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 30 deletions.
140 changes: 115 additions & 25 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,10 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
for item in &mut *items {
match item {
Item::Use(item_use) => {
let mut is_pyo3 = false;
item_use.attrs.retain(|attr| {
let found = attr.path().is_ident("pymodule_export");
is_pyo3 |= found;
!found
});
if is_pyo3 {
let cfg_attrs = item_use
.attrs
.iter()
.filter(|attr| attr.path().is_ident("cfg"))
.cloned()
.collect::<Vec<_>>();
let is_pymodule_export =
find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
if is_pymodule_export {
let cfg_attrs = get_cfg_attributes(&item_use.attrs);
extract_use_items(
&item_use.tree,
&cfg_attrs,
Expand All @@ -136,23 +127,97 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
}
Item::Fn(item_fn) => {
let mut is_module_init = false;
item_fn.attrs.retain(|attr| {
let found = attr.path().is_ident("pymodule_init");
is_module_init |= found;
!found
});
if is_module_init {
let is_pymodule_init =
find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
let is_pymodule_export =
find_and_remove_attribute(&mut item_fn.attrs, "pymodule_export");
let ident = &item_fn.sig.ident;
if is_pymodule_init {
if is_pymodule_export {
bail_spanned!(item_fn.span() => "#[pymodule_export] cannot be used on the #[pymodule_init] function");
}
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one pymodule_init may be specified");
let ident = &item_fn.sig.ident;
pymodule_init = Some(quote! { #ident(module)?; });
} else {
bail_spanned!(item.span() => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]")
} else if is_pymodule_export {
module_items.push(ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
}
}
Item::Struct(item_struct) => {
let is_pymodule_export =
find_and_remove_attribute(&mut item_struct.attrs, "pymodule_export");
if is_pymodule_export {
module_items.push(item_struct.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
}
}
Item::Enum(item_enum) => {
let is_pymodule_export =
find_and_remove_attribute(&mut item_enum.attrs, "pymodule_export");
if is_pymodule_export {
module_items.push(item_enum.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
}
}
Item::Mod(item_mod) => {
let is_pymodule_export =
find_and_remove_attribute(&mut item_mod.attrs, "pymodule_export");
if is_pymodule_export {
module_items.push(item_mod.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
}
}
Item::ForeignMod(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on foreign mods but only on inline mods");
}
}
Item::Trait(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on traits");
}
}
Item::Const(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on const");
}
}
Item::Static(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on statics");
}
}
Item::Macro(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on macros");
}
}
item => {
bail_spanned!(item.span() => "only 'use' statements and and pymodule_init functions are allowed in #[pymodule]")
Item::ExternCrate(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on extern crates");
}
}
Item::Impl(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on impls");
}
}
Item::TraitAlias(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on trait aliases");
}
}
Item::Type(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on type aliases");
}
}
Item::Union(item) => {
if has_attribute(&item.attrs, "pymodule_export") {
bail_spanned!(item.span() => "#[pymodule_export] cannot be used on unions");
}
}
_ => (),
}
}

Expand Down Expand Up @@ -337,6 +402,31 @@ fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs
Ok(pyfn_args)
}

fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
attrs
.iter()
.filter(|attr| attr.path().is_ident("cfg"))
.cloned()
.collect()
}

fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
let mut found = false;
attrs.retain(|attr| {
if attr.path().is_ident(ident) {
found = true;
false
} else {
true
}
});
found
}

fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(ident))
}

enum PyModulePyO3Option {
Crate(CrateAttribute),
Name(NameAttribute),
Expand Down
40 changes: 36 additions & 4 deletions tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ struct ValueClass {
#[pymethods]
impl ValueClass {
#[new]
fn new(value: usize) -> ValueClass {
ValueClass { value }
fn new(value: usize) -> Self {
Self { value }
}
}

Expand Down Expand Up @@ -48,6 +48,37 @@ mod declarative_module {
#[pymodule_export]
use super::{declarative_module2, double, MyError, ValueClass as Value};

#[pymodule_export]
#[pymodule]
mod inner {
use super::*;

#[pymodule_export]
#[pyfunction]
fn triple(x: usize) -> usize {
x * 3
}

#[pymodule_export]
#[pyclass]
struct Struct;

#[pymethods]
impl Struct {
#[new]
fn new() -> Self {
Self
}
}

#[pymodule_export]
#[pyclass]
enum Enum {
A,
B,
}
}

#[pymodule_init]
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("double2", m.getattr("double")?)
Expand All @@ -65,7 +96,6 @@ mod declarative_submodule {
use super::{double, double_value};
}

/// A module written using declarative syntax.
#[pymodule]
#[pyo3(name = "declarative_module_renamed")]
mod declarative_module2 {
Expand All @@ -84,7 +114,7 @@ fn test_declarative_module() {
);

py_assert!(py, m, "m.double(2) == 4");
py_assert!(py, m, "m.double2(3) == 6");
py_assert!(py, m, "m.inner.triple(3) == 9");
py_assert!(py, m, "m.declarative_submodule.double(4) == 8");
py_assert!(
py,
Expand All @@ -97,5 +127,7 @@ fn test_declarative_module() {
py_assert!(py, m, "not hasattr(m, 'LocatedClass')");
#[cfg(not(Py_LIMITED_API))]
py_assert!(py, m, "hasattr(m, 'LocatedClass')");
py_assert!(py, m, "isinstance(m.inner.Struct(), m.inner.Struct)");
py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)");
})
}
2 changes: 1 addition & 1 deletion tests/ui/invalid_pymodule_trait.stderr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error: only 'use' statements and and pymodule_init functions are allowed in #[pymodule]
error: #[pymodule_export] cannot be used on traits
--> tests/ui/invalid_pymodule_trait.rs:5:5
|
5 | #[pymodule_export]
Expand Down

0 comments on commit bc2df2b

Please sign in to comment.