Skip to content

Commit

Permalink
sampler quad_dist
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Sep 18, 2024
1 parent ceed8da commit 8e74234
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
4 changes: 1 addition & 3 deletions mrmustard/lab_dev/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,5 @@ def __init__(
self._step = step

def probabilities(self, state, atol=1e-4):
q_state = state.dm() >> BtoQ(state.modes, phi=self.povms[0].phi.value[0])
z = [x * 2 for x in product(self.meas_outcomes, repeat=len(state.modes))]
probs = q_state.representation(z) * math.sqrt(settings.HBAR)
probs = state.quadrature_distribution(self.meas_outcomes, self.povms[0].phi.value[0])
return self._validate_probs(probs, self._step ** len(state.modes), atol)
7 changes: 4 additions & 3 deletions mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def phase_space(self, s: float) -> tuple:
*new_state.bargmann_triple(batched=True), batched=True
)

def quadrature_distribution(self, quad: Vector, phi: float = 0.0) -> tuple | ComplexTensor:
def quadrature_distribution(self, quad: Vector, phi: float = 0.0) -> ComplexTensor:
r"""
The (discretized) quadrature distribution of the State.
Expand All @@ -345,18 +345,19 @@ def quadrature_distribution(self, quad: Vector, phi: float = 0.0) -> tuple | Com
Returns:
The quadrature distribution.
"""
quad = math.astensor(quad)
if len(quad.shape) != 1 and len(quad.shape) != self.n_modes:
raise ValueError(
f"The dimensionality of quad should be 1, or match the number of modes."
)

if len(quad.shape) == 1:
quad = math.transpose(math.astensor([quad] * self.n_modes))
quad = math.astensor(list(product(quad, repeat=len(self.modes))))

if isinstance(self, Ket):
return math.abs(self.quadrature(quad, phi)) ** 2
else:
quad = math.tile(math.astensor(quad), (1, 2))
quad = math.tile(quad, (1, 2))
return self.quadrature(quad, phi)

def visualize_2d(
Expand Down

0 comments on commit 8e74234

Please sign in to comment.