diff --git a/dunge/tests/shader.rs b/dunge/tests/shader.rs index 5a8c18a..b7edb97 100644 --- a/dunge/tests/shader.rs +++ b/dunge/tests/shader.rs @@ -3,7 +3,7 @@ type Error = Box; #[test] -fn render() -> Result<(), Error> { +fn shader_calc() -> Result<(), Error> { use dunge::{ glam::Vec4, sl::{self, Out}, @@ -12,7 +12,7 @@ fn render() -> Result<(), Error> { let compute = || { let m = -sl::mat2(sl::vec2(1., 0.), sl::vec2(0., 1.)); let [m0, m1, m3] = sl::thunk(m); - let v = m0.x() + (-m1.y()); + let v = m0.x() + m1.y(); let z = sl::splat_vec3(1.).z(); Out { @@ -23,6 +23,30 @@ fn render() -> Result<(), Error> { let cx = helpers::block_on(dunge::context())?; let shader = cx.make_shader(compute); - assert_eq!(shader.debug_wgsl(), include_str!("shader.wgsl")); + assert_eq!(shader.debug_wgsl(), include_str!("shader_calc.wgsl")); + Ok(()) +} + +#[test] +fn shader_if() -> Result<(), Error> { + use dunge::{ + glam::Vec4, + sl::{self, Out}, + }; + + let compute = || { + let a = Vec4::splat(3.); + let b = sl::splat_vec4(2.) * 2.; + let x = sl::if_then_else(true, a, b); + + Out { + place: x, + color: sl::splat_vec4(1.), + } + }; + + let cx = helpers::block_on(dunge::context())?; + let shader = cx.make_shader(compute); + assert_eq!(shader.debug_wgsl(), include_str!("shader_if.wgsl")); Ok(()) } diff --git a/dunge/tests/shader.wgsl b/dunge/tests/shader_calc.wgsl similarity index 75% rename from dunge/tests/shader.wgsl rename to dunge/tests/shader_calc.wgsl index 59541fe..3dfa57b 100644 --- a/dunge/tests/shader.wgsl +++ b/dunge/tests/shader_calc.wgsl @@ -5,7 +5,7 @@ struct VertexOutput { @vertex fn vs() -> VertexOutput { let _e7: mat2x2 = -(mat2x2(vec2(1f, 0f), vec2(0f, 1f))); - return VertexOutput(((vec4(_e7[0], (_e7[0] + -(_e7[1]))) * f32(1i)) * vec3(1f, 1f, 1f).z)); + return VertexOutput(((vec4(_e7[0], (_e7[0] + _e7[1])) * f32(1i)) * vec3(1f, 1f, 1f).z)); } @fragment diff --git a/dunge/tests/shader_if.wgsl b/dunge/tests/shader_if.wgsl new file mode 100644 index 0000000..e48f0c3 --- /dev/null +++ b/dunge/tests/shader_if.wgsl @@ -0,0 +1,21 @@ +struct VertexOutput { + @builtin(position) member: vec4, +} + +@vertex +fn vs() -> VertexOutput { + var local: vec4; + + if true { + local = vec4(3f, 3f, 3f, 3f); + } else { + local = (vec4(2f, 2f, 2f, 2f) * 2f); + } + let _e11: vec4 = local; + return VertexOutput(_e11); +} + +@fragment +fn fs(param: VertexOutput) -> @location(0) vec4 { + return vec4(1f, 1f, 1f, 1f); +} diff --git a/dunge_shader/src/eval.rs b/dunge_shader/src/eval.rs index 565d8b0..a07b040 100644 --- a/dunge_shader/src/eval.rs +++ b/dunge_shader/src/eval.rs @@ -7,8 +7,8 @@ use { types::{self, MemberType, ScalarType, ValueType, VectorType}, }, naga::{ - AddressSpace, Arena, BinaryOperator, Binding, Block, BuiltIn, EntryPoint, Expression, - Function, FunctionArgument, FunctionResult, GlobalVariable, Handle, Literal, Range, + AddressSpace, Arena, BinaryOperator, Binding, BuiltIn, EntryPoint, Expression, Function, + FunctionArgument, FunctionResult, GlobalVariable, Handle, Literal, LocalVariable, Range, ResourceBinding, SampleLevel, ShaderStage, Span, Statement, StructMember, Type, TypeInner, UnaryOperator, UniqueArena, }, @@ -338,11 +338,11 @@ impl Clone for Thunk { } } -impl Eval for Ret, O> +impl Eval for Ret, A::Out> where A: Eval, { - type Out = O; + type Out = A::Out; fn eval(self, en: &mut E) -> Expr { let Thunk { s, .. } = self.get(); @@ -363,6 +363,50 @@ enum State { Expr(Expr), } +pub fn if_then_else(c: C, a: A, b: B) -> Ret, A::Out> +where + C: Eval, + A: Eval, + A::Out: types::Value, + B: Eval, +{ + Ret::new(IfThenElse { + c, + a, + b, + e: PhantomData, + }) +} + +pub struct IfThenElse { + c: C, + a: A, + b: B, + e: PhantomData, +} + +impl Eval for Ret, A::Out> +where + C: Eval, + A: Eval, + A::Out: types::Value, + B: Eval, + E: GetEntry, +{ + type Out = A::Out; + + fn eval(self, en: &mut E) -> Expr { + let IfThenElse { c, a, b, .. } = self.get(); + let c = c.eval(en); + let a = a.eval(en); + let b = b.eval(en); + let en = en.get_entry(); + let valty = ::VALUE_TYPE; + let ty = en.new_type(valty.ty()); + en.if_then_else(c, a, b, ty) + } +} + #[derive(Default)] pub(crate) struct Evaluated([Option; 4]); @@ -651,9 +695,11 @@ impl Sampled { pub struct Entry { compl: Compiler, + locls: Arena, exprs: Arena, stats: Statements, cached_glob: HashMap, Expr>, + cached_locl: HashMap, Expr>, cached_args: HashMap, } @@ -661,9 +707,11 @@ impl Entry { fn new(compl: Compiler) -> Self { Self { compl, + locls: Arena::default(), exprs: Arena::default(), stats: Statements::default(), cached_glob: HashMap::default(), + cached_locl: HashMap::default(), cached_args: HashMap::default(), } } @@ -672,6 +720,16 @@ impl Entry { self.compl.types.insert(ty, Span::UNDEFINED) } + fn add_local(&mut self, ty: Handle) -> Handle { + let local = LocalVariable { + name: None, + ty, + init: None, + }; + + self.locls.append(local, Span::UNDEFINED) + } + fn literal(&mut self, literal: Literal) -> Expr { let ex = Expression::Literal(literal); Expr(self.exprs.append(ex, Span::UNDEFINED)) @@ -691,6 +749,13 @@ impl Entry { }) } + fn local(&mut self, v: Handle) -> Expr { + *self.cached_locl.entry(v).or_insert_with(|| { + let ex = Expression::LocalVariable(v); + Expr(self.exprs.append(ex, Span::UNDEFINED)) + }) + } + fn load(&mut self, ptr: Expr) -> Expr { let ex = Expression::Load { pointer: ptr.0 }; let handle = self.exprs.append(ex, Span::UNDEFINED); @@ -777,6 +842,30 @@ impl Entry { Expr(handle) } + // TODO: Lazy evaluation + fn if_then_else(&mut self, cond: Expr, a: Expr, b: Expr, ty: Handle) -> Expr { + let v = self.add_local(ty); + let pointer = self.local(v); + let a = Statements(vec![Statement::Store { + pointer: pointer.0, + value: a.0, + }]); + + let b = Statements(vec![Statement::Store { + pointer: pointer.0, + value: b.0, + }]); + + let st = Statement::If { + condition: cond.0, + accept: a.0.into(), + reject: b.0.into(), + }; + + self.stats.push(st, &self.exprs); + self.load(pointer) + } + fn ret(&mut self, value: Expr) { let st = Statement::Return { value: Some(value.0), @@ -807,8 +896,9 @@ impl Entry { function: Function { arguments: args.map(Argument::into_function).collect(), result: Some(res), + local_variables: self.locls, expressions: self.exprs, - body: Block::from_vec(self.stats.0), + body: self.stats.0.into(), ..Default::default() }, };