Skip to content

Commit

Permalink
update: test_compute_power
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Nov 3, 2023
1 parent 60f1862 commit 2855b79
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,42 +144,42 @@ def test_compute_power():
assert torch.tensor([1.0]) == x

# case 3 : len(x.shape) != 1 and x.shape[0] != 1, n&n-1 != 0
x = compute_power_schur_newton(torch.ones((2, 2)), p=5)
x = compute_power_schur_newton(torch.ones((2, 2)), p=3)
np.testing.assert_array_almost_equal(
np.asarray([[7.35, -6.48], [-6.48, 7.35]]),
np.asarray([[39.7070, -38.9133], [-38.9133, 39.7070]]),
x.numpy(),
decimal=2,
decimal=4,
)

# case 4 p=1
# case 4 : p=1
x = compute_power_schur_newton(torch.ones((2, 2)), p=1)
assert np.sum(x.numpy() - np.asarray([[252206.4062, -252205.8750], [-252205.8750, 252206.4062]])) < 200

# case 5 p=8
# case 5 : p=8
x = compute_power_schur_newton(torch.ones((2, 2)), p=8)
np.testing.assert_array_almost_equal(
np.asarray([[3.0399, -2.1229], [-2.1229, 3.0399]]),
x.numpy(),
decimal=2,
)

# case 6 p=16
# case 6 : p=16
x = compute_power_schur_newton(torch.ones((2, 2)), p=16)
np.testing.assert_array_almost_equal(
np.asarray([[1.6142, -0.6567], [-0.6567, 1.6142]]),
x.numpy(),
decimal=2,
)

# case 7 max_error_ratio=0
# case 7 : max_error_ratio=0
x = compute_power_schur_newton(torch.ones((2, 2)), p=16, max_error_ratio=0.0)
np.testing.assert_array_almost_equal(
np.asarray([[1.0946, 0.0000], [0.0000, 1.0946]]),
x.numpy(),
decimal=2,
)

# case 8 p=2
# case 8 : p=2
x = compute_power_schur_newton(torch.ones((2, 2)), p=2)
assert np.sum(x.numpy() - np.asarray([[359.1108, -358.4036], [-358.4036, 359.1108]])) < 50

Expand Down

0 comments on commit 2855b79

Please sign in to comment.