From f01f822c3bb800c3fb417be66f64df6f527afbb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Mon, 18 Sep 2023 23:08:59 +0200 Subject: [PATCH] feat: Add support for generics in interface --- sylvia-derive/src/check_generics.rs | 1 + sylvia-derive/src/input.rs | 40 +++++---- sylvia-derive/src/message.rs | 128 ++++++++++++++++++++-------- sylvia-derive/src/multitest.rs | 12 +-- sylvia-derive/src/utils.rs | 32 ++++--- sylvia/examples/generics.rs | 17 ---- sylvia/src/into_response.rs | 2 +- sylvia/src/lib.rs | 1 + sylvia/src/types.rs | 3 + sylvia/tests/generics.rs | 44 ++++++++++ 10 files changed, 187 insertions(+), 93 deletions(-) delete mode 100644 sylvia/examples/generics.rs create mode 100644 sylvia/tests/generics.rs diff --git a/sylvia-derive/src/check_generics.rs b/sylvia-derive/src/check_generics.rs index fbd2188a..edba0ce4 100644 --- a/sylvia-derive/src/check_generics.rs +++ b/sylvia-derive/src/check_generics.rs @@ -1,6 +1,7 @@ use syn::visit::Visit; use syn::GenericParam; +#[derive(Debug)] pub struct CheckGenerics<'g> { generics: &'g [&'g GenericParam], used: Vec<&'g GenericParam>, diff --git a/sylvia-derive/src/input.rs b/sylvia-derive/src/input.rs index 6fe658c5..83e05237 100644 --- a/sylvia-derive/src/input.rs +++ b/sylvia-derive/src/input.rs @@ -62,7 +62,14 @@ impl<'a> TraitInput<'a> { let messages = self.emit_messages(); let multitest_helpers = self.emit_helpers(); let remote = Remote::new(&Interfaces::default()).emit(); - let querier = MsgVariants::new(self.item.as_variants(), &self.generics).emit_querier(); + + let querier = MsgVariants::new( + self.item.as_variants(), + MsgType::Query, + &self.generics, + &self.item.generics.where_clause, + ) + .emit_querier(); #[cfg(not(tarpaulin_include))] { @@ -159,22 +166,26 @@ impl<'a> ImplInput<'a> { quote! {} }; - let interfaces = Interfaces::new(self.item); - let variants = MsgVariants::new(self.item.as_variants(), &self.generics); + let unbonded_generics = &vec![]; + let variants = MsgVariants::new( + self.item.as_variants(), + MsgType::Query, + unbonded_generics, + &None, + ); match is_trait { - true => self.process_interface(&interfaces, variants, multitest_helpers), - false => self.process_contract(&interfaces, variants, multitest_helpers), + true => self.process_interface(variants, multitest_helpers), + false => self.process_contract(variants, multitest_helpers), } } fn process_interface( &self, - interfaces: &Interfaces, variants: MsgVariants<'a>, multitest_helpers: TokenStream, ) -> TokenStream { - let querier_bound_for_impl = self.emit_querier_for_bound_impl(interfaces, variants); + let querier_bound_for_impl = self.emit_querier_for_bound_impl(variants); #[cfg(not(tarpaulin_include))] quote! { @@ -186,14 +197,14 @@ impl<'a> ImplInput<'a> { fn process_contract( &self, - interfaces: &Interfaces, variants: MsgVariants<'a>, multitest_helpers: TokenStream, ) -> TokenStream { let messages = self.emit_messages(); - let remote = Remote::new(interfaces).emit(); + let remote = Remote::new(&self.interfaces).emit(); + let querier = variants.emit_querier(); - let querier_from_impl = interfaces.emit_querier_from_impl(); + let querier_from_impl = self.interfaces.emit_querier_from_impl(); #[cfg(not(tarpaulin_include))] { @@ -268,12 +279,9 @@ impl<'a> ImplInput<'a> { .emit() } - fn emit_querier_for_bound_impl( - &self, - interfaces: &Interfaces, - variants: MsgVariants<'a>, - ) -> TokenStream { - let trait_module = interfaces + fn emit_querier_for_bound_impl(&self, variants: MsgVariants<'a>) -> TokenStream { + let trait_module = self + .interfaces .interfaces() .first() .map(|interface| &interface.module); diff --git a/sylvia-derive/src/message.rs b/sylvia-derive/src/message.rs index 3bb14c52..1a410453 100644 --- a/sylvia-derive/src/message.rs +++ b/sylvia-derive/src/message.rs @@ -303,7 +303,7 @@ impl<'a> EnumMessage<'a> { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, cosmwasm_schema::QueryResponses)] #[serde(rename_all="snake_case")] - pub enum #unique_enum_name #generics #where_clause { + pub enum #unique_enum_name #generics { #(#variants,)* } pub type #name #generics = #unique_enum_name #generics; @@ -314,7 +314,7 @@ impl<'a> EnumMessage<'a> { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] #[serde(rename_all="snake_case")] - pub enum #unique_enum_name #generics #where_clause { + pub enum #unique_enum_name #generics { #(#variants,)* } pub type #name #generics = #unique_enum_name #generics; @@ -506,7 +506,9 @@ impl<'a> MsgVariant<'a> { let return_type = if let MsgAttr::Query { resp_type } = msg_attr { match resp_type { - Some(resp_type) => quote! {#resp_type}, + Some(resp_type) => { + quote! {#resp_type} + } None => { let return_type = extract_return_type(&sig.output); quote! {#return_type} @@ -667,12 +669,20 @@ impl<'a> MsgVariant<'a> { } } -pub struct MsgVariants<'a>(Vec>); +pub struct MsgVariants<'a> { + variants: Vec>, + unbonded_generics: Vec<&'a GenericParam>, + where_clause: Option, +} impl<'a> MsgVariants<'a> { - pub fn new(source: VariantDescs<'a>, generics: &[&'a GenericParam]) -> Self { + pub fn new( + source: VariantDescs<'a>, + msg_type: MsgType, + generics: &'a Vec<&'a GenericParam>, + unfiltered_where_clause: &'a Option, + ) -> Self { let mut generics_checker = CheckGenerics::new(generics); - let variants: Vec<_> = source .filter_map(|variant_desc| { let msg_attr = variant_desc.attr_msg()?; @@ -684,6 +694,10 @@ impl<'a> MsgVariants<'a> { } }; + if attr.msg_type() != msg_type { + return None; + } + Some(MsgVariant::new( variant_desc.into_sig(), &mut generics_checker, @@ -691,12 +705,38 @@ impl<'a> MsgVariants<'a> { )) }) .collect(); - Self(variants) + + let (unbonded_generics, _) = generics_checker.used_unused(); + let wheres = filter_wheres( + unfiltered_where_clause, + generics.as_slice(), + &unbonded_generics, + ); + let where_clause = if !wheres.is_empty() { + Some(parse_quote! { where #(#wheres),* }) + } else { + None + }; + + Self { + variants, + unbonded_generics, + where_clause, + } + } + + pub fn variants(&self) -> &Vec> { + &self.variants } pub fn emit_querier(&self) -> TokenStream { let sylvia = crate_module(); - let variants = &self.0; + let Self { + variants, + unbonded_generics, + where_clause, + .. + } = self; let methods_impl = variants .iter() @@ -708,6 +748,12 @@ impl<'a> MsgVariants<'a> { .filter(|variant| variant.msg_type == MsgType::Query) .map(MsgVariant::emit_querier_declaration); + let querier = if !unbonded_generics.is_empty() { + quote! { Querier < #(#unbonded_generics,)* > } + } else { + quote! { Querier } + }; + #[cfg(not(tarpaulin_include))] { quote! { @@ -730,12 +776,11 @@ impl<'a> MsgVariants<'a> { } } - impl <'a, C: #sylvia ::cw_std::CustomQuery> Querier for BoundQuerier<'a, C> { + impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for BoundQuerier<'a, C> #where_clause { #(#methods_impl)* } - - pub trait Querier { + pub trait #querier { #(#methods_declaration)* } } @@ -748,24 +793,33 @@ impl<'a> MsgVariants<'a> { contract_module: Option<&Path>, ) -> TokenStream { let sylvia = crate_module(); - let variants = &self.0; + let Self { + variants, + unbonded_generics, + where_clause, + .. + } = self; let methods_impl = variants .iter() .filter(|variant| variant.msg_type == MsgType::Query) .map(|variant| variant.emit_querier_impl(trait_module)); - let querier = trait_module + let mut querier = trait_module .map(|module| quote! { #module ::Querier }) .unwrap_or_else(|| quote! { Querier }); let bound_querier = contract_module .map(|module| quote! { #module ::BoundQuerier}) .unwrap_or_else(|| quote! { BoundQuerier }); + if !unbonded_generics.is_empty() { + querier = quote! { #querier < #(#unbonded_generics,)* > }; + } + #[cfg(not(tarpaulin_include))] { quote! { - impl <'a, C: #sylvia ::cw_std::CustomQuery> #querier for #bound_querier<'a, C> { + impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for #bound_querier<'a, C> #where_clause { #(#methods_impl)* } } @@ -886,7 +940,7 @@ impl<'a> GlueMessage<'a> { interfaces, } = self; let contract = StripGenerics.fold_type((*contract).clone()); - let contract_name = Ident::new(&format!("Contract{}", name), name.span()); + let enum_name = Ident::new(&format!("Contract{}", name), name.span()); let variants = interfaces.emit_glue_message_variants(msg_ty, name); @@ -916,15 +970,15 @@ impl<'a> GlueMessage<'a> { match (msg_ty, customs.has_msg) { (MsgType::Exec, true) => quote! { - #contract_name :: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?) + #enum_name:: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?) }, _ => quote! { - #contract_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx )) + #enum_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx )) }, } }); - let dispatch_arm = quote! {#contract_name :: #contract (msg) =>msg.dispatch(contract, ctx)}; + let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)}; let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(name); @@ -951,7 +1005,7 @@ impl<'a> GlueMessage<'a> { { quote! { #[cfg(not(target_arch = "wasm32"))] - impl cosmwasm_schema::QueryResponses for #contract_name { + impl #sylvia ::cw_schema::QueryResponses for #enum_name { fn response_schemas_impl() -> std::collections::BTreeMap { let responses = [#(#response_schemas_calls),*]; responses.into_iter().flatten().collect() @@ -971,12 +1025,12 @@ impl<'a> GlueMessage<'a> { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(#sylvia ::serde::Serialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)] #[serde(rename_all="snake_case", untagged)] - pub enum #contract_name { + pub enum #enum_name { #(#variants,)* #msg_name } - impl #contract_name { + impl #enum_name { pub fn dispatch( self, contract: &#contract, @@ -996,7 +1050,7 @@ impl<'a> GlueMessage<'a> { #response_schemas - impl<'de> serde::Deserialize<'de> for #contract_name { + impl<'de> serde::Deserialize<'de> for #enum_name { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { @@ -1043,7 +1097,8 @@ pub struct EntryPoints<'a> { error: Type, custom: Custom<'a>, override_entry_points: OverrideEntryPoints, - variants: MsgVariants<'a>, + has_migrate: bool, + reply: Option, } impl<'a> EntryPoints<'a> { @@ -1067,9 +1122,15 @@ impl<'a> EntryPoints<'a> { ) .unwrap_or_else(|| parse_quote! { #sylvia ::cw_std::StdError }); - let generics: Vec<_> = source.generics.params.iter().collect(); + let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &vec![], &None) + .variants() + .is_empty(); - let variants = MsgVariants::new(source.as_variants(), &generics); + let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &vec![], &None) + .variants() + .iter() + .map(|variant| variant.function_name.clone()) + .next(); let custom = Custom::new(&source.attrs); Self { @@ -1077,7 +1138,8 @@ impl<'a> EntryPoints<'a> { error, custom, override_entry_points, - variants, + has_migrate, + reply, } } @@ -1087,17 +1149,13 @@ impl<'a> EntryPoints<'a> { error, custom, override_entry_points, - variants, + has_migrate, + reply, } = self; let sylvia = crate_module(); let custom_msg = custom.msg_or_default(); let custom_query = custom.query_or_default(); - let reply = variants - .0 - .iter() - .find(|variant| variant.msg_type == MsgType::Reply) - .map(|variant| variant.function_name.clone()); #[cfg(not(tarpaulin_include))] { @@ -1119,12 +1177,8 @@ impl<'a> EntryPoints<'a> { let migrate_not_overridden = override_entry_points .get_entry_point(MsgType::Migrate) .is_none(); - let migrate_msg_defined = variants - .0 - .iter() - .any(|variant| variant.msg_type == MsgType::Migrate); - let migrate = if migrate_not_overridden && migrate_msg_defined { + let migrate = if migrate_not_overridden && *has_migrate { OverrideEntryPoint::emit_default_entry_point( &custom_msg, &custom_query, diff --git a/sylvia-derive/src/multitest.rs b/sylvia-derive/src/multitest.rs index 8b871d36..3183e1ef 100644 --- a/sylvia-derive/src/multitest.rs +++ b/sylvia-derive/src/multitest.rs @@ -42,7 +42,9 @@ pub struct MultitestHelpers<'a> { fn interface_name(source: &ItemImpl) -> &Ident { let trait_name = &source.trait_; - let Some(trait_name) = trait_name else {unreachable!()}; + let Some(trait_name) = trait_name else { + unreachable!() + }; let (_, Path { segments, .. }, _) = &trait_name; assert!(!segments.is_empty()); @@ -50,9 +52,9 @@ fn interface_name(source: &ItemImpl) -> &Ident { } fn extract_contract_name(contract: &Type) -> &Ident { - let Type::Path(type_path) = contract else { - unreachable!() - }; + let Type::Path(type_path) = contract else { + unreachable!() + }; let segments = &type_path.path.segments; assert!(!segments.is_empty()); let segment = &segments[0]; @@ -540,7 +542,7 @@ impl<'a> MultitestHelpers<'a> { let mut generics_checker = CheckGenerics::new(generics); let parsed = parse_struct_message(source, MsgType::Instantiate); - let Some((method,_)) = parsed else { + let Some((method, _)) = parsed else { return quote! {}; }; diff --git a/sylvia-derive/src/utils.rs b/sylvia-derive/src/utils.rs index 108e6db4..609ebea0 100644 --- a/sylvia-derive/src/utils.rs +++ b/sylvia-derive/src/utils.rs @@ -2,7 +2,7 @@ use proc_macro_error::emit_error; use syn::spanned::Spanned; use syn::visit::Visit; use syn::{ - FnArg, GenericArgument, GenericParam, PathArguments, PathSegment, ReturnType, Signature, Type, + FnArg, GenericArgument, GenericParam, Path, PathArguments, ReturnType, Signature, Type, WhereClause, WherePredicate, }; @@ -52,14 +52,14 @@ pub fn process_fields<'s>( .collect() } -pub fn extract_return_type(ret_type: &ReturnType) -> &PathSegment { - let ReturnType::Type(_, ty) = ret_type else { - unreachable!() - }; +pub fn extract_return_type(ret_type: &ReturnType) -> &Path { + let ReturnType::Type(_, ty) = ret_type else { + unreachable!() + }; - let Type::Path(type_path) = ty.as_ref() else { - unreachable!() - }; + let Type::Path(type_path) = ty.as_ref() else { + unreachable!() + }; let segments = &type_path.path.segments; assert!(!segments.is_empty()); let segment = &segments[0]; @@ -73,16 +73,14 @@ pub fn extract_return_type(ret_type: &ReturnType) -> &PathSegment { Please use #[msg(return_type=)]" ); } - let PathArguments::AngleBracketed(args) = &segments[0].arguments else{ - unreachable!() - }; + let PathArguments::AngleBracketed(args) = &segments[0].arguments else { + unreachable!() + }; let args = &args.args; assert!(!args.is_empty()); - let GenericArgument::Type(Type::Path(type_path)) = &args[0] else{ - unreachable!() - }; - let segments = &type_path.path.segments; - assert!(!segments.is_empty()); + let GenericArgument::Type(Type::Path(type_path)) = &args[0] else { + unreachable!() + }; - &segments[0] + &type_path.path } diff --git a/sylvia/examples/generics.rs b/sylvia/examples/generics.rs deleted file mode 100644 index 75acbed5..00000000 --- a/sylvia/examples/generics.rs +++ /dev/null @@ -1,17 +0,0 @@ -use cosmwasm_std::{CosmosMsg, Response}; - -use sylvia::types::ExecCtx; -use sylvia_derive::interface; - -#[interface(module=msg)] -pub trait Cw1 -where - Msg: std::fmt::Debug + PartialEq + Clone + schemars::JsonSchema, -{ - type Error; - - #[msg(exec)] - fn execute(&self, ctx: ExecCtx, msgs: Vec>) -> Result; -} - -fn main() {} diff --git a/sylvia/src/into_response.rs b/sylvia/src/into_response.rs index b3a3e73d..06f974f3 100644 --- a/sylvia/src/into_response.rs +++ b/sylvia/src/into_response.rs @@ -52,7 +52,7 @@ impl IntoResponse for Response { .map(|msg| msg.into_msg()) .collect::>()?; let mut resp = Response::new() - .add_submessages(messages.into_iter()) + .add_submessages(messages) .add_events(self.events) .add_attributes(self.attributes); resp.data = self.data; diff --git a/sylvia/src/lib.rs b/sylvia/src/lib.rs index f3502c47..66afeb6f 100644 --- a/sylvia/src/lib.rs +++ b/sylvia/src/lib.rs @@ -10,6 +10,7 @@ pub mod utils; #[cfg(feature = "mt")] pub use anyhow; +pub use cosmwasm_schema as cw_schema; pub use cosmwasm_std as cw_std; #[cfg(feature = "mt")] pub use cw_multi_test; diff --git a/sylvia/src/types.rs b/sylvia/src/types.rs index 725d57d3..71495870 100644 --- a/sylvia/src/types.rs +++ b/sylvia/src/types.rs @@ -1,4 +1,5 @@ use cosmwasm_std::{CustomQuery, Deps, DepsMut, Empty, Env, MessageInfo}; +use serde::de::DeserializeOwned; pub struct ReplyCtx<'a, C: CustomQuery = Empty> { pub deps: DepsMut<'a, C>, @@ -93,3 +94,5 @@ impl<'a, C: CustomQuery> From<(Deps<'a, C>, Env)> for QueryCtx<'a, C> { Self { deps, env } } } + +pub trait CustomMsg: cosmwasm_std::CustomMsg + DeserializeOwned {} diff --git a/sylvia/tests/generics.rs b/sylvia/tests/generics.rs new file mode 100644 index 00000000..de3eaccd --- /dev/null +++ b/sylvia/tests/generics.rs @@ -0,0 +1,44 @@ +use cosmwasm_schema::cw_serde; + +pub mod cw1 { + use cosmwasm_std::{CosmosMsg, CustomMsg, Response, StdError}; + + use serde::Deserialize; + use sylvia::types::{ExecCtx, QueryCtx}; + use sylvia_derive::interface; + + #[interface(module=msg)] + pub trait Cw1 + where + for<'msg_de> Msg: CustomMsg + Deserialize<'msg_de>, + Param: sylvia::types::CustomMsg, + { + type Error: From; + + #[msg(exec)] + fn execute(&self, ctx: ExecCtx, msgs: Vec>) + -> Result; + + #[msg(query)] + fn query(&self, ctx: QueryCtx, param: Param) -> Result; + } +} + +#[cw_serde] +pub struct ExternalMsg; +impl cosmwasm_std::CustomMsg for ExternalMsg {} +impl sylvia::types::CustomMsg for ExternalMsg {} + +#[cfg(test)] +mod tests { + use cosmwasm_std::{CosmosMsg, Empty}; + + use crate::ExternalMsg; + + #[test] + fn construct_messages() { + let _ = crate::cw1::QueryMsg::query(ExternalMsg {}); + let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(ExternalMsg {})]); + let _ = crate::cw1::ExecMsg::execute(vec![CosmosMsg::Custom(Empty {})]); + } +}