Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ No more special pred type
Browse files Browse the repository at this point in the history
Zeta611 committed Jun 10, 2024

Verified

This commit was signed with the committer’s verified signature. The key has expired.
snoyberg Michael Snoyman
1 parent 0f20467 commit 28c240f
Showing 3 changed files with 80 additions and 67 deletions.
92 changes: 61 additions & 31 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
@@ -25,10 +25,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
match peval te with
| { exp = Value v; _ } -> { ty; exp = Value (uop.op v) }
| e -> { ty; exp = Uop (uop, e) })
| If_pred (pred, te_con, te_alt) -> (
match peval_pred pred with
| True -> peval { ty; exp = If_just te_con }
| False -> peval { ty; exp = If_just te_alt }
| If_pred (te_pred, te_con, te_alt) -> (
match peval te_pred with
| { exp = Value true; _ } -> peval { ty; exp = If_just te_con }
| { exp = Value false; _ } -> peval { ty; exp = If_just te_alt }
| p -> { ty; exp = If_pred (p, peval te_con, peval te_alt) })
| Call (f, args) -> (
match peval_args args with
@@ -48,9 +48,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
in
{ ty; exp = Call (f_dist, []) })
| If_pred_dist (p, de) -> (
match peval_pred p with
| True -> peval de
| False -> { ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) }
match peval p with
| { exp = Value true; _ } -> peval de
| { exp = Value false; _ } ->
{ ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) }
| p -> { ty; exp = If_pred_dist (p, peval de) })
| If_just de -> { ty; exp = If_just (peval de) }

@@ -63,23 +64,41 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
({ ty; exp = Value v } :: tl, Some ((dty_of_dat_ty ty, v) :: vargs))
| te, (tl, _) -> (te :: tl, None))

and peval_pred : pred -> pred = function
| Empty -> failwith "[Bug] Empty predicate"
| True -> True
| False -> False
| And (p, de) -> (
match peval de with
| { exp = Value true; _ } -> peval_pred p
| { exp = Value false; _ } -> False
| de -> And (p, de))
| And_not (p, de) -> (
match peval de with
| { exp = Value true; _ } -> False
| { exp = Value false; _ } -> peval_pred p
| de -> And_not (p, de))
let ( &&& ) :
type s1 s2 s.
((bool, _) dat_ty, det) texp ->
((bool, _) dat_ty, det) texp ->
bool some_dat_det_texp =
fun ({ ty = Dat_ty (Tyb, s1); _ } as p1) ({ ty = Dat_ty (Tyb, s2); _ } as p2) ->
let (Ex (ms, s)) = merge_stamps s1 s2 in
Ex
(peval
{
ty = Dat_ty (Tyb, s);
exp = Bop ({ name = "&&"; op = ( && ) }, p1, p2, ms);
})

let ( &&& ) p de = peval_pred (And (p, de))
let ( &&! ) p de = peval_pred (And_not (p, de))
let ( &&! ) :
type s1 s2 s.
((bool, _) dat_ty, det) texp ->
((bool, _) dat_ty, det) texp ->
bool some_dat_det_texp =
fun ({ ty = Dat_ty (Tyb, s1); _ } as p1) ({ ty = Dat_ty (Tyb, s2); _ } as p2) ->
let (Ex (ms, s)) = merge_stamps s1 s2 in
Ex
(peval
{
ty = Dat_ty (Tyb, s);
exp =
Bop
( { name = "&&"; op = ( && ) },
p1,
{
ty = Dat_ty (Tyb, s2);
exp = Uop ({ name = "not"; op = not }, p2);
},
ms );
})

let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp =
function
@@ -88,9 +107,12 @@ let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp =
| { exp = Call _; _ } as e -> e

let rec compile :
type a s. env:env -> ?pred:pred -> (a, ndet) texp -> Graph.t * (a, det) texp
=
fun ~env ?(pred = Empty) { ty; exp } ->
type a s.
env:env ->
pred:((bool, s) dat_ty, det) texp ->
(a, ndet) texp ->
Graph.t * (a, det) texp =
fun ~env ~pred { ty; exp } ->
match exp with
| Value _ as exp -> (Graph.empty, { ty; exp })
| Var x -> (
@@ -107,8 +129,8 @@ let rec compile :
(g, peval { ty; exp = Uop (op, te) })
| If (e_pred, e_con, e_alt, _, _) ->
let g1, de_pred = compile ~env ~pred e_pred in
let pred_con = pred &&& de_pred in
let pred_alt = pred &&! de_pred in
let (Ex pred_con) = pred &&& de_pred in
let (Ex pred_alt) = pred &&! de_pred in
let g2, de_con = compile ~env ~pred:pred_con e_con in
let g3, de_alt = compile ~env ~pred:pred_alt e_alt in
let g = Graph.(g1 @| g2 @| g3) in
@@ -143,7 +165,7 @@ let rec compile :
let v = gen_vertex () in
let f1 = score de1 in
let f = { ty = f1.ty; exp = If_pred_dist (pred, f1) } in
let fvs = Id.(fv de1.exp @| fv_pred pred) in
let fvs = Id.(fv de1.exp @| fv pred.exp) in
if not (Set.is_empty (fv de2.exp)) then
failwith "[Bug] Not closed observation";
let g' =
@@ -158,7 +180,11 @@ let rec compile :
Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () })

and compile_args :
type a. env -> pred -> (a, ndet) args -> Graph.t * (a, det) args =
type a s.
env ->
((bool, s) dat_ty, det) texp ->
(a, ndet) args ->
Graph.t * (a, det) args =
fun env pred args ->
match args with
| [] -> (Graph.empty, [])
@@ -177,7 +203,11 @@ let compile_program (prog : program) : Graph.t * Evaluator.query =
m "Inlined program %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]);

let (Ex e) = Typing.check exp in
let g, { ty; exp } = compile ~env:Id.Map.empty e in
let g, { ty; exp } =
compile ~env:Id.Map.empty
~pred:{ ty = Dat_ty (Tyb, Val); exp = Value true }
e
in
match ty with
| Dat_ty (_, Rv) -> (g, Ex { ty; exp })
| _ -> raise Query_not_found
14 changes: 4 additions & 10 deletions lib/evaluator.ml
Original file line number Diff line number Diff line change
@@ -22,24 +22,18 @@ let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a =
| None -> assert false)
| Bop ({ op; _ }, te1, te2, _) -> op (eval_dat ctx te1) (eval_dat ctx te2)
| Uop ({ op; _ }, te) -> op (eval_dat ctx te)
| If_pred (pred, te_con, te_alt) ->
if eval_pred ctx pred then eval_dat ctx te_con else eval_dat ctx te_alt
| If_pred (te_pred, te_con, te_alt) ->
if eval_dat ctx te_pred then eval_dat ctx te_con else eval_dat ctx te_alt
| If_just te -> eval_dat ctx te

and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a =
fun ctx { ty = Dist_ty dty as ty; exp } ->
match exp with
| Call (f, args) -> f.sampler (eval_args ctx args)
| If_pred_dist (pred, dist) ->
if eval_pred ctx pred then eval_dist ctx dist
if eval_dat ctx pred then eval_dist ctx dist
else eval_dist ctx { ty; exp = Call (Dist.one dty, []) }

and eval_pred (ctx : Ctx.t) : pred -> bool = function
| Empty | True -> true
| False -> false
| And (p, de) -> eval_dat ctx de && eval_pred ctx p
| And_not (p, de) -> (not (eval_dat ctx de)) && eval_pred ctx p

and eval_args : type a. Ctx.t -> (a, det) args -> a vargs =
fun ctx -> function
| [] -> []
@@ -51,7 +45,7 @@ let rec eval_pmdf :
fun ctx { ty = Dist_ty dty as ty; exp } ->
match exp with
| If_pred_dist (pred, te) ->
if eval_pred ctx pred then eval_pmdf ctx te
if eval_dat ctx pred then eval_pmdf ctx te
else eval_pmdf ctx { ty; exp = Call (Dist.one dty, []) }
| Call (f, args) ->
let pmdf (Ex (ty', v) : some_val) =
41 changes: 15 additions & 26 deletions lib/typed_tree.ml
Original file line number Diff line number Diff line change
@@ -40,18 +40,10 @@ type ('a, 'b) dist = {
log_pmdf : 'b vargs -> 'a -> real;
}

(* TODO: Why args should also be det? *)
type (_, _) args =
| [] : (unit, _) args
| ( :: ) : (('a, _) dat_ty, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args

and pred =
| Empty : pred
| True : pred
| False : pred
| And : pred * ((bool, _) dat_ty, det) texp -> pred
| And_not : pred * ((bool, _) dat_ty, det) texp -> pred

and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp }

and (_, _) exp =
@@ -73,9 +65,13 @@ and (_, _) exp =
* ('s_pred, 's_ca, 's) merge_stamp
-> (('a, 's) dat_ty, ndet) exp
| If_pred :
pred * (('a, _) dat_ty, det) texp * (('a, _) dat_ty, det) texp
((bool, _) dat_ty, det) texp
* (('a, _) dat_ty, det) texp
* (('a, _) dat_ty, det) texp
-> (('a, _) dat_ty, det) exp
| If_pred_dist : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp
| If_pred_dist :
((bool, _) dat_ty, det) texp * ('a dist_ty, det) texp
-> ('a dist_ty, det) exp
| If_just : (('a, _) dat_ty, det) texp -> (('a, _) dat_ty, det) exp
| Let : Id.t * ('a, ndet) texp * ('b, ndet) texp -> ('b, ndet) exp
| Call : ('a, 'b) dist * ('b, 'd) args -> ('a dist_ty, 'd) exp
@@ -92,6 +88,9 @@ type _ some_texp = Ex : (_, 'd) texp -> 'd some_texp
type _ some_dat_ndet_texp =
| Ex : (('a, _) dat_ty, ndet) texp -> 'a some_dat_ndet_texp

type _ some_dat_det_texp =
| Ex : (('a, _) dat_ty, det) texp -> 'a some_dat_det_texp

type _ some_val_texp = Ex : ((_, value) dat_ty, 'd) texp -> 'd some_val_texp
type _ some_rv_texp = Ex : ((_, rv) dat_ty, 'd) texp -> 'd some_rv_texp
type _ some_dat_texp = Ex : (_ dat_ty, 'd) texp -> 'd some_dat_texp
@@ -184,21 +183,17 @@ let rec fv : type a. (a, det) exp -> Id.Set.t = function
| Rvar x -> Id.Set.singleton x
| Bop (_, { exp = e1; _ }, { exp = e2; _ }, _) -> Id.(fv e1 @| fv e2)
| Uop (_, { exp; _ }) -> fv exp
| If_pred (pred, { exp = e_con; _ }, { exp = e_alt; _ }) ->
Id.(fv_pred pred @| fv e_con @| fv e_alt)
| If_pred_dist (pred, { exp = e_con; _ }) -> Id.(fv_pred pred @| fv e_con)
| If_pred ({ exp = e_pred; _ }, { exp = e_con; _ }, { exp = e_alt; _ }) ->
Id.(fv e_pred @| fv e_con @| fv e_alt)
| If_pred_dist ({ exp = e_pred; _ }, { exp = e_con; _ }) ->
Id.(fv e_pred @| fv e_con)
| If_just { exp; _ } -> fv exp
| Call (_, args) -> fv_args args

and fv_args : type a. (a, det) args -> Id.Set.t = function
| [] -> Id.Set.empty
| { exp; _ } :: es -> Id.(fv exp @| fv_args es)

and fv_pred : pred -> Id.Set.t = function
| Empty | True | False -> Id.Set.empty
| And (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p)
| And_not (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p)

module Erased = struct
type exp =
| Value : string -> exp
@@ -220,8 +215,8 @@ module Erased = struct
fun { ty; exp } ->
match exp with
| If (pred, con, alt, _, _) -> If (of_exp pred, of_exp con, of_exp alt)
| If_pred (pred, con, alt) -> If (of_pred pred, of_exp con, of_exp alt)
| If_pred_dist (pred, con) -> If (of_pred pred, of_exp con, Value "1")
| If_pred (pred, con, alt) -> If (of_exp pred, of_exp con, of_exp alt)
| If_pred_dist (pred, con) -> If (of_exp pred, of_exp con, Value "1")
| If_just exp -> If_just (of_exp exp)
| Value v -> (
match ty with
@@ -241,11 +236,5 @@ module Erased = struct
| [] -> []
| arg :: args -> of_exp arg :: of_args args

and of_pred : pred -> exp = function
| Empty | True -> Value "true"
| False -> Value "false"
| And (pred, exp) -> Bop ("&&", of_pred pred, of_exp exp)
| And_not (pred, exp) -> Bop ("&&", of_pred pred, Uop ("not", of_exp exp))

let of_rv (Ex rv : _ some_rv_texp) = rv |> of_exp
end

0 comments on commit 28c240f

Please sign in to comment.