Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Fix the use of tensor as arguments of random variables in unit tests …
Browse files Browse the repository at this point in the history
…and tutorials

Differential Revision: D40856513

fbshipit-source-id: 9fca5a59a62d3515d2f3005c81eb9c84d2ea0f36
  • Loading branch information
horizon-blue authored and facebook-github-bot committed Oct 31, 2022
1 parent 1f04780 commit 152a4f9
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 286 deletions.
8 changes: 0 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,8 @@ filterwarnings = [
"ignore::DeprecationWarning:nbval",
# PyTorch 1.10 warns against creating a tensor from a list of numpy arrays
"default:Creating a tensor from a list of numpy.ndarrays is extremely slow.*:UserWarning",
# xarray uses a module that's deprecated since setuptools 60.0.0. This has been
# fixed in xarray/pull/6096, so we can remove this filter with the next xarray
# release
"default:distutils Version classes are deprecated.*:DeprecationWarning",
# statsmodels imports a module that's deprecated since pandas 1.4.0
"default:pandas.Int64Index is deprecated *:FutureWarning",
# functorch 0.1.0 imports deprecated _stateless module
"default:The `torch.nn.utils._stateless` code is deprecated*:DeprecationWarning",
# BM warns against using torch tensors as arguments of random variables
"default:PyTorch tensors are hashed by memory address instead of value.*:UserWarning",
# Arviz warns against the use of deprecated methods, due to the recent release of matplotlib v3.6.0
"default:The register_cmap function will be deprecated in a future version.*:PendingDeprecationWarning",
# gpytorch < 1.9.0 uses torch.triangular_solve
Expand Down
3 changes: 2 additions & 1 deletion src/beanmachine/ppl/model/rv_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __post_init__(self):
warnings.warn(
"PyTorch tensors are hashed by memory address instead of value. "
"Therefore, it is not recommended to use tensors as indices of random variables.",
stacklevel=3,
# display the warning at the location where the RVIdentifier is created
stacklevel=5,
)

def __str__(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/ppl/compiler/gaussian_mixture_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def category(item):

@bm.random_variable
def mixed(item):
return Normal(mean(category(item)), 2)
return Normal(mean(category(item).item()), 2)


class GaussianMixtureModelTest(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ppl/compiler/gmm_1d_2comp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def component(self, i):

@bm.random_variable
def y(self, i):
c = self.component(i)
c = self.component(i).item()
return dist.Normal(self.mu(c), self.sigma(c))


Expand Down
5 changes: 5 additions & 0 deletions tests/ppl/compiler/jit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import astor
import beanmachine.ppl as bm
import pytest
from beanmachine.ppl.compiler.bm_to_bmg import (
_bm_function_to_bmg_ast,
_bm_function_to_bmg_function,
Expand Down Expand Up @@ -528,6 +529,10 @@ def test_bad_control_flow_4(self) -> None:
"Functional calls must not have named arguments.",
)

# Ignore the warning against using tensors as arguments of random variables
@pytest.mark.filterwarnings(
"ignore:PyTorch tensors are hashed by memory address*:UserWarning"
)
def test_rv_identity(self) -> None:
self.maxDiff = None

Expand Down
7 changes: 7 additions & 0 deletions tests/ppl/compiler/stochastic_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,19 @@
import unittest

import beanmachine.ppl as bm

import pytest
from beanmachine.ppl.compiler.gen_dot import to_dot
from beanmachine.ppl.compiler.runtime import BMGRuntime
from beanmachine.ppl.inference import BMGInference
from torch import tensor
from torch.distributions import Bernoulli, Beta, Dirichlet, Normal

# Ignore all warnings in this module against using tensors as arguments of
# random variables
pytestmark = pytest.mark.filterwarnings(
"ignore:PyTorch tensors are hashed by memory address*:UserWarning"
)

# Random variable that takes an argument
@bm.random_variable
Expand Down
4 changes: 2 additions & 2 deletions tests/ppl/compiler/support_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def cat_or_bern(n):

@bm.functional
def switch_inf():
return normal_or_bern(flip1(0))
return normal_or_bern(flip1(0).item())


@bm.functional
def switch_4():
return cat_or_bern(flip1(0))
return cat_or_bern(flip1(0).item())


class NodeSupportTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def component(self, i):

@bm.random_variable
def y(self, i):
c = self.component(i)
c = self.component(i).item()
return dist.Normal(self.mu(c), self.sigma(c))


Expand Down
26 changes: 20 additions & 6 deletions tests/ppl/compiler/tutorial_Robust_Linear_Regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
# TODO: Check imports for consistency

import beanmachine.ppl as bm
import pytest
import torch # from torch import manual_seed, tensor
import torch.distributions as dist # from torch.distributions import Bernoulli, Normal, Uniform
from beanmachine.ppl.distributions import Flat
from beanmachine.ppl.inference.bmg_inference import BMGInference
from sklearn import model_selection
from torch import tensor
Expand Down Expand Up @@ -65,12 +67,19 @@ def df_nu():


@bm.random_variable
def y_robust(X):
def X():
return Flat()


@bm.random_variable
def y_robust():
"""
Heavy-Tailed Noise model for regression utilizing StudentT
Student's T : https://en.wikipedia.org/wiki/Student%27s_t-distribution
"""
return dist.StudentT(df=df_nu(), loc=beta() * X + alpha(), scale=sigma_regressor())
return dist.StudentT(
df=df_nu(), loc=beta() * X() + alpha(), scale=sigma_regressor()
)


# Creating sample data
Expand All @@ -88,8 +97,9 @@ def y_robust(X):

dist_clean = dist.MultivariateNormal(loc=torch.zeros(2), covariance_matrix=cov)
points = tensor([dist_clean.sample().tolist() for i in range(N)]).view(N, 2)
X = X_clean = points[:, 0]
Y = Y_clean = points[:, 1]

X_clean = points[:, 0]
Y_clean = points[:, 1]

true_beta_1 = 2.0
true_beta_0 = 5.0
Expand All @@ -102,7 +112,7 @@ def y_robust(X):
X_corr = points_noisy[:, 0]
Y_corr = points_noisy[:, 1]

X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y)
X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X_corr, Y_corr)

# Inference parameters

Expand All @@ -111,7 +121,7 @@ def y_robust(X):
)
num_chains = 4

observations = {y_robust(X_train): Y_train}
observations = {y_robust(): Y_train, X(): X_train}

queries = [beta(), alpha(), sigma_regressor(), df_nu()]

Expand All @@ -137,6 +147,10 @@ def test_tutorial_Robust_Linear_Regression(self) -> None:

self.assertTrue(True, msg="We just want to check this point is reached")

# TODO: re-enable once we can compile Flat distribution
@pytest.mark.xfail(
raises=TypeError, reason="Flat distribution not supported by BMG yet"
)
def test_tutorial_Robust_Linear_Regression_to_dot_cpp_python(
self,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/ppl/inference/predictive_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def test_predictive_dynamic(self):
def test_predictive_data(self):
x = torch.randn(4)
y = torch.randn(4) + 2.0
obs = {self.likelihood_reg(x): y}
obs = {self.likelihood_reg(x.item()): y}
post_samples = bm.SingleSiteAncestralMetropolisHastings().infer(
[self.prior()], obs, num_samples=10, num_chains=2
)
assert post_samples[self.prior()].shape == (2, 10)
test_x = torch.randn(4, 1, 1)
test_query = self.likelihood_reg(test_x)
test_query = self.likelihood_reg(test_x.item())
predictives = bm.simulate([test_query], post_samples, vectorized=True)
assert predictives[test_query].shape == (4, 2, 10)

Expand Down
Loading

0 comments on commit 152a4f9

Please sign in to comment.