Linearizing the vanilla algorithm #312

Merged · 21 commits · Dec 6, 2023
4 changes: 4 additions & 0 deletions .github/CHANGELOG.md
@@ -53,6 +53,10 @@ which uses the old Numba code. When setting to a higher value, the new Julia code
[(#303)](https://github.com/XanaduAI/MrMustard/pull/303)
[(#304)](https://github.com/XanaduAI/MrMustard/pull/304)

* Improves the algorithm implemented in `vanilla` to achieve a speedup. Specifically, the improved
algorithm works on a flattened array (which is reshaped before returning) as opposed to a
multi-dimensional array.
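  A minimal sketch of the pattern this entry describes, not the actual implementation (which lives in `mrmustard/math/lattice/strategies/vanilla.py` below); the `rule` callback is hypothetical and stands in for the Fock-Bargmann recurrence:

  ```python
  import numpy as np

  def fill_flat(shape: tuple[int, ...], rule) -> np.ndarray:
      """Fill a tensor through a flat 1D buffer, reshaping once at the end."""
      buf = np.zeros(np.prod(shape), dtype=np.complex128)
      buf[0] = 1.0  # seed amplitude (the vacuum amplitude in ``vanilla``)
      for k in range(1, buf.size):
          # ``rule`` computes one amplitude from previously filled entries,
          # using only integer (flat) indexing inside the hot loop.
          buf[k] = rule(buf, k)
      return buf.reshape(shape)  # single reshape before returning
  ```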

* Adds functions `hermite_renormalized_batch` and `hermite_renormalized_diagonal_batch` to speed up calculating
Hermite polynomials over a batch of B vectors.
[(#308)](https://github.com/XanaduAI/MrMustard/pull/308)
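  A hedged sketch of how the new batched entry point is called, mirroring the test added later in this PR (the triple ``A, B, C`` is assumed to come from e.g. ``wigner_to_bargmann_rho``, as in those tests):

  ```python
  import numpy as np
  from mrmustard import math

  n_batches = 3
  # Replicate a single B vector; the batch dimension sits on the last index.
  B_batched = np.stack((B,) * n_batches, axis=1)

  # One call evaluates the renormalized Hermite polynomials for every B in the batch.
  G_batched = math.hermite_renormalized_batch(
      A, B_batched, C, shape=(20, 20, 20, 20, n_batches)
  )
  ```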
14 changes: 6 additions & 8 deletions mrmustard/math/backend_numpy.py
@@ -428,21 +428,19 @@ def hermite_renormalized(

precision_bits = settings.PRECISION_BITS_HERMITE_POLY

- _A, _B, _C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)
-
if precision_bits == 128: # numba
- G = vanilla(tuple(shape), _A, _B, _C)
+ G = vanilla(tuple(shape), A, B, C)
else: # julia (with precision_bits = 512)
# The following import must come after running "jl = Julia(compiled_modules=False)" in settings.py
from julia import Main as Main_julia # pylint: disable=import-outside-toplevel

- _A, _B, _C = (
- _A.astype(np.complex128),
- _B.astype(np.complex128),
- _C.astype(np.complex128),
+ A, B, C = (
+ A.astype(np.complex128),
+ B.astype(np.complex128),
+ C.astype(np.complex128),
)
G = Main_julia.Vanilla.vanilla(
- _A, _B, _C.item(), np.array(shape, dtype=np.int64), precision_bits
+ A, B, C.item(), np.array(shape, dtype=np.int64), precision_bits
)

return G
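A hedged sketch of how the two branches above are selected from user code, assuming ``A, B, C`` form a valid Bargmann triple (the setting name is taken from this diff; 128 bits dispatches to the Numba-compiled ``vanilla``, higher values to Julia, per the changelog entry above):

```python
from mrmustard import math, settings

# Default precision: the Numba code path.
settings.PRECISION_BITS_HERMITE_POLY = 128
G_numba = math.hermite_renormalized(A, B, C, shape=(10, 10))

# Higher precision: the Julia code path; inputs are cast to
# complex128 before the call, as in the diff above.
settings.PRECISION_BITS_HERMITE_POLY = 512
G_julia = math.hermite_renormalized(A, B, C, shape=(10, 10))
```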
16 changes: 8 additions & 8 deletions mrmustard/math/backend_tensorflow.py
@@ -400,26 +400,26 @@ def hermite_renormalized(

precision_bits = settings.PRECISION_BITS_HERMITE_POLY

- _A, _B, _C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)
+ A, B, C = self.asnumpy(A), self.asnumpy(B), self.asnumpy(C)

if precision_bits == 128: # numba
- G = strategies.vanilla(tuple(shape), _A, _B, _C)
+ G = strategies.vanilla(tuple(shape), A, B, C)
else: # julia
# The following import must come after running "jl = Julia(compiled_modules=False)" in settings.py
from julia import Main as Main_julia # pylint: disable=import-outside-toplevel

- _A, _B, _C = (
- _A.astype(np.complex128),
- _B.astype(np.complex128),
- _C.astype(np.complex128),
+ A, B, C = (
+ A.astype(np.complex128),
+ B.astype(np.complex128),
+ C.astype(np.complex128),
)

G = Main_julia.Vanilla.vanilla(
- _A, _B, _C.item(), np.array(shape, dtype=np.int64), precision_bits
+ A, B, C.item(), np.array(shape, dtype=np.int64), precision_bits
)

def grad(dLdGconj):
- dLdA, dLdB, dLdC = strategies.vanilla_vjp(G, _C, np.conj(dLdGconj))
+ dLdA, dLdB, dLdC = strategies.vanilla_vjp(G, C, np.conj(dLdGconj))
return self.conj(dLdA), self.conj(dLdB), self.conj(dLdC)

return G, grad
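The ``grad`` closure above supplies the backward pass through ``strategies.vanilla_vjp``. A hedged, self-contained sketch of the surrounding pattern, assuming eager execution and that the ``strategies`` subpackage is imported as in the backend (the real method carries more machinery; ``tf.custom_gradient`` is TensorFlow's standard way to pair a forward value with a custom VJP):

```python
import numpy as np
import tensorflow as tf
from mrmustard.math.lattice import strategies

def make_hermite(shape: tuple[int, ...]):
    @tf.custom_gradient
    def hermite(A, B, C):
        # Forward pass runs in numpy/Numba; only the VJP is wired into TF.
        A_np, B_np, C_np = A.numpy(), B.numpy(), C.numpy()
        G = strategies.vanilla(tuple(shape), A_np, B_np, C_np)

        def grad(dLdGconj):
            # ``vanilla_vjp`` returns gradients w.r.t. A, B and C given the
            # conjugated upstream gradient, matching the diff above.
            dLdG = np.conj(np.asarray(dLdGconj))
            dLdA, dLdB, dLdC = strategies.vanilla_vjp(G, C_np, dLdG)
            return tf.math.conj(dLdA), tf.math.conj(dLdB), tf.math.conj(dLdC)

        return G, grad

    return hermite
```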
76 changes: 76 additions & 0 deletions mrmustard/math/lattice/strategies/flat_indices.py
@@ -0,0 +1,76 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Contains the functions to operate with flattened indices.

Given a multi-dimensional ``np.ndarray``, we can index its elements using ``np.ndindex``.
Alternatively, we can flatten the multi-dimensional array and index its elements with
plain ``int`` indices (hereafter referred to as "flat indices").
"""

from typing import Iterator, Sequence
from numba import njit

import numpy as np


@njit
def first_available_pivot(
index: int, strides: Sequence[int]
) -> tuple[int, int]: # pragma: no cover
r"""
Returns the first available pivot for the given flat index.
A pivot is a nearest neighbor of the index. Here we pick the first available pivot.

Arguments:
index: the flat index to get the first available pivot of.
strides: the strides that allow mapping the flat index to a tuple index.

Returns:
the position ``i`` of the first stride that can be subtracted without going negative, and the resulting pivot (a flat index).
"""
for i, s in enumerate(strides):
y = index - s
if y >= 0:
return (i, y)
raise ValueError("Index is zero.")


@njit
def lower_neighbors(
index: int, strides: Sequence[int], start: int
) -> Iterator[tuple[int, int]]: # pragma: no cover
r"""
Yields the flat indices of the lower neighbours of the given flat index, starting
from axis ``start``. A negative value signals that a neighbour does not exist.
"""
for i in range(start, len(strides)):
yield i, index - strides[i]


@njit
def shape_to_strides(shape: Sequence[int]) -> Sequence[int]: # pragma: no cover
r"""
Calculates strides from shape.

Arguments:
shape: the shape of the ``np.ndindex``.

Returns:
the strides that allow mapping a flat index to the corresponding ``np.ndindex``.
"""
strides = np.ones_like(shape)
for i in range(1, len(shape)):
strides[i - 1] = np.prod(shape[i:])
return strides
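A short demonstration of how these helpers line up with ``np.ndindex`` (a hedged sketch; the ``(4, 5, 6)`` shape is arbitrary, and the values below follow directly from the definitions above):

```python
import numpy as np
from mrmustard.math.lattice.strategies.flat_indices import (
    first_available_pivot,
    lower_neighbors,
    shape_to_strides,
)

shape = np.array([4, 5, 6])
strides = shape_to_strides(shape)  # [30, 6, 1]: C-order strides

# ``np.ndindex`` enumerates tuples in the same order as flat indices 0, 1, 2, ...
# e.g. the tuple (1, 2, 3) corresponds to flat index 1*30 + 2*6 + 3*1 = 45.
assert int(np.dot(strides, (1, 2, 3))) == 45

# The first available pivot decrements the first axis with a valid neighbour:
# (1, 2, 3) -> axis 0, neighbour (0, 2, 3), i.e. flat index 15.
assert first_available_pivot(45, strides) == (0, 15)

# Lower neighbours of the pivot, starting from axis 0; a negative flat index
# means the neighbour does not exist along that axis.
assert list(lower_neighbors(15, strides, 0)) == [(0, -15), (1, 9), (2, 14)]
```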
54 changes: 38 additions & 16 deletions mrmustard/math/lattice/strategies/vanilla.py
@@ -17,16 +17,17 @@

from mrmustard.math.lattice import paths, steps
from mrmustard.utils.typing import ComplexMatrix, ComplexTensor, ComplexVector

- SQRT = np.sqrt(np.arange(100000))
+ from .flat_indices import first_available_pivot, lower_neighbors, shape_to_strides

__all__ = ["vanilla", "vanilla_batch", "vanilla_jacobian", "vanilla_vjp"]


@njit
def vanilla(shape: tuple[int, ...], A, b, c) -> ComplexTensor: # pragma: no cover
- r"""Vanilla Fock-Bargmann strategy. Fills the tensor by iterating over all indices
- in ndindex order.
+ r"""Vanilla Fock-Bargmann strategy.

+ Flattens the output tensor, then fills it by iterating over all indices in the order
+ given by ``np.ndindex``. Finally, it reshapes the tensor before returning.

Args:
shape (tuple[int, ...]): shape of the output tensor
@@ -37,28 +37,38 @@ def vanilla(shape: tuple[int, ...], A, b, c) -> ComplexTensor: # pragma: no cover
Returns:
np.ndarray: Fock representation of the Gaussian tensor with shape ``shape``
"""
+ # calculate the strides
+ strides = shape_to_strides(np.array(shape))

- # init output tensor
- G = np.zeros(shape, dtype=np.complex128)
+ # init flat output tensor
+ ret = np.array([0 + 0j] * np.prod(np.array(shape)))

- # initialize path iterator
- path = np.ndindex(shape)
+ # initialize the indices.
+ # ``index`` is the index of the flattened output tensor, while
+ # ``index_u_iter`` iterates through the unravelled counterparts of
+ # ``index``.
+ index = 0
+ index_u_iter = np.ndindex(shape)
+ next(index_u_iter)

# write vacuum amplitude
- G[next(path)] = c
+ ret[0] = c

# iterate over the rest of the indices
- for index in path:
- G[index] = steps.vanilla_step(G, A, b, index)
- return G
+ for index_u in index_u_iter:
+ # update index
+ index += 1

+ # calculate the pivot's contribution
+ i, pivot = first_available_pivot(index, strides)
+ value_at_index = b[i] * ret[pivot]

+ # add the contributions of the pivot's lower neighbours
+ ns = lower_neighbors(pivot, strides, i)
+ (j0, n0) = next(ns)
+ value_at_index += A[i, j0] * np.sqrt(index_u[j0] - 1) * ret[n0]
+ for j, n in ns:
+ value_at_index += A[i, j] * np.sqrt(index_u[j]) * ret[n]
+ ret[index] = value_at_index / np.sqrt(index_u[i])

+ return ret.reshape(shape)


@njit
def vanilla_batch(shape: tuple[int, ...], A, b, c) -> ComplexTensor: # pragma: no cover
- r"""Vanilla batched Fock-Bargmann strategy. Fills the tensor by iterating over all indices
- in ndindex order.
- Note that this function is different from vanilla with b is no longer a vector,
- it becomes a bathced vector with the batch dimension on the last index.
+ r"""Vanilla Fock-Bargmann strategy for batched ``b``, with the batch dimension on the
+ last index.

+ Fills the tensor by iterating over all indices in the order given by ``np.ndindex``.

Args:
shape (tuple[int, ...]): shape of the output tensor with the batch dimension on the last term
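As a sanity check on the recurrence implemented above: for a single mode with ``A = 0``, the update reduces to ``ret[n] = b[0] * ret[n-1] / sqrt(n)``, i.e. the coherent-state amplitudes ``c * beta**n / sqrt(n!)``. A hedged sketch (the import path follows the pattern of the test files in this PR):

```python
from math import factorial

import numpy as np

from mrmustard.math.lattice.strategies.vanilla import vanilla

beta = 0.5 + 0.3j
A = np.zeros((1, 1), dtype=np.complex128)  # no A-matrix couplings
b = np.array([beta], dtype=np.complex128)
c = np.exp(-abs(beta) ** 2 / 2)  # vacuum amplitude of a coherent state

G = vanilla((10,), A, b, c)

# Expected amplitudes: c * beta**n / sqrt(n!) for n = 0, ..., 9.
expected = np.array([c * beta**n / np.sqrt(factorial(n)) for n in range(10)])
assert np.allclose(G, expected)
```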
69 changes: 69 additions & 0 deletions tests/test_math/test_flat_indices.py
@@ -0,0 +1,69 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for flat indices"""

import numpy as np
import pytest

from mrmustard.math.lattice.strategies.flat_indices import (
first_available_pivot,
lower_neighbors,
shape_to_strides,
)


def test_shape_to_strides():
r"""
Tests the ``shape_to_strides`` method.
"""
shape1 = np.array([2])
strides1 = np.array([1])
assert np.allclose(shape_to_strides(shape1), strides1)

shape2 = np.array([1, 2])
strides2 = np.array([2, 1])
assert np.allclose(shape_to_strides(shape2), strides2)

shape3 = np.array([4, 5, 6])
strides3 = np.array([30, 6, 1])
assert np.allclose(shape_to_strides(shape3), strides3)


def test_first_available_pivot():
r"""
Tests the ``first_available_pivot`` method.
"""
strides1 = shape_to_strides(np.array([2, 2, 2]))

with pytest.raises(ValueError, match="zero"):
first_available_pivot(0, strides1)
assert first_available_pivot(1, strides1) == (2, 0)
assert first_available_pivot(2, strides1) == (1, 0)
assert first_available_pivot(3, strides1) == (1, 1)
assert first_available_pivot(4, strides1) == (0, 0)
assert first_available_pivot(5, strides1) == (0, 1)
assert first_available_pivot(6, strides1) == (0, 2)
assert first_available_pivot(7, strides1) == (0, 3)


def test_lower_neighbors():
r"""
Tests the ``lower_neighbors`` method.
"""
strides = shape_to_strides(np.array([2, 2, 2]))

assert list(lower_neighbors(1, strides, 0)) == [(0, -3), (1, -1), (2, 0)]
assert list(lower_neighbors(1, strides, 1)) == [(1, -1), (2, 0)]
assert list(lower_neighbors(1, strides, 2)) == [(2, 0)]
32 changes: 15 additions & 17 deletions tests/test_math/test_lattice.py
@@ -18,7 +18,7 @@
import pytest
import numpy as np

- from mrmustard.lab import Gaussian
+ from mrmustard.lab import Gaussian, Dgate
from mrmustard import settings, math
from mrmustard.physics.bargmann import wigner_to_bargmann_rho
from mrmustard.math.lattice.strategies.binomial import binomial, binomial_dict
@@ -66,47 +66,45 @@ def test_binomial_vs_binomialDict():
assert np.isclose(D[idx], G[idx])


- def test_vanillabatchNumba_vs_vanillaNumba():
+ @pytest.mark.parametrize("n_batches", [1, 3])
+ def test_vanillabatchNumba_vs_vanillaNumba(n_batches):
"""Test the batch version works versus the normal vanilla version."""
- state = Gaussian(3)
+ state = Gaussian(3) >> Dgate([0.0, 0.1, 0.2])
A, B, C = wigner_to_bargmann_rho(
state.cov, state.means
) # Create random state (M mode Gaussian state with displacement)

- batch = 3
- cutoffs = (20, 20, 20, 20, batch)
+ cutoffs = (20, 20, 20, 20, n_batches)

# Vanilla MM
G_ref = math.hermite_renormalized(A, B, C, shape=cutoffs[:-1])

# replicate the B
- B_batched = np.stack((B,) * batch, axis=1)
+ B_batched = np.stack((B,) * n_batches, axis=1)

G_batched = math.hermite_renormalized_batch(A, B_batched, C, shape=cutoffs)

- assert np.allclose(G_ref, G_batched[:, :, :, :, 0])
- assert np.allclose(G_ref, G_batched[:, :, :, :, 1])
- assert np.allclose(G_ref, G_batched[:, :, :, :, 2])
+ for nb in range(n_batches):
+ assert np.allclose(G_ref, G_batched[:, :, :, :, nb])


- def test_diagonalbatchNumba_vs_diagonalNumba():
+ @pytest.mark.parametrize("n_batches", [1, 3])
+ def test_diagonalbatchNumba_vs_diagonalNumba(n_batches):
"""Test the batch version works versus the normal diagonal version."""
- state = Gaussian(3)
+ state = Gaussian(3) >> Dgate([0.0, 0.1, 0.2])
A, B, C = wigner_to_bargmann_rho(
state.cov, state.means
) # Create random state (M mode Gaussian state with displacement)

- batch = 3
- cutoffs = (18, 19, 20, batch)
+ cutoffs = (18, 19, 20, n_batches)

# Diagonal MM
G_ref = math.hermite_renormalized_diagonal(A, B, C, cutoffs=cutoffs[:-1])

# replicate the B
- B_batched = np.stack((B,) * batch, axis=1)
+ B_batched = np.stack((B,) * n_batches, axis=1)

G_batched = math.hermite_renormalized_diagonal_batch(A, B_batched, C, cutoffs=cutoffs[:-1])

- assert np.allclose(G_ref, G_batched[:, :, :, 0])
- assert np.allclose(G_ref, G_batched[:, :, :, 1])
- assert np.allclose(G_ref, G_batched[:, :, :, 2])
+ for nb in range(n_batches):
+ assert np.allclose(G_ref, G_batched[:, :, :, nb])