From 86990bc5a23c924266020172c21abe1768c75fb4 Mon Sep 17 00:00:00 2001 From: Aishwarya Sivaraman Date: Tue, 1 Nov 2022 11:10:47 -0700 Subject: [PATCH] BMInference API to compile and run BM models from Beanstalk (#1801) Summary: Pull Request resolved: https://github.com/facebookresearch/beanmachine/pull/1801 We currently have `BMGInference`, that takes a BM python program produces a valid BMG graph and runs inference on the BMG C++ backend. In the same vein, this diff implements a `BMInference` API that takes a BM python program, optimizes it (currently uses the same optimization path as BMG) and runs BM inference methods to compute the samples. This allows us to separate the type of optimizations each backend needs and allows us to test end-to-end with an inference method. Differential Revision: D40853055 fbshipit-source-id: 8b5ce58385c5fe17e0a688e428261f3276e9b03b --- src/beanmachine/ppl/compiler/gen_bm_python.py | 53 ++++- src/beanmachine/ppl/inference/bm_inference.py | 209 ++++++++++++++++++ .../ppl/inference/bmg_inference.py | 12 +- tests/ppl/compiler/gen_bm_python_test.py | 67 ++++-- 4 files changed, 308 insertions(+), 33 deletions(-) create mode 100644 src/beanmachine/ppl/inference/bm_inference.py diff --git a/src/beanmachine/ppl/compiler/gen_bm_python.py b/src/beanmachine/ppl/compiler/gen_bm_python.py index 0cb192101f..11c6a71ee4 100644 --- a/src/beanmachine/ppl/compiler/gen_bm_python.py +++ b/src/beanmachine/ppl/compiler/gen_bm_python.py @@ -5,13 +5,23 @@ from collections import defaultdict -from typing import Dict, List +from enum import Enum +from typing import Dict, List, Tuple from beanmachine.ppl.compiler import bmg_nodes as bn from beanmachine.ppl.compiler.bm_graph_builder import BMGraphBuilder from beanmachine.ppl.compiler.fix_problems import fix_problems from beanmachine.ppl.compiler.internal_error import InternalError +from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler +from beanmachine.ppl.inference.single_site_nmc import SingleSiteNewtonianMonteCarlo +from beanmachine.ppl.model.rv_identifier import RVIdentifier + + +class InferenceType(Enum): + SingleSiteNewtonianMonteCarlo = SingleSiteNewtonianMonteCarlo + GlobalNoUTurnSampler = GlobalNoUTurnSampler + _node_type_to_distribution = { bn.BernoulliNode: "torch.distributions.Bernoulli", @@ -37,6 +47,8 @@ class ToBMPython: no_dist_samples: Dict[bn.BMGNode, int] queries: List[str] observations: List[str] + node_to_rv_id: Dict[str, str] + node_to_query_map: Dict[str, RVIdentifier] def __init__(self, bmg: BMGraphBuilder) -> None: self.code = "" @@ -51,6 +63,8 @@ def __init__(self, bmg: BMGraphBuilder) -> None: self.no_dist_samples = defaultdict(lambda: 0) self.queries = [] self.observations = [] + self.node_to_rv_id = {} + self.node_to_query_map = {} def _get_node_id_mapping(self, node: bn.BMGNode) -> str: if node in self.node_to_var_id: @@ -132,12 +146,17 @@ def _add_sample(self, node: bn.SampleNode) -> None: total_samples = self._no_dist_samples(node.operand) if total_samples > 1: param = f"{self.no_dist_samples[node.operand]}" + self._code.append(f"v{var_id} = rv{rv_id}({param},)") + self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param},)" else: param = "" - self._code.append(f"v{var_id} = rv{rv_id}({param})") + self._code.append(f"v{var_id} = rv{rv_id}({param})") + self.node_to_rv_id[f"v{var_id}"] = f"rv{rv_id}({param})" def _add_query(self, node: bn.Query) -> None: - self.queries.append(f"{self._get_node_id_mapping(node.operator)}") + query_id = self._get_node_id_mapping(node.operator) + self.node_to_query_map[self.node_to_rv_id[query_id]] = node.rv_identifier + self.queries.append(f"{query_id}") def _add_observation(self, node: bn.Observation) -> None: val = node.value @@ -163,18 +182,34 @@ def _generate_python(self, node: bn.BMGNode) -> None: elif isinstance(node, bn.Observation): self._add_observation(node) - def _generate_bm_python(self) -> str: + def _generate_bm_python( + self, inference_type, infer_config + ) -> Tuple[str, Dict[str, RVIdentifier]]: bmg, error_report = fix_problems(self.bmg) self.bmg = bmg error_report.raise_errors() + self._code.append( + f"from {inference_type.value.__module__} import {inference_type.value.__name__}" + ) for node in self.bmg.all_ancestor_nodes(): self._generate_python(node) - self._code.append(f"queries = [{(','.join(self.queries))}]") - self._code.append(f"observations = {{{','.join(self.observations)}}}") + self._code.append(f"opt_queries = [{(','.join(self.queries))}]") + self._code.append(f"opt_observations = {{{','.join(self.observations)}}}") + self._code.append( + f"""samples = {inference_type.value.__name__}().infer( + opt_queries, + opt_observations, + num_samples={infer_config['num_samples']}, + num_chains={infer_config['num_chains']}, + num_adaptive_samples={infer_config['num_adaptive_samples']} + )""" + ) self.code = "\n".join(self._code) - return self.code + return self.code, self.node_to_query_map -def to_bm_python(bmg: BMGraphBuilder) -> str: +def to_bm_python( + bmg: BMGraphBuilder, inference_type: InferenceType, infer_config: Dict +) -> Tuple[str, Dict[str, RVIdentifier]]: bmp = ToBMPython(bmg) - return bmp._generate_bm_python() + return bmp._generate_bm_python(inference_type, infer_config) diff --git a/src/beanmachine/ppl/inference/bm_inference.py b/src/beanmachine/ppl/inference/bm_inference.py new file mode 100644 index 0000000000..7eee7fe704 --- /dev/null +++ b/src/beanmachine/ppl/inference/bm_inference.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""An inference engine which uses Bean Machine to make +inferences on optimized Bean Machine models.""" + +from typing import Dict, List, Set + +import graphviz +import torch + +from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations +from beanmachine.ppl.compiler.gen_bm_python import InferenceType, to_bm_python +from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph +from beanmachine.ppl.compiler.gen_dot import to_dot +from beanmachine.ppl.compiler.gen_mini import to_mini +from beanmachine.ppl.compiler.runtime import BMGRuntime +from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples +from beanmachine.ppl.inference.utils import _verify_queries_and_observations +from beanmachine.ppl.model.rv_identifier import RVIdentifier + + +class BMInference: + """ + Interface to Bean Machine Inference on optimized models. + + Please note that this is a highly experimental implementation under active + development, and that the subset of Bean Machine model is limited. Limitations + include that the runtime graph should be static (meaning, it does not change + during inference), and that the types of primitive distributions supported + is currently limited. + """ + + _fix_observe_true: bool = False + _infer_config = {} + + def __init__(self): + pass + + def _accumulate_graph( + self, + queries: List[RVIdentifier], + observations: Dict[RVIdentifier, torch.Tensor], + ) -> BMGRuntime: + _verify_queries_and_observations(queries, observations, True) + rt = BMGRuntime() + bmg = rt.accumulate_graph(queries, observations) + # TODO: Figure out a better way to pass this flag around + bmg._fix_observe_true = self._fix_observe_true + return rt + + def _build_mcsamples( + self, + queries, + opt_rv_to_query_map, + samples, + ) -> MonteCarloSamples: + assert len(samples) == len(queries) + + results: Dict[RVIdentifier, torch.Tensor] = {} + for rv in samples.keys(): + query = opt_rv_to_query_map[rv.__str__()] + results[query] = samples[rv] + mcsamples = MonteCarloSamples(results) + return mcsamples + + def _infer( + self, + queries: List[RVIdentifier], + observations: Dict[RVIdentifier, torch.Tensor], + num_samples: int, + num_chains: int = 1, + num_adaptive_samples: int = 0, + inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler, + skip_optimizations: Set[str] = default_skip_optimizations, + ) -> MonteCarloSamples: + + rt = self._accumulate_graph(queries, observations) + bmg = rt._bmg + + self._infer_config["num_samples"] = num_samples + self._infer_config["num_chains"] = num_chains + self._infer_config["num_adaptive_samples"] = num_adaptive_samples + + generated_graph = to_bmg_graph(bmg, skip_optimizations) + optimized_python, opt_rv_to_query_map = to_bm_python( + generated_graph.bmg, inference_type, self._infer_config + ) + + try: + exec(optimized_python, globals()) # noqa + except RuntimeError as e: + raise RuntimeError("Error during BM inference\n") from e + + opt_samples = self._build_mcsamples( + queries, + opt_rv_to_query_map, + # pyre-ignore + samples, # noqa + ) + return opt_samples + + def infer( + self, + queries: List[RVIdentifier], + observations: Dict[RVIdentifier, torch.Tensor], + num_samples: int, + num_chains: int = 4, + num_adaptive_samples: int = 0, + inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler, + skip_optimizations: Set[str] = default_skip_optimizations, + ) -> MonteCarloSamples: + """ + Perform inference by (runtime) compilation of Python source code associated + with its parameters, constructing an optimized BM graph, and then calling the + BM implementation of a particular inference method on this graph. + + Args: + queries: queried random variables + observations: observations dict + num_samples: number of samples in each chain + num_chains: number of chains generated + num_adaptive_samples: number of burn in samples to discard + inference_type: inference method + skip_optimizations: list of optimization to disable in this call + + Returns: + MonteCarloSamples: The requested samples + """ + # TODO: Add verbose level + # TODO: Add logging + samples = self._infer( + queries, + observations, + num_samples, + num_chains, + num_adaptive_samples, + inference_type, + skip_optimizations, + ) + return samples + + def to_dot( + self, + queries: List[RVIdentifier], + observations: Dict[RVIdentifier, torch.Tensor], + after_transform: bool = True, + label_edges: bool = False, + skip_optimizations: Set[str] = default_skip_optimizations, + ) -> str: + """Produce a string containing a program in the GraphViz DOT language + representing the graph deduced from the model.""" + node_types = False + node_sizes = False + edge_requirements = False + bmg = self._accumulate_graph(queries, observations)._bmg + return to_dot( + bmg, + node_types, + node_sizes, + edge_requirements, + after_transform, + label_edges, + skip_optimizations, + ) + + def _to_mini( + self, + queries: List[RVIdentifier], + observations: Dict[RVIdentifier, torch.Tensor], + indent=None, + ) -> str: + """Internal test method for Neal's MiniBMG prototype.""" + bmg = self._accumulate_graph(queries, observations)._bmg + return to_mini(bmg, indent) + + def to_graphviz( + self, + queries: List[RVIdentifier], + observations: Dict[RVIdentifier, torch.Tensor], + after_transform: bool = True, + label_edges: bool = False, + skip_optimizations: Set[str] = default_skip_optimizations, + ) -> graphviz.Source: + """Small wrapper to generate an actual graphviz object""" + s = self.to_dot( + queries, observations, after_transform, label_edges, skip_optimizations + ) + return graphviz.Source(s) + + def to_python( + self, + queries: List[RVIdentifier], + observations: Dict[RVIdentifier, torch.Tensor], + num_samples: int, + num_chains: int = 4, + num_adaptive_samples: int = 0, + inference_type: InferenceType = InferenceType.GlobalNoUTurnSampler, + skip_optimizations: Set[str] = default_skip_optimizations, + ) -> str: + """Produce a string containing a BM Python program from the graph.""" + bmg = self._accumulate_graph(queries, observations)._bmg + self._infer_config["num_samples"] = num_samples + self._infer_config["num_chains"] = num_chains + self._infer_config["num_adaptive_samples"] = num_adaptive_samples + opt_bm, _ = to_bm_python(bmg, inference_type, self._infer_config) + return opt_bm diff --git a/src/beanmachine/ppl/inference/bmg_inference.py b/src/beanmachine/ppl/inference/bmg_inference.py index dae1ebbc96..4ddab32876 100644 --- a/src/beanmachine/ppl/inference/bmg_inference.py +++ b/src/beanmachine/ppl/inference/bmg_inference.py @@ -16,7 +16,6 @@ from beanmachine.ppl.compiler.bm_graph_builder import rv_to_query from beanmachine.ppl.compiler.fix_problems import default_skip_optimizations -from beanmachine.ppl.compiler.gen_bm_python import to_bm_python from beanmachine.ppl.compiler.gen_bmg_cpp import to_bmg_cpp from beanmachine.ppl.compiler.gen_bmg_graph import to_bmg_graph from beanmachine.ppl.compiler.gen_bmg_python import to_bmg_python @@ -167,7 +166,7 @@ def _build_mcsamples( # but it requires the input to be of a different type in the # cases of num_chains==1 and !=1 respectively. Furthermore, # we had to tweak it to support the right operator for merging - # saumple values when num_chains!=1. + # sample values when num_chains!=1. if num_chains == 1: mcsamples = MonteCarloSamples( results[0], num_adaptive_samples, stack_not_cat=True @@ -365,15 +364,6 @@ def to_python( bmg = self._accumulate_graph(queries, observations)._bmg return to_bmg_python(bmg).code - def to_bm_python( - self, - queries: List[RVIdentifier], - observations: Dict[RVIdentifier, torch.Tensor], - ) -> str: - """Produce a string containing a BM Python program from the graph.""" - bmg = self._accumulate_graph(queries, observations)._bmg - return to_bm_python(bmg) - def to_graph( self, queries: List[RVIdentifier], diff --git a/tests/ppl/compiler/gen_bm_python_test.py b/tests/ppl/compiler/gen_bm_python_test.py index 6b329d307b..8778800a4d 100644 --- a/tests/ppl/compiler/gen_bm_python_test.py +++ b/tests/ppl/compiler/gen_bm_python_test.py @@ -6,7 +6,8 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.inference.bmg_inference import BMGInference +from beanmachine.ppl.inference.bm_inference import BMInference +from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler from torch import tensor from torch.distributions import Bernoulli, Beta, Normal @@ -41,10 +42,17 @@ def test_gen_bm_python_simple(self) -> None: flip(2): tensor(1.0), flip(3): tensor(0.0), } - observed = BMGInference().to_bm_python(queries, observations) + observed = BMInference().to_python( + queries, + observations, + num_samples=1000, + num_chains=1, + num_adaptive_samples=500, + ) expected = """ import beanmachine.ppl as bm import torch +from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler v0 = 2.0 @bm.random_variable def rv0(): @@ -53,25 +61,51 @@ def rv0(): @bm.random_variable def rv1(i): \treturn torch.distributions.Bernoulli(v1.wrapper(*v1.arguments)) -v2 = rv1(1) -v3 = rv1(2) -v4 = rv1(3) -v5 = rv1(4) -queries = [v1] -observations = {v2 : torch.tensor(0.0),v3 : torch.tensor(0.0),v4 : torch.tensor(1.0),v5 : torch.tensor(0.0)} +v2 = rv1(1,) +v3 = rv1(2,) +v4 = rv1(3,) +v5 = rv1(4,) +opt_queries = [v1] +opt_observations = {v2 : torch.tensor(0.0),v3 : torch.tensor(0.0),v4 : torch.tensor(1.0),v5 : torch.tensor(0.0)} +samples = GlobalNoUTurnSampler().infer( + opt_queries, + opt_observations, + num_samples=1000, + num_chains=1, + num_adaptive_samples=500 + ) """ self.assertEqual(expected.strip(), observed.strip()) + observed_samples = BMInference().infer( + queries, + observations, + num_samples=1000, + num_chains=1, + num_adaptive_samples=500, + ) + expected_samples = GlobalNoUTurnSampler().infer( + queries, + observations, + num_samples=1000, + num_chains=1, + num_adaptive_samples=500, + ) + observed_mean = observed_samples[beta()].mean() + expected_mean = expected_samples[beta()].mean() + self.assertAlmostEqual(expected_mean, observed_mean, delta=0.05) + def test_gen_bm_python_rv_operations(self) -> None: self.maxDiff = None queries = [beta(), normal(0), normal(1)] observations = { flip_2(0): tensor(0.0), } - observed = BMGInference().to_bm_python(queries, observations) + observed = BMInference().to_python(queries, observations, num_samples=1000) expected = """ import beanmachine.ppl as bm import torch +from beanmachine.ppl.inference.nuts_inference import GlobalNoUTurnSampler v0 = 2.0 @bm.random_variable def rv0(): @@ -84,7 +118,7 @@ def f3(): @bm.random_variable def rv1(i): \treturn torch.distributions.Bernoulli(f3()) -v4 = rv1(1) +v4 = rv1(1,) @bm.functional def f5(): \treturn (v4.wrapper(*v4.arguments)) @@ -93,7 +127,7 @@ def f5(): def rv2(): \treturn torch.distributions.Normal(f5(), v6) v7 = rv2() -v8 = rv1(2) +v8 = rv1(2,) @bm.functional def f9(): \treturn (v8.wrapper(*v8.arguments)) @@ -101,7 +135,14 @@ def f9(): def rv3(): \treturn torch.distributions.Normal(f9(), v6) v10 = rv3() -queries = [v1,v7,v10] -observations = {v4 : torch.tensor(0.0)} +opt_queries = [v1,v7,v10] +opt_observations = {v4 : torch.tensor(0.0)} +samples = GlobalNoUTurnSampler().infer( + opt_queries, + opt_observations, + num_samples=1000, + num_chains=4, + num_adaptive_samples=0 + ) """ self.assertEqual(expected.strip(), observed.strip())