diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 1a410453..67bbbe3f 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -6,7 +6,9 @@ use crate::parser::{ Custom, MsgAttr, MsgType, OverrideEntryPoint, OverrideEntryPoints, }; use crate::strip_generics::StripGenerics; -use crate::utils::{extract_return_type, filter_wheres, process_fields}; +use crate::utils::{ + as_where_clause, brace_generics, extract_return_type, filter_wheres, process_fields, +}; use crate::variant_descs::{AsVariantDescs, VariantDescs}; use convert_case::{Case, Casing}; use proc_macro2::{Span, TokenStream}; @@ -100,14 +102,6 @@ impl<'a> StructMessage<'a> { custom, } = self; - let where_clause = if !wheres.is_empty() { - quote! { - where #(#wheres,)* - } - } else { - quote! {} - }; - let ctx_type = msg_attr .msg_type() .emit_ctx_type(&custom.query_or_default()); @@ -119,21 +113,9 @@ impl<'a> StructMessage<'a> { }); let fields = fields.iter().map(MsgField::emit); - let generics = if generics.is_empty() { - quote! {} - } else { - quote! { - <#(#generics,)*> - } - }; - - let unused_generics = if unused_generics.is_empty() { - quote! {} - } else { - quote! { - <#(#unused_generics,)*> - } - }; + let where_clause = as_where_clause(wheres); + let generics = brace_generics(generics); + let unused_generics = brace_generics(unused_generics); #[cfg(not(tarpaulin_include))] { @@ -282,18 +264,33 @@ impl<'a> EnumMessage<'a> { let ctx_type = msg_ty.emit_ctx_type(query_type); let dispatch_type = msg_ty.emit_result_type(resp_type, &parse_quote!(C::Error)); - let all_generics = if all_generics.is_empty() { + let all_generics = brace_generics(all_generics); + let phantom = if generics.is_empty() { quote! {} + } else if MsgType::Query == *msg_ty { + quote! { + #[returns((#(#generics,)*))] + _Phantom(std::marker::PhantomData<( #(#generics,)* )>), + } } else { - quote! { <#(#all_generics,)*> } + quote! { + _Phantom(std::marker::PhantomData<( #(#generics,)* )>), + } }; - let generics = if generics.is_empty() { - quote! {} + let match_arms = if !generics.is_empty() { + quote! { + #(#match_arms,)* + _Phantom(_) => unreachable!(), + } } else { - quote! { <#(#generics,)*> } + quote! { + #(#match_arms,)* + } }; + let generics = brace_generics(generics); + let unique_enum_name = Ident::new(&format!("{}{}", trait_name, name), name.span()); #[cfg(not(tarpaulin_include))] @@ -305,6 +302,7 @@ impl<'a> EnumMessage<'a> { #[serde(rename_all="snake_case")] pub enum #unique_enum_name #generics { #(#variants,)* + #phantom } pub type #name #generics = #unique_enum_name #generics; } @@ -316,6 +314,7 @@ impl<'a> EnumMessage<'a> { #[serde(rename_all="snake_case")] pub enum #unique_enum_name #generics { #(#variants,)* + #phantom } pub type #name #generics = #unique_enum_name #generics; } @@ -334,7 +333,7 @@ impl<'a> EnumMessage<'a> { use #unique_enum_name::*; match self { - #(#match_arms,)* + #match_arms } } pub const fn messages() -> [&'static str; #msgs_cnt] { @@ -507,10 +506,12 @@ impl<'a> MsgVariant<'a> { let return_type = if let MsgAttr::Query { resp_type } = msg_attr { match resp_type { Some(resp_type) => { + generics_checker.visit_path(&parse_quote! { #resp_type }); quote! {#resp_type} } None => { let return_type = extract_return_type(&sig.output); + generics_checker.visit_path(return_type); quote! {#return_type} } } @@ -621,7 +622,11 @@ impl<'a> MsgVariant<'a> { } } - pub fn emit_querier_impl(&self, trait_module: Option<&Path>) -> TokenStream { + pub fn emit_querier_impl( + &self, + trait_module: Option<&Path>, + unbonded_generics: &Vec<&GenericParam>, + ) -> TokenStream { let sylvia = crate_module(); let Self { name, @@ -637,6 +642,12 @@ impl<'a> MsgVariant<'a> { .map(|module| quote! { #module ::QueryMsg }) .unwrap_or_else(|| quote! { QueryMsg }); + let msg = if !unbonded_generics.is_empty() { + quote! { #msg ::< #(#unbonded_generics,)* > } + } else { + quote! { #msg } + }; + #[cfg(not(tarpaulin_include))] { quote! { @@ -741,18 +752,15 @@ impl<'a> MsgVariants<'a> { let methods_impl = variants .iter() .filter(|variant| variant.msg_type == MsgType::Query) - .map(|variant| variant.emit_querier_impl(None)); + .map(|variant| variant.emit_querier_impl(None, unbonded_generics)); let methods_declaration = variants .iter() .filter(|variant| variant.msg_type == MsgType::Query) .map(MsgVariant::emit_querier_declaration); - let querier = if !unbonded_generics.is_empty() { - quote! { Querier < #(#unbonded_generics,)* > } - } else { - quote! { Querier } - }; + let braced_generics = brace_generics(unbonded_generics); + let querier = quote! { Querier #braced_generics }; #[cfg(not(tarpaulin_include))] { @@ -803,7 +811,7 @@ impl<'a> MsgVariants<'a> { let methods_impl = variants .iter() .filter(|variant| variant.msg_type == MsgType::Query) - .map(|variant| variant.emit_querier_impl(trait_module)); + .map(|variant| variant.emit_querier_impl(trait_module, unbonded_generics)); let mut querier = trait_module .map(|module| quote! { #module ::Querier }) diff --git a/sylvia-derive/src/utils.rs b/sylvia-derive/src/utils.rs index 609ebea0..90ced524 100644 --- a/sylvia-derive/src/utils.rs +++ b/sylvia-derive/src/utils.rs @@ -1,9 +1,11 @@ +use proc_macro2::TokenStream; use proc_macro_error::emit_error; +use quote::quote; use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ - FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature, Type, - WhereClause, WherePredicate, + parse_quote, FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature, + Type, WhereClause, WherePredicate, }; use crate::check_generics::CheckGenerics; @@ -84,3 +86,17 @@ pub fn extract_return_type(ret_type: &ReturnType) -> &Path { &type_path.path } + +pub fn as_where_clause(where_predicates: &[&WherePredicate]) -> Option { + match where_predicates.is_empty() { + true => None, + false => Some(parse_quote! { where #(#where_predicates),* }), + } +} + +pub fn brace_generics(unbonded_generics: &[&GenericParam]) -> TokenStream { + match unbonded_generics.is_empty() { + true => quote! {}, + false => quote! { < #(#unbonded_generics,)* > }, + } +} diff --git a/sylvia/tests/generics.rs b/sylvia/tests/generics.rs index de3eaccd..31283d3e 100644 --- a/sylvia/tests/generics.rs +++ b/sylvia/tests/generics.rs @@ -1,17 +1,18 @@ use cosmwasm_schema::cw_serde; pub mod cw1 { - use cosmwasm_std::{CosmosMsg, CustomMsg, Response, StdError}; + use cosmwasm_std::{CosmosMsg, CustomMsg, CustomQuery, Response, StdError}; - use serde::Deserialize; + use serde::{de::DeserializeOwned, Deserialize}; use sylvia::types::{ExecCtx, QueryCtx}; use sylvia_derive::interface; #[interface(module=msg)] - pub trait Cw1 + pub trait Cw1 where for<'msg_de> Msg: CustomMsg + Deserialize<'msg_de>, Param: sylvia::types::CustomMsg, + for<'msg_de> QueryRet: CustomQuery + DeserializeOwned, { type Error: From; @@ -20,7 +21,7 @@ pub mod cw1 { -> Result; #[msg(query)] - fn query(&self, ctx: QueryCtx, param: Param) -> Result; + fn some_query(&self, ctx: QueryCtx, param: Param) -> Result; } } @@ -37,7 +38,7 @@ mod tests { #[test] fn construct_messages() { - let _ = crate::cw1::QueryMsg::query(ExternalMsg {}); + let _ = crate::cw1::QueryMsg::<_, Empty>::some_query(ExternalMsg {}); let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(ExternalMsg {})]); let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(Empty {})]); }