-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
362 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# Copyright 2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
Samplers for measurement devices. | ||
""" | ||
|
||
from __future__ import annotations | ||
from itertools import product | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from typing import Any, Sequence | ||
|
||
import numpy as np | ||
|
||
from mrmustard import math, settings | ||
|
||
from .states import State, Number, QuadratureEigenstate | ||
from .circuit_components import CircuitComponent | ||
|
||
__all__ = ["Sampler", "PNRSampler", "HomodyneSampler"] | ||
|
||
|
||
class Sampler(ABC): | ||
r""" | ||
A sampler for measurements of quantum circuits. | ||
Args: | ||
meas_outcomes: The measurement outcomes for this sampler. | ||
povms: The POVMs of this sampler. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
meas_outcomes: Sequence[Any], | ||
povms: CircuitComponent | Sequence[CircuitComponent], | ||
): | ||
self._povms = povms | ||
self._meas_outcomes = meas_outcomes | ||
|
||
self._outcome_arg = None | ||
|
||
@property | ||
def povms(self) -> CircuitComponent | Sequence[CircuitComponent]: | ||
r""" | ||
The POVMs of this sampler. | ||
""" | ||
return self._povms | ||
|
||
@property | ||
def meas_outcomes(self) -> Sequence[Any]: | ||
r""" | ||
The measurement outcomes of this sampler. | ||
""" | ||
return self._meas_outcomes | ||
|
||
@abstractmethod | ||
def probabilities(self, state: State, atol: float = 1e-4) -> Sequence[float]: | ||
r""" | ||
Returns the probability distribution of a state w.r.t. measurement outcomes. | ||
Args: | ||
state: The state to generate the probability distribution of. Note: the | ||
input state must be normalized. | ||
atol: The absolute tolerance used for validating that the computed | ||
probability distribution sums to ``1``. | ||
""" | ||
|
||
def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray: | ||
r""" | ||
Returns an array of samples given a state. | ||
Args: | ||
state: The state to sample. | ||
n_samples: The number of samples to generate. | ||
seed: An optional seed for random sampling. | ||
Returns: | ||
An array of samples such that the shape is ``(n_samples, n_modes)``. | ||
""" | ||
initial_mode = state.modes[0] | ||
initial_samples = self.sample_prob_dist(state[initial_mode], n_samples, seed) | ||
|
||
if len(state.modes) == 1: | ||
return initial_samples | ||
|
||
unique_samples, counts = np.unique(initial_samples, return_counts=True) | ||
ret = [] | ||
for unique_sample, counts in zip(unique_samples, counts): | ||
meas_op = self._get_povm(unique_sample, initial_mode).dual | ||
reduced_state = (state >> meas_op).normalize() | ||
samples = self.sample(reduced_state, counts) | ||
for sample in samples: | ||
ret.append(np.append([unique_sample], sample)) | ||
return np.array(ret) | ||
|
||
def sample_prob_dist( | ||
self, state: State, n_samples: int = 1000, seed: int | None = None | ||
) -> np.ndarray: | ||
r""" | ||
Samples a state by computing the probability distribution. | ||
Args: | ||
state: The state to sample. | ||
n_samples: The number of samples to generate. | ||
seed: An optional seed for random sampling. | ||
""" | ||
rng = np.random.default_rng(seed) if seed else settings.rng | ||
return rng.choice( | ||
a=list(product(self.meas_outcomes, repeat=len(state.modes))), | ||
p=self.probabilities(state), | ||
size=n_samples, | ||
) | ||
|
||
def _get_povm(self, meas_outcome: Any, mode: int) -> CircuitComponent: | ||
r""" | ||
Returns the POVM associated with a given outcome on a given mode. | ||
Args: | ||
meas_outcome: The measurement outcome. | ||
mode: The mode. | ||
Returns: | ||
The POVM circuit component. | ||
""" | ||
if isinstance(self.povms, CircuitComponent): | ||
kwargs = self.povms.parameter_set.to_dict() | ||
kwargs[self._outcome_arg] = meas_outcome | ||
return self.povms.__class__(modes=[mode], **kwargs) | ||
else: | ||
return self.povms[self.meas_outcomes.index(meas_outcome)].on([mode]) | ||
|
||
def _validate_probs(self, probs: Sequence[float], atol: float) -> Sequence[float]: | ||
r""" | ||
Validates that the given probability distribution sums to ``1`` within some | ||
tolerance and returns a renormalized probability distribution to account for | ||
small numerical errors. | ||
Args: | ||
probs: The probability distribution to validate. | ||
atol: The absolute tolerance to validate with. | ||
""" | ||
atol = atol or settings.ATOL | ||
prob_sum = sum(probs) | ||
if not math.allclose(prob_sum, 1, atol): | ||
raise ValueError(f"Probabilities sum to {prob_sum} and not 1.0.") | ||
return math.real(probs / prob_sum) | ||
|
||
|
||
class PNRSampler(Sampler): | ||
r""" | ||
A sampler for photon-number resolving (PNR) detectors. | ||
Args: | ||
cutoff: The photon number cutoff. | ||
""" | ||
|
||
def __init__(self, cutoff: int) -> None: | ||
super().__init__(list(range(cutoff)), Number([0], 0)) | ||
self._cutoff = cutoff | ||
self._outcome_arg = "n" | ||
|
||
def probabilities(self, state, atol=1e-4): | ||
return self._validate_probs(state.fock_distribution(self._cutoff), atol) | ||
|
||
|
||
class HomodyneSampler(Sampler): | ||
r""" | ||
A sampler for homodyne measurements. | ||
Args: | ||
phi: The quadrature angle where ``0`` corresponds to ``x`` and ``\pi/2`` to ``p``. | ||
bounds: The range of values to discretize over. | ||
num: The number of points to discretize over. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
phi: float = 0, | ||
bounds: tuple[float, float] = (-10, 10), | ||
num: int = 1000, | ||
) -> None: | ||
meas_outcomes, step = np.linspace(*bounds, num, retstep=True) | ||
super().__init__( | ||
list(meas_outcomes), | ||
QuadratureEigenstate([0], x=0, phi=phi), | ||
) | ||
self._step = step | ||
self._outcome_arg = "x" | ||
|
||
def probabilities(self, state, atol=1e-4): | ||
probs = state.quadrature_distribution( | ||
self.meas_outcomes, self.povms.phi.value[0] | ||
) * self._step ** len(state.modes) | ||
return self._validate_probs(probs, atol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright 2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for the sampler.""" | ||
|
||
# pylint: disable=missing-function-docstring | ||
|
||
import numpy as np | ||
|
||
from mrmustard import math, settings | ||
from mrmustard.lab_dev.samplers import PNRSampler, HomodyneSampler | ||
from mrmustard.lab_dev import Coherent, Number, Vacuum, QuadratureEigenstate | ||
|
||
|
||
class TestPNRSampler: | ||
r""" | ||
Tests ``PNRSampler`` objects. | ||
""" | ||
|
||
def test_init(self): | ||
sampler = PNRSampler(cutoff=10) | ||
assert sampler.meas_outcomes == list(range(10)) | ||
assert sampler.povms == Number([0], 0) | ||
|
||
def test_probabilities(self): | ||
atol = 1e-4 | ||
|
||
sampler = PNRSampler(cutoff=10) | ||
vac_prob = [1.0] + [0.0] * 99 | ||
assert math.allclose(sampler.probabilities(Vacuum([0, 1])), vac_prob) | ||
|
||
coh_state = Coherent([0, 1], x=[0.5, 1]) | ||
exp_probs = [ | ||
(coh_state >> Number([0], n0).dual >> Number([1], n1).dual) ** 2 | ||
for n0 in range(10) | ||
for n1 in range(10) | ||
] | ||
assert math.allclose(sampler.probabilities(coh_state), exp_probs, atol) | ||
|
||
def test_sample(self): | ||
n_samples = 1000 | ||
sampler = PNRSampler(cutoff=10) | ||
|
||
assert not np.any(sampler.sample(Vacuum([0]))) | ||
assert not np.any(sampler.sample_prob_dist(Vacuum([0]))) | ||
assert not np.any(sampler.sample(Vacuum([0, 1]))) | ||
assert not np.any(sampler.sample_prob_dist(Vacuum([0, 1]))) | ||
|
||
state = Coherent([0], x=[0.1]) | ||
samples = sampler.sample(state, n_samples) | ||
|
||
count = np.zeros_like(sampler.meas_outcomes) | ||
for sample in samples: | ||
idx = sampler.meas_outcomes.index(sample) | ||
count[idx] += 1 | ||
probs = count / n_samples | ||
|
||
assert np.allclose(probs, sampler.probabilities(state), atol=1e-2) | ||
|
||
|
||
class TestHomodyneSampler: | ||
r""" | ||
Tests ``HomodyneSampler`` objects. | ||
""" | ||
|
||
def test_init(self): | ||
sampler = HomodyneSampler(phi=0.5, bounds=(-5, 5), num=100) | ||
assert sampler.povms == QuadratureEigenstate([0], x=0, phi=0.5) | ||
assert math.allclose(sampler.meas_outcomes, list(np.linspace(-5, 5, 100))) | ||
|
||
def test_probabilties(self): | ||
sampler = HomodyneSampler() | ||
|
||
state = Coherent([0], x=[0.1]) | ||
|
||
exp_probs = ( | ||
state.quadrature_distribution(sampler.meas_outcomes) | ||
* sampler._step # pylint: disable=protected-access | ||
) | ||
assert math.allclose(sampler.probabilities(state), exp_probs) | ||
|
||
sampler2 = HomodyneSampler(phi=np.pi / 2) | ||
|
||
exp_probs = ( | ||
state.quadrature_distribution(sampler2.meas_outcomes, sampler2.povms[0].phi.value[0]) | ||
* sampler2._step # pylint: disable=protected-access | ||
) | ||
assert math.allclose(sampler2.probabilities(state), exp_probs) | ||
|
||
def test_sample(self): | ||
n_samples = 1000 | ||
sampler = HomodyneSampler() | ||
state = Coherent([0], x=[0.1]) | ||
samples = sampler.sample(state, n_samples) | ||
|
||
count = np.zeros_like(sampler.meas_outcomes) | ||
for sample in samples: | ||
idx = sampler.meas_outcomes.index(sample) | ||
count[idx] += 1 | ||
probs = count / n_samples | ||
|
||
assert np.allclose(probs, sampler.probabilities(state), atol=1e-2) | ||
|
||
def test_sample_mean_coherent(self): | ||
r""" | ||
Porting test from strawberry fields: | ||
https://github.com/XanaduAI/strawberryfields/blob/master/tests/backend/test_homodyne.py#L56 | ||
""" | ||
N_MEAS = 300 | ||
NUM_STDS = 10.0 | ||
std_10 = NUM_STDS / np.sqrt(N_MEAS) | ||
alpha = 1.0 + 1.0j | ||
x = np.empty(0) | ||
tol = settings.ATOL | ||
|
||
state = Coherent([0], x=math.real(alpha), y=math.imag(alpha)) | ||
sampler = HomodyneSampler() | ||
|
||
for _ in range(N_MEAS): | ||
meas_result = sampler.sample(state, 1)[0] | ||
x = np.append(x, meas_result) | ||
|
||
assert math.allclose(x.mean(), 2 * alpha.real, atol=std_10 + tol) | ||
|
||
def test_sample_mean_and_std_vacuum(self): | ||
r""" | ||
Porting test from strawberry fields: | ||
https://github.com/XanaduAI/strawberryfields/blob/master/tests/backend/test_homodyne.py#L40 | ||
""" | ||
N_MEAS = 300 | ||
NUM_STDS = 10.0 | ||
std_10 = NUM_STDS / np.sqrt(N_MEAS) | ||
x = np.empty(0) | ||
tol = settings.ATOL | ||
|
||
state = Vacuum([0]) | ||
sampler = HomodyneSampler() | ||
|
||
for _ in range(N_MEAS): | ||
meas_result = sampler.sample(state, 1)[0] | ||
x = np.append(x, meas_result) | ||
|
||
assert np.allclose(x.mean(), 0.0, atol=std_10 + tol, rtol=0) | ||
assert np.allclose(x.std(), 1.0, atol=std_10 + tol, rtol=0) |