Skip to content

Commit

Permalink
some CR
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Sep 18, 2024
1 parent d32d0c9 commit ec8928b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 17 deletions.
62 changes: 51 additions & 11 deletions tests/test_lab_dev/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

from mrmustard import math
from mrmustard import math, settings
from mrmustard.lab_dev.sampler import PNRSampler, HomodyneSampler
from mrmustard.lab_dev import Coherent, Number, Vacuum, QuadratureEigenstate

Expand Down Expand Up @@ -85,20 +85,18 @@ def test_probabilties(self):
sampler = HomodyneSampler()

state = Coherent([0], x=[0.1])
exp_probs = [
(state.dm() >> QuadratureEigenstate([0], x).dual)
* sampler._step # pylint: disable=protected-access
for x in sampler.meas_outcomes
]

exp_probs = (
state.quadrature_distribution(sampler.meas_outcomes) * sampler._step
) # pylint: disable=protected-access
assert math.allclose(sampler.probabilities(state), exp_probs)

sampler2 = HomodyneSampler(phi=np.pi / 2)

exp_probs = [
(state.dm() >> QuadratureEigenstate([0], x, phi=np.pi / 2).dual)
* sampler2._step # pylint: disable=protected-access
for x in sampler2.meas_outcomes
]
exp_probs = (
state.quadrature_distribution(sampler2.meas_outcomes, sampler2.povms[0].phi.value[0])
* sampler2._step
) # pylint: disable=protected-access
assert math.allclose(sampler2.probabilities(state), exp_probs)

def test_sample(self):
Expand All @@ -114,3 +112,45 @@ def test_sample(self):
probs = count / n_samples

assert np.allclose(probs, sampler.probabilities(state), atol=1e-2)

def test_sample_mean_coherent(self):
r"""
Porting test from strawberry fields:
https://github.com/XanaduAI/strawberryfields/blob/master/tests/backend/test_homodyne.py#L56
"""
N_MEAS = 300
NUM_STDS = 10.0
std_10 = NUM_STDS / np.sqrt(N_MEAS)
alpha = 1.0 + 1.0j
x = np.empty(0)
tol = settings.ATOL

state = Coherent([0], x=math.real(alpha), y=math.imag(alpha))
sampler = HomodyneSampler()

for _ in range(N_MEAS):
meas_result = sampler.sample(state, 1)[0]
x = np.append(x, meas_result)

assert math.allclose(x.mean(), 2 * alpha.real, atol=std_10 + tol)

def test_sample_mean_and_std_vacuum(self):
r"""
Porting test from strawberry fields:
https://github.com/XanaduAI/strawberryfields/blob/master/tests/backend/test_homodyne.py#L40
"""
N_MEAS = 300
NUM_STDS = 10.0
std_10 = NUM_STDS / np.sqrt(N_MEAS)
x = np.empty(0)
tol = settings.ATOL

state = Vacuum([0])
sampler = HomodyneSampler()

for _ in range(N_MEAS):
meas_result = sampler.sample(state, 1)[0]
x = np.append(x, meas_result)

assert np.allclose(x.mean(), 0.0, atol=std_10 + tol, rtol=0)
assert np.allclose(x.std(), 1.0, atol=std_10 + tol, rtol=0)
14 changes: 8 additions & 6 deletions tests/test_lab_dev/test_states/test_states_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

# pylint: disable=protected-access, unspecified-encoding, missing-function-docstring, expression-not-assigned, pointless-statement

import numpy as np
from itertools import product
import numpy as np
from ipywidgets import Box, HBox, VBox, HTML
from plotly.graph_objs import FigureWidget
import pytest
Expand Down Expand Up @@ -626,13 +626,15 @@ def test_quadrature_multimode_dm(self):
x, y = 1, 2
state = Coherent(modes=[0, 1], x=x, y=y).dm()
q = np.linspace(-10, 10, 100)
quad = math.transpose(math.astensor([q, q, q + 1, q + 1]))
ket = coherent_state_quad(q + 1, x, y) * coherent_state_quad(q + 1, x, y)
bra = np.conj(coherent_state_quad(q, x, y)) * np.conj(coherent_state_quad(q, x, y))
quad = math.tile(math.astensor(list(product(q, repeat=2))), (1, 2))
ket = math.kron(coherent_state_quad(q, x, y), coherent_state_quad(q, x, y))
bra = math.kron(
np.conj(coherent_state_quad(q, x, y)), np.conj(coherent_state_quad(q, x, y))
)
assert math.allclose(state.quadrature(quad), bra * ket)
assert math.allclose(state.quadrature_distribution(q), math.abs(bra) ** 2)
assert math.allclose(state.to_fock(40).quadrature(quad), bra * ket)
assert math.allclose(state.to_fock(40).quadrature_distribution(q), math.abs(bra) ** 2)
# assert math.allclose(state.to_fock(40).quadrature(quad), bra * ket)
# assert math.allclose(state.to_fock(40).quadrature_distribution(q), math.abs(bra) ** 2)

def test_quadrature_multivariable_dm(self):
x, y = 1, 2
Expand Down

0 comments on commit ec8928b

Please sign in to comment.