Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
  • Loading branch information
SamFerracin committed Dec 5, 2023
1 parent bf4c490 commit 081404b
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 19 deletions.
14 changes: 6 additions & 8 deletions mrmustard/math/backend_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,21 +428,19 @@ def hermite_renormalized(

precision_bits = settings.PRECISION_BITS_HERMITE_POLY

_A, _B, _C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)

if precision_bits == 128: # numba
G = vanilla(tuple(shape), _A, _B, _C.max())
G = vanilla(tuple(shape), A, B, C)
else: # julia (with precision_bits = 512)
# The following import must come after running "jl = Julia(compiled_modules=False)" in settings.py
from julia import Main as Main_julia # pylint: disable=import-outside-toplevel

_A, _B, _C = (
_A.astype(np.complex128),
_B.astype(np.complex128),
_C.astype(np.complex128),
A, B, C = (
A.astype(np.complex128),
B.astype(np.complex128),
C.astype(np.complex128),
)
G = Main_julia.Vanilla.vanilla(
_A, _B, _C.item(), np.array(shape, dtype=np.int64), precision_bits
A, B, C.item(), np.array(shape, dtype=np.int64), precision_bits
)

return G
Expand Down
16 changes: 8 additions & 8 deletions mrmustard/math/backend_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,26 +400,26 @@ def hermite_renormalized(

precision_bits = settings.PRECISION_BITS_HERMITE_POLY

_A, _B, _C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)
A, B, C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)

if precision_bits == 128: # numba
G = strategies.vanilla(tuple(shape), _A, _B, _C)
G = strategies.vanilla(tuple(shape), A, B, C)
else: # julia
# The following import must come after running "jl = Julia(compiled_modules=False)" in settings.py
from julia import Main as Main_julia # pylint: disable=import-outside-toplevel

_A, _B, _C = (
_A.astype(np.complex128),
_B.astype(np.complex128),
_C.astype(np.complex128),
A, B, C = (
A.astype(np.complex128),
B.astype(np.complex128),
C.astype(np.complex128),
)

G = Main_julia.Vanilla.vanilla(
_A, _B, _C.item(), np.array(shape, dtype=np.int64), precision_bits
A, B, C.item(), np.array(shape, dtype=np.int64), precision_bits
)

def grad(dLdGconj):
dLdA, dLdB, dLdC = strategies.vanilla_vjp(G, _C, np.conj(dLdGconj))
dLdA, dLdB, dLdC = strategies.vanilla_vjp(G, C, np.conj(dLdGconj))
return self.conj(dLdA), self.conj(dLdB), self.conj(dLdC)

return G, grad
Expand Down
4 changes: 2 additions & 2 deletions mrmustard/math/lattice/strategies/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from mrmustard.math.lattice import paths, steps
from mrmustard.utils.typing import ComplexMatrix, ComplexTensor, ComplexVector
from .flat_indices import first_available_pivot, lower_neighbours, shape_to_strides
from .flat_indices import first_available_pivot, lower_neighbors, shape_to_strides

__all__ = ["vanilla", "vanilla_batch", "vanilla_jacobian", "vanilla_vjp"]

Expand Down Expand Up @@ -65,7 +65,7 @@ def vanilla(shape: tuple[int, ...], A, b, c) -> ComplexTensor: # pragma: no cov
value_at_index = b[i] * ret[pivot]

# add the contribution of pivot's lower's neighbours
ns = lower_neighbours(pivot, strides, i)
ns = lower_neighbors(pivot, strides, i)
(j0, n0) = next(ns)
value_at_index += A[i, j0] * np.sqrt(index_u[j0] - 1) * ret[n0]
for j, n in ns:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_math/test_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,4 @@ def test_diagonalbatchNumba_vs_diagonalNumba():

assert np.allclose(G_ref, G_batched[:, :, :, 0])
assert np.allclose(G_ref, G_batched[:, :, :, 1])
assert np.allclose(G_ref, G_batched[:, :, :, 2])
assert np.allclose(G_ref, G_batched[:, :, :, 2])

0 comments on commit 081404b

Please sign in to comment.