Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stage1 mvp batch dimension #269

Closed
wants to merge 118 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
118 commits
Select commit Hold shift + click to select a range
f329c36
Add utils for the Representations project
ryk-wolf Jul 10, 2023
8f325f9
Add missing Mock classes for tests
ryk-wolf Jul 10, 2023
55a753f
Changelog and blacked tests
ryk-wolf Jul 10, 2023
ac8af69
Add Data and tests from legacy branch, tested for non-regression of t…
ryk-wolf Jul 10, 2023
5d1525c
Add tests for QPolyData, stable
ryk-wolf Jul 10, 2023
32a7f8d
Add tests for SymplecticData
ryk-wolf Jul 11, 2023
f56e917
Stable basic tests for all data classes
ryk-wolf Jul 11, 2023
bba2bde
Proofread tests and code, fixed bugs
ryk-wolf Jul 12, 2023
513c3fe
modify pytest.ini to ignore the parent abstract test files which shou…
ryk-wolf Jul 12, 2023
41241d7
Files blacked and all test suite run
ryk-wolf Jul 12, 2023
6ff67ea
Add first version of the test on multiplication between 2 GaussianDat…
ryk-wolf Jul 12, 2023
e48e054
Skip test after debugging revealed it depends on batch dimension
ryk-wolf Jul 14, 2023
f141554
Fix multiply
ryk-wolf Jul 14, 2023
448593b
Blacked
ryk-wolf Jul 14, 2023
0733fd9
Blacked Data files
ryk-wolf Jul 14, 2023
299593c
Update changelog
ryk-wolf Jul 14, 2023
58df8cb
Minor updates
ryk-wolf Jul 14, 2023
34ea05b
Merge branch 'stage1_MVP_data' into stage1_MVP_batch_dimension
ryk-wolf Jul 19, 2023
487ee0c
Merge branch 'stage1_MVP_utils' into stage1_MVP_batch_dimension
ryk-wolf Jul 19, 2023
a717e50
Single test passing for QPolyData with batch
ryk-wolf Jul 26, 2023
72c5a4c
Updates to qpoly data for batch dimension
ryk-wolf Jul 27, 2023
1c7f496
WIP qpoly mul
ryk-wolf Jul 28, 2023
91a2b76
Tests of mul in qpolydata done
ryk-wolf Jul 31, 2023
a2bb429
Two tests for MatVecData
ryk-wolf Aug 1, 2023
4d2edc2
Tests for symplectic mul
ryk-wolf Aug 1, 2023
eb5dc8c
qpoly data fixtures update
ryk-wolf Aug 1, 2023
d1883c6
update symplectic data checking of whether matrices are symplectic or…
ryk-wolf Aug 1, 2023
926264a
A few tests for batched gaussian data
ryk-wolf Aug 1, 2023
6f5b197
Turn pytest fixtures for vecs and mats into arrays instead of lists
ryk-wolf Aug 1, 2023
b609cc3
Minor updates to data structures to support batch
ryk-wolf Aug 1, 2023
5d46333
more GaussianData tests
ryk-wolf Aug 2, 2023
541258d
Updates to GaussianData tests - broken right now
ryk-wolf Aug 3, 2023
3377fbf
Some more tests for GaussianData
ryk-wolf Aug 4, 2023
441cd1e
Update doc to include types in docstring
ryk-wolf Aug 4, 2023
1294712
Minor updates to tests
ryk-wolf Aug 4, 2023
90a93a6
Correctness tests for GaussianData
ryk-wolf Aug 4, 2023
2a10e4b
Modify equality definition in Arraydata and WaveFunctionData
ryk-wolf Aug 8, 2023
fcea3d2
Modify the way equality is computed for batched objects
ryk-wolf Aug 8, 2023
bb2c6cf
Add one equality test for all Data objects
ryk-wolf Aug 8, 2023
737b288
Add test on MatVecMul
ryk-wolf Aug 8, 2023
2fd8022
Add one more test on MatVec mul
ryk-wolf Aug 8, 2023
89f6827
Update the way add/sub is computed for objects which have same matric…
ryk-wolf Aug 9, 2023
c51e28e
Document new helper functions
ryk-wolf Aug 9, 2023
807e72c
Annotating code on pain points
ryk-wolf Aug 9, 2023
3ff9c3a
Correct MatvecData code
ryk-wolf Aug 10, 2023
88feeaf
gaussian data starting to modify coeffs computation in multiplication
ryk-wolf Aug 16, 2023
e01b308
tests for length of results of add/sub of MatVecData
ryk-wolf Aug 22, 2023
952d826
corrected tests for add sub in matvecdata based on length of output
ryk-wolf Aug 22, 2023
362c0eb
implements missing GaussianData methods
ziofil Aug 22, 2023
53f4371
test for gaussiandata product
ziofil Aug 22, 2023
30969c4
uses seeded rng
ziofil Aug 22, 2023
a50a949
fixes gaussian data product and test
ziofil Aug 23, 2023
649086f
renames qpoly -> abc
ziofil Aug 23, 2023
7965a25
more renaming qpoly -> abc
ziofil Aug 23, 2023
3880d03
restores np.random in tests
ziofil Aug 23, 2023
9a61bf0
removes unnecessary test
ziofil Aug 23, 2023
372a2b2
fix matvec data
ziofil Aug 23, 2023
179995e
fix import
ziofil Aug 23, 2023
e13840d
fix imports in data module init file
ziofil Aug 23, 2023
fea8167
adds atleast_2d and atleast_3d to math backend
ziofil Aug 23, 2023
819b9d1
handle initializing AbcData with single elements
ziofil Aug 23, 2023
418b134
transform eventual list to tensor
ziofil Aug 23, 2023
dcc6eb3
fix changelog
ziofil Aug 23, 2023
b0f39e5
Merge branch 'stage1_representations_MVP' into stage1_MVP_batch_dimen…
ziofil Aug 23, 2023
c516909
replace NotIimplementedError with Ellipsis in ABC
ziofil Aug 23, 2023
f0d21da
Merge branch 'stage1_MVP_batch_dimension' of github.com:XanaduAI/MrMu…
ziofil Aug 23, 2023
5d6e6ae
removed unnecessary imports
ziofil Aug 23, 2023
ea259e5
matvec_data improvements
ziofil Aug 23, 2023
c563d62
removed unused attribute
ziofil Aug 23, 2023
9ac4279
fixed dtype
ziofil Aug 23, 2023
6160d6f
removed unused argument
ziofil Aug 23, 2023
f5551ad
simplify test fixture, reformat
ziofil Aug 23, 2023
b20551c
catch ValueError because Tensorflow
ziofil Aug 23, 2023
d7889d3
TensorFlow raises ValueError instead of TypeError
ziofil Aug 23, 2023
753bb92
add arrayflatten,symplectic_tp,block_diag to math
ziofil Aug 23, 2023
1e26d28
implement tensor product for symplectic data
ziofil Aug 23, 2023
3da808f
implement tensor product for gaussian data
ziofil Aug 23, 2023
d46c2a6
implement tensor product for abc data
ziofil Aug 23, 2023
59288f8
removes commented code
ziofil Aug 23, 2023
3b984cd
simplified test
ziofil Aug 23, 2023
92e9835
remove unused helper functions
ryk-wolf Aug 23, 2023
9c072a1
Remove unused imports, reorder imports, fix typo
ryk-wolf Aug 25, 2023
b7cef3c
Blacking
ryk-wolf Aug 25, 2023
894849a
removed unused imports
ziofil Aug 28, 2023
1f32360
adds matmul
ziofil Aug 28, 2023
fbc8800
remove unused import
ziofil Aug 28, 2023
2fe6a66
adds __call__ abstract method
ziofil Aug 28, 2023
a4301a8
Merge branch 'stage1_MVP_batch_dimension' of github.com:XanaduAI/MrMu…
ziofil Aug 28, 2023
e19a72a
fix block_diag function
ziofil Aug 28, 2023
dc019dc
fix matmul in abcdata
ziofil Aug 28, 2023
45fe66e
fix assert in matvec data
ziofil Aug 28, 2023
d20422d
fix tensor concat
ziofil Aug 28, 2023
b2dba20
faster equality check
ziofil Aug 28, 2023
217f112
remove unused methods
ziofil Aug 28, 2023
e3f016a
simpler code
ziofil Aug 28, 2023
49bc5b4
fix multi-index matmul
ziofil Aug 28, 2023
b89e209
detect out of bounds index
ziofil Aug 28, 2023
d92454a
allow int
ziofil Aug 28, 2023
bc47d85
new tests
ziofil Aug 31, 2023
f8467aa
remove test
ziofil Aug 31, 2023
40653d6
improved abc_data, adds conjugate
ziofil Sep 5, 2023
2a45ab7
more efficient equality check
ziofil Sep 5, 2023
0d80d13
equality settings
ziofil Sep 5, 2023
67202d7
remove the abstract function call
sylviemonet Sep 6, 2023
f86b0d0
import change orders
sylviemonet Sep 6, 2023
8173885
fix the test errors
sylviemonet Sep 6, 2023
f93a3dd
fix more errors
sylviemonet Sep 6, 2023
ab01087
put tests from data to matvecdata because of the batch
sylviemonet Sep 7, 2023
0b1a87c
change the assertation in init of matvecdata into raise error
sylviemonet Sep 7, 2023
45e2fa1
fix the batch errors in the test
sylviemonet Sep 7, 2023
7255fde
Merge branch 'stage1_representations_MVP' into stage1_MVP_batch_dimen…
sylviemonet Sep 7, 2023
44cd2ea
remove my_stuff folder
ziofil Sep 19, 2023
5c15fa3
Merge branch 'stage1_representations_MVP' into stage1_MVP_batch_dimen…
ziofil Sep 19, 2023
56c1f87
Add notimplement error for matmul in ABCdata and removed unused import
sylviemonet Sep 20, 2023
1a762e7
store the function getitem
sylviemonet Sep 20, 2023
198b70e
blacked
sylviemonet Sep 20, 2023
2291610
Merge branch 'stage1_MVP_batch_dimension' of https://github.com/Xanad…
ziofil Oct 4, 2023
2b54fb2
added back matmul for ABCData - tests are missing
ziofil Oct 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,26 @@
### New features

* A general factory method and a duck style type-checker.
[#256](https://github.com/XanaduAI/MrMustard/pull/256)
[(#256)](https://github.com/XanaduAI/MrMustard/pull/256)

* Data classes for the representations project. Includes `WavefunctionArrayData`, `GaussianData`,
`ABCData` and `SymplecticData`.
[(#258)](https://github.com/XanaduAI/MrMustard/pull/258)

### Breaking changes

### Improvements

### Bug fixes

* Fixed a bug about the variable names in functions (apply_kraus_to_ket, apply_kraus_to_dm, apply_choi_to_ket, apply_choi_to_dm). [(#271)](https://github.com/XanaduAI/MrMustard/pull/271)
* Fixed a bug about the variable names in functions (apply_kraus_to_ket, apply_kraus_to_dm, apply_choi_to_ket, apply_choi_to_dm).
[(#271)](https://github.com/XanaduAI/MrMustard/pull/271)

### Documentation

### Contributors
[Yuan Yao](https://github.com/sylviemonet), [Richard A. Wolf](https://github.com/ryk-wolf)

[Richard A. Wolf](https://github.com/ryk-wolf), [Filippo Miatto](https://github.com/ziofil), [Yuan Yao](https://github.com/sylviemonet)


# Release 0.5.0 (current release)
Expand Down Expand Up @@ -105,6 +111,9 @@
* More robust implementation of cutoffs for States.
[(#239)](https://github.com/XanaduAI/MrMustard/pull/239)

* Dependencies and versioning are now managed using Poetry.
[(#257)](https://github.com/XanaduAI/MrMustard/pull/257)

### Bug fixes

* Fixed a bug that would make two progress bars appear during an optimization
Expand All @@ -126,6 +135,7 @@ cutoff of the first detector is equal to 1, the resulting density matrix is now
[Robbe De Prins](https://github.com/rdprins), [Gabriele Gullì](https://github.com/ggulli),
[Richard A. Wolf](https://github.com/ryk-wolf)


---

# Release 0.4.1
Expand Down
1 change: 1 addition & 0 deletions mrmustard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self):
self.EQ_TRANSFORMATION_CUTOFF = 3 # 3 is enough to include a full step of the rec relations
self.EQ_TRANSFORMATION_RTOL_FOCK = 1e-3
self.EQ_TRANSFORMATION_RTOL_GAUSS = 1e-6
self.EQUALITY_PRECISION_DECIMALS = 6 # decimals for (A,b,c) equality check
# for the detectors
self.PNR_INTERNAL_CUTOFF = 50
self.HOMODYNE_SQUEEZING = 10.0
Expand Down
23 changes: 23 additions & 0 deletions mrmustard/lab/representations/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" This package contains the modules implementing base classes for datas.
"""
from .gaussian_data import GaussianData
from .abc_data import ABCData
from .array_data import ArrayData
from .symplectic_data import SymplecticData
from .wavefunctionarray_data import WavefunctionArrayData

__all__ = ["GaussianData", "ABCData", "ArrayData", "SymplecticData", "WavefunctionArrayData"]
170 changes: 170 additions & 0 deletions mrmustard/lab/representations/data/abc_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from itertools import product
from typing import Optional, Union
import numpy as np

from mrmustard.lab.representations.data.matvec_data import MatVecData
from mrmustard.math import Math
from mrmustard.typing import Batch, ComplexMatrix, ComplexVector, Scalar

math = Math()


class ABCData(MatVecData):
r"""Exponential of quadratic polynomial for the Bargmann representation.

Quadratic Gaussian data is made of: quadratic coefficients, linear coefficients, constant.
Each of these has a batch dimension, and the batch dimension is the same for all of them.
They are the parameters of the function `c * exp(x^T A x / 2 + x^T b)`.

Note that if constants are not provided, they will all be initialized at 1.

Args:
A (Batch[Matrix]): series of quadratic coefficient
b (Batch[Vector]): series of linear coefficients
c (Optional[Batch[Scalar]]):series of constants
"""

def __init__(
self, A: Batch[ComplexMatrix], b: Batch[ComplexVector], c: Optional[Batch[Scalar]] = None
) -> None:
super().__init__(mat=A, vec=b, coeffs=c)
self._contract_idxs = []

def value(self, x: ComplexVector) -> Scalar:
r"""Value of this function at x.

Args:
x (Vector): point at which the function is evaluated

Returns:
Scalar: value of the function
"""
val = 0.0
for A, b, c in zip(self.A, self.b, self.c):
val += math.exp(0.5 * math.sum(x * math.matvec(A, x)) + math.sum(x * b)) * c
return val

def __call__(self, x: ComplexVector) -> Scalar:
return self.value(x)

@property
def A(self) -> Batch[ComplexMatrix]:
return self.mat

@property
def b(self) -> Batch[ComplexVector]:
return self.vec

@property
def c(self) -> Batch[Scalar]:
return self.coeffs

def __mul__(self, other: Union[Scalar, ABCData]) -> ABCData:
if isinstance(other, ABCData):
new_a = [A1 + A2 for A1, A2 in product(self.A, other.A)]
new_b = [b1 + b2 for b1, b2 in product(self.b, other.b)]
new_c = [c1 * c2 for c1, c2 in product(self.c, other.c)]
return self.__class__(A=new_a, b=new_b, c=new_c)
else:
try: # scalar
return self.__class__(self.A, self.b, other * self.c)
except Exception as e: # Neither same object type nor a scalar case
raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e

def __and__(self, other: ABCData) -> ABCData:
As = [math.block_diag(a1, a2) for a1 in self.A for a2 in other.A]
bs = [math.concat([b1, b2], axis=-1) for b1 in self.b for b2 in other.b]
cs = [c1 * c2 for c1 in self.c for c2 in other.c]
return self.__class__(As, bs, cs)

def conj(self):
new = self.__class__(math.conj(self.A), math.conj(self.b), math.conj(self.c))
new._contract_idxs = self._contract_idxs
return new

def __matmul__(self, other: ABCData) -> ABCData:
r"""Implements the contraction of (A,b,c) triples across the marked indices."""
# Useful for the future, but not for this PR
# raise NotImplementedError()

graph = self & other
newA = graph.A
newb = graph.b
newc = graph.c
for n, (i, j) in enumerate(zip(self._contract_idxs, other._contract_idxs)):
i = i - np.sum(np.array(self._contract_idxs[:n]) < i)
j = j + self.dim - n - np.sum(np.array(other._contract_idxs[:n]) < j)
noij = list(range(i)) + list(range(i + 1, j)) + list(range(j + 1, newA.shape[-1]))
Abar = math.gather(math.gather(newA, noij, axis=1), noij, axis=2)
bbar = math.gather(newb, noij, axis=1)
D = math.gather(
math.concat([newA[..., i][..., None], newA[..., j][..., None]], axis=-1),
noij,
axis=1,
)
M = math.concat(
[
math.concat(
[
newA[:, i, i][:, None, None],
newA[:, j, i][:, None, None] - 1,
],
axis=-1,
),
math.concat(
[
newA[:, i, j][:, None, None] - 1,
newA[:, j, j][:, None, None],
],
axis=-1,
),
],
axis=-2,
)
Minv = math.inv(M)
b_ = math.concat([newb[:, i][:, None], newb[:, j][:, None]], axis=-1)

newA = Abar - math.einsum("bij,bjk,blk", D, Minv, D)
newb = bbar - math.einsum("bij,bjk,bk", D, Minv, b_)
newc = (
newc
* math.exp(-math.einsum("bi,bij,bj", b_, Minv, b_) / 2)
/ math.sqrt(-math.det(M))
)
return self.__class__(newA, newb, newc)

def __getitem__(self, idx: int | tuple[int, ...]) -> ABCData:
idx = (idx,) if isinstance(idx, int) else idx
for i in idx:
if i > self.dim:
raise IndexError(
f"Index {i} out of bounds for {self.__class__.__qualname__} of dimension {self.dim}."
)
new = self.__class__(self.A, self.b, self.c)
new._contract_idxs = idx
return new

def transpose(self, order: tuple[int, ...] | list[int]) -> ABCData:
new = self.__class__(
A=math.gather(math.gather(self.A, order, -1), order, -2),
b=math.gather(self.b, order, -1),
c=self.c,
)
new._contract_idxs = self._contract_idxs
return new
71 changes: 71 additions & 0 deletions mrmustard/lab/representations/data/array_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import numpy as np

from mrmustard.lab.representations.data.data import Data
from mrmustard.math import Math
from mrmustard.typing import Scalar, Vector

math = Math()


class ArrayData(Data):
"""Contains array-like data for certain Representation objects.

Args:
array (Vector) : data to be contained in the class
"""

def __init__(self, array: Vector) -> None:
self.array = array

@property
def cutoffs(self):
return self.array.shape

def __neg__(self) -> Data:
return self.__class__(array=-self.array)

def __eq__(self, other: ArrayData) -> bool:
try:
return np.allclose(self.array, other.array)
except AttributeError as e:
raise TypeError(f"Cannot compare {self.__class__} and {other.__class__}.") from e

def __add__(self, other: ArrayData) -> ArrayData:
try:
return self.__class__(array=self.array + other.array)
except AttributeError as e:
raise TypeError(f"Cannot add/subtract {self.__class__} and {other.__class__}.") from e

def __truediv__(self, other: Scalar) -> ArrayData:
try:
return self.__class__(array=self.array / other)
except TypeError as e:
raise TypeError("Can only divide by a scalar.") from e

def __mul__(self, other: Scalar) -> ArrayData:
try:
return self.__class__(array=self.array * other)
except TypeError as e:
raise TypeError("Can only multiply by a scalar.") from e
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if other is an array that has compatible shape with self.array, the multiplication will still be successful. So we aren't only allowing multiplication by scalars.


def __and__(self, other: ArrayData) -> ArrayData:
try:
return self.__class__(array=np.outer(self.array, other.array))
except AttributeError as e:
raise TypeError(f"Cannot tensor product {self.__class__} and {other.__class__}.") from e
58 changes: 58 additions & 0 deletions mrmustard/lab/representations/data/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Union

from mrmustard.typing import Scalar


class Data(ABC):
r"""Abstract parent class for types of data encoding a quantum state's representation."""

@abstractmethod
def __neg__(self) -> Data:
...

@abstractmethod
def __eq__(self, other: Data) -> bool:
...

@abstractmethod
def __add__(self, other: Data) -> Data:
...

def __sub__(self, other: Data) -> Data:
try:
return self.__add__(-other)
except AttributeError as e:
raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e

# @abstractmethod
# def __call__(self, dom: Any) -> Scalar:
# r"""Evaluate the function at a point in the domain."""
# ...

@abstractmethod
def __truediv__(self, other: Union[Scalar, Data]) -> Data:
...

@abstractmethod
def __mul__(self, other: Union[Scalar, Data]) -> Data:
...

def __rmul__(self, other: Scalar) -> Data:
return self.__mul__(other=other)
Loading
Loading