Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
elib20 committed Aug 26, 2024
1 parent 827e9e1 commit b803d28
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
20 changes: 16 additions & 4 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def to_quadrature(self, phi: float = 0.0) -> CircuitComponent:
in terms of A,b,c.
Args:
phi: The quadrature angle. ``phi=0`` corresponds to the x quadrature,
phi (float): The quadrature angle. ``phi=0`` corresponds to the x quadrature,
``phi=pi/2`` to the p quadrature. The default value is ``0``.
Returns:
A circuit component with the given quadrature representation.
Expand All @@ -345,6 +345,12 @@ def to_quadrature(self, phi: float = 0.0) -> CircuitComponent:
def quadrature_triple(self, phi: float = 0.0) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]:
r"""
The quadrature representation triple A,b,c of this circuit component.
Args:
phi: The quadrature angle. ``phi=0`` corresponds to the x quadrature,
``phi=pi/2`` to the p quadrature. The default value is ``0``.
Returns:
A,b,c triple of the quadrature representation
"""
if isinstance(self.representation, Fock):
raise NotImplementedError("Not implemented with Fock representation.")
Expand All @@ -355,14 +361,20 @@ def quadrature_triple(self, phi: float = 0.0) -> tuple[Batch[ComplexMatrix], Bat
def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor:
r"""
The (discretized) quadrature basis representation of the circuit component.
Args:
quad: discretized quadrature points to evaluate over in the
quadrature representation
phi: The quadrature angle. ``phi=0`` corresponds to the x quadrature,
``phi=pi/2`` to the p quadrature. The default value is ``0``.
Returns:
A circuit component with the given quadrature representation.
"""

if isinstance(self.representation, Fock):
fock_arrays = self.representation.array
# Find where all the bras and kets are so they can be conjugated appropriately
conjugates = [False] * len(self.wires.indices)
conjugates = [
conjugates[i] if i in self.wires.ket.indices else True
False if i in self.wires.ket.indices else True
for i in range(len(self.wires.indices))
]
quad_basis = math.sum(
Expand All @@ -371,7 +383,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor:
return quad_basis

QQQQ = self.to_quadrature(phi=phi)
return QQQQ.representation.ansatz(quad)
return QQQQ.representation(quad)

@classmethod
def _from_attributes(
Expand Down
12 changes: 6 additions & 6 deletions mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ def phase_space(self, s: float) -> tuple:
def quadrature_distribution(self, quad: Vector, phi: float = 0.0) -> tuple | ComplexTensor:
r"""
The (discretized) quadrature distribution of the State.
Args:
quad: the discretized quadrature axis over which the distribution is computed.
phi: The quadrature angle. ``phi=0`` corresponds to the x quadrature,
``phi=pi/2`` to the p quadrature. The default value is ``0``.
Returns:
A,b,c triple of the quadrature representation
"""
raise NotImplementedError

Expand Down Expand Up @@ -727,9 +733,6 @@ def purity(self) -> float:
return self.L2_norm

def quadrature_distribution(self, quad: Vector, phi: float = 0.0) -> tuple | ComplexTensor:
r"""
The (discretized) quadrature distribution of the circuit component.
"""
quad = math.transpose(
math.astensor(
[
Expand Down Expand Up @@ -1011,9 +1014,6 @@ def dm(self) -> DM:
return ret

def quadrature_distribution(self, quad: Vector, phi: float = 0.0) -> tuple | ComplexTensor:
r"""
The (discretized) quadrature distribution of the circuit component.
"""
quad = math.transpose(
math.astensor(
[
Expand Down
7 changes: 3 additions & 4 deletions mrmustard/physics/fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from __future__ import annotations

from functools import lru_cache
from typing import Sequence
from typing import Sequence, Iterable

import numpy as np

Expand Down Expand Up @@ -843,8 +843,7 @@ def quadrature_basis(
f"Input fock array has dimension {dims} whereas ``quad`` has {quad.shape[-1]}."
)

if type(conjugates) is bool:
conjugates = [conjugates] * dims
conjugates = conjugates if isinstance(conjugates, Iterable) else [conjugates] * dims

# construct quadrature basis vectors
cutoffs = fock_array.shape
Expand Down Expand Up @@ -873,7 +872,7 @@ def quadrature_basis(
def quadrature_distribution(
state: Tensor,
quadrature_angle: float = 0.0,
x: Vector = None,
x: Vector | None = None,
):
r"""Given the ket or density matrix of a single-mode state, it generates the probability
density distribution :math:`\tr [ \rho |x_\phi><x_\phi| ]` where `\rho` is the
Expand Down

0 comments on commit b803d28

Please sign in to comment.