Skip to content

Commit

Permalink
Make unitary on sim init
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian Carver committed Jan 14, 2025
1 parent d3300d1 commit 5ee4524
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions simphony/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,10 @@ def __init__(self, ckt: Model, **kwargs) -> None:
raise ValueError("Must specify 'wl' (wavelengths to simulate).")
super().__init__(ckt, kwargs["wl"])

# get the unitary s-parameters of the circuit
self.s_params = dict_to_matrix(self.ckt())
self.unitary = self.to_unitary(self.s_params)

def add_qstate(self, qstate: QuantumState) -> None:
"""Add a quantum state to the simulation.
Expand Down Expand Up @@ -494,13 +498,10 @@ def run(self) -> QuantumResult:
"""Run the simulation."""
ports = get_ports(self.ckt())
n_ports = len(ports)
# get the unitary s-parameters of the circuit
s_params = dict_to_matrix(self.ckt())
unitary = self.to_unitary(s_params)
# get an array of the indices of the input ports
input_indices = [ports.index(port) for port in self.input.ports]
# create vacuum ports for each extra mode in the unitary matrix
n_modes = unitary.shape[1]
n_modes = self.unitary.shape[1]
n_vacuum = n_modes - len(input_indices)
self.input._add_vacuums(n_vacuum)
input_indices += [i for i in range(n_modes) if i not in input_indices]
Expand All @@ -511,7 +512,7 @@ def run(self) -> QuantumResult:
means = []
covs = []
for wl_ind in range(len(self.wl)):
s_wl = unitary[wl_ind]
s_wl = self.unitary[wl_ind]
transform = jnp.zeros((n_modes * 2, n_modes * 2))
n = n_modes

Expand All @@ -533,7 +534,7 @@ def run(self) -> QuantumResult:
covs.append(output_cov)

return QuantumResult(
s_params=s_params,
s_params=self.s_params,
input_means=input_means,
input_cov=input_cov,
transforms=jnp.stack(transforms),
Expand Down

0 comments on commit 5ee4524

Please sign in to comment.