Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Sep 19, 2024
1 parent 395a402 commit 76151a3
Show file tree
Hide file tree
Showing 2 changed files with 362 additions and 0 deletions.
207 changes: 207 additions & 0 deletions mrmustard/lab_dev/samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Samplers for measurement devices.
"""

from __future__ import annotations
from itertools import product

from abc import ABC, abstractmethod

from typing import Any, Sequence

import numpy as np

from mrmustard import math, settings

from .states import State, Number, QuadratureEigenstate
from .circuit_components import CircuitComponent

__all__ = ["Sampler", "PNRSampler", "HomodyneSampler"]


class Sampler(ABC):
r"""
A sampler for measurements of quantum circuits.
Args:
meas_outcomes: The measurement outcomes for this sampler.
povms: The POVMs of this sampler.
"""

def __init__(
self,
meas_outcomes: Sequence[Any],
povms: CircuitComponent | Sequence[CircuitComponent],
):
self._povms = povms
self._meas_outcomes = meas_outcomes

self._outcome_arg = None

@property
def povms(self) -> CircuitComponent | Sequence[CircuitComponent]:
r"""
The POVMs of this sampler.
"""
return self._povms

@property
def meas_outcomes(self) -> Sequence[Any]:
r"""
The measurement outcomes of this sampler.
"""
return self._meas_outcomes

@abstractmethod
def probabilities(self, state: State, atol: float = 1e-4) -> Sequence[float]:
r"""
Returns the probability distribution of a state w.r.t. measurement outcomes.
Args:
state: The state to generate the probability distribution of. Note: the
input state must be normalized.
atol: The absolute tolerance used for validating that the computed
probability distribution sums to ``1``.
"""

def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray:
r"""
Returns an array of samples given a state.
Args:
state: The state to sample.
n_samples: The number of samples to generate.
seed: An optional seed for random sampling.
Returns:
An array of samples such that the shape is ``(n_samples, n_modes)``.
"""
initial_mode = state.modes[0]
initial_samples = self.sample_prob_dist(state[initial_mode], n_samples, seed)

if len(state.modes) == 1:
return initial_samples

unique_samples, counts = np.unique(initial_samples, return_counts=True)
ret = []
for unique_sample, counts in zip(unique_samples, counts):
meas_op = self._get_povm(unique_sample, initial_mode).dual
reduced_state = (state >> meas_op).normalize()
samples = self.sample(reduced_state, counts)
for sample in samples:
ret.append(np.append([unique_sample], sample))
return np.array(ret)

def sample_prob_dist(
self, state: State, n_samples: int = 1000, seed: int | None = None
) -> np.ndarray:
r"""
Samples a state by computing the probability distribution.
Args:
state: The state to sample.
n_samples: The number of samples to generate.
seed: An optional seed for random sampling.
"""
rng = np.random.default_rng(seed) if seed else settings.rng
return rng.choice(
a=list(product(self.meas_outcomes, repeat=len(state.modes))),
p=self.probabilities(state),
size=n_samples,
)

def _get_povm(self, meas_outcome: Any, mode: int) -> CircuitComponent:
r"""
Returns the POVM associated with a given outcome on a given mode.
Args:
meas_outcome: The measurement outcome.
mode: The mode.
Returns:
The POVM circuit component.
"""
if isinstance(self.povms, CircuitComponent):
kwargs = self.povms.parameter_set.to_dict()
kwargs[self._outcome_arg] = meas_outcome
return self.povms.__class__(modes=[mode], **kwargs)
else:
return self.povms[self.meas_outcomes.index(meas_outcome)].on([mode])

def _validate_probs(self, probs: Sequence[float], atol: float) -> Sequence[float]:
r"""
Validates that the given probability distribution sums to ``1`` within some
tolerance and returns a renormalized probability distribution to account for
small numerical errors.
Args:
probs: The probability distribution to validate.
atol: The absolute tolerance to validate with.
"""
atol = atol or settings.ATOL
prob_sum = sum(probs)
if not math.allclose(prob_sum, 1, atol):
raise ValueError(f"Probabilities sum to {prob_sum} and not 1.0.")
return math.real(probs / prob_sum)


class PNRSampler(Sampler):
r"""
A sampler for photon-number resolving (PNR) detectors.
Args:
cutoff: The photon number cutoff.
"""

def __init__(self, cutoff: int) -> None:
super().__init__(list(range(cutoff)), Number([0], 0))
self._cutoff = cutoff
self._outcome_arg = "n"

def probabilities(self, state, atol=1e-4):
return self._validate_probs(state.fock_distribution(self._cutoff), atol)


class HomodyneSampler(Sampler):
r"""
A sampler for homodyne measurements.
Args:
phi: The quadrature angle where ``0`` corresponds to ``x`` and ``\pi/2`` to ``p``.
bounds: The range of values to discretize over.
num: The number of points to discretize over.
"""

def __init__(
self,
phi: float = 0,
bounds: tuple[float, float] = (-10, 10),
num: int = 1000,
) -> None:
meas_outcomes, step = np.linspace(*bounds, num, retstep=True)
super().__init__(
list(meas_outcomes),
QuadratureEigenstate([0], x=0, phi=phi),
)
self._step = step
self._outcome_arg = "x"

def probabilities(self, state, atol=1e-4):
probs = state.quadrature_distribution(
self.meas_outcomes, self.povms.phi.value[0]
) * self._step ** len(state.modes)
return self._validate_probs(probs, atol)
155 changes: 155 additions & 0 deletions tests/test_lab_dev/test_samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the sampler."""

# pylint: disable=missing-function-docstring

import numpy as np

from mrmustard import math, settings
from mrmustard.lab_dev.samplers import PNRSampler, HomodyneSampler
from mrmustard.lab_dev import Coherent, Number, Vacuum, QuadratureEigenstate


class TestPNRSampler:
r"""
Tests ``PNRSampler`` objects.
"""

def test_init(self):
sampler = PNRSampler(cutoff=10)
assert sampler.meas_outcomes == list(range(10))
assert sampler.povms == Number([0], 0)

def test_probabilities(self):
atol = 1e-4

sampler = PNRSampler(cutoff=10)
vac_prob = [1.0] + [0.0] * 99
assert math.allclose(sampler.probabilities(Vacuum([0, 1])), vac_prob)

coh_state = Coherent([0, 1], x=[0.5, 1])
exp_probs = [
(coh_state >> Number([0], n0).dual >> Number([1], n1).dual) ** 2
for n0 in range(10)
for n1 in range(10)
]
assert math.allclose(sampler.probabilities(coh_state), exp_probs, atol)

def test_sample(self):
n_samples = 1000
sampler = PNRSampler(cutoff=10)

assert not np.any(sampler.sample(Vacuum([0])))
assert not np.any(sampler.sample_prob_dist(Vacuum([0])))
assert not np.any(sampler.sample(Vacuum([0, 1])))
assert not np.any(sampler.sample_prob_dist(Vacuum([0, 1])))

state = Coherent([0], x=[0.1])
samples = sampler.sample(state, n_samples)

count = np.zeros_like(sampler.meas_outcomes)
for sample in samples:
idx = sampler.meas_outcomes.index(sample)
count[idx] += 1
probs = count / n_samples

assert np.allclose(probs, sampler.probabilities(state), atol=1e-2)


class TestHomodyneSampler:
r"""
Tests ``HomodyneSampler`` objects.
"""

def test_init(self):
sampler = HomodyneSampler(phi=0.5, bounds=(-5, 5), num=100)
assert sampler.povms == QuadratureEigenstate([0], x=0, phi=0.5)
assert math.allclose(sampler.meas_outcomes, list(np.linspace(-5, 5, 100)))

def test_probabilties(self):
sampler = HomodyneSampler()

state = Coherent([0], x=[0.1])

exp_probs = (
state.quadrature_distribution(sampler.meas_outcomes)
* sampler._step # pylint: disable=protected-access
)
assert math.allclose(sampler.probabilities(state), exp_probs)

sampler2 = HomodyneSampler(phi=np.pi / 2)

exp_probs = (
state.quadrature_distribution(sampler2.meas_outcomes, sampler2.povms[0].phi.value[0])
* sampler2._step # pylint: disable=protected-access
)
assert math.allclose(sampler2.probabilities(state), exp_probs)

def test_sample(self):
n_samples = 1000
sampler = HomodyneSampler()
state = Coherent([0], x=[0.1])
samples = sampler.sample(state, n_samples)

count = np.zeros_like(sampler.meas_outcomes)
for sample in samples:
idx = sampler.meas_outcomes.index(sample)
count[idx] += 1
probs = count / n_samples

assert np.allclose(probs, sampler.probabilities(state), atol=1e-2)

def test_sample_mean_coherent(self):
r"""
Porting test from strawberry fields:
https://github.com/XanaduAI/strawberryfields/blob/master/tests/backend/test_homodyne.py#L56
"""
N_MEAS = 300
NUM_STDS = 10.0
std_10 = NUM_STDS / np.sqrt(N_MEAS)
alpha = 1.0 + 1.0j
x = np.empty(0)
tol = settings.ATOL

state = Coherent([0], x=math.real(alpha), y=math.imag(alpha))
sampler = HomodyneSampler()

for _ in range(N_MEAS):
meas_result = sampler.sample(state, 1)[0]
x = np.append(x, meas_result)

assert math.allclose(x.mean(), 2 * alpha.real, atol=std_10 + tol)

def test_sample_mean_and_std_vacuum(self):
r"""
Porting test from strawberry fields:
https://github.com/XanaduAI/strawberryfields/blob/master/tests/backend/test_homodyne.py#L40
"""
N_MEAS = 300
NUM_STDS = 10.0
std_10 = NUM_STDS / np.sqrt(N_MEAS)
x = np.empty(0)
tol = settings.ATOL

state = Vacuum([0])
sampler = HomodyneSampler()

for _ in range(N_MEAS):
meas_result = sampler.sample(state, 1)[0]
x = np.append(x, meas_result)

assert np.allclose(x.mean(), 0.0, atol=std_10 + tol, rtol=0)
assert np.allclose(x.std(), 1.0, atol=std_10 + tol, rtol=0)

0 comments on commit 76151a3

Please sign in to comment.