From fa25d9cfb12d99d042cfc2000c970d435364a466 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Sun, 24 Jul 2022 05:57:26 -0700 Subject: [PATCH 1/2] add logdensity_def --- src/primitives/logdensity.jl | 53 ++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/primitives/logdensity.jl b/src/primitives/logdensity.jl index 81fa4cd..b157153 100644 --- a/src/primitives/logdensity.jl +++ b/src/primitives/logdensity.jl @@ -43,3 +43,56 @@ end end end +############################################################################### + + +function MeasureBase.logdensity_def(c::ConditionalModel{A,B,M}, x=NamedTuple()) where {A,B,M} + @show c + _logdensity_def(M, Model(c), argvals(c), observations(c), x) +end + +export sourceLogdensityDef + +sourceLogdensityDef(m::AbstractModel) = sourceLogdensityDef()(Model(m)) + +# function Base.convert(nt::NamedTuple, args...) +# @show nt +# @show args +# for (n, t) in enumerate(stacktrace()) +# print("\t",n,". ") +# println(t) +# end + +# end + +function sourceLogdensityDef() + function(_m::Model) + proc(_m, st :: Assign) = :($(st.x) = $(st.rhs)) + proc(_m, st :: Return) = nothing + proc(_m, st :: LineNumber) = nothing + function proc(_m, st :: Sample) + x = st.x + rhs = st.rhs + @q begin + _ℓ += Soss.logdensity_def($rhs, $x) + $x = Soss.predict($rhs, $x) + end + end + + wrap(kernel) = @q begin + _ℓ = 0.0 + $kernel + return _ℓ + end + + buildSource(_m, proc, wrap) |> MacroTools.flatten + end +end + +@gg function _logdensity_def(M::Type{<:TypeLevel}, _m::Model, _args, _data, _pars) + body = type2model(_m) |> sourceLogdensityDef() |> loadvals(_args, _data, _pars) + @under_global from_type(_unwrap_type(M)) @q let M + $body + end +end + From 05c6f82f4f1adece89c4cfcd0416df22e3b4c849 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Sun, 24 Jul 2022 06:00:31 -0700 Subject: [PATCH 2/2] drop show --- src/primitives/logdensity.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primitives/logdensity.jl b/src/primitives/logdensity.jl index b157153..4d13367 100644 --- a/src/primitives/logdensity.jl +++ b/src/primitives/logdensity.jl @@ -47,7 +47,6 @@ end function MeasureBase.logdensity_def(c::ConditionalModel{A,B,M}, x=NamedTuple()) where {A,B,M} - @show c _logdensity_def(M, Model(c), argvals(c), observations(c), x) end