Skip to content

Commit

Permalink
feat: Add support for generics in interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jawoznia committed Sep 27, 2023
1 parent e6131e4 commit b9ba7ca
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 93 deletions.
1 change: 1 addition & 0 deletions sylvia-derive/src/check_generics.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use syn::visit::Visit;
use syn::GenericParam;

#[derive(Debug)]
pub struct CheckGenerics<'g> {
generics: &'g [&'g GenericParam],
used: Vec<&'g GenericParam>,
Expand Down
40 changes: 24 additions & 16 deletions sylvia-derive/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ impl<'a> TraitInput<'a> {
let messages = self.emit_messages();
let multitest_helpers = self.emit_helpers();
let remote = Remote::new(&Interfaces::default()).emit();
let querier = MsgVariants::new(self.item.as_variants(), &self.generics).emit_querier();

let querier = MsgVariants::new(
self.item.as_variants(),
MsgType::Query,
&self.generics,
&self.item.generics.where_clause,
)
.emit_querier();

#[cfg(not(tarpaulin_include))]
{
Expand Down Expand Up @@ -159,22 +166,26 @@ impl<'a> ImplInput<'a> {
quote! {}
};

let interfaces = Interfaces::new(self.item);
let variants = MsgVariants::new(self.item.as_variants(), &self.generics);
let unbonded_generics = &vec![];
let variants = MsgVariants::new(
self.item.as_variants(),
MsgType::Query,
unbonded_generics,
&None,
);

match is_trait {
true => self.process_interface(&interfaces, variants, multitest_helpers),
false => self.process_contract(&interfaces, variants, multitest_helpers),
true => self.process_interface(variants, multitest_helpers),
false => self.process_contract(variants, multitest_helpers),
}
}

fn process_interface(
&self,
interfaces: &Interfaces,
variants: MsgVariants<'a>,
multitest_helpers: TokenStream,
) -> TokenStream {
let querier_bound_for_impl = self.emit_querier_for_bound_impl(interfaces, variants);
let querier_bound_for_impl = self.emit_querier_for_bound_impl(variants);

#[cfg(not(tarpaulin_include))]
quote! {
Expand All @@ -186,14 +197,14 @@ impl<'a> ImplInput<'a> {

fn process_contract(
&self,
interfaces: &Interfaces,
variants: MsgVariants<'a>,
multitest_helpers: TokenStream,
) -> TokenStream {
let messages = self.emit_messages();
let remote = Remote::new(interfaces).emit();
let remote = Remote::new(&self.interfaces).emit();

let querier = variants.emit_querier();
let querier_from_impl = interfaces.emit_querier_from_impl();
let querier_from_impl = self.interfaces.emit_querier_from_impl();

#[cfg(not(tarpaulin_include))]
{
Expand Down Expand Up @@ -268,12 +279,9 @@ impl<'a> ImplInput<'a> {
.emit()
}

fn emit_querier_for_bound_impl(
&self,
interfaces: &Interfaces,
variants: MsgVariants<'a>,
) -> TokenStream {
let trait_module = interfaces
fn emit_querier_for_bound_impl(&self, variants: MsgVariants<'a>) -> TokenStream {
let trait_module = self
.interfaces
.interfaces()
.first()
.map(|interface| &interface.module);
Expand Down
128 changes: 91 additions & 37 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl<'a> EnumMessage<'a> {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, cosmwasm_schema::QueryResponses)]
#[serde(rename_all="snake_case")]
pub enum #unique_enum_name #generics #where_clause {
pub enum #unique_enum_name #generics {
#(#variants,)*
}
pub type #name #generics = #unique_enum_name #generics;
Expand All @@ -314,7 +314,7 @@ impl<'a> EnumMessage<'a> {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)]
#[serde(rename_all="snake_case")]
pub enum #unique_enum_name #generics #where_clause {
pub enum #unique_enum_name #generics {
#(#variants,)*
}
pub type #name #generics = #unique_enum_name #generics;
Expand Down Expand Up @@ -506,7 +506,9 @@ impl<'a> MsgVariant<'a> {

let return_type = if let MsgAttr::Query { resp_type } = msg_attr {
match resp_type {
Some(resp_type) => quote! {#resp_type},
Some(resp_type) => {
quote! {#resp_type}
}
None => {
let return_type = extract_return_type(&sig.output);
quote! {#return_type}
Expand Down Expand Up @@ -667,12 +669,20 @@ impl<'a> MsgVariant<'a> {
}
}

pub struct MsgVariants<'a>(Vec<MsgVariant<'a>>);
pub struct MsgVariants<'a> {
variants: Vec<MsgVariant<'a>>,
unbonded_generics: Vec<&'a GenericParam>,
where_clause: Option<WhereClause>,
}

impl<'a> MsgVariants<'a> {
pub fn new(source: VariantDescs<'a>, generics: &[&'a GenericParam]) -> Self {
pub fn new(
source: VariantDescs<'a>,
msg_type: MsgType,
generics: &'a Vec<&'a GenericParam>,
unfiltered_where_clause: &'a Option<WhereClause>,
) -> Self {
let mut generics_checker = CheckGenerics::new(generics);

let variants: Vec<_> = source
.filter_map(|variant_desc| {
let msg_attr = variant_desc.attr_msg()?;
Expand All @@ -684,19 +694,49 @@ impl<'a> MsgVariants<'a> {
}
};

if attr.msg_type() != msg_type {
return None;
}

Some(MsgVariant::new(
variant_desc.into_sig(),
&mut generics_checker,
attr,
))
})
.collect();
Self(variants)

let (unbonded_generics, _) = generics_checker.used_unused();
let wheres = filter_wheres(
unfiltered_where_clause,
generics.as_slice(),
&unbonded_generics,
);
let where_clause = if !wheres.is_empty() {
Some(parse_quote! { where #(#wheres),* })
} else {
None
};

Self {
variants,
unbonded_generics,
where_clause,
}
}

pub fn variants(&self) -> &Vec<MsgVariant<'a>> {
&self.variants
}

pub fn emit_querier(&self) -> TokenStream {
let sylvia = crate_module();
let variants = &self.0;
let Self {
variants,
unbonded_generics,
where_clause,
..
} = self;

let methods_impl = variants
.iter()
Expand All @@ -708,6 +748,12 @@ impl<'a> MsgVariants<'a> {
.filter(|variant| variant.msg_type == MsgType::Query)
.map(MsgVariant::emit_querier_declaration);

let querier = if !unbonded_generics.is_empty() {
quote! { Querier < #(#unbonded_generics,)* > }
} else {
quote! { Querier }
};

#[cfg(not(tarpaulin_include))]
{
quote! {
Expand All @@ -730,12 +776,11 @@ impl<'a> MsgVariants<'a> {
}
}

impl <'a, C: #sylvia ::cw_std::CustomQuery> Querier for BoundQuerier<'a, C> {
impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for BoundQuerier<'a, C> #where_clause {
#(#methods_impl)*
}


pub trait Querier {
pub trait #querier {
#(#methods_declaration)*
}
}
Expand All @@ -748,24 +793,33 @@ impl<'a> MsgVariants<'a> {
contract_module: Option<&Path>,
) -> TokenStream {
let sylvia = crate_module();
let variants = &self.0;
let Self {
variants,
unbonded_generics,
where_clause,
..
} = self;

let methods_impl = variants
.iter()
.filter(|variant| variant.msg_type == MsgType::Query)
.map(|variant| variant.emit_querier_impl(trait_module));

let querier = trait_module
let mut querier = trait_module
.map(|module| quote! { #module ::Querier })
.unwrap_or_else(|| quote! { Querier });
let bound_querier = contract_module
.map(|module| quote! { #module ::BoundQuerier})
.unwrap_or_else(|| quote! { BoundQuerier });

if !unbonded_generics.is_empty() {
querier = quote! { #querier < #(#unbonded_generics,)* > };
}

#[cfg(not(tarpaulin_include))]
{
quote! {
impl <'a, C: #sylvia ::cw_std::CustomQuery> #querier for #bound_querier<'a, C> {
impl <'a, C: #sylvia ::cw_std::CustomQuery, #(#unbonded_generics,)*> #querier for #bound_querier<'a, C> #where_clause {
#(#methods_impl)*
}
}
Expand Down Expand Up @@ -886,7 +940,7 @@ impl<'a> GlueMessage<'a> {
interfaces,
} = self;
let contract = StripGenerics.fold_type((*contract).clone());
let contract_name = Ident::new(&format!("Contract{}", name), name.span());
let enum_name = Ident::new(&format!("Contract{}", name), name.span());

let variants = interfaces.emit_glue_message_variants(msg_ty, name);

Expand Down Expand Up @@ -916,15 +970,15 @@ impl<'a> GlueMessage<'a> {

match (msg_ty, customs.has_msg) {
(MsgType::Exec, true) => quote! {
#contract_name :: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?)
#enum_name:: #variant(msg) => #sylvia ::into_response::IntoResponse::into_response(msg.dispatch(contract, Into::into( #ctx ))?)
},
_ => quote! {
#contract_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx ))
#enum_name :: #variant(msg) => msg.dispatch(contract, Into::into( #ctx ))
},
}
});

let dispatch_arm = quote! {#contract_name :: #contract (msg) =>msg.dispatch(contract, ctx)};
let dispatch_arm = quote! {#enum_name :: #contract (msg) => msg.dispatch(contract, ctx)};

let interfaces_deserialization_attempts = interfaces.emit_deserialization_attempts(name);

Expand All @@ -951,7 +1005,7 @@ impl<'a> GlueMessage<'a> {
{
quote! {
#[cfg(not(target_arch = "wasm32"))]
impl cosmwasm_schema::QueryResponses for #contract_name {
impl #sylvia ::cw_schema::QueryResponses for #enum_name {
fn response_schemas_impl() -> std::collections::BTreeMap<String, #sylvia ::schemars::schema::RootSchema> {
let responses = [#(#response_schemas_calls),*];
responses.into_iter().flatten().collect()
Expand All @@ -971,12 +1025,12 @@ impl<'a> GlueMessage<'a> {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(#sylvia ::serde::Serialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)]
#[serde(rename_all="snake_case", untagged)]
pub enum #contract_name {
pub enum #enum_name {
#(#variants,)*
#msg_name
}

impl #contract_name {
impl #enum_name {
pub fn dispatch(
self,
contract: &#contract,
Expand All @@ -996,7 +1050,7 @@ impl<'a> GlueMessage<'a> {

#response_schemas

impl<'de> serde::Deserialize<'de> for #contract_name {
impl<'de> serde::Deserialize<'de> for #enum_name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: serde::Deserializer<'de>,
{
Expand Down Expand Up @@ -1043,7 +1097,8 @@ pub struct EntryPoints<'a> {
error: Type,
custom: Custom<'a>,
override_entry_points: OverrideEntryPoints,
variants: MsgVariants<'a>,
has_migrate: bool,
reply: Option<Ident>,
}

impl<'a> EntryPoints<'a> {
Expand All @@ -1067,17 +1122,24 @@ impl<'a> EntryPoints<'a> {
)
.unwrap_or_else(|| parse_quote! { #sylvia ::cw_std::StdError });

let generics: Vec<_> = source.generics.params.iter().collect();
let has_migrate = !MsgVariants::new(source.as_variants(), MsgType::Migrate, &vec![], &None)
.variants()
.is_empty();

let variants = MsgVariants::new(source.as_variants(), &generics);
let reply = MsgVariants::new(source.as_variants(), MsgType::Reply, &vec![], &None)
.variants()
.iter()
.map(|variant| variant.function_name.clone())
.next();
let custom = Custom::new(&source.attrs);

Self {
name,
error,
custom,
override_entry_points,
variants,
has_migrate,
reply,
}
}

Expand All @@ -1087,17 +1149,13 @@ impl<'a> EntryPoints<'a> {
error,
custom,
override_entry_points,
variants,
has_migrate,
reply,
} = self;
let sylvia = crate_module();

let custom_msg = custom.msg_or_default();
let custom_query = custom.query_or_default();
let reply = variants
.0
.iter()
.find(|variant| variant.msg_type == MsgType::Reply)
.map(|variant| variant.function_name.clone());

#[cfg(not(tarpaulin_include))]
{
Expand All @@ -1119,12 +1177,8 @@ impl<'a> EntryPoints<'a> {
let migrate_not_overridden = override_entry_points
.get_entry_point(MsgType::Migrate)
.is_none();
let migrate_msg_defined = variants
.0
.iter()
.any(|variant| variant.msg_type == MsgType::Migrate);

let migrate = if migrate_not_overridden && migrate_msg_defined {
let migrate = if migrate_not_overridden && *has_migrate {
OverrideEntryPoint::emit_default_entry_point(
&custom_msg,
&custom_query,
Expand Down
Loading

0 comments on commit b9ba7ca

Please sign in to comment.