diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md
index f48350533..4377404de 100644
--- a/.github/CHANGELOG.md
+++ b/.github/CHANGELOG.md
@@ -53,6 +53,11 @@ which uses the old Numba code. When setting to a higher value, the new Julia cod
   [(#303)](https://github.com/XanaduAI/MrMustard/pull/303)
   [(#304)](https://github.com/XanaduAI/MrMustard/pull/304)
 
+* Improves the algorithm implemented in `vanilla` to achieve a speedup. Specifically, the improved
+  algorithm works on a flattened array (which is reshaped before returning) as opposed to a
+  multi-dimensional array.
+  [(#312)](https://github.com/XanaduAI/MrMustard/pull/312)
+
 * Adds functions `hermite_renormalized_batch` and `hermite_renormalized_diagonal_batch` to speed
   up calculating Hermite polynomials over a batch of B vectors.
   [(#308)](https://github.com/XanaduAI/MrMustard/pull/308)
diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py
index 94d7267f2..f1df21ee5 100644
--- a/mrmustard/math/backend_numpy.py
+++ b/mrmustard/math/backend_numpy.py
@@ -428,21 +428,19 @@ def hermite_renormalized(
 
         precision_bits = settings.PRECISION_BITS_HERMITE_POLY
 
-        _A, _B, _C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)
-
         if precision_bits == 128:  # numba
-            G = vanilla(tuple(shape), _A, _B, _C)
+            G = vanilla(tuple(shape), A, B, C)
         else:  # julia (with precision_bits = 512)
             # The following import must come after running "jl = Julia(compiled_modules=False)" in settings.py
             from julia import Main as Main_julia  # pylint: disable=import-outside-toplevel
 
-            _A, _B, _C = (
-                _A.astype(np.complex128),
-                _B.astype(np.complex128),
-                _C.astype(np.complex128),
+            A, B, C = (
+                A.astype(np.complex128),
+                B.astype(np.complex128),
+                C.astype(np.complex128),
             )
 
             G = Main_julia.Vanilla.vanilla(
-                _A, _B, _C.item(), np.array(shape, dtype=np.int64), precision_bits
+                A, B, C.item(), np.array(shape, dtype=np.int64), precision_bits
             )
         return G
diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py
index bf93e425a..9d3e0e85c 100644
--- a/mrmustard/math/backend_tensorflow.py
+++ b/mrmustard/math/backend_tensorflow.py
@@ -400,26 +400,26 @@ def hermite_renormalized(
 
         precision_bits = settings.PRECISION_BITS_HERMITE_POLY
 
-        _A, _B, _C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)
+        A, B, C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)
 
         if precision_bits == 128:  # numba
-            G = strategies.vanilla(tuple(shape), _A, _B, _C)
+            G = strategies.vanilla(tuple(shape), A, B, C)
         else:  # julia
             # The following import must come after running "jl = Julia(compiled_modules=False)" in settings.py
             from julia import Main as Main_julia  # pylint: disable=import-outside-toplevel
 
-            _A, _B, _C = (
-                _A.astype(np.complex128),
-                _B.astype(np.complex128),
-                _C.astype(np.complex128),
+            A, B, C = (
+                A.astype(np.complex128),
+                B.astype(np.complex128),
+                C.astype(np.complex128),
             )
 
             G = Main_julia.Vanilla.vanilla(
-                _A, _B, _C.item(), np.array(shape, dtype=np.int64), precision_bits
+                A, B, C.item(), np.array(shape, dtype=np.int64), precision_bits
             )
 
         def grad(dLdGconj):
-            dLdA, dLdB, dLdC = strategies.vanilla_vjp(G, _C, np.conj(dLdGconj))
+            dLdA, dLdB, dLdC = strategies.vanilla_vjp(G, C, np.conj(dLdGconj))
             return self.conj(dLdA), self.conj(dLdB), self.conj(dLdC)
 
         return G, grad
diff --git a/mrmustard/math/lattice/strategies/flat_indices.py b/mrmustard/math/lattice/strategies/flat_indices.py
new file mode 100644
index 000000000..4a5d945e0
--- /dev/null
+++ b/mrmustard/math/lattice/strategies/flat_indices.py
@@ -0,0 +1,76 @@
+# Copyright 2023 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Contains the functions to operate with flattened indices.
+
+Given a multi-dimensional ``np.ndarray``, we can index its elements using ``np.ndindex``.
+Alternatively, we can flatten the multi-dimensional array and index its elements with
+``int``s (hereby referred to as "flat indices").
+"""
+
+from typing import Iterator, Sequence
+from numba import njit
+
+import numpy as np
+
+
+@njit
+def first_available_pivot(
+    index: int, strides: Sequence[int]
+) -> tuple[int, int]:  # pragma: no cover
+    r"""
+    Returns the first available pivot for the given flat index.
+    A pivot is a nearest neighbor of the index. Here we pick the first available pivot.
+
+    Arguments:
+        index: the flat index to get the first available pivot of.
+        strides: the strides that allow mapping the flat index to a tuple index.
+
+    Returns:
+        the index of the dimension that was decremented, together with the flat index of the pivot.
+    """
+    for i, s in enumerate(strides):
+        y = index - s
+        if y >= 0:
+            return (i, y)
+    raise ValueError("Index is zero.")
+
+
+@njit
+def lower_neighbors(
+    index: int, strides: Sequence[int], start: int
+) -> Iterator[tuple[int, int]]:  # pragma: no cover
+    r"""
+    Yields the flat indices of the lower neighbors of the given flat index, starting from dimension ``start``.
+    """
+    for i in range(start, len(strides)):
+        yield i, index - strides[i]
+
+
+@njit
+def shape_to_strides(shape: Sequence[int]) -> Sequence[int]:  # pragma: no cover
+    r"""
+    Calculates strides from shape.
+
+    Arguments:
+        shape: the shape of the multi-dimensional array.
+
+    Returns:
+        the strides that allow mapping a flat index to the corresponding ``np.ndindex``.
+    """
+    strides = np.ones_like(shape)
+    for i in range(1, len(shape)):
+        strides[i - 1] = np.prod(shape[i:])
+    return strides
diff --git a/mrmustard/math/lattice/strategies/vanilla.py b/mrmustard/math/lattice/strategies/vanilla.py
index 7be29948d..8be1440d9 100644
--- a/mrmustard/math/lattice/strategies/vanilla.py
+++ b/mrmustard/math/lattice/strategies/vanilla.py
@@ -17,16 +17,17 @@
 from mrmustard.math.lattice import paths, steps
 from mrmustard.utils.typing import ComplexMatrix, ComplexTensor, ComplexVector
 
-
-SQRT = np.sqrt(np.arange(100000))
+from .flat_indices import first_available_pivot, lower_neighbors, shape_to_strides
 
 __all__ = ["vanilla", "vanilla_batch", "vanilla_jacobian", "vanilla_vjp"]
 
 
 @njit
 def vanilla(shape: tuple[int, ...], A, b, c) -> ComplexTensor:  # pragma: no cover
-    r"""Vanilla Fock-Bargmann strategy. Fills the tensor by iterating over all indices
-    in ndindex order.
+    r"""Vanilla Fock-Bargmann strategy.
+
+    Creates the output tensor in flattened form, fills it by iterating over all indices
+    in the order given by ``np.ndindex``, and reshapes it before returning.
 
     Args:
         shape (tuple[int, ...]): shape of the output tensor
@@ -37,28 +38,49 @@ def vanilla(shape: tuple[int, ...], A, b, c) -> ComplexTensor:  # pragma: no cov
     Returns:
         np.ndarray: Fock representation of the Gaussian tensor with shape ``shape``
     """
+    # calculate the strides
+    strides = shape_to_strides(np.array(shape))
 
-    # init output tensor
-    G = np.zeros(shape, dtype=np.complex128)
+    # init flat output tensor
+    ret = np.array([0 + 0j] * np.prod(np.array(shape)))
 
-    # initialize path iterator
-    path = np.ndindex(shape)
+    # initialize the indices.
+    # ``index`` is the index of the flattened output tensor, while
+    # ``index_u_iter`` iterates through the unravelled counterparts of
+    # ``index``.
+    index = 0
+    index_u_iter = np.ndindex(shape)
+    next(index_u_iter)
 
     # write vacuum amplitude
-    G[next(path)] = c
+    ret[0] = c
 
     # iterate over the rest of the indices
-    for index in path:
-        G[index] = steps.vanilla_step(G, A, b, index)
-    return G
+    for index_u in index_u_iter:
+        # update index
+        index += 1
+
+        # calculate the pivot's contribution
+        i, pivot = first_available_pivot(index, strides)
+        value_at_index = b[i] * ret[pivot]
+
+        # add the contribution of the pivot's lower neighbors
+        ns = lower_neighbors(pivot, strides, i)
+        (j0, n0) = next(ns)
+        value_at_index += A[i, j0] * np.sqrt(index_u[j0] - 1) * ret[n0]
+        for j, n in ns:
+            value_at_index += A[i, j] * np.sqrt(index_u[j]) * ret[n]
+        ret[index] = value_at_index / np.sqrt(index_u[i])
+
+    return ret.reshape(shape)
 
 
 @njit
 def vanilla_batch(shape: tuple[int, ...], A, b, c) -> ComplexTensor:  # pragma: no cover
-    r"""Vanilla batched Fock-Bargmann strategy. Fills the tensor by iterating over all indices
-    in ndindex order.
-    Note that this function is different from vanilla with b is no longer a vector,
-    it becomes a bathced vector with the batch dimension on the last index.
+    r"""Vanilla Fock-Bargmann strategy for batched ``b``, with the batch dimension
+    on the last index.
+
+    Fills the tensor by iterating over all indices in the order given by ``np.ndindex``.
 
     Args:
         shape (tuple[int, ...]): shape of the output tensor with the batch dimension on the last term
diff --git a/tests/test_math/test_flat_indices.py b/tests/test_math/test_flat_indices.py
new file mode 100644
index 000000000..1d38ed810
--- /dev/null
+++ b/tests/test_math/test_flat_indices.py
@@ -0,0 +1,69 @@
+# Copyright 2023 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for flat indices"""
+
+import numpy as np
+import pytest
+
+from mrmustard.math.lattice.strategies.flat_indices import (
+    first_available_pivot,
+    lower_neighbors,
+    shape_to_strides,
+)
+
+
+def test_shape_to_strides():
+    r"""
+    Tests the ``shape_to_strides`` function.
+    """
+    shape1 = np.array([2])
+    strides1 = np.array([1])
+    assert np.allclose(shape_to_strides(shape1), strides1)
+
+    shape2 = np.array([1, 2])
+    strides2 = np.array([2, 1])
+    assert np.allclose(shape_to_strides(shape2), strides2)
+
+    shape3 = np.array([4, 5, 6])
+    strides3 = np.array([30, 6, 1])
+    assert np.allclose(shape_to_strides(shape3), strides3)
+
+
+def test_first_available_pivot():
+    r"""
+    Tests the ``first_available_pivot`` function.
+    """
+    strides1 = shape_to_strides(np.array([2, 2, 2]))
+
+    with pytest.raises(ValueError, match="zero"):
+        first_available_pivot(0, strides1)
+    assert first_available_pivot(1, strides1) == (2, 0)
+    assert first_available_pivot(2, strides1) == (1, 0)
+    assert first_available_pivot(3, strides1) == (1, 1)
+    assert first_available_pivot(4, strides1) == (0, 0)
+    assert first_available_pivot(5, strides1) == (0, 1)
+    assert first_available_pivot(6, strides1) == (0, 2)
+    assert first_available_pivot(7, strides1) == (0, 3)
+
+
+def test_lower_neighbors():
+    r"""
+    Tests the ``lower_neighbors`` function.
+    """
+    strides = shape_to_strides(np.array([2, 2, 2]))
+
+    assert list(lower_neighbors(1, strides, 0)) == [(0, -3), (1, -1), (2, 0)]
+    assert list(lower_neighbors(1, strides, 1)) == [(1, -1), (2, 0)]
+    assert list(lower_neighbors(1, strides, 2)) == [(2, 0)]
diff --git a/tests/test_math/test_lattice.py b/tests/test_math/test_lattice.py
index dd02ab0c3..76506660c 100644
--- a/tests/test_math/test_lattice.py
+++ b/tests/test_math/test_lattice.py
@@ -18,7 +18,7 @@
 import pytest
 import numpy as np
 
-from mrmustard.lab import Gaussian
+from mrmustard.lab import Gaussian, Dgate
 from mrmustard import settings, math
 from mrmustard.physics.bargmann import wigner_to_bargmann_rho
 from mrmustard.math.lattice.strategies.binomial import binomial, binomial_dict
@@ -66,47 +66,45 @@ def test_binomial_vs_binomialDict():
         assert np.isclose(D[idx], G[idx])
 
 
-def test_vanillabatchNumba_vs_vanillaNumba():
+@pytest.mark.parametrize("batch_size", [1, 3])
+def test_vanillabatchNumba_vs_vanillaNumba(batch_size):
     """Test the batch version works versus the normal vanilla version."""
-    state = Gaussian(3)
+    state = Gaussian(3) >> Dgate([0.0, 0.1, 0.2])
     A, B, C = wigner_to_bargmann_rho(
         state.cov, state.means
     )  # Create random state (M mode Gaussian state with displacement)
 
-    batch = 3
-    cutoffs = (20, 20, 20, 20, batch)
+    cutoffs = (20, 20, 20, 20, batch_size)
 
     # Vanilla MM
    G_ref = math.hermite_renormalized(A, B, C, shape=cutoffs[:-1])
 
     # replicate the B
-    B_batched = np.stack((B,) * batch, axis=1)
+    B_batched = np.stack((B,) * batch_size, axis=1)
 
     G_batched = math.hermite_renormalized_batch(A, B_batched, C, shape=cutoffs)
 
-    assert np.allclose(G_ref, G_batched[:, :, :, :, 0])
-    assert np.allclose(G_ref, G_batched[:, :, :, :, 1])
-    assert np.allclose(G_ref, G_batched[:, :, :, :, 2])
+    for nb in range(batch_size):
+        assert np.allclose(G_ref, G_batched[:, :, :, :, nb])
 
 
-def test_diagonalbatchNumba_vs_diagonalNumba():
+@pytest.mark.parametrize("batch_size", [1, 3])
+def test_diagonalbatchNumba_vs_diagonalNumba(batch_size):
     """Test the batch version works versus the normal diagonal version."""
-    state = Gaussian(3)
+    state = Gaussian(3) >> Dgate([0.0, 0.1, 0.2])
     A, B, C = wigner_to_bargmann_rho(
         state.cov, state.means
     )  # Create random state (M mode Gaussian state with displacement)
 
-    batch = 3
-    cutoffs = (18, 19, 20, batch)
+    cutoffs = (18, 19, 20, batch_size)
 
     # Diagonal MM
     G_ref = math.hermite_renormalized_diagonal(A, B, C, cutoffs=cutoffs[:-1])
 
     # replicate the B
-    B_batched = np.stack((B,) * batch, axis=1)
+    B_batched = np.stack((B,) * batch_size, axis=1)
 
     G_batched = math.hermite_renormalized_diagonal_batch(A, B_batched, C, cutoffs=cutoffs[:-1])
 
-    assert np.allclose(G_ref, G_batched[:, :, :, 0])
-    assert np.allclose(G_ref, G_batched[:, :, :, 1])
-    assert np.allclose(G_ref, G_batched[:, :, :, 2])
+    for nb in range(batch_size):
+        assert np.allclose(G_ref, G_batched[:, :, :, nb])
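
Reviewer note (illustrative, not part of the diff): the flat-index bookkeeping in `flat_indices.py` can be checked with plain NumPy. The sketch below copies `shape_to_strides` verbatim and shows that, in the row-major (C-order) convention used by `np.ndindex`, a multi-dimensional index maps to the flat index `dot(index, strides)`, and that decrementing dimension `i` corresponds to subtracting `strides[i]` in flat coordinates, which is exactly what `first_available_pivot` and `lower_neighbors` rely on.

```python
import numpy as np


def shape_to_strides(shape):
    # row-major strides in units of elements, copied from flat_indices.py
    strides = np.ones_like(shape)
    for i in range(1, len(shape)):
        strides[i - 1] = np.prod(shape[i:])
    return strides


shape = (4, 5, 6)
strides = shape_to_strides(np.array(shape))  # [30, 6, 1]

# a multi-dimensional index maps to the flat index dot(index, strides),
# and np.ndindex enumerates exactly the flat indices 0, 1, 2, ... in order
for flat, index_u in enumerate(np.ndindex(shape)):
    assert flat == np.dot(index_u, strides)

# decrementing dimension i of a multi-dimensional index corresponds to
# subtracting strides[i] from its flat counterpart
assert np.dot((2, 3, 1), strides) - strides[1] == np.dot((2, 2, 1), strides)
```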
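Similarly, the flattened loop in `vanilla` computes the same recurrence as the previous multi-dimensional implementation, only with flat indices. The sketch below is a naive reference written directly against `np.ndindex` indexing; `vanilla_reference`, the random `(A, b, c)` triple and the seed are ours for illustration and assume nothing beyond NumPy and the post-PR `vanilla` import.

```python
import numpy as np

from mrmustard.math.lattice.strategies.vanilla import vanilla


def vanilla_reference(shape, A, b, c):
    # same recurrence as `vanilla`, written with multi-dimensional indexing
    G = np.zeros(shape, dtype=np.complex128)
    G[(0,) * len(shape)] = c  # vacuum amplitude
    for index in np.ndindex(*shape):
        if sum(index) == 0:
            continue
        # pivot: decrement the first nonzero dimension of the index
        i = next(k for k, v in enumerate(index) if v > 0)
        pivot = list(index)
        pivot[i] -= 1
        value = b[i] * G[tuple(pivot)]
        # add the contributions of the pivot's lower neighbors
        for j in range(len(shape)):
            if pivot[j] > 0:
                neighbor = list(pivot)
                neighbor[j] -= 1
                value += A[i, j] * np.sqrt(pivot[j]) * G[tuple(neighbor)]
        G[index] = value / np.sqrt(index[i])
    return G


rng = np.random.default_rng(seed=14)
A = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2))
b = rng.normal(size=2) + 1j * rng.normal(size=2)
c = 1.0 + 0j

shape = (5, 5)
assert np.allclose(vanilla(shape, A, b, c), vanilla_reference(shape, A, b, c))
```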