You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
{{ message }}
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
When other random variables (RVs) are concatenated together using torch.stack, BMGInference fails to
trace execution because it assumes that all arguments to stack are of type Tensor.
The example runs fine if stack is replaced by torch.tensor, but torch.tensor is not differentiable with respect to its arguments, which precludes methods that need gradients, such as VI and HMC.
Steps to Reproduce
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
from beanmachine.ppl.inference import BMGInference

# `foo` is a Normal whose mean is the sum of two `bar` samples that are
# combined with torch.stack. `bar` is only looked up when the lambda body runs,
# so defining it after `foo` is fine.
foo = bm.random_variable(
    lambda: dist.Normal(torch.stack([bar(i) for i in range(2)]).sum(), 1.0)
)
bar = bm.random_variable(lambda i: dist.Normal(0.0, 1.0))

# Reported to fail while BMGInference traces `foo`, with:
#   TypeError: expected Tensor as element 0 in argument 0, but got SampleNode
BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)
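raises the following error. (Note: the traceback below was captured from a variant of the example that passes torch.stack([bar(i) for i in range(2)]) to dist.MultivariateNormal with covariance torch.eye(2); it fails with the same message.)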
expected Tensor as element 0 in argument 0, but got SampleNode
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-358-e65cd9e99a94> in <module>
4 foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2)))
5 bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
----> 6 BMGInference().infer(
7 queries=[foo()],
8 observations={},
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in infer(self, queries, observations, num_samples, num_chains, inference_type, skip_optimizations)
262 # TODO: Add verbose level
263 # TODO: Add logging
--> 264 samples, _ = self._infer(
265 queries,
266 observations,
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _infer(self, queries, observations, num_samples, num_chains, inference_type, produce_report, skip_optimizations)
182 self._pd = prof.ProfilerData()
183
--> 184 rt = self._accumulate_graph(queries, observations)
185 bmg = rt._bmg
186 report = pr.PerformanceReport()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _accumulate_graph(self, queries, observations)
71 rt = BMGRuntime()
72 rt._pd = self._pd
---> 73 bmg = rt.accumulate_graph(queries, observations)
74 # TODO: Figure out a better way to pass this flag around
75 bmg._fix_observe_true = self._fix_observe_true
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in accumulate_graph(self, queries, observations)
719 self._bmg.add_observation(node, val)
720 for qrv in queries:
--> 721 node = self._rv_to_node(qrv)
722 q = self._bmg.add_query(node)
723 self._rv_to_query[qrv] = q
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in _rv_to_node(self, rv)
583 # RVID, and if we're in the second situation, we will not.
584
--> 585 value = self._context.call(rewritten_function, rv.arguments)
586 if isinstance(value, RVIdentifier):
587 # We have a rewritten function with a decorator already applied.
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/execution_context.py in call(self, func, args, kwargs)
92 self._stack.push(FunctionCall(func, args, kwargs))
93 try:
---> 94 return func(*args, **kwargs)
95 finally:
96 self._stack.pop()
<BMGJIT> in a1()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in handle_function(self, function, arguments, kwargs)
510 function, arguments, kwargs
511 ):
--> 512 result = self._special_function_caller.do_special_call_maybe_stochastic(
513 function, arguments, kwargs
514 )
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/special_function_caller.py in do_special_call_maybe_stochastic(self, func, args, kwargs)
629 new_args = (_get_ordinary_value(arg) for arg in args)
630 new_kwargs = {key: _get_ordinary_value(arg) for key, arg in kwargs.items()}
--> 631 return func(*new_args, **new_kwargs)
632
633 if _is_in_place_operator(func):
TypeError: expected Tensor as element 0 in argument 0, but got SampleNode
Expected Behavior
Successful execution, with results identical to those obtained by applying s/stack/tensor (i.e. replacing torch.stack with torch.tensor), as in:
import beanmachine.ppl as bm
import torch
import torch.distributions as dist
from beanmachine.ppl.inference import BMGInference

# Same model as the torch.stack repro, but the two `bar` samples are combined
# with torch.tensor instead. The report states that this variant runs, but
# torch.tensor is not differentiable with respect to its arguments.
foo = bm.random_variable(
    lambda: dist.Normal(torch.tensor([bar(i) for i in range(2)]).sum(), 1.0)
)
bar = bm.random_variable(lambda i: dist.Normal(0.0, 1.0))

BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)
The text was updated successfully, but these errors were encountered:
Sign up for free to subscribe to this conversation on GitHub.
Already have an account?
Sign in.
Issue Description
When other RVs are concatenated together using torch.stack, BMGInference fails to trace execution because it assumes that all arguments to stack are of type Tensor. The example runs fine if stack is replaced by torch.tensor, but torch.tensor is not differentiable with respect to its arguments, which precludes methods such as VI and HMC.
Steps to Reproduce
raises
Expected Behavior
Successful execution, with results identical to those obtained by applying s/stack/tensor (i.e. replacing torch.stack with torch.tensor).
The text was updated successfully, but these errors were encountered: