Skip to content

Commit

Permalink
🔕(🤖) Better Error Span
Browse files Browse the repository at this point in the history
  • Loading branch information
KAIYOHUGO committed May 24, 2022
1 parent ea6555b commit 3507a1d
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 113 deletions.
39 changes: 7 additions & 32 deletions mady_macro_core/src/folder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use syn::spanned::Spanned;
use syn::{parse_quote, Error};
use syn::{parse_quote_spanned, Error};

use crate::error::ParseError;
use crate::gen::Chain;
Expand Down Expand Up @@ -48,10 +48,10 @@ impl Chain for AfterFolder {
destruct.push(var.to_ident())
}

let span = t.span();
let left = t.left;
let right = t.right;

let ast = parse_quote! {
let ast = parse_quote_spanned! {span=>
{
let mady_tmp;
(mady_tmp, (#(#destruct),*)) = #left.#grad_fn(#right);
Expand All @@ -61,31 +61,6 @@ impl Chain for AfterFolder {
Ok(ast)
}

// fn chain_stmt_expr(
// &mut self,
// c: &mut Self::Input,
// t: syn::Expr,
// ) -> Result<syn::Stmt, Self::Err> {
// let stmt = if c.is_top_level() {
// let outs = c
// .graph
// .out_nodes()
// .into_iter()
// .map(|x| c.graph.node_weight(x).id().to_ident());
// let backward = gen_backward(c)?;
// parse_quote! {
// {
// let mady_return = #t;
// #(#backward)*
// (mady_return, (#(#outs),*))
// }
// }
// } else {
// t
// };
// Ok(syn::Stmt::Expr(stmt))
// }

fn chain_block(&mut self, c: &mut Self::Input, t: syn::Block) -> Result<syn::Block, Self::Err> {
let stmt = if c.is_sig_level() {
let outs = c
Expand All @@ -94,7 +69,7 @@ impl Chain for AfterFolder {
.into_iter()
.map(|x| c.graph.node_weight(x).id().to_ident());
let backward = gen_backward(c)?;
parse_quote! {
parse_quote_spanned! {t.span()=>
{
let _mady_return = {#t};
#(#backward)*
Expand All @@ -115,7 +90,7 @@ impl Chain for AfterFolder {
let outs = c.tys().iter().take(c.tys().len() - 1);
match t {
syn::ReturnType::Default => todo!(),
syn::ReturnType::Type(_, t) => Ok(parse_quote! {
syn::ReturnType::Type(_, t) => Ok(parse_quote_spanned! {t.span()=>
-> (#t, (#(#outs),*))
}),
}
Expand All @@ -138,7 +113,7 @@ impl Chain for AfterFolder {
let mut t = t;
t.method = grad_method(t.method);

let ast = parse_quote! {
let ast = parse_quote_spanned! {t.span()=>
{
let mady_tmp;
(mady_tmp, (#(#destruct),*)) = #t;
Expand Down Expand Up @@ -173,7 +148,7 @@ impl Chain for AfterFolder {
return Err(ParseError::UnsupportedSyntax.new(t.func.span()));
}

let ast = parse_quote! {
let ast = parse_quote_spanned! {t.span()=>
{
let mady_tmp;
(mady_tmp, (#(#destruct),*)) = #t;
Expand Down
75 changes: 10 additions & 65 deletions mady_macro_core/src/generator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use syn::{parse_quote, Error};
use syn::{parse_quote_spanned, Error};

use crate::{
error::ParseError,
Expand All @@ -14,13 +14,13 @@ pub fn gen_declare(c: &Recorder) -> Result<Vec<syn::Stmt>, Error> {
match node {
VarType::Tmp(v) | VarType::Grad(v) => {
let ident = v.to_ident();
stmts.push(parse_quote! {
stmts.push(parse_quote_spanned! {v.span()=>
let mut #ident;
});
for e in c.graph.to_edges(n) {
let edge = c.graph.edge_weight(e);
let ident = edge.to_ident();
stmts.push(parse_quote! {
stmts.push(parse_quote_spanned! {edge.span()=>
let #ident;
});
}
Expand All @@ -45,23 +45,24 @@ pub fn gen_backward(c: &Recorder) -> Result<Vec<syn::Stmt>, Error> {
let node_ident = v.to_ident();
if roots.contains(&n) {
let node_ty = tys.next().ok_or(ParseError::NotFindType.new(v.span()))?;
stmts.push(parse_quote! {
stmts.push(parse_quote_spanned! {v.span()=>
#node_ident = <#node_ty>::one();
});
}
for e in c.graph.to_edges(n) {
let edge = c.graph.edge_weight(e).to_ident();
let edge = c.graph.edge_weight(e);
let edge_ident = edge.to_ident();
let t = c.graph.to_node(e);
let to_node = c.graph.node_weight(t).id().to_ident();

if grads.contains(&t) {
stmts.push(parse_quote! {
#to_node += #node_ident.mady_chain(#edge).clone();
stmts.push(parse_quote_spanned! {edge.span()=>
#to_node += #node_ident.mady_chain(#edge_ident).clone();
});
} else {
grads.insert(t);
stmts.push(parse_quote! {
#to_node = #node_ident.mady_chain(#edge).clone();
stmts.push(parse_quote_spanned! {edge.span()=>
#to_node = #node_ident.mady_chain(#edge_ident).clone();
});
}
}
Expand All @@ -73,59 +74,3 @@ pub fn gen_backward(c: &Recorder) -> Result<Vec<syn::Stmt>, Error> {

Ok(stmts)
}

// pub fn _gen_types(c: &Recorder) -> Result<Vec<syn::Stmt>, Error> {
// let mut stmts = vec![];
// let roots = HashSet::<_>::from_iter(c.graph.roots());
// let mut grads = HashSet::new();
// for n in c.graph.topological_iter() {
// grads.insert(n);
// let node = c.graph.node_weight(n);
// match node {
// VarType::Tmp(v) | VarType::Grad(v) => {
// let ty = v.to_type_ident();
// let node_grad_ty = v.to_grad_type_ident();
// let annotate = v
// .ty()
// .clone()
// .ok_or(ParseError::CantInferType.new(node.span()))?;
// stmts.push(parse_quote! {
// type #ty = #annotate;
// });

// if roots.contains(&n) {
// stmts.push(parse_quote! {
// type #node_grad_ty = <#ty as One>::O0;
// });
// }

// for e in c.graph.to_edges(n) {
// let edge = c.graph.edge_weight(e);
// let ty = edge.to_type_ident();
// let annotate = edge
// .ty()
// .clone()
// .ok_or(ParseError::CantInferType.new(edge.span()))?;

// stmts.push(parse_quote! {
// type #ty = #annotate;
// });

// let to_node = c.graph.to_node(e);
// if !grads.contains(&to_node) {
// grads.insert(to_node);
// let grad_ty = c.graph.node_weight(to_node).id().to_grad_type_ident();

// stmts.push(parse_quote! {
// type #grad_ty = <#node_grad_ty as MadyChain<#ty>>::O0;
// });
// }
// }
// }
// VarType::Null => {}
// _ => unimplemented!(),
// }
// }

// Ok(stmts)
// }
11 changes: 1 addition & 10 deletions mady_macro_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,4 @@ pub fn new() -> Parser {
Parser::new()
.register(Linker::new())
.register(Folder::new())
}

#[cfg(test)]
mod tests {
use quote::quote;
use syn::parse_quote;

use super::*;

}
}
6 changes: 0 additions & 6 deletions mady_macro_core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ impl ParseGraph {
}
}

pub fn to_upper_camel_case(name: String) -> String {
name.split('_')
.map(|x| format!("{}{}", x[..1].to_uppercase(), &x[1..]))
.collect::<String>()
}

pub fn grad_method<T>(method_name: T) -> Ident
where
T: ToString,
Expand Down

0 comments on commit 3507a1d

Please sign in to comment.