-
Notifications
You must be signed in to change notification settings - Fork 0
/
Class.hs
76 lines (69 loc) · 3.4 KB
/
Class.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}
module Numeric.ADEV.Class (
D(..), C(..), ADEV(..)) where
import Numeric.Log
-- | Type of density-carrying distributions.
--
-- A value pairs a way to draw samples in the randomness monad @m@ with a
-- density function whose results are 'Log' numbers (log-domain
-- representation from "Numeric.Log") over the real-number type @r@.
data D m r a
  = D
      (m a)         -- ^ sampler for the distribution
      (a -> Log r)  -- ^ density of a value, as a 'Log' @r@
-- | Type of CDF-carrying distributions over the reals,
-- for implicit reparameterization.
--
-- NOTE(review): the exact roles of the second and third fields are not
-- evident from this module alone; confirm against the 'implicit_reparam'
-- instances.
data C m r
  = C
      (m Double)              -- ^ sampler producing plain 'Double' values
      (Double -> Log Double)  -- ^ function on plain 'Double's (presumably the CDF)
      (r -> Log r)            -- ^ function on the real type @r@ (presumably its differentiable counterpart)
-- ---------------------------------------------------------------------------
-- | A typeclass for ADEV programs, parameterized by:
--
-- * @r@ - the type used to represent real numbers
-- * @m@ - the monad used to encode randomness (so that @m r@ is the type of
--   unbiasedly estimated real numbers)
-- * @p@ - the type used for monadic probabilistic programming (so
--   that @p m a@ is a probabilistic program returning @a@)
--
-- The functional dependencies @p -> r@ and @r -> p@ pair each
-- probabilistic-programming type with exactly one real-number type (and
-- vice versa). This is what keeps methods whose types never mention @p@,
-- such as 'plus_' or 'exact_', unambiguous.
class (RealFrac r, Monad (p m), Monad m) => ADEV p m r | p -> r, r -> p where
  -- | Sample a random uniform value between 0 and 1.
  sample :: p m r

  -- | Add a real value into a running cost accumulator.
  -- When a @p m r@ is passed to 'expect', the result is
  -- an estimator of the expected cost /plus/ the expected
  -- return value.
  add_cost :: r -> p m ()

  -- | Flip a coin with a specified probability of heads.
  -- Uses enumeration (costly but low-variance) to estimate
  -- gradients.
  flip_enum :: r -> p m Bool

  -- | Flip a coin with a specified probability of heads.
  -- Uses the REINFORCE estimator (cheaper but higher-variance)
  -- for gradients.
  flip_reinforce :: r -> p m Bool

  -- | Generate from a normal distribution. Uses the REPARAM gradient estimator.
  normal_reparam :: r -> r -> p m r

  -- | Generate from a normal distribution. Uses the REINFORCE gradient estimator.
  normal_reinforce :: r -> r -> p m r

  -- | Estimate the expectation of a probabilistic computation.
  expect :: p m r -> m r

  -- | Combinator DSL for estimators: combine two estimators into one.
  -- By its name, presumably an estimator of the sum of the two estimated
  -- quantities. NOTE(review): confirm against the instances.
  plus_ :: m r -> m r -> m r

  -- | Combinator DSL for estimators: combine two estimators into one.
  -- By its name, presumably an estimator of the product of the two
  -- estimated quantities. NOTE(review): confirm against the instances.
  times_ :: m r -> m r -> m r

  -- | Combinator DSL for estimators: transform a single estimator.
  -- By its name, presumably an estimator of the exponential of the
  -- estimated quantity. NOTE(review): confirm against the instances.
  exp_ :: m r -> m r

  -- | Combinator DSL for estimators: build an estimator from a family of
  -- estimators indexed by 'Int'.
  -- NOTE(review): the two leading 'Int' arguments are presumably the total
  -- number of terms and the minibatch size (order unconfirmed); check the
  -- instances.
  minibatch_ :: Int -> Int -> (Int -> m r) -> m r

  -- | Combinator DSL for estimators: lift a known real value into an
  -- estimator (presumably one that always returns that value).
  exact_ :: r -> m r

  -- | Baselines for controlling variance: given a program and a baseline
  -- value, produce an estimator.
  baseline :: p m r -> r -> m r

  -- | Automatic construction of new REINFORCE estimators from a
  -- density-carrying distribution.
  reinforce :: D m r a -> p m a

  -- | Storchastic leave-one-out estimator for a density-carrying
  -- distribution. NOTE(review): the 'Int' is presumably the number of
  -- samples drawn; confirm against the instances.
  leave_one_out :: Int -> D m r a -> p m a

  -- | Differentiable particle filter, accepting:
  --
  -- * @p@: a density function for the target measure.
  -- * @q0@: an initial proposal for the particle filter.
  -- * @q@: a transition proposal for the particle filter.
  -- * @f@: an unbiased estimator of an integrand to estimate.
  -- * @n@: the number of SMC steps to run.
  -- * @k@: the number of particles to use.
  --
  -- Returns an SMC estimator of the integral.
  smc
    :: ([a] -> Log r)
    -> D m r a
    -> (a -> D m r a)
    -> ([a] -> m r)
    -> Int
    -> Int
    -> m r

  -- | Importance sampling gradient estimator.
  importance :: D m r a -> D m r a -> p m a

  -- | Implicit reparameterization for real-valued distributions
  -- differentiable with CDFs (e.g., mixtures of Gaussians).
  implicit_reparam :: C m r -> p m r

  -- | Sample from a Poisson distribution, using a measure-valued derivative.
  poisson_weak :: Log r -> p m Int

  -- | Gradients through rejection sampling for density-carrying distributions.
  reparam_reject :: D m r a -> (a -> b) -> D m r b -> D m r b -> Log r -> p m b