This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

BMGInference does not handle torch.stack called with other RVs #1565

Open
feynmanliang opened this issue Jul 23, 2022 · 0 comments

@feynmanliang
Contributor

feynmanliang commented Jul 23, 2022

Issue Description

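When random variables (RVs) are combined with torch.stack, BMGInference fails to trace the model. torch.stack requires every element of its input sequence to be a Tensor, but during graph accumulation each RV is represented by a SampleNode, so the call raises a TypeError.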
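The failure can be reproduced with plain PyTorch: torch.stack raises the same kind of TypeError whenever an element of its input is not a Tensor. In this sketch, SampleNodeStandIn is a hypothetical placeholder for Bean Machine's SampleNode, not the real class.

```python
import torch


class SampleNodeStandIn:
    """Hypothetical stand-in for the SampleNode objects seen while BMGInference traces a model."""


def try_stack(elements):
    """Return the stacked tensor, or the TypeError that torch.stack raises."""
    try:
        return torch.stack(elements)
    except TypeError as err:
        return err


ok = try_stack([torch.tensor(0.0), torch.tensor(1.0)])       # all Tensors: succeeds
bad = try_stack([SampleNodeStandIn(), torch.tensor(1.0)])    # element 0 is not a Tensor
print(ok)
print(bad)
```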

The example runs if torch.stack is replaced by torch.tensor, but torch.tensor does not propagate gradients back to its arguments, which rules out gradient-based methods such as VI and HMC.
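The gradient difference can be seen with plain PyTorch tensors standing in for RV values (x and y here are illustrative and not part of the model below): gradients flow back through torch.stack, while torch.tensor copies the values into a new leaf tensor with no autograd history.

```python
import warnings

import torch

# Two leaf tensors standing in for RV values; autograd tracks both.
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)

# torch.stack is a differentiable op: the result stays connected to x and y.
stacked = torch.stack([x, y])
stacked.sum().backward()
print(x.grad, y.grad)  # gradient of the sum with respect to each input

# torch.tensor copies the values into a brand-new leaf tensor,
# so the result has no autograd link back to x and y.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # some PyTorch versions warn when copy-constructing from tensors
    copied = torch.tensor([x, y])
print(copied.requires_grad, copied.grad_fn)
```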

Steps to Reproduce

import beanmachine.ppl as bm
import torch
import torch.distributions as dist
from beanmachine.ppl.inference import BMGInference

# The mean of foo is the sum of the stacked RVs bar(0) and bar(1)
foo = bm.random_variable(lambda: dist.Normal(torch.stack([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)

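raises the following error. (The traceback was captured from a variant of the example that passes the stacked values to dist.MultivariateNormal(..., torch.eye(2)); the failing torch.stack call is the same.)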

expected Tensor as element 0 in argument 0, but got SampleNode
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-358-e65cd9e99a94> in <module>
      4 foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2)))
      5 bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
----> 6 BMGInference().infer(
      7     queries=[foo()],
      8     observations={},
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in infer(self, queries, observations, num_samples, num_chains, inference_type, skip_optimizations)
    262         # TODO: Add verbose level
    263         # TODO: Add logging
--> 264         samples, _ = self._infer(
    265             queries,
    266             observations,
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _infer(self, queries, observations, num_samples, num_chains, inference_type, produce_report, skip_optimizations)
    182             self._pd = prof.ProfilerData()
    183 
--> 184         rt = self._accumulate_graph(queries, observations)
    185         bmg = rt._bmg
    186         report = pr.PerformanceReport()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _accumulate_graph(self, queries, observations)
     71         rt = BMGRuntime()
     72         rt._pd = self._pd
---> 73         bmg = rt.accumulate_graph(queries, observations)
     74         # TODO: Figure out a better way to pass this flag around
     75         bmg._fix_observe_true = self._fix_observe_true
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in accumulate_graph(self, queries, observations)
    719             self._bmg.add_observation(node, val)
    720         for qrv in queries:
--> 721             node = self._rv_to_node(qrv)
    722             q = self._bmg.add_query(node)
    723             self._rv_to_query[qrv] = q
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in _rv_to_node(self, rv)
    583                 # RVID, and if we're in the second situation, we will not.
    584 
--> 585                 value = self._context.call(rewritten_function, rv.arguments)
    586                 if isinstance(value, RVIdentifier):
    587                     # We have a rewritten function with a decorator already applied.
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/execution_context.py in call(self, func, args, kwargs)
     92         self._stack.push(FunctionCall(func, args, kwargs))
     93         try:
---> 94             return func(*args, **kwargs)
     95         finally:
     96             self._stack.pop()
<BMGJIT> in a1()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in handle_function(self, function, arguments, kwargs)
    510             function, arguments, kwargs
    511         ):
--> 512             result = self._special_function_caller.do_special_call_maybe_stochastic(
    513                 function, arguments, kwargs
    514             )
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/special_function_caller.py in do_special_call_maybe_stochastic(self, func, args, kwargs)
    629             new_args = (_get_ordinary_value(arg) for arg in args)
    630             new_kwargs = {key: _get_ordinary_value(arg) for key, arg in kwargs.items()}
--> 631             return func(*new_args, **new_kwargs)
    632 
    633         if _is_in_place_operator(func):
TypeError: expected Tensor as element 0 in argument 0, but got SampleNode
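In the final frame, do_special_call_maybe_stochastic calls func(*new_args, **new_kwargs), i.e. the real torch.stack, while the list it receives still contains a SampleNode. One possible direction for a fix is a stack handler that records a graph node when any input is stochastic and only falls back to torch.stack for ordinary tensors. The sketch below is a minimal, self-contained illustration under that assumption; Node, StackNode and traced_stack are hypothetical names and do not correspond to Bean Machine's actual compiler classes.

```python
from dataclasses import dataclass, field
from typing import Any, List, Sequence

import torch


@dataclass
class Node:
    """Hypothetical graph node, standing in for something like a SampleNode."""
    name: str


@dataclass
class StackNode(Node):
    """Hypothetical graph node that records a stack of its inputs."""
    inputs: List[Any] = field(default_factory=list)
    dim: int = 0


def traced_stack(tensors: Sequence[Any], dim: int = 0) -> Any:
    """Stack ordinary tensors eagerly; record a StackNode if any input is a graph node."""
    items = list(tensors)
    if any(isinstance(t, Node) for t in items):
        # At least one input is stochastic: build a graph node instead of calling torch.stack.
        return StackNode(name="stack", inputs=items, dim=dim)
    # Every input is an ordinary tensor, so the real torch.stack can run.
    return torch.stack(items, dim=dim)


# Usage: stochastic inputs produce a graph node, plain tensors produce a tensor.
graph_result = traced_stack([Node("bar(0)"), Node("bar(1)")])
eager_result = traced_stack([torch.tensor(0.0), torch.tensor(1.0)])
print(type(graph_result).__name__, len(graph_result.inputs))
print(eager_result)
```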

Expected Behavior

Execution succeeds and produces the same results as the variant below, in which torch.stack is replaced by torch.tensor:

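import beanmachine.ppl as bm
import torch
import torch.distributions as dist
from beanmachine.ppl.inference import BMGInference

# Same model, but the RVs are combined with torch.tensor instead of torch.stack
foo = bm.random_variable(lambda: dist.Normal(torch.tensor([bar(i) for i in range(2)]).sum(), 1.))
bar = bm.random_variable(lambda i: dist.Normal(0., 1.))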
BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant