diff --git a/src/beanmachine/ppl/compiler/gen_bm_python.py b/src/beanmachine/ppl/compiler/gen_bm_python.py
index 0cb192101f..11c6a71ee4 100644
--- a/src/beanmachine/ppl/compiler/gen_bm_python.py
+++ b/src/beanmachine/ppl/compiler/gen_bm_python.py
@@ -5,13 +5,23 @@
 from collections import defaultdict
-from typing import Dict, List
+from enum import Enum
+from typing import Dict, List, Tuple
 
 from beanmachine.ppl.compiler import bmg_nodes as bn
 from beanmachine.ppl.compiler.bm_graph_builder import BMGraphBuilder
 from beanmachine.ppl.compiler.fix_problems import fix_problems
 from beanmachine.ppl.compiler.internal_error import InternalError
+from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler
+from beanmachine.ppl.inference.single_site_nmc import SingleSiteNewtonianMonteCarlo
+from beanmachine.ppl.model.rv_identifier import RVIdentifier
+
+
+class InferenceType(Enum):
+    SingleSiteNewtonianMonteCarlo = SingleSiteNewtonianMonteCarlo
+    GlobalNoUTurnSampler = GlobalNoUTurnSampler
+
 
 _node_type_to_distribution = {
     bn.BernoulliNode: "torch.distributions.Bernoulli",
@@ -37,6 +47,8 @@ class ToBMPython:
     no_dist_samples: Dict[bn.BMGNode, int]
     queries: List[str]
     observations: List[str]
+    node_to_rv_id: Dict[str, str]
+    node_to_query_map: Dict[str, RVIdentifier]
 
     def __init__(self, bmg: BMGraphBuilder) -> None:
         self.code = ""
@@ -51,6 +63,8 @@ def __init__(self, bmg: BMGraphBuilder) -> None:
         self.no_dist_samples = defaultdict(lambda: 0)
         self.queries = []
         self.observations = []
+        self.node_to_rv_id = {}
+        self.node_to_query_map = {}
 
     def _get_node_id_mapping(self, node: bn.BMGNode) -> str:
         if node in self.node_to_var_id:
@@ -132,12 +146,17 @@ def _add_sample(self, node: bn.SampleNode) -> None:
         total_samples = self._no_dist_samples(node.operand)
         if total_samples > 1:
             param = f"{self.no_dist_samples[node.operand]}"
+            self._code.append(f"v{var_id} = rv{rv_id}({param},)")
+            self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param},)"
         else:
             param = ""
-        self._code.append(f"v{var_id} = rv{rv_id}({param})")
+            self._code.append(f"v{var_id} = rv{rv_id}({param})")
+            self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param})"
 
     def _add_query(self, node: bn.Query) -> None:
-        self.queries.append(f"{self._get_node_id_mapping(node.operator)}")
+        query_id = self._get_node_id_mapping(node.operator)
+        self.node_to_query_map[self.node_to_rv_id[query_id]] = node.rv_identifier
+        self.queries.append(f"{query_id}")
 
     def _add_observation(self, node: bn.Observation) -> None:
         val = node.value
@@ -163,18 +182,34 @@ def _generate_python(self, node: bn.BMGNode) -> None:
         elif isinstance(node, bn.Observation):
             self._add_observation(node)
 
-    def _generate_bm_python(self) -> str:
+    def _generate_bm_python(
+        self, inference_type, infer_config
+    ) -> Tuple[str, Dict[str, RVIdentifier]]:
         bmg, error_report = fix_problems(self.bmg)
         self.bmg = bmg
         error_report.raise_errors()
+        self._code.append(
+            f"from {inference_type.value.__module__} import {inference_type.value.__name__}"
+        )
         for node in self.bmg.all_ancestor_nodes():
             self._generate_python(node)
-        self._code.append(f"queries = [{(','.join(self.queries))}]")
-        self._code.append(f"observations = {{{','.join(self.observations)}}}")
+        self._code.append(f"opt_queries = [{(','.join(self.queries))}]")
+        self._code.append(f"opt_observations = {{{','.join(self.observations)}}}")
+        self._code.append(
+            f"""samples = {inference_type.value.__name__}().infer(
+            opt_queries,
+            opt_observations,
+            num_samples={infer_config['num_samples']},
+            num_chains={infer_config['num_chains']},
+            num_adaptive_samples={infer_config['num_adaptive_samples']}
+        )"""
+        )
         self.code = "\n".join(self._code)
-        return self.code
+        return self.code, self.node_to_query_map
 
 
-def to_bm_python(bmg: BMGraphBuilder) -> str:
+def to_bm_python(
+    bmg: BMGraphBuilder, inference_type: InferenceType, infer_config: Dict
+) -> Tuple[str, Dict[str, RVIdentifier]]:
     bmp = ToBMPython(bmg)
-    return bmp._generate_bm_python()
+    return bmp._generate_bm_python(inference_type, infer_config)
diff --git a/src/beanmachine/ppl/inference/bm_inference.py b/src/beanmachine/ppl/inference/bm_inference.py
new file mode 100644
index 0000000000..198b0c9074
--- /dev/null
+++ b/src/beanmachine/ppl/inference/bm_inference.py
@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""An inference engine which uses Bean Machine to perform
+inference on optimized Bean Machine models."""
+
+from typing import Dict, List, Set
+
+import graphviz
+import torch
+
+from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations
+from beanmachine.ppl.compiler.gen_bm_python import InferenceType, to_bm_python
+from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph
+from beanmachine.ppl.compiler.gen_dot import to_dot
+from beanmachine.ppl.compiler.gen_mini import to_mini
+from beanmachine.ppl.compiler.runtime import BMGRuntime
+from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
+from beanmachine.ppl.inference.utils import _verify_queries_and_observations
+from beanmachine.ppl.model.rv_identifier import RVIdentifier
+
+
+class BMInference:
+    """
+    Interface to Bean Machine inference on optimized models.
+
+    Please note that this is a highly experimental implementation under
+    active development, and that the supported subset of Bean Machine
+    models is limited. In particular, the runtime graph must be static
+    (that is, it must not change during inference), and only a limited
+    set of primitive distributions is currently supported.
+    """
+
+    _fix_observe_true: bool = False
+    _infer_config = {}
+
+    def __init__(self):
+        pass
+
+    def _accumulate_graph(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+    ) -> BMGRuntime:
+        _verify_queries_and_observations(queries, observations, True)
+        rt = BMGRuntime()
+        bmg = rt.accumulate_graph(queries, observations)
+        # TODO: Figure out a better way to pass this flag around
+        bmg._fix_observe_true = self._fix_observe_true
+        return rt
+
+    def _build_mcsamples(
+        self,
+        queries,
+        opt_rv_to_query_map,
+        samples,
+    ) -> MonteCarloSamples:
+        assert len(samples) == len(queries)
+
+        # Map each sample of an optimized-model RV back to the RVIdentifier
+        # the caller originally queried.
+        results: Dict[RVIdentifier, torch.Tensor] = {}
+        for rv in samples.keys():
+            query = opt_rv_to_query_map[str(rv)]
+            results[query] = samples[rv]
+        mcsamples = MonteCarloSamples(results)
+        return mcsamples
+
+    def _infer(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        num_samples: int,
+        num_chains: int = 1,
+        num_adaptive_samples: int = 0,
+        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> MonteCarloSamples:
+
+        rt = self._accumulate_graph(queries, observations)
+        bmg = rt._bmg
+
+        self._infer_config["num_samples"] = num_samples
+        self._infer_config["num_chains"] = num_chains
+        self._infer_config["num_adaptive_samples"] = num_adaptive_samples
+
+        generated_graph = to_bmg_graph(bmg, skip_optimizations)
+        optimized_python, opt_rv_to_query_map = to_bm_python(
+            generated_graph.bmg, inference_type, self._infer_config
+        )
+
+        # The generated program binds `samples` at its top level; executing
+        # it against globals() makes that binding visible below.
+        try:
+            exec(optimized_python, globals())  # noqa
+        except RuntimeError as e:
+            raise RuntimeError("Error during BM inference\n") from e
+
+        opt_samples = self._build_mcsamples(
+            queries,
+            opt_rv_to_query_map,
+            # pyre-ignore
+            samples,  # noqa
+        )
+        return opt_samples
+
+    def infer(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        num_samples: int,
+        num_chains: int = 4,
+        num_adaptive_samples: int = 0,
+        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> MonteCarloSamples:
+        """
+        Perform inference by (runtime) compilation of the model into an
+        optimized BM graph, generation of a Bean Machine Python program
+        from that graph, and execution of the generated program with the
+        requested Bean Machine inference method.
+
+        Args:
+            queries: queried random variables
+            observations: observations dict
+            num_samples: number of samples in each chain
+            num_chains: number of chains generated
+            num_adaptive_samples: number of burn-in samples to discard
+            inference_type: inference method
+            skip_optimizations: set of optimizations to disable in this call
+
+        Returns:
+            MonteCarloSamples: The requested samples
+        """
+        # TODO: Add verbose level
+        # TODO: Add logging
+        samples = self._infer(
+            queries,
+            observations,
+            num_samples,
+            num_chains,
+            num_adaptive_samples,
+            inference_type,
+            skip_optimizations,
+        )
+        return samples
+
+    def to_dot(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        after_transform: bool = True,
+        label_edges: bool = False,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> str:
+        """Produce a string containing a program in the GraphViz DOT language
+        representing the graph deduced from the model."""
+        node_types = False
+        node_sizes = False
+        edge_requirements = False
+        bmg = self._accumulate_graph(queries, observations)._bmg
+        return to_dot(
+            bmg,
+            node_types,
+            node_sizes,
+            edge_requirements,
+            after_transform,
+            label_edges,
+            skip_optimizations,
+        )
+
+    def _to_mini(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        indent=None,
+    ) -> str:
+        """Internal test method for Neal's MiniBMG prototype."""
+        bmg = self._accumulate_graph(queries, observations)._bmg
+        return to_mini(bmg, indent)
+
+    def to_graphviz(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        after_transform: bool = True,
+        label_edges: bool = False,
+        skip_optimizations: Set[str] = default_skip_optimizations,
+    ) -> graphviz.Source:
+        """Small wrapper to generate an actual graphviz object."""
+        s = self.to_dot(
+            queries, observations, after_transform, label_edges, skip_optimizations
+        )
+        return graphviz.Source(s)
+
+    def to_python(
+        self,
+        queries: List[RVIdentifier],
+        observations: Dict[RVIdentifier, torch.Tensor],
+        inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler,
+    ) -> str:
+        """Produce a string containing a BM Python program from the graph."""
+        bmg = self._accumulate_graph(queries, observations)._bmg
+        self._infer_config["num_samples"] = 0
+        self._infer_config["num_chains"] = 0
+        self._infer_config["num_adaptive_samples"] = 0
+        opt_bm, _ = to_bm_python(bmg, inference_type, self._infer_config)
+        return opt_bm
diff --git a/src/beanmachine/ppl/inference/bmg_inference.py b/src/beanmachine/ppl/inference/bmg_inference.py
index dae1ebbc96..4ddab32876 100644
--- a/src/beanmachine/ppl/inference/bmg_inference.py
+++ b/src/beanmachine/ppl/inference/bmg_inference.py
@@ -16,7 +16,6 @@
 from beanmachine.ppl.compiler.bm_graph_builder import rv_to_query
 from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations
-from beanmachine.ppl.compiler.gen_bm_python import to_bm_python
 from beanmachine.ppl.compiler.gen_bmg_cpp import to_bmg_cpp
 from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph
 from beanmachine.ppl.compiler.gen_bmg_python import to_bmg_python
@@ -167,7 +166,7 @@ def _build_mcsamples(
     # but it requires the input to be of a different type in the
     # cases of num_chains==1 and !=1 respectively. Furthermore,
     # we had to tweak it to support the right operator for merging
-    # saumple values when num_chains!=1.
+    # sample values when num_chains!=1.
         if num_chains == 1:
             mcsamples = MonteCarloSamples(
                 results[0], num_adaptive_samples, stack_not_cat=True
@@ -365,15 +364,6 @@ def to_python(
         bmg = self._accumulate_graph(queries, observations)._bmg
         return to_bmg_python(bmg).code
 
-    def to_bm_python(
-        self,
-        queries: List[RVIdentifier],
-        observations: Dict[RVIdentifier, torch.Tensor],
-    ) -> str:
-        """Produce a string containing a BM Python program from the graph."""
-        bmg = self._accumulate_graph(queries, observations)._bmg
-        return to_bm_python(bmg)
-
     def to_graph(
         self,
         queries: List[RVIdentifier],
diff --git a/tests/ppl/compiler/gen_bm_python_test.py b/tests/ppl/compiler/gen_bm_python_test.py
index 6b329d307b..68f6903c87 100644
--- a/tests/ppl/compiler/gen_bm_python_test.py
+++ b/tests/ppl/compiler/gen_bm_python_test.py
@@ -6,7 +6,8 @@
 import unittest
 
 import beanmachine.ppl as bm
-from beanmachine.ppl.inference.bmg_inference import BMGInference
+from beanmachine.ppl.inference.bm_inference import BMInference
+from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler
 from torch import tensor
 from torch.distributions import Bernoulli, Beta, Normal
 
@@ -41,10 +42,11 @@ def test_gen_bm_python_simple(self) -> None:
             flip(2): tensor(1.0),
             flip(3): tensor(0.0),
         }
-        observed = BMGInference().to_bm_python(queries, observations)
+        observed = BMInference().to_python(queries, observations)
         expected = """
 import beanmachine.ppl as bm
 import torch
+from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler
 v0 = 2.0
 @bm.random_variable
 def rv0():
@@ -53,25 +55,43 @@ def rv0():
 @bm.random_variable
 def rv1(i):
 \treturn torch.distributions.Bernoulli(v1.wrapper(*v1.arguments))
-v2 = rv1(1)
-v3 = rv1(2)
-v4 = rv1(3)
-v5 = rv1(4)
-queries = [v1]
-observations = {v2 : torch.tensor(0.0),v3 : torch.tensor(0.0),v4 : torch.tensor(1.0),v5 : torch.tensor(0.0)}
+v2 = rv1(1,)
+v3 = rv1(2,)
+v4 = rv1(3,)
+v5 = rv1(4,)
+opt_queries = [v1]
+opt_observations = {v2 : torch.tensor(0.0),v3 : torch.tensor(0.0),v4 : torch.tensor(1.0),v5 : torch.tensor(0.0)}
+samples = GlobalNoUTurnSampler().infer(
+            opt_queries,
+            opt_observations,
+            num_samples=0,
+            num_chains=0,
+            num_adaptive_samples=0
+        )
 """
         self.assertEqual(expected.strip(), observed.strip())
 
+        observed_samples = BMInference().infer(
+            queries, observations, num_samples=1000, num_chains=1
+        )
+        expected_samples = GlobalNoUTurnSampler().infer(
+            queries, observations, num_samples=1000, num_chains=1
+        )
+        observed_mean = observed_samples[beta()].mean()
+        expected_mean = expected_samples[beta()].mean()
+        self.assertAlmostEqual(expected_mean, observed_mean, delta=0.05)
+
     def test_gen_bm_python_rv_operations(self) -> None:
         self.maxDiff = None
         queries = [beta(), normal(0), normal(1)]
         observations = {
             flip_2(0): tensor(0.0),
         }
-        observed = BMGInference().to_bm_python(queries, observations)
+        observed = BMInference().to_python(queries, observations)
         expected = """
 import beanmachine.ppl as bm
 import torch
+from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler
 v0 = 2.0
 @bm.random_variable
 def rv0():
@@ -84,7 +104,7 @@ def f3():
 @bm.random_variable
 def rv1(i):
 \treturn torch.distributions.Bernoulli(f3())
-v4 = rv1(1)
+v4 = rv1(1,)
 @bm.functional
 def f5():
 \treturn (v4.wrapper(*v4.arguments))
@@ -93,7 +113,7 @@ def f5():
 def rv2():
 \treturn torch.distributions.Normal(f5(), v6)
 v7 = rv2()
-v8 = rv1(2)
+v8 = rv1(2,)
 @bm.functional
 def f9():
 \treturn (v8.wrapper(*v8.arguments))
@@ -101,7 +121,14 @@ def f9():
 def rv3():
 \treturn torch.distributions.Normal(f9(), v6)
 v10 = rv3()
-queries = [v1,v7,v10]
-observations = {v4 : torch.tensor(0.0)}
+opt_queries = [v1,v7,v10]
+opt_observations = {v4 : torch.tensor(0.0)}
+samples = GlobalNoUTurnSampler().infer(
+            opt_queries,
+            opt_observations,
+            num_samples=0,
+            num_chains=0,
+            num_adaptive_samples=0
+        )
 """
         self.assertEqual(expected.strip(), observed.strip())
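
Reviewer note: a minimal usage sketch of the new BMInference entry point, not part of the patch. The beta/flip model is assumed to mirror the one in the tests above (the exact Beta parameters are illustrative); only the BMInference and InferenceType calls are the API under review.

# Illustrative sketch (assumed model); exercises BMInference end to end.
import beanmachine.ppl as bm
from torch import tensor
from torch.distributions import Bernoulli, Beta

from beanmachine.ppl.compiler.gen_bm_python import InferenceType
from beanmachine.ppl.inference.bm_inference import BMInference


@bm.random_variable
def beta():
    # Hypothetical prior; the tests suggest Beta(2.0, 2.0) but this is an assumption.
    return Beta(2.0, 2.0)


@bm.random_variable
def flip(i):
    return Bernoulli(beta())


queries = [beta()]
observations = {flip(0): tensor(0.0), flip(1): tensor(1.0)}

# Inspect the generated, optimized Bean Machine program without running it.
print(BMInference().to_python(queries, observations))

# Compile the model, execute the generated program, and map the samples
# back to the original RVIdentifiers.
samples = BMInference().infer(
    queries,
    observations,
    num_samples=1000,
    num_chains=1,
    inference_type=InferenceType.GlobalNoUTurnSampler,
)
print(samples[beta()].mean())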
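One non-obvious mechanism worth calling out: _infer can reference the `samples` name (marked `# noqa` in the diff) only because exec is given this module's globals(), so the top-level binding created by the generated program becomes visible afterwards. A minimal standalone illustration of that Python behavior:

# Minimal illustration of the exec() pattern used in BMInference._infer:
# running generated source against globals() makes its top-level bindings
# visible to the surrounding module afterwards.
generated_source = "samples = [1, 2, 3]"
exec(generated_source, globals())  # noqa
print(samples)  # prints [1, 2, 3]; `samples` was bound by the exec above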