Skip to content

Commit

Permalink
fix jacobian tests for new inference network design
Browse files Browse the repository at this point in the history
also adjust invertible layer test for consistency
  • Loading branch information
LarsKue committed May 31, 2024
1 parent 8f9d433 commit 37b0cf7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_jacobian_numerically(invertible_layer, random_input):
import torch

forward_output, forward_log_det = invertible_layer(random_input)
numerical_forward_jacobian, _ = torch.autograd.functional.jacobian(invertible_layer, random_input, vectorize=True)
numerical_forward_jacobian, *_ = torch.autograd.functional.jacobian(invertible_layer, random_input, vectorize=True)

# TODO: torch is somehow permuted wrt keras
numerical_forward_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_forward_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_input)[0])]
Expand All @@ -78,7 +78,7 @@ def test_jacobian_numerically(invertible_layer, random_input):

inverse_output, inverse_log_det = invertible_layer(random_input, inverse=True)

numerical_inverse_jacobian, _ = torch.autograd.functional.jacobian(functools.partial(invertible_layer, inverse=True), random_input, vectorize=True)
numerical_inverse_jacobian, *_ = torch.autograd.functional.jacobian(functools.partial(invertible_layer, inverse=True), random_input, vectorize=True)

# TODO: torch is somehow permuted wrt keras
numerical_inverse_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_inverse_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_input)[0])]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_networks/test_inference_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,20 @@ def test_jacobian_numerically(inference_network, random_samples):
import torch

forward_output, forward_log_det = inference_network(random_samples, jacobian=True)
numerical_forward_jacobian, _ = torch.autograd.functional.jacobian(inference_network, random_samples, vectorize=True)
numerical_forward_jacobian, *_ = torch.autograd.functional.jacobian(inference_network, random_samples, vectorize=True)

# TODO: torch is somehow permuted wrt keras
numerical_forward_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_forward_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_samples)[0])]
numerical_forward_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_forward_jacobian[:, i, :]))) for i in range(keras.ops.shape(random_samples)[0])]
numerical_forward_log_det = keras.ops.stack(numerical_forward_log_det, axis=0)

assert allclose(forward_log_det, numerical_forward_log_det, rtol=1e-4, atol=1e-5)

inverse_output, inverse_log_det = inference_network(random_samples, jacobian=True, inverse=True)

numerical_inverse_jacobian, _ = torch.autograd.functional.jacobian(functools.partial(inference_network, inverse=True), random_samples, vectorize=True)
numerical_inverse_jacobian, *_ = torch.autograd.functional.jacobian(functools.partial(inference_network, inverse=True), random_samples, vectorize=True)

# TODO: torch is somehow permuted wrt keras
numerical_inverse_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_inverse_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_samples)[0])]
numerical_inverse_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_inverse_jacobian[:, i, :]))) for i in range(keras.ops.shape(random_samples)[0])]
numerical_inverse_log_det = keras.ops.stack(numerical_inverse_log_det, axis=0)

assert allclose(inverse_log_det, numerical_inverse_log_det, rtol=1e-4, atol=1e-5)
Expand Down

0 comments on commit 37b0cf7

Please sign in to comment.