Skip to content

Commit

Permalink
Merge pull request #5 from shapespeare/observe
Browse files Browse the repository at this point in the history
✨ Implement Observe
  • Loading branch information
yhs0602 authored Jun 1, 2024
2 parents e7a1afe + 5a72283 commit 365eceb
Showing 1 changed file with 47 additions and 15 deletions.
62 changes: 47 additions & 15 deletions lib/compile.ml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
open Core
open Program

type number = Int of int | Float of float

module Env = struct
type t = (Id.t, fn, Id.comparator_witness) Map.t

Expand All @@ -13,13 +11,19 @@ end

module Pred = struct
type t = Empty | And of Det_exp.t * t | And_not of Det_exp.t * t

let rec fv : t -> Set.M(Id).t = function
| Empty -> Set.empty (module Id)
| And (de, p) | And_not (de, p) -> Set.union (Det_exp.fv de) (fv p)
end

module Dist = struct
type t
type one = One

type exp =
| If of Det_exp.t * exp * exp
| If_de of Det_exp.t * exp * exp
| If_pred of Pred.t * exp * one
| Dist_obj of { dist : t; var : Id.t; args : Det_exp.t list }

exception Score_invalid_arguments
Expand All @@ -31,7 +35,7 @@ module Dist = struct
| If (e_pred, e_con, e_alt) ->
let s_con = score e_con var in
let s_alt = score e_alt var in
If (e_pred, s_con, s_alt)
If_de (e_pred, s_con, s_alt)
| Prim_call (c, es) -> Dist_obj { dist = prim_to_dist c; var; args = es }
| _ -> raise Score_invalid_arguments
end
Expand All @@ -40,7 +44,7 @@ module Graph = struct
type vertex = Id.t
type arc = vertex * vertex
type det_map = (Id.t, Dist.exp, Id.comparator_witness) Map.t
type obs_map = (Id.t, number, Id.comparator_witness) Map.t
type obs_map = (Id.t, Det_exp.t, Id.comparator_witness) Map.t

type t = {
vertices : vertex list;
Expand Down Expand Up @@ -81,7 +85,13 @@ let gen_sym =
let cnt = ref 0 in
fun () ->
incr cnt;
Printf.sprintf "X_%d" !cnt
Printf.sprintf "#%d" !cnt

let gen_vertex =
let cnt = ref 0 in
fun () ->
incr cnt;
Printf.sprintf "X%d" !cnt

let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t =
let sub' exp = sub exp x det_exp in
Expand Down Expand Up @@ -120,6 +130,8 @@ let rec sub (exp : Exp.t) (x : Id.t) (det_exp : Det_exp.t) : Exp.t =
| Sample e -> Sample (sub' e)
| Observe (e1, e2) -> Observe (sub' e1, sub' e2)

exception Not_closed_observation

let rec compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) :
Graph.t * Det_exp.t =
ignore env;
Expand All @@ -130,17 +142,37 @@ let rec compile (env : Env.t) (pred : Pred.t) (exp : Exp.t) :
| Var x -> (Graph.empty, Det_exp.Var x)
| Sample e ->
let g, de = compile env pred e in
let v = gen_sym () in
let v = gen_vertex () in
let de_fvs = Det_exp.fv de in
let f = Dist.score de v in
( g
@+ {
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.empty (module Id);
},
Det_exp.Var v )
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list de_fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.empty (module Id);
}
in
(g @+ g', Det_exp.Var v)
| Observe (e1, e2) ->
let g1, de1 = compile env pred e1 in
let g2, de2 = compile env pred e2 in
let v = gen_vertex () in
let f1 = Dist.score de1 v in
let f = Dist.(If_pred (pred, f1, One)) in
let fvs = Set.union (Det_exp.fv de1) (Pred.fv pred) in
if not @@ Set.is_empty (Det_exp.fv de2) then raise Not_closed_observation;
let g' =
Graph.
{
vertices = [ v ];
arcs = List.map (Set.to_list fvs) ~f:(fun z -> (z, v));
det_map = Map.singleton (module Id) v f;
obs_map = Map.singleton (module Id) v de2;
}
in
(g1 @+ g2 @+ g', de2)
| Assign (x, e, body) ->
let g1, det_exp1 = compile env pred e in
let sub_body = sub body x det_exp1 in
Expand Down

0 comments on commit 365eceb

Please sign in to comment.