Skip to content

Commit

Permalink
approx equal (#1862)
Browse files Browse the repository at this point in the history
  • Loading branch information
watiss authored Jan 20, 2023
1 parent 1ac9c0c commit 7c01ce3
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/models/test_models_latent_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def prep_model(cls=SCVI, layer=None, use_size_factor=False):
return model, adata, adata_lib_size, adata_before_setup


def assert_approx_equal(a, b):
# Allclose because on GPU, the values are not exactly the same
# as latents are moved to cpu in latent mode
np.testing.assert_allclose(a, b, rtol=3e-1, atol=5e-1)


def run_test_scvi_latent_mode(
cls=SCVI,
n_samples: int = 1,
Expand Down Expand Up @@ -115,11 +121,7 @@ def run_test_scvi_latent_mode(
assert params_latent[k].shape == adata_orig.shape

for k in keys:
# Allclose because on GPU, the values are not exactly the same
# as latents are moved to cpu in latent mode
np.testing.assert_allclose(
params_latent[k], params_orig[k], rtol=3e-1, atol=5e-1
)
assert_approx_equal(params_latent[k], params_orig[k])


def test_scvi_latent_mode_one_sample():
Expand Down Expand Up @@ -410,4 +412,4 @@ def test_scvi_latent_mode_get_feature_correlation_matrix():
transform_batch=["batch_0", "batch_1"],
)

np.testing.assert_array_equal(fcm_latent, fcm_orig)
assert_approx_equal(fcm_latent, fcm_orig)

0 comments on commit 7c01ce3

Please sign in to comment.