This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

BMInference API to compile and run BM models from Beanstalk #1801

Open · wants to merge 1 commit into base: main
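For context, the diff below wires a new BMInference class into the compiler pipeline: it accumulates the model graph, optimizes it, generates ordinary Bean Machine Python source, executes that source, and re-keys the resulting samples by the caller's original queries. A minimal usage sketch against the new API (the coin-flip model is illustrative and not part of this PR; only the BMInference.infer signature comes from the diff):

import beanmachine.ppl as bm
import torch

from beanmachine.ppl.inference.bm_inference import BMInference


# Hypothetical two-variable model used only to exercise the new API.
@bm.random_variable
def theta():
    return torch.distributions.Beta(2.0, 2.0)


@bm.random_variable
def flip():
    return torch.distributions.Bernoulli(theta())


samples = BMInference().infer(
    queries=[theta()],
    observations={flip(): torch.tensor(1.0)},
    num_samples=500,
    num_chains=2,
)
print(samples[theta()])  # MonteCarloSamples keyed by the original query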
53 changes: 44 additions & 9 deletions src/beanmachine/ppl/compiler/gen_bm_python.py
@@ -5,13 +5,23 @@
 
 
 from collections import defaultdict
-from typing import Dict, List
+from enum import Enum
+from typing import Dict, List, Tuple
 
 from beanmachine.ppl.compiler import bmg_nodes as bn
 
 from beanmachine.ppl.compiler.bm_graph_builder import BMGraphBuilder
 from beanmachine.ppl.compiler.fix_problems import fix_problems
 from beanmachine.ppl.compiler.internal_error import InternalError
+from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler
+from beanmachine.ppl.inference.single_site_nmc import SingleSiteNewtonianMonteCarlo
+from beanmachine.ppl.model.rv_identifier import RVIdentifier
+
+
+class InferenceType(Enum):
+    SingleSiteNewtonianMonteCarlo = SingleSiteNewtonianMonteCarlo
+    GlobalNoUTurnSampler = GlobalNoUTurnSampler
 
 
 _node_type_to_distribution = {
     bn.BernoulliNode: "torch.distributions.Bernoulli",
@@ -37,6 +47,8 @@ class ToBMPython:
     no_dist_samples: Dict[bn.BMGNode, int]
     queries: List[str]
     observations: List[str]
+    node_to_rv_id: Dict[str, str]
+    node_to_query_map: Dict[str, RVIdentifier]
 
     def __init__(self, bmg: BMGraphBuilder) -> None:
         self.code = ""
@@ -51,6 +63,8 @@ def __init__(self, bmg: BMGraphBuilder) -> None:
         self.no_dist_samples = defaultdict(lambda: 0)
         self.queries = []
         self.observations = []
+        self.node_to_rv_id = {}
+        self.node_to_query_map = {}
 
     def _get_node_id_mapping(self, node: bn.BMGNode) -> str:
         if node in self.node_to_var_id:
@@ -132,12 +146,17 @@ def _add_sample(self, node: bn.SampleNode) -> None:
         total_samples = self._no_dist_samples(node.operand)
         if total_samples > 1:
             param = f"{self.no_dist_samples[node.operand]}"
+            self._code.append(f"v{var_id} = rv{rv_id}({param},)")
+            self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param},)"
         else:
             param = ""
-        self._code.append(f"v{var_id} = rv{rv_id}({param})")
+            self._code.append(f"v{var_id} = rv{rv_id}({param})")
+            self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param})"
 
     def _add_query(self, node: bn.Query) -> None:
-        self.queries.append(f"{self._get_node_id_mapping(node.operator)}")
+        query_id = self._get_node_id_mapping(node.operator)
+        self.node_to_query_map[self.node_to_rv_id[query_id]] = node.rv_identifier
+        self.queries.append(f"{query_id}")
 
     def _add_observation(self, node: bn.Observation) -> None:
         val = node.value
@@ -163,18 +182,34 @@ def _generate_python(self, node: bn.BMGNode) -> None:
         elif isinstance(node, bn.Observation):
             self._add_observation(node)
 
-    def _generate_bm_python(self) -> str:
+    def _generate_bm_python(
+        self, inference_type, infer_config
+    ) -> Tuple[str, Dict[str, RVIdentifier]]:
         bmg, error_report = fix_problems(self.bmg)
         self.bmg = bmg
         error_report.raise_errors()
+        self._code.append(
+            f"from {inference_type.value.__module__} import {inference_type.value.__name__}"
+        )
         for node in self.bmg.all_ancestor_nodes():
             self._generate_python(node)
-        self._code.append(f"queries = [{(','.join(self.queries))}]")
-        self._code.append(f"observations = {{{','.join(self.observations)}}}")
+        self._code.append(f"opt_queries = [{(','.join(self.queries))}]")
+        self._code.append(f"opt_observations = {{{','.join(self.observations)}}}")
+        self._code.append(
+            f"""samples = {inference_type.value.__name__}().infer(
+    opt_queries,
+    opt_observations,
+    num_samples={infer_config['num_samples']},
+    num_chains={infer_config['num_chains']},
+    num_adaptive_samples={infer_config['num_adaptive_samples']}
+)"""
+        )
         self.code = "\n".join(self._code)
-        return self.code
+        return self.code, self.node_to_query_map
 
 
-def to_bm_python(bmg: BMGraphBuilder) -> str:
+def to_bm_python(
+    bmg: BMGraphBuilder, inference_type: InferenceType, infer_config: Dict
+) -> Tuple[str, Dict[str, RVIdentifier]]:
     bmp = ToBMPython(bmg)
-    return bmp._generate_bm_python()
+    return bmp._generate_bm_python(inference_type, infer_config)
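Taken together, _generate_bm_python now emits a complete, runnable program rather than just the queries and observations: an import of the chosen inference class, the generated rv functions and sampled values, the opt_queries/opt_observations bindings, and a trailing infer() call whose configuration is spliced in from infer_config. For a hypothetical one-variable Bernoulli model with GlobalNoUTurnSampler, the emitted source would look roughly like this (a sketch: rv1/v1 follow the generator's numbering, and the bm/torch preamble is assumed from the pre-existing generator, not shown in this diff):

import beanmachine.ppl as bm
import torch

from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler


@bm.random_variable
def rv1():
    return torch.distributions.Bernoulli(0.5)


v1 = rv1()
opt_queries = [v1]
opt_observations = {}
samples = GlobalNoUTurnSampler().infer(
    opt_queries,
    opt_observations,
    num_samples=1000,
    num_chains=4,
    num_adaptive_samples=0
)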
209 changes: 209 additions & 0 deletions src/beanmachine/ppl/inference/bm_inference.py
@@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""An inference engine which uses Bean Machine to make
inferences on optimized Bean Machine models."""

from typing import Dict, List, Set

import graphviz
import torch

from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations
from beanmachine.ppl.compiler.gen_bm_python import InferenceType, to_bm_python
from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph
from beanmachine.ppl.compiler.gen_dot import to_dot
from beanmachine.ppl.compiler.gen_mini import to_mini
from beanmachine.ppl.compiler.runtime import BMGRuntime
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
from beanmachine.ppl.inference.utils import _verify_queries_and_observations
from beanmachine.ppl.model.rv_identifier import RVIdentifier


class BMInference:
    """
    Interface to Bean Machine inference on optimized models.

    Please note that this is a highly experimental implementation under active
    development, and that only a subset of Bean Machine models is supported.
    Limitations include that the runtime graph must be static (meaning it does
    not change during inference) and that the set of primitive distributions
    currently supported is limited.
    """

    _fix_observe_true: bool = False
    _infer_config = {}

    def __init__(self):
        pass

    def _accumulate_graph(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
    ) -> BMGRuntime:
        _verify_queries_and_observations(queries, observations, True)
        rt = BMGRuntime()
        bmg = rt.accumulate_graph(queries, observations)
        # TODO: Figure out a better way to pass this flag around
        bmg._fix_observe_true = self._fix_observe_true
        return rt

    def _build_mcsamples(
        self,
        queries,
        opt_rv_to_query_map,
        samples,
    ) -> MonteCarloSamples:
        assert len(samples) == len(queries)

        results: Dict[RVIdentifier, torch.Tensor] = {}
        for rv in samples.keys():
            query = opt_rv_to_query_map[rv.__str__()]
            results[query] = samples[rv]
        mcsamples = MonteCarloSamples(results)
        return mcsamples

    def _infer(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        num_samples: int,
        num_chains: int = 1,
        num_adaptive_samples: int = 0,
        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> MonteCarloSamples:

        rt = self._accumulate_graph(queries, observations)
        bmg = rt._bmg

        self._infer_config["num_samples"] = num_samples
        self._infer_config["num_chains"] = num_chains
        self._infer_config["num_adaptive_samples"] = num_adaptive_samples

        generated_graph = to_bmg_graph(bmg, skip_optimizations)
        optimized_python, opt_rv_to_query_map = to_bm_python(
            generated_graph.bmg, inference_type, self._infer_config
        )

        try:
            exec(optimized_python, globals())  # noqa
        except RuntimeError as e:
            raise RuntimeError("Error during BM inference\n") from e

        opt_samples = self._build_mcsamples(
            queries,
            opt_rv_to_query_map,
            # pyre-ignore
            samples,  # noqa
        )
        return opt_samples

    def infer(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        num_samples: int,
        num_chains: int = 4,
        num_adaptive_samples: int = 0,
        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> MonteCarloSamples:
        """
        Perform inference by (runtime) compilation of Python source code associated
        with its parameters, constructing an optimized BM graph, and then calling the
        BM implementation of a particular inference method on this graph.

        Args:
            queries: queried random variables
            observations: observations dict
            num_samples: number of samples in each chain
            num_chains: number of chains generated
            num_adaptive_samples: number of burn-in samples to discard
            inference_type: inference method
            skip_optimizations: list of optimizations to disable in this call

        Returns:
            MonteCarloSamples: The requested samples
        """
        # TODO: Add verbose level
        # TODO: Add logging
        samples = self._infer(
            queries,
            observations,
            num_samples,
            num_chains,
            num_adaptive_samples,
            inference_type,
            skip_optimizations,
        )
        return samples

    def to_dot(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        after_transform: bool = True,
        label_edges: bool = False,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> str:
        """Produce a string containing a program in the GraphViz DOT language
        representing the graph deduced from the model."""
        node_types = False
        node_sizes = False
        edge_requirements = False
        bmg = self._accumulate_graph(queries, observations)._bmg
        return to_dot(
            bmg,
            node_types,
            node_sizes,
            edge_requirements,
            after_transform,
            label_edges,
            skip_optimizations,
        )

    def _to_mini(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        indent=None,
    ) -> str:
        """Internal test method for Neal's MiniBMG prototype."""
        bmg = self._accumulate_graph(queries, observations)._bmg
        return to_mini(bmg, indent)

    def to_graphviz(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        after_transform: bool = True,
        label_edges: bool = False,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> graphviz.Source:
        """Small wrapper to generate an actual graphviz object."""
        s = self.to_dot(
            queries, observations, after_transform, label_edges, skip_optimizations
        )
        return graphviz.Source(s)

    def to_python(
        self,
        queries: List[RVIdentifier],
        observations: Dict[RVIdentifier, torch.Tensor],
        num_samples: int,
        num_chains: int = 4,
        num_adaptive_samples: int = 0,
        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
        skip_optimizations: Set[str] = default_skip_optimizations,
    ) -> str:
        """Produce a string containing a BM Python program from the graph."""
        bmg = self._accumulate_graph(queries, observations)._bmg
        self._infer_config["num_samples"] = num_samples
        self._infer_config["num_chains"] = num_chains
        self._infer_config["num_adaptive_samples"] = num_adaptive_samples
        opt_bm, _ = to_bm_python(bmg, inference_type, self._infer_config)
        return opt_bm
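Note the design choice in _infer: exec runs the generated program in globals(), so the trailing samples variable it defines becomes visible to _build_mcsamples, which then re-keys the results from the generated rv identifiers back to the caller's original queries via the string-keyed opt_rv_to_query_map. Conceptually, with hypothetical stand-ins for the real RVIdentifiers and tensors:

import torch

# Sketch of the re-keying performed by _build_mcsamples; "rv1()" and
# "theta()" stand in for RVIdentifiers from the generated and original
# models respectively, and the tensor shape is illustrative.
samples = {"rv1()": torch.zeros(2, 500)}    # produced by the generated program
opt_rv_to_query_map = {"rv1()": "theta()"}  # returned by to_bm_python
results = {opt_rv_to_query_map[str(rv)]: t for rv, t in samples.items()}
# results is now keyed by the caller's original query identifiers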
12 changes: 1 addition & 11 deletions src/beanmachine/ppl/inference/bmg_inference.py
@@ -16,7 +16,6 @@
 
 from beanmachine.ppl.compiler.bm_graph_builder import rv_to_query
 from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations
-from beanmachine.ppl.compiler.gen_bm_python import to_bm_python
 from beanmachine.ppl.compiler.gen_bmg_cpp import to_bmg_cpp
 from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph
 from beanmachine.ppl.compiler.gen_bmg_python import to_bmg_python
@@ -167,7 +166,7 @@ def _build_mcsamples(
         # but it requires the input to be of a different type in the
         # cases of num_chains==1 and !=1 respectively. Furthermore,
         # we had to tweak it to support the right operator for merging
-        # saumple values when num_chains!=1.
+        # sample values when num_chains!=1.
         if num_chains == 1:
             mcsamples = MonteCarloSamples(
                 results[0], num_adaptive_samples, stack_not_cat=True
@@ -365,15 +364,6 @@ def to_python(
         bmg = self._accumulate_graph(queries, observations)._bmg
         return to_bmg_python(bmg).code
 
-    def to_bm_python(
-        self,
-        queries: List[RVIdentifier],
-        observations: Dict[RVIdentifier, torch.Tensor],
-    ) -> str:
-        """Produce a string containing a BM Python program from the graph."""
-        bmg = self._accumulate_graph(queries, observations)._bmg
-        return to_bm_python(bmg)
-
     def to_graph(
         self,
         queries: List[RVIdentifier],
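With to_bm_python removed from BMGInference, the equivalent entry point now lives on the new class, whose to_python requires the inference configuration up front. A sketch of the migration (argument values illustrative; theta/flip as in the earlier model sketch):

from beanmachine.ppl.inference.bm_inference import BMInference

# Before (removed in this PR):
#     BMGInference().to_bm_python(queries, observations)
# After:
source = BMInference().to_python(
    queries=[theta()],
    observations={flip(): torch.tensor(1.0)},
    num_samples=1000,
)
print(source)  # the generated BM Python program, including the infer() call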