-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stage1 mvp batch dimension #269
Closed
Closed
Changes from all commits
Commits
Show all changes
118 commits
Select commit
Hold shift + click to select a range
f329c36
Add utils for the Representations project
ryk-wolf 8f325f9
Add missing Mock classes for tests
ryk-wolf 55a753f
Changelog and blacked tests
ryk-wolf ac8af69
Add Data and tests from legacy branch, tested for non-regression of t…
ryk-wolf 5d1525c
Add tests for QPolyData, stable
ryk-wolf 32a7f8d
Add tests for SymplecticData
ryk-wolf f56e917
Stable basic tests for all data classes
ryk-wolf bba2bde
Proofread tests and code, fixed bugs
ryk-wolf 513c3fe
modify pytest.ini to ignore the parent abstract test files which shou…
ryk-wolf 41241d7
Files blacked and all test suite run
ryk-wolf 6ff67ea
Add first version of the test on multiplication between 2 GaussianDat…
ryk-wolf e48e054
Skip test after debugging revealed it depends on batch dimension
ryk-wolf f141554
Fix multiply
ryk-wolf 448593b
Blacked
ryk-wolf 0733fd9
Blacked Data files
ryk-wolf 299593c
Update changelog
ryk-wolf 58df8cb
Minor updates
ryk-wolf 34ea05b
Merge branch 'stage1_MVP_data' into stage1_MVP_batch_dimension
ryk-wolf 487ee0c
Merge branch 'stage1_MVP_utils' into stage1_MVP_batch_dimension
ryk-wolf a717e50
Single test passing for QPolyData with batch
ryk-wolf 72c5a4c
Updates to qpoly data for batch dimension
ryk-wolf 1c7f496
WIP qpoly mul
ryk-wolf 91a2b76
Tests of mul in qpolydata done
ryk-wolf a2bb429
Two tests for MatVecData
ryk-wolf 4d2edc2
Tests for symplectic mul
ryk-wolf eb5dc8c
qpoly data fixtures update
ryk-wolf d1883c6
update symplectic data checking of whether matrices are symplectic or…
ryk-wolf 926264a
A few tests for batched gaussian data
ryk-wolf 6f5b197
Turn pytest fixtures for vecs and mats into arrays instead of lists
ryk-wolf b609cc3
Minor updates to data structures to support batch
ryk-wolf 5d46333
more GaussianData tests
ryk-wolf 541258d
Updates to GaussianData tests - broken right now
ryk-wolf 3377fbf
Some more tests for GaussianData
ryk-wolf 441cd1e
Update doc to include types in docstring
ryk-wolf 1294712
Minor updates to tests
ryk-wolf 90a93a6
Correctness tests for GaussianData
ryk-wolf 2a10e4b
Modify equality definition in Arraydata and WaveFunctionData
ryk-wolf fcea3d2
Modify the way equality is computed for batched objects
ryk-wolf bb2c6cf
Add one equality test for all Data objects
ryk-wolf 737b288
Add test on MatVecMul
ryk-wolf 2fd8022
Add one more test on MatVec mul
ryk-wolf 89f6827
Update the way add/sub is computed for objects which have same matric…
ryk-wolf c51e28e
Document new helper functions
ryk-wolf 807e72c
Annotating code on pain points
ryk-wolf 3ff9c3a
Correct MatvecData code
ryk-wolf 88feeaf
gaussian data starting to modify coeffs computation in multiplication
ryk-wolf e01b308
tests for length of results of add/sub of MatVecData
ryk-wolf 952d826
corrected tests for add sub in matvecdata based on length of output
ryk-wolf 362c0eb
implements missing GaussianData methods
ziofil 53f4371
test for gaussiandata product
ziofil 30969c4
uses seeded rng
ziofil a50a949
fixes gaussian data product and test
ziofil 649086f
renames qpoly -> abc
ziofil 7965a25
more renaming qpoly -> abc
ziofil 3880d03
restores np.random in tests
ziofil 9a61bf0
removes unnecessary test
ziofil 372a2b2
fix matvec data
ziofil 179995e
fix import
ziofil e13840d
fix imports in data module init file
ziofil fea8167
adds atleast_2d and atleast_3d to math backend
ziofil 819b9d1
handle initializing AbcData with single elements
ziofil 418b134
transform eventual list to tensor
ziofil dcc6eb3
fix changelog
ziofil b0f39e5
Merge branch 'stage1_representations_MVP' into stage1_MVP_batch_dimen…
ziofil c516909
replace NotIimplementedError with Ellipsis in ABC
ziofil f0d21da
Merge branch 'stage1_MVP_batch_dimension' of github.com:XanaduAI/MrMu…
ziofil 5d6e6ae
removed unnecessary imports
ziofil ea259e5
matvec_data improvements
ziofil c563d62
removed unused attribute
ziofil 9ac4279
fixed dtype
ziofil 6160d6f
removed unused argument
ziofil f5551ad
simplify test fixture, reformat
ziofil b20551c
catch ValueError because Tensorflow
ziofil d7889d3
TensorFlow raises ValueError instead of TypeError
ziofil 753bb92
add arrayflatten,symplectic_tp,block_diag to math
ziofil 1e26d28
implement tensor product for symplectic data
ziofil 3da808f
implement tensor product for gaussian data
ziofil d46c2a6
implement tensor product for abc data
ziofil 59288f8
removes commented code
ziofil 3b984cd
simplified test
ziofil 92e9835
remove unused helper functions
ryk-wolf 9c072a1
Remove unused imports, reorder imports, fix typo
ryk-wolf b7cef3c
Blacking
ryk-wolf 894849a
removed unused imports
ziofil 1f32360
adds matmul
ziofil fbc8800
remove unused import
ziofil 2fe6a66
adds __call__ abstract method
ziofil a4301a8
Merge branch 'stage1_MVP_batch_dimension' of github.com:XanaduAI/MrMu…
ziofil e19a72a
fix block_diag function
ziofil dc019dc
fix matmul in abcdata
ziofil 45fe66e
fix assert in matvec data
ziofil d20422d
fix tensor concat
ziofil b2dba20
faster equality check
ziofil 217f112
remove unused methods
ziofil e3f016a
simpler code
ziofil 49bc5b4
fix multi-index matmul
ziofil b89e209
detect out of bounds index
ziofil d92454a
allow int
ziofil bc47d85
new tests
ziofil f8467aa
remove test
ziofil 40653d6
improved abc_data, adds conjugate
ziofil 2a45ab7
more efficient equality check
ziofil 0d80d13
equality settings
ziofil 67202d7
remove the abstract function call
sylviemonet f86b0d0
import change orders
sylviemonet 8173885
fix the test errors
sylviemonet f93a3dd
fix more errors
sylviemonet ab01087
put tests from data to matvecdata because of the batch
sylviemonet 0b1a87c
change the assertation in init of matvecdata into raise error
sylviemonet 45e2fa1
fix the batch errors in the test
sylviemonet 7255fde
Merge branch 'stage1_representations_MVP' into stage1_MVP_batch_dimen…
sylviemonet 44cd2ea
remove my_stuff folder
ziofil 5c15fa3
Merge branch 'stage1_representations_MVP' into stage1_MVP_batch_dimen…
ziofil 56c1f87
Add notimplement error for matmul in ABCdata and removed unused import
sylviemonet 1a762e7
store the function getitem
sylviemonet 198b70e
blacked
sylviemonet 2291610
Merge branch 'stage1_MVP_batch_dimension' of https://github.com/Xanad…
ziofil 2b54fb2
added back matmul for ABCData - tests are missing
ziofil File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright 2021 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" This package contains the modules implementing base classes for datas. | ||
""" | ||
from .gaussian_data import GaussianData | ||
from .abc_data import ABCData | ||
from .array_data import ArrayData | ||
from .symplectic_data import SymplecticData | ||
from .wavefunctionarray_data import WavefunctionArrayData | ||
|
||
__all__ = ["GaussianData", "ABCData", "ArrayData", "SymplecticData", "WavefunctionArrayData"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright 2023 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from itertools import product | ||
from typing import Optional, Union | ||
import numpy as np | ||
|
||
from mrmustard.lab.representations.data.matvec_data import MatVecData | ||
from mrmustard.math import Math | ||
from mrmustard.typing import Batch, ComplexMatrix, ComplexVector, Scalar | ||
|
||
math = Math() | ||
|
||
|
||
class ABCData(MatVecData): | ||
r"""Exponential of quadratic polynomial for the Bargmann representation. | ||
|
||
Quadratic Gaussian data is made of: quadratic coefficients, linear coefficients, constant. | ||
Each of these has a batch dimension, and the batch dimension is the same for all of them. | ||
They are the parameters of the function `c * exp(x^T A x / 2 + x^T b)`. | ||
|
||
Note that if constants are not provided, they will all be initialized at 1. | ||
|
||
Args: | ||
A (Batch[Matrix]): series of quadratic coefficient | ||
b (Batch[Vector]): series of linear coefficients | ||
c (Optional[Batch[Scalar]]):series of constants | ||
""" | ||
|
||
def __init__( | ||
self, A: Batch[ComplexMatrix], b: Batch[ComplexVector], c: Optional[Batch[Scalar]] = None | ||
) -> None: | ||
super().__init__(mat=A, vec=b, coeffs=c) | ||
self._contract_idxs = [] | ||
|
||
def value(self, x: ComplexVector) -> Scalar: | ||
r"""Value of this function at x. | ||
|
||
Args: | ||
x (Vector): point at which the function is evaluated | ||
|
||
Returns: | ||
Scalar: value of the function | ||
""" | ||
val = 0.0 | ||
for A, b, c in zip(self.A, self.b, self.c): | ||
val += math.exp(0.5 * math.sum(x * math.matvec(A, x)) + math.sum(x * b)) * c | ||
return val | ||
|
||
def __call__(self, x: ComplexVector) -> Scalar: | ||
return self.value(x) | ||
|
||
@property | ||
def A(self) -> Batch[ComplexMatrix]: | ||
return self.mat | ||
|
||
@property | ||
def b(self) -> Batch[ComplexVector]: | ||
return self.vec | ||
|
||
@property | ||
def c(self) -> Batch[Scalar]: | ||
return self.coeffs | ||
|
||
def __mul__(self, other: Union[Scalar, ABCData]) -> ABCData: | ||
if isinstance(other, ABCData): | ||
new_a = [A1 + A2 for A1, A2 in product(self.A, other.A)] | ||
new_b = [b1 + b2 for b1, b2 in product(self.b, other.b)] | ||
new_c = [c1 * c2 for c1, c2 in product(self.c, other.c)] | ||
return self.__class__(A=new_a, b=new_b, c=new_c) | ||
else: | ||
try: # scalar | ||
return self.__class__(self.A, self.b, other * self.c) | ||
except Exception as e: # Neither same object type nor a scalar case | ||
raise TypeError(f"Cannot multiply {self.__class__} and {other.__class__}.") from e | ||
|
||
def __and__(self, other: ABCData) -> ABCData: | ||
As = [math.block_diag(a1, a2) for a1 in self.A for a2 in other.A] | ||
bs = [math.concat([b1, b2], axis=-1) for b1 in self.b for b2 in other.b] | ||
cs = [c1 * c2 for c1 in self.c for c2 in other.c] | ||
return self.__class__(As, bs, cs) | ||
|
||
def conj(self): | ||
new = self.__class__(math.conj(self.A), math.conj(self.b), math.conj(self.c)) | ||
new._contract_idxs = self._contract_idxs | ||
return new | ||
|
||
def __matmul__(self, other: ABCData) -> ABCData: | ||
r"""Implements the contraction of (A,b,c) triples across the marked indices.""" | ||
# Useful for the future, but not for this PR | ||
# raise NotImplementedError() | ||
|
||
graph = self & other | ||
newA = graph.A | ||
newb = graph.b | ||
newc = graph.c | ||
for n, (i, j) in enumerate(zip(self._contract_idxs, other._contract_idxs)): | ||
i = i - np.sum(np.array(self._contract_idxs[:n]) < i) | ||
j = j + self.dim - n - np.sum(np.array(other._contract_idxs[:n]) < j) | ||
noij = list(range(i)) + list(range(i + 1, j)) + list(range(j + 1, newA.shape[-1])) | ||
Abar = math.gather(math.gather(newA, noij, axis=1), noij, axis=2) | ||
bbar = math.gather(newb, noij, axis=1) | ||
D = math.gather( | ||
math.concat([newA[..., i][..., None], newA[..., j][..., None]], axis=-1), | ||
noij, | ||
axis=1, | ||
) | ||
M = math.concat( | ||
[ | ||
math.concat( | ||
[ | ||
newA[:, i, i][:, None, None], | ||
newA[:, j, i][:, None, None] - 1, | ||
], | ||
axis=-1, | ||
), | ||
math.concat( | ||
[ | ||
newA[:, i, j][:, None, None] - 1, | ||
newA[:, j, j][:, None, None], | ||
], | ||
axis=-1, | ||
), | ||
], | ||
axis=-2, | ||
) | ||
Minv = math.inv(M) | ||
b_ = math.concat([newb[:, i][:, None], newb[:, j][:, None]], axis=-1) | ||
|
||
newA = Abar - math.einsum("bij,bjk,blk", D, Minv, D) | ||
newb = bbar - math.einsum("bij,bjk,bk", D, Minv, b_) | ||
newc = ( | ||
newc | ||
* math.exp(-math.einsum("bi,bij,bj", b_, Minv, b_) / 2) | ||
/ math.sqrt(-math.det(M)) | ||
) | ||
return self.__class__(newA, newb, newc) | ||
|
||
def __getitem__(self, idx: int | tuple[int, ...]) -> ABCData: | ||
idx = (idx,) if isinstance(idx, int) else idx | ||
for i in idx: | ||
if i > self.dim: | ||
raise IndexError( | ||
f"Index {i} out of bounds for {self.__class__.__qualname__} of dimension {self.dim}." | ||
) | ||
new = self.__class__(self.A, self.b, self.c) | ||
new._contract_idxs = idx | ||
return new | ||
|
||
def transpose(self, order: tuple[int, ...] | list[int]) -> ABCData: | ||
new = self.__class__( | ||
A=math.gather(math.gather(self.A, order, -1), order, -2), | ||
b=math.gather(self.b, order, -1), | ||
c=self.c, | ||
) | ||
new._contract_idxs = self._contract_idxs | ||
return new |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright 2023 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from mrmustard.lab.representations.data.data import Data | ||
from mrmustard.math import Math | ||
from mrmustard.typing import Scalar, Vector | ||
|
||
math = Math() | ||
|
||
|
||
class ArrayData(Data): | ||
"""Contains array-like data for certain Representation objects. | ||
|
||
Args: | ||
array (Vector) : data to be contained in the class | ||
""" | ||
|
||
def __init__(self, array: Vector) -> None: | ||
self.array = array | ||
|
||
@property | ||
def cutoffs(self): | ||
return self.array.shape | ||
|
||
def __neg__(self) -> Data: | ||
return self.__class__(array=-self.array) | ||
|
||
def __eq__(self, other: ArrayData) -> bool: | ||
try: | ||
return np.allclose(self.array, other.array) | ||
except AttributeError as e: | ||
raise TypeError(f"Cannot compare {self.__class__} and {other.__class__}.") from e | ||
|
||
def __add__(self, other: ArrayData) -> ArrayData: | ||
try: | ||
return self.__class__(array=self.array + other.array) | ||
except AttributeError as e: | ||
raise TypeError(f"Cannot add/subtract {self.__class__} and {other.__class__}.") from e | ||
|
||
def __truediv__(self, other: Scalar) -> ArrayData: | ||
try: | ||
return self.__class__(array=self.array / other) | ||
except TypeError as e: | ||
raise TypeError("Can only divide by a scalar.") from e | ||
|
||
def __mul__(self, other: Scalar) -> ArrayData: | ||
try: | ||
return self.__class__(array=self.array * other) | ||
except TypeError as e: | ||
raise TypeError("Can only multiply by a scalar.") from e | ||
|
||
def __and__(self, other: ArrayData) -> ArrayData: | ||
try: | ||
return self.__class__(array=np.outer(self.array, other.array)) | ||
except AttributeError as e: | ||
raise TypeError(f"Cannot tensor product {self.__class__} and {other.__class__}.") from e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2023 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any, Union | ||
|
||
from mrmustard.typing import Scalar | ||
|
||
|
||
class Data(ABC): | ||
r"""Abstract parent class for types of data encoding a quantum state's representation.""" | ||
|
||
@abstractmethod | ||
def __neg__(self) -> Data: | ||
... | ||
|
||
@abstractmethod | ||
def __eq__(self, other: Data) -> bool: | ||
... | ||
|
||
@abstractmethod | ||
def __add__(self, other: Data) -> Data: | ||
... | ||
|
||
def __sub__(self, other: Data) -> Data: | ||
try: | ||
return self.__add__(-other) | ||
except AttributeError as e: | ||
raise TypeError(f"Cannot subtract {self.__class__} and {other.__class__}.") from e | ||
|
||
# @abstractmethod | ||
# def __call__(self, dom: Any) -> Scalar: | ||
# r"""Evaluate the function at a point in the domain.""" | ||
# ... | ||
|
||
@abstractmethod | ||
def __truediv__(self, other: Union[Scalar, Data]) -> Data: | ||
... | ||
|
||
@abstractmethod | ||
def __mul__(self, other: Union[Scalar, Data]) -> Data: | ||
... | ||
|
||
def __rmul__(self, other: Scalar) -> Data: | ||
return self.__mul__(other=other) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if
other
is an array that has compatible shape withself.array
, the multiplication will still be successful. So we aren't only allowing multiplication by scalars.