From 0942a674a53f28c379cfd4444105810c9c029298 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 21 Oct 2022 11:54:03 +0100 Subject: [PATCH 01/25] initial densitytensor --- qujax/__init__.py | 5 ++ qujax/density_matrix.py | 141 +++++++++++++++++++++++++++++++++++ tests/test_density_matrix.py | 28 +++++++ 3 files changed, 174 insertions(+) create mode 100644 qujax/density_matrix.py create mode 100644 tests/test_density_matrix.py diff --git a/qujax/__init__.py b/qujax/__init__.py index 9af7a26..8ae3aea 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -17,7 +17,12 @@ from qujax.circuit_tools import check_circuit from qujax.circuit_tools import print_circuit +from qujax.density_matrix import _kraus_single +from qujax.density_matrix import kraus +from qujax.density_matrix import get_params_to_densitytensor_func + del version del circuit del observable del circuit_tools +del density_matrix diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py new file mode 100644 index 0000000..7e737e3 --- /dev/null +++ b/qujax/density_matrix.py @@ -0,0 +1,141 @@ +from __future__ import annotations +from typing import Sequence, Union, Callable +from jax import numpy as jnp +from jax.lax import scan + +from qujax.circuit import apply_gate, UnionCallableOptionalArray, _to_gate_funcs, _arrayify_inds +from qujax.circuit_tools import check_circuit + + +def _kraus_single(density_tensor: jnp.ndarray, + array: jnp.ndarray, + qubit_inds: Sequence[int]) -> jnp.ndarray: + """ + Performs single Kraus operation + + .. math:: + \rho_\text{out} = B \rho_\text{in} B^{\dagger} + + Args: + density_tensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits + array: Array containing the Kraus operator. + qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. + + Returns: + Updated density matrix. + """ + n_qubits = density_tensor.ndim // 2 + density_tensor = apply_gate(density_tensor, array, qubit_inds) + density_tensor = apply_gate(density_tensor, array.conj(), [n_qubits + i for i in qubit_inds]) + return density_tensor + + +def kraus(density_tensor: jnp.ndarray, + arrays: Union[Sequence[jnp.ndarray], jnp.ndarray], + qubit_inds: Sequence[int]) -> jnp.ndarray: + """ + Performs Kraus operation. + + .. math:: + \rho_\text{out} = \sum_i B_i \rho_\text{in} B_i^{\dagger} + + Args: + density_tensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits + arrays: Sequence of arrays containing the Kraus operators. + qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. + + Returns: + Updated density matrix. + """ + arrays = jnp.atleast_3d(arrays) + new_density_tensor, _ = scan(lambda dt, arr: dt + _kraus_single(density_tensor, arr, qubit_inds), + init=jnp.zeros_like(density_tensor), xs=arrays) + # i.e. new_density_tensor = vmap(_kraus_single, in_axes=(None, 0, None))(density_tensor, arrays, qubit_inds) + return new_density_tensor + + +def get_params_to_densitytensor_func(gate_seq: Sequence[Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]]], + qubit_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Sequence[int]], + n_qubits: int = None) -> UnionCallableOptionalArray: + """ + Creates a function that maps circuit parameters to a density tensor. + densitytensor = densitymatrix.reshape((2,) * 2 * n_qubits) + densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits) + + Args: + gate_seq: Sequence of gates. + Each element is either a string matching a unitary array or function in qujax.gates, + a custom unitary array or a custom function taking parameters and returning a unitary array. + Unitary arrays will be reshaped into tensor form (2, 2,...) + qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are acting on. + i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the zeroth qubit, + the second gate is a two qubit gate acting on the zeroth and first qubit etc. + param_inds_seq: Sequence of sequences representing parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter + (the float at position zero in the parameter vector/array), the second gate is not parameterised + and the third gates used the parameters at position five and two. + n_qubits: Number of qubits, if fixed. + + Returns: + Function which maps parameters (and optional densitytensor_in) to a densitytensor. + If no parameters are found then the function only takes optional densitytensor_in. + + """ + + check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + if n_qubits is None: + n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 + + gate_seq_callable = _to_gate_funcs(gate_seq) + param_inds_seq = _arrayify_inds(param_inds_seq) + + def params_to_densitytensor_func(params: jnp.ndarray, + densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: + """ + Applies parameterised circuit (series of gates) to a densitytensor_in (default is |0>^N <0|^N). + + Args: + params: Parameters of the circuit. + densitytensor_in: Optional. Input densitytensor. + Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index). + + Returns: + Updated densitytensor. + + """ + if densitytensor_in is None: + densitytensor = jnp.zeros((2,) * 2 * n_qubits) + densitytensor = densitytensor.at[(0,) * 2 * n_qubits].set(1.) + else: + densitytensor = densitytensor_in + params = jnp.atleast_1d(params) + for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq): + gate_params = jnp.take(params, param_inds) + gate_unitary = gate_func(*gate_params) + gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + densitytensor = _kraus_single(densitytensor, gate_unitary, qubit_inds) + return densitytensor + + if all([pi.size == 0 for pi in param_inds_seq]): + def no_params_to_densitytensor_func(densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: + """ + Applies circuit (series of gates with no parameters) to a densitytensor_in (default is |0>^N <0|^N). + + Args: + densitytensor_in: Optional. Input densitytensor. + Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in [0]*(2*N) index). + + Returns: + Updated densitytensor. + + """ + return params_to_densitytensor_func(jnp.array([]), densitytensor_in) + + return no_params_to_densitytensor_func + + return params_to_densitytensor_func diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py new file mode 100644 index 0000000..57f8112 --- /dev/null +++ b/tests/test_density_matrix.py @@ -0,0 +1,28 @@ +from jax import numpy as jnp, jit + +import qujax +from qujax.density_matrix import _kraus_single + + +def test_kraus_single(): + n_qubits = 3 + dim = 2 ** n_qubits + density_matrix = jnp.arange(dim**2).reshape(dim, dim) + density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) + kraus_operator = qujax.gates.Rx(0.2) + + qubit_inds = (1,) + + qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + + unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) + unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) + check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T + + assert jnp.all(jnp.abs(qujax_kraus_dm - check_kraus_dm) < 1e-5) + + qujax_kraus_dt_jit = jit(_kraus_single, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.all(jnp.abs(qujax_kraus_dm_jit - check_kraus_dm) < 1e-5) + From 52cfdd0e717fbc2e621c879730c3d55b870239d2 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 21 Oct 2022 12:02:03 +0100 Subject: [PATCH 02/25] consolidate densitytensor --- qujax/density_matrix.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index 7e737e3..09f558a 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -7,7 +7,7 @@ from qujax.circuit_tools import check_circuit -def _kraus_single(density_tensor: jnp.ndarray, +def _kraus_single(densitytensor: jnp.ndarray, array: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: """ @@ -17,20 +17,20 @@ def _kraus_single(density_tensor: jnp.ndarray, \rho_\text{out} = B \rho_\text{in} B^{\dagger} Args: - density_tensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits + densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits array: Array containing the Kraus operator. qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. Returns: Updated density matrix. """ - n_qubits = density_tensor.ndim // 2 - density_tensor = apply_gate(density_tensor, array, qubit_inds) - density_tensor = apply_gate(density_tensor, array.conj(), [n_qubits + i for i in qubit_inds]) - return density_tensor + n_qubits = densitytensor.ndim // 2 + densitytensor = apply_gate(densitytensor, array, qubit_inds) + densitytensor = apply_gate(densitytensor, array.conj(), [n_qubits + i for i in qubit_inds]) + return densitytensor -def kraus(density_tensor: jnp.ndarray, +def kraus(densitytensor: jnp.ndarray, arrays: Union[Sequence[jnp.ndarray], jnp.ndarray], qubit_inds: Sequence[int]) -> jnp.ndarray: """ @@ -40,7 +40,7 @@ def kraus(density_tensor: jnp.ndarray, \rho_\text{out} = \sum_i B_i \rho_\text{in} B_i^{\dagger} Args: - density_tensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits + densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits arrays: Sequence of arrays containing the Kraus operators. qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. @@ -48,10 +48,10 @@ def kraus(density_tensor: jnp.ndarray, Updated density matrix. """ arrays = jnp.atleast_3d(arrays) - new_density_tensor, _ = scan(lambda dt, arr: dt + _kraus_single(density_tensor, arr, qubit_inds), - init=jnp.zeros_like(density_tensor), xs=arrays) - # i.e. new_density_tensor = vmap(_kraus_single, in_axes=(None, 0, None))(density_tensor, arrays, qubit_inds) - return new_density_tensor + new_densitytensor, _ = scan(lambda dt, arr: dt + _kraus_single(densitytensor, arr, qubit_inds), + init=jnp.zeros_like(densitytensor), xs=arrays) + # i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds) + return new_densitytensor def get_params_to_densitytensor_func(gate_seq: Sequence[Union[str, From 9bf3b8b3576930e96cf9ac1bebd8e05565a817d5 Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Sat, 22 Oct 2022 22:23:30 +0100 Subject: [PATCH 03/25] Add test_params_to_densitytensor_func --- tests/test_density_matrix.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 57f8112..1be5ace 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -1,7 +1,8 @@ from jax import numpy as jnp, jit import qujax -from qujax.density_matrix import _kraus_single +from qujax.density_matrix import _kraus_single, get_params_to_densitytensor_func +from qujax import get_params_to_statetensor_func def test_kraus_single(): @@ -26,3 +27,26 @@ def test_kraus_single(): qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.all(jnp.abs(qujax_kraus_dm_jit - check_kraus_dm) < 1e-5) + +def test_params_to_densitytensor_func(): + n_qubits = 2 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i+1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(n_qubits)/10. + + st = params_to_st(params) + dt_test = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2*n_qubits)) + + dt = params_to_dt(params) + + assert jnp.allclose(dt, dt_test) From 81d4d7cac32107eb44b43ca03357085814b4146b Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 25 Oct 2022 11:25:15 +0100 Subject: [PATCH 04/25] kraus fix --- qujax/density_matrix.py | 18 +++++--- tests/test_density_matrix.py | 82 ++++++++++++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 11 deletions(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index 09f558a..26cb2b6 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -18,7 +18,7 @@ def _kraus_single(densitytensor: jnp.ndarray, Args: densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits - array: Array containing the Kraus operator. + array: Array containing the Kraus operator (in tensor form). qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. Returns: @@ -41,16 +41,20 @@ def kraus(densitytensor: jnp.ndarray, Args: densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits - arrays: Sequence of arrays containing the Kraus operators. + arrays: Sequence of arrays containing the Kraus operators (in tensor form). qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. Returns: Updated density matrix. """ - arrays = jnp.atleast_3d(arrays) - new_densitytensor, _ = scan(lambda dt, arr: dt + _kraus_single(densitytensor, arr, qubit_inds), - init=jnp.zeros_like(densitytensor), xs=arrays) - # i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds) + arrays = jnp.array(arrays) + if arrays.ndim == (2 * len(qubit_inds)): + arrays = arrays[jnp.newaxis] + # ensure first dimensions indexes different kraus operators + + new_densitytensor, _ = scan(lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), + init=jnp.zeros_like(densitytensor, dtype='complex64'), xs=arrays) + # i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds).sum(0) return new_densitytensor @@ -118,7 +122,7 @@ def params_to_densitytensor_func(params: jnp.ndarray, gate_params = jnp.take(params, param_inds) gate_unitary = gate_func(*gate_params) gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form - densitytensor = _kraus_single(densitytensor, gate_unitary, qubit_inds) + densitytensor = kraus(densitytensor, gate_unitary, qubit_inds) return densitytensor if all([pi.size == 0 for pi in param_inds_seq]): diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 1be5ace..c246cdf 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -1,7 +1,7 @@ from jax import numpy as jnp, jit import qujax -from qujax.density_matrix import _kraus_single, get_params_to_densitytensor_func +from qujax import _kraus_single, kraus, get_params_to_densitytensor_func from qujax import get_params_to_statetensor_func @@ -14,18 +14,89 @@ def test_kraus_single(): qubit_inds = (1,) + # qujax._kraus_single qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator, qubit_inds) qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) - unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) + unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T - assert jnp.all(jnp.abs(qujax_kraus_dm - check_kraus_dm) < 1e-5) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) qujax_kraus_dt_jit = jit(_kraus_single, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) - assert jnp.all(jnp.abs(qujax_kraus_dm_jit - check_kraus_dm) < 1e-5) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + # qujax.kraus (but for a single array) + qujax_kraus_dt = kraus(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + +def test_kraus_single_2qubit(): + n_qubits = 4 + dim = 2 ** n_qubits + density_matrix = jnp.arange(dim**2).reshape(dim, dim) + density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) + kraus_operator_tensor = qujax.gates.ZZPhase(0.1) + kraus_operator = qujax.gates.ZZPhase(0.1).reshape(4, 4) + + qubit_inds = (1, 2) + + # qujax._kraus_single + qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + + unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) + unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) + check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T + + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(_kraus_single, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + # qujax.kraus (but for a single array) + qujax_kraus_dt = kraus(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) + + +def test_kraus_multiple(): + n_qubits = 3 + dim = 2 ** n_qubits + density_matrix = jnp.arange(dim**2).reshape(dim, dim) + density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) + + kraus_operators = [0.25 * qujax.gates.H, 0.25 * qujax.gates.Rx(0.3), 0.5 * qujax.gates.Ry(0.1)] + + qubit_inds = (1,) + + qujax_kraus_dt = kraus(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + + unitary_matrices = [jnp.kron(jnp.eye(2 * qubit_inds[0]), ko) for ko in kraus_operators] + unitary_matrices = [jnp.kron(um, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) for um in unitary_matrices] + + check_kraus_dm = jnp.zeros_like(density_matrix) + for um in unitary_matrices: + check_kraus_dm += um @ density_matrix @ um.conj().T + + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + + qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) def test_params_to_densitytensor_func(): @@ -50,3 +121,6 @@ def test_params_to_densitytensor_func(): dt = params_to_dt(params) assert jnp.allclose(dt, dt_test) + + jit_dt = jit(params_to_dt)(params) + assert jnp.allclose(jit_dt, dt_test) From a6dbca96a58bba3e42d5c04bb2c60901d2712a55 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 25 Oct 2022 12:31:30 +0100 Subject: [PATCH 05/25] type hint --- qujax/density_matrix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index 26cb2b6..a91e47c 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Sequence, Union, Callable +from typing import Sequence, Union, Callable, Iterable from jax import numpy as jnp from jax.lax import scan @@ -31,7 +31,7 @@ def _kraus_single(densitytensor: jnp.ndarray, def kraus(densitytensor: jnp.ndarray, - arrays: Union[Sequence[jnp.ndarray], jnp.ndarray], + arrays: Iterable[jnp.ndarray], qubit_inds: Sequence[int]) -> jnp.ndarray: """ Performs Kraus operation. From b85e1bd244fb63a1c99bd1a4671cbf819c23c2e2 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 25 Oct 2022 17:10:15 +0100 Subject: [PATCH 06/25] permit non-unitary Kraus --- qujax/circuit.py | 95 +++++++++++++++++++++++------------- qujax/circuit_tools.py | 9 ++-- qujax/density_matrix.py | 62 +++++++++++++++++------ tests/test_density_matrix.py | 38 +++++++++++++++ 4 files changed, 150 insertions(+), 54 deletions(-) diff --git a/qujax/circuit.py b/qujax/circuit.py index 259c6d5..18e4e17 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -17,6 +17,10 @@ def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] +gate_type = Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]] def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: @@ -40,55 +44,77 @@ def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: return statetensor -def _to_gate_funcs(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]])\ - -> Sequence[Callable[[jnp.ndarray], jnp.ndarray]]: +def _to_gate_func(gate: gate_type) -> Callable[[jnp.ndarray], jnp.ndarray]: """ - Ensures all gate_seq elements are functions that map (possibly empty) parameters + Ensures a gate_seq element is a function that map (possibly empty) parameters to a unitary tensor. Args: - gate_seq: Sequence of gates. - Each element is either a string matching an array or function in qujax.gates, + gate: Either a string matching an array or function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) or a function taking parameters and returning gate unitary in tensor form. Returns: - Sequence of gate parameter to unitary functions - + Gate parameter to unitary functions """ + def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: return lambda: arr - gate_seq_callable = [] - for gate in gate_seq: - if isinstance(gate, str): - gate = gates.__dict__[gate] + if isinstance(gate, str): + gate = gates.__dict__[gate] - if callable(gate): - gate_func = gate - elif hasattr(gate, '__array__'): - gate_func = _array_to_callable(jnp.array(gate)) - else: - raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' - f'callable: {gate}') - gate_seq_callable.append(gate_func) - - return gate_seq_callable + if callable(gate): + gate_func = gate + elif hasattr(gate, '__array__'): + gate_func = _array_to_callable(jnp.array(gate)) + else: + raise TypeError(f'Unsupported gate type - gate must be either a string in qujax.gates, an array or ' + f'callable: {gate}') + return gate_func def _arrayify_inds(param_inds_seq: Sequence[Sequence[int]]) -> Sequence[jnp.ndarray]: + """ + Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) + + Args: + param_inds_seq: Sequence of sequences representing parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter + (the float at position zero in the parameter vector/array), the second gate is not parameterised + and the third gates used the parameters at position five and two. + + Returns: + Sequence of arrays representing parameter indices. + """ param_inds_seq = [jnp.array(p) for p in param_inds_seq] param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] return param_inds_seq -def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], +def _gate_func_to_unitary(gate_func: Callable[[jnp.ndarray], jnp.ndarray], + qubit_inds: Sequence[int], + param_inds: jnp.ndarray, + params: jnp.ndarray) -> jnp.ndarray: + """ + Extract gate unitary. + + Args: + gate_func: Function that maps a (possibly empty) parameter array to a unitary tensor (array) + qubit_inds: Indices of qubits to apply gate to (only needed to ensure gate is in tensor form) + param_inds: Indices of full parameter to extract gate specific parameters + params: Full parameter vector + + Returns: + Array containing gate unitary in tensor form. + """ + gate_params = jnp.take(params, param_inds) + gate_unitary = gate_func(*gate_params) + gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + return gate_unitary + + +def get_params_to_statetensor_func(gate_seq: Sequence[gate_type], qubit_inds_seq: Sequence[Sequence[int]], param_inds_seq: Sequence[Sequence[int]], n_qubits: int = None) -> UnionCallableOptionalArray: @@ -120,8 +146,8 @@ def get_params_to_statetensor_func(gate_seq: Sequence[Union[str, if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - gate_seq_callable = _to_gate_funcs(gate_seq) - param_inds_seq = _arrayify_inds(param_inds_seq) + gate_seq_callable = [_to_gate_func(g) for g in gate_seq] + param_inds_array_seq = _arrayify_inds(param_inds_seq) def params_to_statetensor_func(params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: @@ -143,14 +169,13 @@ def params_to_statetensor_func(params: jnp.ndarray, else: statetensor = statetensor_in params = jnp.atleast_1d(params) - for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq): - gate_params = jnp.take(params, param_inds) - gate_unitary = gate_func(*gate_params) - gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form + for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_array_seq): + gate_unitary = _gate_func_to_unitary(gate_func, qubit_inds, param_inds, params) statetensor = apply_gate(statetensor, gate_unitary, qubit_inds) return statetensor - if all([pi.size == 0 for pi in param_inds_seq]): + non_parameterised = all([pi.size == 0 for pi in param_inds_array_seq]) + if non_parameterised: def no_params_to_statetensor_func(statetensor_in: jnp.ndarray = None) -> jnp.ndarray: """ Applies circuit (series of gates with no parameters) to a statetensor_in (default is |0>^N). diff --git a/qujax/circuit_tools.py b/qujax/circuit_tools.py index ed1ff0a..9301b5e 100644 --- a/qujax/circuit_tools.py +++ b/qujax/circuit_tools.py @@ -41,7 +41,8 @@ def check_circuit(gate_seq: Sequence[Union[str, Callable[[], jnp.ndarray]]], qubit_inds_seq: Sequence[Sequence[int]], param_inds_seq: Sequence[Sequence[int]], - n_qubits: int = None): + n_qubits: int = None, + check_unitaries: bool = True): """ Basic checks that circuit arguments conform. @@ -55,6 +56,7 @@ def check_circuit(gate_seq: Sequence[Union[str, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, the second gate is not parameterised and the third gates used the fifth and second parameters. n_qubits: Number of qubits, if fixed. + check_unitaries: boolean on whether to check if each gate represents a unitary matrix """ if not isinstance(gate_seq, collections.abc.Sequence): @@ -75,8 +77,9 @@ def check_circuit(gate_seq: Sequence[Union[str, if n_qubits is not None and n_qubits < max([max(qi) for qi in qubit_inds_seq]) + 1: raise TypeError('n_qubits must be larger than largest qubit index in qubit_inds_seq') - for g in gate_seq: - check_unitary(g) + if check_unitaries: + for g in gate_seq: + check_unitary(g) def _get_gate_str(gate_obj: Union[str, diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index a91e47c..e8fecb7 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, Iterable +from typing import Sequence, Union, Callable, Iterable, Tuple from jax import numpy as jnp from jax.lax import scan -from qujax.circuit import apply_gate, UnionCallableOptionalArray, _to_gate_funcs, _arrayify_inds +from qujax.circuit import apply_gate, UnionCallableOptionalArray, gate_type +from qujax.circuit import _to_gate_func, _arrayify_inds, _gate_func_to_unitary from qujax.circuit_tools import check_circuit +kraus_op_type = Union[gate_type, Iterable[gate_type]] + def _kraus_single(densitytensor: jnp.ndarray, array: jnp.ndarray, @@ -58,10 +61,34 @@ def kraus(densitytensor: jnp.ndarray, return new_densitytensor -def get_params_to_densitytensor_func(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], +def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, + param_inds: Union[Sequence[int], Sequence[Sequence[int]]]) \ + -> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]], + Sequence[jnp.ndarray]]: + """ + Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to tensors + and that each element of param_inds_seq is a sequence of arrays that correspond to the parameter indices + of each Kraus operator. + + Args: + kraus_op: Either a normal gate_type or a sequence of gate_types representing Kraus operators. + param_inds: If kraus_op is a normal gate_type then a sequence of parameter indices, + if kraus_op is a sequence of Kraus operators then a sequence of sequences of parameter indices + + Returns: + Tuple containing sequence of functions mapping to Kraus operators + and sequence of arrays with parameter indices + + """ + if isinstance(kraus_op, (list, tuple)): + kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op] + else: + kraus_op_funcs = [_to_gate_func(kraus_op)] + param_inds = [param_inds] + return kraus_op_funcs, _arrayify_inds(param_inds) + + +def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], param_inds_seq: Sequence[Sequence[int]], n_qubits: int = None) -> UnionCallableOptionalArray: @@ -71,7 +98,7 @@ def get_params_to_densitytensor_func(gate_seq: Sequence[Union[str, densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits) Args: - gate_seq: Sequence of gates. + kraus_ops_seq: Sequence of gates. Each element is either a string matching a unitary array or function in qujax.gates, a custom unitary array or a custom function taking parameters and returning a unitary array. Unitary arrays will be reshaped into tensor form (2, 2,...) @@ -90,13 +117,15 @@ def get_params_to_densitytensor_func(gate_seq: Sequence[Union[str, """ - check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + check_circuit(kraus_ops_seq, qubit_inds_seq, param_inds_seq, n_qubits, False) if n_qubits is None: n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 - gate_seq_callable = _to_gate_funcs(gate_seq) - param_inds_seq = _arrayify_inds(param_inds_seq) + kraus_ops_seq_callable_and_param_inds = [_to_kraus_operator_seq_funcs(ko, param_inds) + for ko, param_inds in zip(kraus_ops_seq, param_inds_seq)] + kraus_ops_seq_callable = [ko_pi[0] for ko_pi in kraus_ops_seq_callable_and_param_inds] + param_inds_array_seq = [ko_pi[1] for ko_pi in kraus_ops_seq_callable_and_param_inds] def params_to_densitytensor_func(params: jnp.ndarray, densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: @@ -118,14 +147,15 @@ def params_to_densitytensor_func(params: jnp.ndarray, else: densitytensor = densitytensor_in params = jnp.atleast_1d(params) - for gate_func, qubit_inds, param_inds in zip(gate_seq_callable, qubit_inds_seq, param_inds_seq): - gate_params = jnp.take(params, param_inds) - gate_unitary = gate_func(*gate_params) - gate_unitary = gate_unitary.reshape((2,) * (2 * len(qubit_inds))) # Ensure gate is in tensor form - densitytensor = kraus(densitytensor, gate_unitary, qubit_inds) + for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip(kraus_ops_seq_callable, qubit_inds_seq, + param_inds_array_seq): + kraus_operators = [_gate_func_to_unitary(gf, qubit_inds, pi, params) + for gf, pi in zip(gate_func_single_seq, param_inds_single_seq)] + densitytensor = kraus(densitytensor, kraus_operators, qubit_inds) return densitytensor - if all([pi.size == 0 for pi in param_inds_seq]): + non_parameterised = all([all([pi.size == 0 for pi in pi_seq]) for pi_seq in param_inds_array_seq]) + if non_parameterised: def no_params_to_densitytensor_func(densitytensor_in: jnp.ndarray = None) -> jnp.ndarray: """ Applies circuit (series of gates with no parameters) to a densitytensor_in (default is |0>^N <0|^N). diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index c246cdf..1a7d1b6 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -124,3 +124,41 @@ def test_params_to_densitytensor_func(): jit_dt = jit(params_to_dt)(params) assert jnp.allclose(jit_dt, dt_test) + + +def test_params_to_densitytensor_func_with_bit_flip(): + n_qubits = 2 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i+1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_pre_bf_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + kraus_ops = [[0.3 * jnp.eye(2), 0.7 * qujax.gates.X]] + kraus_qubit_inds = [(0,)] + kraus_param_inds = [((), ())] + + gate_seq += kraus_ops + qubit_inds_seq += kraus_qubit_inds + param_inds_seq += kraus_param_inds + + params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(n_qubits)/10. + + pre_bf_st = params_to_pre_bf_st(params) + pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape(2 for _ in range(2*n_qubits)) + dt_test = kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0]) + + dt = params_to_dt(params) + + assert jnp.allclose(dt, dt_test) + + jit_dt = jit(params_to_dt)(params) + assert jnp.allclose(jit_dt, dt_test) + From 82d06fe8e7feb937570177efde3c7fdd4febc9ad Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 25 Oct 2022 17:16:33 +0100 Subject: [PATCH 07/25] typehint --- qujax/density_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index e8fecb7..dd33261 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -90,7 +90,7 @@ def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Union[Sequence[int], Sequence[Sequence[int]]]], n_qubits: int = None) -> UnionCallableOptionalArray: """ Creates a function that maps circuit parameters to a density tensor. From 41ff2fcefacb0a77c98667c821f441b2a228cadc Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Wed, 26 Oct 2022 15:53:49 +0100 Subject: [PATCH 08/25] 'complex64' -> 'complex' in kraus Fixes scan error: dtype mismatch when JAX is in double precision mode --- qujax/density_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index dd33261..37cb6dc 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -56,7 +56,7 @@ def kraus(densitytensor: jnp.ndarray, # ensure first dimensions indexes different kraus operators new_densitytensor, _ = scan(lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), - init=jnp.zeros_like(densitytensor, dtype='complex64'), xs=arrays) + init=jnp.zeros_like(densitytensor, dtype='complex'), xs=arrays) # i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds).sum(0) return new_densitytensor From 3fabaa2ff4adbd7c80e650914c7a536936ff64bd Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Wed, 26 Oct 2022 21:20:20 +0100 Subject: [PATCH 09/25] Multiply by 0.j instead to make densitytensor complex Co-authored-by: SamDuffield <34280297+SamDuffield@users.noreply.github.com> --- qujax/density_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index 37cb6dc..ad15f9b 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -56,7 +56,7 @@ def kraus(densitytensor: jnp.ndarray, # ensure first dimensions indexes different kraus operators new_densitytensor, _ = scan(lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), - init=jnp.zeros_like(densitytensor, dtype='complex'), xs=arrays) + init=jnp.zeros_like(densitytensor) * 0.j, xs=arrays) # i.e. new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))(densitytensor, arrays, qubit_inds).sum(0) return new_densitytensor From 60d6bf269c2e673422dae02f7138bc2b2b3b25ba Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 28 Oct 2022 11:11:08 +0100 Subject: [PATCH 10/25] allow None in param_inds_seq --- qujax/circuit.py | 4 ++-- qujax/circuit_tools.py | 3 ++- qujax/density_matrix.py | 7 +++++-- tests/test_circuits.py | 2 +- tests/test_density_matrix.py | 10 +++++----- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/qujax/circuit.py b/qujax/circuit.py index 18e4e17..22755ac 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -74,7 +74,7 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: return gate_func -def _arrayify_inds(param_inds_seq: Sequence[Sequence[int]]) -> Sequence[jnp.ndarray]: +def _arrayify_inds(param_inds_seq: Sequence[Union[None, Sequence[int]]]) -> Sequence[jnp.ndarray]: """ Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) @@ -116,7 +116,7 @@ def _gate_func_to_unitary(gate_func: Callable[[jnp.ndarray], jnp.ndarray], def get_params_to_statetensor_func(gate_seq: Sequence[gate_type], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Sequence[int]], + param_inds_seq: Sequence[Union[None, Sequence[int]]], n_qubits: int = None) -> UnionCallableOptionalArray: """ Creates a function that maps circuit parameters to a statetensor. diff --git a/qujax/circuit_tools.py b/qujax/circuit_tools.py index 9301b5e..7e6a4f4 100644 --- a/qujax/circuit_tools.py +++ b/qujax/circuit_tools.py @@ -67,7 +67,8 @@ def check_circuit(gate_seq: Sequence[Union[str, raise TypeError('qubit_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]') if (not isinstance(param_inds_seq, collections.abc.Sequence)) or \ - (any([not (isinstance(p, collections.abc.Sequence) or hasattr(p, '__array__')) for p in param_inds_seq])): + (any([not (isinstance(p, collections.abc.Sequence) or hasattr(p, '__array__') or p is None) + for p in param_inds_seq])): raise TypeError('param_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]') if len(gate_seq) != len(qubit_inds_seq) or len(param_inds_seq) != len(param_inds_seq): diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index ad15f9b..3b19555 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -62,7 +62,7 @@ def kraus(densitytensor: jnp.ndarray, def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, - param_inds: Union[Sequence[int], Sequence[Sequence[int]]]) \ + param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) \ -> Tuple[Sequence[Callable[[jnp.ndarray], jnp.ndarray]], Sequence[jnp.ndarray]]: """ @@ -80,6 +80,9 @@ def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, and sequence of arrays with parameter indices """ + if param_inds is None: + param_inds = [None for _ in kraus_op] + if isinstance(kraus_op, (list, tuple)): kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op] else: @@ -90,7 +93,7 @@ def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], - param_inds_seq: Sequence[Union[Sequence[int], Sequence[Sequence[int]]]], + param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]], n_qubits: int = None) -> UnionCallableOptionalArray: """ Creates a function that maps circuit parameters to a density tensor. diff --git a/tests/test_circuits.py b/tests/test_circuits.py index dd3f3c6..7dd4219 100644 --- a/tests/test_circuits.py +++ b/tests/test_circuits.py @@ -37,7 +37,7 @@ def test_H_redundant_qubits(): def test_CX_Rz_CY(): gates = ['H', 'H', 'H', 'CX', 'Rz', 'CY'] qubits = [[0], [1], [2], [0, 1], [1], [1, 2]] - param_inds = [[], [], [], [], [0], []] + param_inds = [[], [], [], None, [0], []] param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds) st = param_to_st(jnp.array(0.1)) diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 1a7d1b6..584735c 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -14,14 +14,14 @@ def test_kraus_single(): qubit_inds = (1,) - # qujax._kraus_single - qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator, qubit_inds) - qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) - unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T + # qujax._kraus_single + qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) qujax_kraus_dt_jit = jit(_kraus_single, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) @@ -141,7 +141,7 @@ def test_params_to_densitytensor_func_with_bit_flip(): kraus_ops = [[0.3 * jnp.eye(2), 0.7 * qujax.gates.X]] kraus_qubit_inds = [(0,)] - kraus_param_inds = [((), ())] + kraus_param_inds = [None] gate_seq += kraus_ops qubit_inds_seq += kraus_qubit_inds From 0c226c760b36868b0343ba051eab6532dba37981 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 28 Oct 2022 13:29:00 +0100 Subject: [PATCH 11/25] reshape kraus --- qujax/density_matrix.py | 3 ++- tests/test_density_matrix.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index 3b19555..b269a24 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -51,9 +51,10 @@ def kraus(densitytensor: jnp.ndarray, Updated density matrix. """ arrays = jnp.array(arrays) - if arrays.ndim == (2 * len(qubit_inds)): + if arrays.ndim % 2 == 0: arrays = arrays[jnp.newaxis] # ensure first dimensions indexes different kraus operators + arrays = arrays.reshape((arrays.shape[0],) + (2,) * 2 * len(qubit_inds)) new_densitytensor, _ = scan(lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), init=jnp.zeros_like(densitytensor) * 0.j, xs=arrays) diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 584735c..b17805f 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -48,14 +48,14 @@ def test_kraus_single_2qubit(): qubit_inds = (1, 2) - # qujax._kraus_single - qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator_tensor, qubit_inds) - qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) - unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1))) check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T + # qujax._kraus_single + qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) qujax_kraus_dt_jit = jit(_kraus_single, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) @@ -67,6 +67,10 @@ def test_kraus_single_2qubit(): qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + qujax_kraus_dt = kraus(density_tensor, kraus_operator, qubit_inds) # check reshape kraus_operator correctly + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) + qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) @@ -82,9 +86,6 @@ def test_kraus_multiple(): qubit_inds = (1,) - qujax_kraus_dt = kraus(density_tensor, kraus_operators, qubit_inds) - qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) - unitary_matrices = [jnp.kron(jnp.eye(2 * qubit_inds[0]), ko) for ko in kraus_operators] unitary_matrices = [jnp.kron(um, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) for um in unitary_matrices] @@ -92,6 +93,9 @@ def test_kraus_multiple(): for um in unitary_matrices: check_kraus_dm += um @ density_matrix @ um.conj().T + qujax_kraus_dt = kraus(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) + assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operators, qubit_inds) From ac0580ea2c5c07353b05df6251615c4d0e31b645 Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Fri, 28 Oct 2022 14:59:47 +0100 Subject: [PATCH 12/25] Implement partial trace function --- qujax/density_matrix.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index b269a24..1d0f82b 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -91,6 +91,27 @@ def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, param_inds = [param_inds] return kraus_op_funcs, _arrayify_inds(param_inds) +def partial_trace(densitytensor: jnp.ndarray, + indices_to_trace: Iterable[int]) -> jnp.ndarray: + """ + Traces out (discards) specified qubits, resulting in a densitytensor + representing the mixed quantum state on the remaining qubits. + + Args: + densitytensor: Input densitytensor. + indices_to_trace: Indices of qubits to trace out/discard. + + Returns: + Resulting densitytensor on remaining qubits. + + """ + n_qubits = densitytensor.ndim // 2 + einsum_indices = list(range(densitytensor.ndim)) + for i in indices_to_trace: + einsum_indices[i + n_qubits] = einsum_indices[i] + densitytensor = jnp.einsum(densitytensor, einsum_indices) + return densitytensor + def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], From 201def8c337b34bc4744e83fca61f3ff7f5f3952 Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Fri, 28 Oct 2022 16:00:41 +0100 Subject: [PATCH 13/25] Add partial_trace to __init__.py --- qujax/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qujax/__init__.py b/qujax/__init__.py index 8ae3aea..006e551 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -20,6 +20,7 @@ from qujax.density_matrix import _kraus_single from qujax.density_matrix import kraus from qujax.density_matrix import get_params_to_densitytensor_func +from qujax.density_matrix import partial_trace del version del circuit From 71c1097ec0da57d9b99c061ae5b05c044c356b93 Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Fri, 28 Oct 2022 16:30:28 +0100 Subject: [PATCH 14/25] Add tests for partial trace --- tests/test_density_matrix.py | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index b17805f..9f414f9 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -3,6 +3,7 @@ import qujax from qujax import _kraus_single, kraus, get_params_to_densitytensor_func from qujax import get_params_to_statetensor_func +from qujax import partial_trace def test_kraus_single(): @@ -166,3 +167,40 @@ def test_params_to_densitytensor_func_with_bit_flip(): jit_dt = jit(params_to_dt)(params) assert jnp.allclose(jit_dt, dt_test) + +def test_partial_trace_1(): + state1 = 1/jnp.sqrt(2) * jnp.array([1., 1.]) + state2 = jnp.kron(state1, state1) + state3 = jnp.kron(state1, state2) + + dt1 = jnp.outer(state1, state1.conj()).reshape((2,) * 2) + dt2 = jnp.outer(state2, state2.conj()).reshape((2,) * 4) + dt3 = jnp.outer(state3, state3.conj()).reshape((2,) * 6) + + for i in range(3): + assert jnp.allclose(partial_trace(dt3, [i]), dt2) + + from itertools import combinations + for i in combinations(range(3), 2): + assert jnp.allclose(partial_trace(dt3, i), dt1) + +def test_partial_trace_2(): + n_qubits = 3 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i+1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(1, n_qubits+1)/10. + + dt = params_to_dt(params) + dt_discard_test = jnp.trace(dt, axis1=0, axis2=n_qubits) + dt_discard = partial_trace(dt, [0]) + + assert jnp.allclose(dt_discard, dt_discard_test) From 5aca3670fcd7e5244bae850fa530edb6c37421cb Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Mon, 31 Oct 2022 14:45:28 +0000 Subject: [PATCH 15/25] revamp expectation tests --- qujax/density_matrix.py | 3 +- tests/test_expectations.py | 75 ++++++++++++++++++++++++++------------ 2 files changed, 54 insertions(+), 24 deletions(-) diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index 1d0f82b..b6f839a 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -91,6 +91,7 @@ def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, param_inds = [param_inds] return kraus_op_funcs, _arrayify_inds(param_inds) + def partial_trace(densitytensor: jnp.ndarray, indices_to_trace: Iterable[int]) -> jnp.ndarray: """ @@ -104,7 +105,7 @@ def partial_trace(densitytensor: jnp.ndarray, Returns: Resulting densitytensor on remaining qubits. - """ + """ n_qubits = densitytensor.ndim // 2 einsum_indices = list(range(densitytensor.ndim)) for i in indices_to_trace: diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 8b4e707..2ef51b5 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -1,4 +1,4 @@ -from jax import numpy as jnp, jit, grad, random +from jax import numpy as jnp, jit, grad, random, config import qujax @@ -31,31 +31,33 @@ def st_to_expectation(statetensor): param_to_expectation = lambda p: st_to_expectation(param_to_st(p)) + def brute_force_param_to_exp(p): + sv = param_to_st(p).flatten() + return jnp.dot(sv, jnp.diag(costs) @ sv.conj()).real + + true_expectation = brute_force_param_to_exp(params) + expectation = param_to_expectation(params) expectation_jit = jit(param_to_expectation)(params) assert expectation.shape == () - assert expectation.dtype == 'float32' - assert jnp.abs(-0.97042876 - expectation) < 1e-5 - assert jnp.abs(-0.97042876 - expectation_jit) < 1e-5 + assert expectation.dtype.name[:5] == 'float' + assert jnp.isclose(true_expectation, expectation) + assert jnp.isclose(true_expectation, expectation_jit) + true_expectation_grad = grad(brute_force_param_to_exp)(params) expectation_grad = grad(param_to_expectation)(params) expectation_grad_jit = jit(grad(param_to_expectation))(params) - true_expectation_grad = jnp.array([5.1673526e-01, 1.2618620e+00, 5.1392573e-01, - 1.5056899e+00, 4.3226164e-02, 3.4227133e-02, - 8.1762001e-02, 7.7345759e-01, 5.1567715e-01, - -3.1131029e-01, -1.7132770e-01, -6.6244489e-01, - 9.3626760e-08, -4.6813380e-08, -2.3406690e-08, - -9.3626760e-08]) - assert expectation_grad.shape == (n_params,) - assert expectation_grad.dtype == 'float32' - assert jnp.all(jnp.abs(expectation_grad - true_expectation_grad) < 1e-5) - assert jnp.all(jnp.abs(expectation_grad_jit - true_expectation_grad) < 1e-5) + assert expectation_grad.dtype.name[:5] == 'float' + assert jnp.allclose(true_expectation_grad, expectation_grad, atol=1e-5) + assert jnp.allclose(true_expectation_grad, expectation_grad_jit, atol=1e-5) def test_ZZ_X(): + config.update("jax_enable_x64", True) # Run this test with 64 bit precision + n_qubits = 5 gate_str_seq_seq = [['Z', 'Z']] * (n_qubits - 1) + [['X']] * n_qubits @@ -70,19 +72,46 @@ def test_ZZ_X(): state /= jnp.linalg.norm(state) st_in = state.reshape((2,) * n_qubits) - jax_exp = st_to_exp(st_in) - jax_exp_jit = jit(st_to_exp)(st_in) - - assert jnp.abs(-0.23738188 - jax_exp) < 1e-5 - assert jnp.abs(-0.23738188 - jax_exp_jit) < 1e-5 + def big_unitary_matrix(gate_str_seq, qubit_inds): + qubit_gate_arrs = [getattr(qujax.gates, s) for s in gate_str_seq] + gate_arrs = [] + j = 0 + for i in range(n_qubits): + if i in qubit_inds: + gate_arrs.append(qubit_gate_arrs[j]) + j += 1 + else: + gate_arrs.append(jnp.eye(2)) + + big_u = gate_arrs[0] + for k in range(1, n_qubits): + big_u = jnp.kron(big_u, gate_arrs[k]) + return big_u + + sum_big_us = jnp.zeros((2 ** n_qubits, 2 ** n_qubits)) + for i in range(len(gate_str_seq_seq)): + sum_big_us += coefs[i] * big_unitary_matrix(gate_str_seq_seq[i], qubit_inds_seq[i]) + + sv = st_in.flatten() + true_exp = jnp.dot(sv, sum_big_us @ sv.conj()) + + qujax_exp = st_to_exp(st_in) + qujax_exp_jit = jit(st_to_exp)(st_in) + + assert jnp.array(qujax_exp).shape == () + assert jnp.array(qujax_exp).dtype.name[:5] == 'float' + assert jnp.isclose(true_exp, qujax_exp) + assert jnp.isclose(true_exp, qujax_exp_jit) st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(gate_str_seq_seq, qubit_inds_seq, coefs) - jax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 10000) - jax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 10000) - assert jnp.abs(-0.23738188 - jax_samp_exp) < 1e-2 - assert jnp.abs(-0.23738188 - jax_samp_exp_jit) < 1e-2 + qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 10000) + qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 100000) + assert jnp.array(qujax_samp_exp).shape == () + assert jnp.array(qujax_samp_exp).dtype.name[:5] == 'float' + assert jnp.isclose(true_exp, qujax_samp_exp, rtol=1e-2) + assert jnp.isclose(true_exp, qujax_samp_exp_jit, rtol=1e-2) def test_sampling(): From 067afa9eebe42d6b192b444c48a69551644f7e5d Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 1 Nov 2022 12:17:17 +0000 Subject: [PATCH 16/25] rename hermitian --- tests/test_expectations.py | 41 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 2ef51b5..7cb5940 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -1,5 +1,4 @@ from jax import numpy as jnp, jit, grad, random, config - import qujax @@ -58,13 +57,13 @@ def brute_force_param_to_exp(p): def test_ZZ_X(): config.update("jax_enable_x64", True) # Run this test with 64 bit precision - n_qubits = 5 + n_qubits = 4 - gate_str_seq_seq = [['Z', 'Z']] * (n_qubits - 1) + [['X']] * n_qubits - coefs = random.normal(random.PRNGKey(0), shape=(len(gate_str_seq_seq),)) + hermitian_str_seq_seq = [['Z', 'Z']] * (n_qubits - 1) + [['Y']] * n_qubits + coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [[i] for i in range(n_qubits)] - st_to_exp = qujax.get_statetensor_to_expectation_func(gate_str_seq_seq, + st_to_exp = qujax.get_statetensor_to_expectation_func(hermitian_str_seq_seq, qubit_inds_seq, coefs) @@ -72,28 +71,30 @@ def test_ZZ_X(): state /= jnp.linalg.norm(state) st_in = state.reshape((2,) * n_qubits) - def big_unitary_matrix(gate_str_seq, qubit_inds): - qubit_gate_arrs = [getattr(qujax.gates, s) for s in gate_str_seq] - gate_arrs = [] + def big_hermitian_matrix(hermitian_str_seq, qubit_inds): + qubit_arrs = [getattr(qujax.gates, s) for s in hermitian_str_seq] + hermitian_arrs = [] j = 0 for i in range(n_qubits): if i in qubit_inds: - gate_arrs.append(qubit_gate_arrs[j]) + hermitian_arrs.append(qubit_arrs[j]) j += 1 else: - gate_arrs.append(jnp.eye(2)) + hermitian_arrs.append(jnp.eye(2)) - big_u = gate_arrs[0] + big_h = hermitian_arrs[0] for k in range(1, n_qubits): - big_u = jnp.kron(big_u, gate_arrs[k]) - return big_u + big_h = jnp.kron(big_h, hermitian_arrs[k]) + return big_h + + sum_big_hs = jnp.zeros((2 ** n_qubits, 2 ** n_qubits), dtype='complex') + for i in range(len(hermitian_str_seq_seq)): + sum_big_hs += coefs[i] * big_hermitian_matrix(hermitian_str_seq_seq[i], qubit_inds_seq[i]) - sum_big_us = jnp.zeros((2 ** n_qubits, 2 ** n_qubits)) - for i in range(len(gate_str_seq_seq)): - sum_big_us += coefs[i] * big_unitary_matrix(gate_str_seq_seq[i], qubit_inds_seq[i]) + assert jnp.allclose(sum_big_hs, sum_big_hs.conj().T) sv = st_in.flatten() - true_exp = jnp.dot(sv, sum_big_us @ sv.conj()) + true_exp = jnp.dot(sv, sum_big_hs @ sv.conj()).real qujax_exp = st_to_exp(st_in) qujax_exp_jit = jit(st_to_exp)(st_in) @@ -103,11 +104,11 @@ def big_unitary_matrix(gate_str_seq, qubit_inds): assert jnp.isclose(true_exp, qujax_exp) assert jnp.isclose(true_exp, qujax_exp_jit) - st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(gate_str_seq_seq, + st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(hermitian_str_seq_seq, qubit_inds_seq, coefs) - qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 10000) - qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 100000) + qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) + qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) assert jnp.array(qujax_samp_exp).shape == () assert jnp.array(qujax_samp_exp).dtype.name[:5] == 'float' assert jnp.isclose(true_exp, qujax_samp_exp, rtol=1e-2) From a855445e45294f0f2be0d12e5660a8bc191b097c Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Tue, 1 Nov 2022 12:18:17 +0000 Subject: [PATCH 17/25] rename test --- tests/test_expectations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 7cb5940..151c3e3 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -54,7 +54,7 @@ def brute_force_param_to_exp(p): assert jnp.allclose(true_expectation_grad, expectation_grad_jit, atol=1e-5) -def test_ZZ_X(): +def test_ZZ_Y(): config.update("jax_enable_x64", True) # Run this test with 64 bit precision n_qubits = 4 From c7aa2946c52da61dcdc643dc21747541f427103b Mon Sep 17 00:00:00 2001 From: SamDuffield <34280297+SamDuffield@users.noreply.github.com> Date: Wed, 2 Nov 2022 13:58:35 +0000 Subject: [PATCH 18/25] Observable check Hermitian (#47) Clarify observable.py to only accept Hermitian matrices/tensors --- qujax/__init__.py | 2 + qujax/circuit.py | 4 +- qujax/circuit_tools.py | 8 +++ qujax/observable.py | 112 +++++++++++++++++++++++------------------ 4 files changed, 75 insertions(+), 51 deletions(-) diff --git a/qujax/__init__.py b/qujax/__init__.py index 006e551..3294e3e 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -6,8 +6,10 @@ from qujax.circuit import apply_gate from qujax.circuit import get_params_to_statetensor_func +from qujax.observable import statetensor_to_single_expectation from qujax.observable import get_statetensor_to_expectation_func from qujax.observable import get_statetensor_to_sampled_expectation_func +from qujax.observable import check_hermitian from qujax.observable import integers_to_bitstrings from qujax.observable import bitstrings_to_integers from qujax.observable import sample_integers diff --git a/qujax/circuit.py b/qujax/circuit.py index 22755ac..d5ef7b8 100644 --- a/qujax/circuit.py +++ b/qujax/circuit.py @@ -26,14 +26,14 @@ def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: """ Applies gate to statetensor and returns updated statetensor. - Gate is represented by a unitary matrix (i.e. not parameterised). + Gate is represented by a unitary matrix in tensor form. Args: statetensor: Input statetensor. gate_unitary: Unitary array representing gate must be in tensor form with shape (2,2,...). qubit_inds: Sequence of indices for gate to be applied to. - 2 * len(qubit_inds) is equal to the dimension of the gate unitary tensor. + Must have 2 * len(qubit_inds) = gate_unitary.ndim Returns: Updated statetensor. diff --git a/qujax/circuit_tools.py b/qujax/circuit_tools.py index 7e6a4f4..1d504ec 100644 --- a/qujax/circuit_tools.py +++ b/qujax/circuit_tools.py @@ -12,6 +12,14 @@ def check_unitary(gate: Union[str, jnp.ndarray, Callable[[jnp.ndarray], jnp.ndarray], Callable[[], jnp.ndarray]]): + """ + Checks whether a matrix or tensor is unitary. + + Args: + gate: array containing potentially unitary string, array + or function (which will be evaluated with all arguments set to 0.1). + + """ if isinstance(gate, str): if gate in gates.__dict__: gate = gates.__dict__[gate] diff --git a/qujax/observable.py b/qujax/observable.py index 376b0c1..b4631a5 100644 --- a/qujax/observable.py +++ b/qujax/observable.py @@ -4,52 +4,62 @@ from jax import numpy as jnp, random from jax.lax import fori_loop -from qujax import gates +from qujax.circuit import apply_gate +from qujax.gates import I, X, Y, Z +paulis = {'I': I, 'X': X, 'Y': Y, 'Z': Z} -def _statetensor_to_single_expectation_func(gate_tensor: jnp.ndarray, - qubit_inds: Sequence[int]) -> Callable[[jnp.ndarray], float]: + +def statetensor_to_single_expectation(statetensor: jnp.ndarray, + hermitian: jnp.ndarray, + qubit_inds: Sequence[int]) -> float: """ - Creates a function that maps statetensor to its expected value under the given gate unitary and qubit indices. + Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). Args: - gate_tensor: Gate unitary in tensor form. - qubit_inds: Sequence of integer qubit indices to apply gate to. + statetensor: Input statetensor. + hermitian: Hermitian array + must be in tensor form with shape (2,2,...). + qubit_inds: Sequence of qubit indices for Hermitian to be applied to. + Must have 2 * len(qubit_inds) = hermitian.ndim Returns: - Function that takes statetensor and returns expected value (float). + Expected value (float). """ + statetensor_new = apply_gate(statetensor, hermitian, qubit_inds) + axes = tuple(range(statetensor.ndim)) + return jnp.tensordot(statetensor.conjugate(), statetensor_new, axes=(axes, axes)).real - def statetensor_to_single_expectation(statetensor: jnp.ndarray) -> float: - """ - Evaluates expected value of statetensor through gate. - Args: - statetensor: Input statetensor. +def check_hermitian(hermitian: Union[str, jnp.ndarray]): + """ + Checks whether a matrix or tensor is Hermitian. - Returns: - Expected value (float). - """ - statetensor_new = jnp.tensordot(gate_tensor, statetensor, - axes=(list(range(-len(qubit_inds), 0)), qubit_inds)) - statetensor_new = jnp.moveaxis(statetensor_new, list(range(len(qubit_inds))), qubit_inds) - axes = tuple(range(statetensor.ndim)) - return jnp.tensordot(statetensor.conjugate(), statetensor_new, axes=(axes, axes)).real + Args: + hermitian: array containing potentially Hermitian matrix or tensor - return statetensor_to_single_expectation + """ + if isinstance(hermitian, str): + if hermitian not in paulis: + raise TypeError(f'qujax only accepts {tuple(paulis.keys())} as Hermitian strings, received: {hermitian}') + else: + n_qubits = hermitian.ndim // 2 + hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits) + if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()): + raise TypeError(f'Array not Hermitian: {hermitian}') -def get_statetensor_to_expectation_func(gate_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], +def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], qubits_seq_seq: Sequence[Sequence[int]], coefficients: Union[Sequence[float], jnp.ndarray]) \ -> Callable[[jnp.ndarray], float]: """ - Converts gate strings (or arrays), qubit indices and coefficients into a function that - converts statetensor into expected value. + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a statetensor into an expected value. Args: - gate_seq_seq: Sequence of sequences of gates. - Each gate is either a tensor (jnp.ndarray) or a string corresponding to an array in qujax.gates. + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. E.g. [[0,1], [2]] @@ -59,28 +69,30 @@ def get_statetensor_to_expectation_func(gate_seq_seq: Sequence[Sequence[Union[st Function that takes statetensor and returns expected value (float). """ - def get_gate_tensor(gate_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: + def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: """ - Convert sequence of gate strings into single gate unitary (in tensor form). + Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form into single array (in tensor form). Args: - gate_seq: Sequence of gate strings or arrays. + hermitian_seq: Sequence of Hermitian strings or arrays. Returns: - Single gate unitary in tensor form (array). + Hermitian matrix in tensor form (array). """ - single_gate_arrs = [gates.__dict__[gate] if isinstance(gate, str) else gate for gate in gate_seq] - single_gate_arrs = [gate_arr.reshape((2,) * int(jnp.log2(gate_arr.size))) - for gate_arr in single_gate_arrs] - full_gate_mat = single_gate_arrs[0] - for single_gate_matrix in single_gate_arrs[1:]: - full_gate_mat = jnp.kron(full_gate_mat, single_gate_matrix) - full_gate_mat = full_gate_mat.reshape((2,) * int(jnp.log2(full_gate_mat.size))) - return full_gate_mat - - apply_gate_funcs = [_statetensor_to_single_expectation_func(get_gate_tensor(gns), qi) - for gns, qi in zip(gate_seq_seq, qubits_seq_seq)] + for h in hermitian_seq: + check_hermitian(h) + + single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] + single_arrs = [h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs] + + full_mat = single_arrs[0] + for single_matrix in single_arrs[1:]: + full_mat = jnp.kron(full_mat, single_matrix) + full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size))) + return full_mat + + hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: """ @@ -94,24 +106,24 @@ def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: """ out = 0 - for coeff, f in zip(coefficients, apply_gate_funcs): - out += coeff * f(statetensor) + for hermitian, qubit_inds, coeff in zip(hermitian_tensors, qubits_seq_seq, coefficients): + out += coeff * statetensor_to_single_expectation(statetensor, hermitian, qubit_inds) return out return statetensor_to_expectation_func -def get_statetensor_to_sampled_expectation_func(gate_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], +def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], qubits_seq_seq: Sequence[Sequence[int]], coefficients: Union[Sequence[float], jnp.ndarray]) \ -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: """ - Converts gate strings (or arrays), qubit indices and coefficients into a function that - converts statetensor into a sampled expectation value. + Converts strings (or arrays) representing Hermitian matrices, qubit indices and + coefficients into a function that converts a statetensor into a sampled expected value. Args: - gate_seq_seq: Sequence of sequences of gates. - Each gate is either a tensor (jnp.ndarray) or a string corresponding to an array in qujax.gates. + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). E.g. [['Z', 'Z'], ['X']] qubits_seq_seq: Sequence of sequences of integer qubit indices. E.g. [[0,1], [2]] @@ -121,7 +133,9 @@ def get_statetensor_to_sampled_expectation_func(gate_seq_seq: Sequence[Sequence[ Function that takes statetensor, random key and integer number of shots and returns sampled expected value (float). """ - statetensor_to_expectation_func = get_statetensor_to_expectation_func(gate_seq_seq, qubits_seq_seq, coefficients) + statetensor_to_expectation_func = get_statetensor_to_expectation_func(hermitian_seq_seq, + qubits_seq_seq, + coefficients) def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, random_key: random.PRNGKeyArray, From 2e9a8364843e95fb23c142a4d7fe69884d562d20 Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Fri, 4 Nov 2022 10:33:22 +0000 Subject: [PATCH 19/25] Implement expectations for density matrices (#48) Implement expectations for density matrices Add densitytensor_to_single_expectation Add get_densitytensor_to_expectation_func Add get_densitytensor_to_sampled_expectation_func Add _get_tensor_to_expectation_func Refactor get_statetensor_to_expectation_func Exclude identity as valid Pauli Add relevant tests --- qujax/__init__.py | 3 + qujax/density_matrix.py | 16 ++++ qujax/observable.py | 175 +++++++++++++++++++++++++++++------ tests/test_density_matrix.py | 3 +- tests/test_expectations.py | 39 ++++++++ 5 files changed, 209 insertions(+), 27 deletions(-) diff --git a/qujax/__init__.py b/qujax/__init__.py index 3294e3e..114f248 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -7,8 +7,11 @@ from qujax.circuit import get_params_to_statetensor_func from qujax.observable import statetensor_to_single_expectation +from qujax.observable import densitytensor_to_single_expectation from qujax.observable import get_statetensor_to_expectation_func from qujax.observable import get_statetensor_to_sampled_expectation_func +from qujax.observable import get_densitytensor_to_expectation_func +from qujax.observable import get_densitytensor_to_sampled_expectation_func from qujax.observable import check_hermitian from qujax.observable import integers_to_bitstrings from qujax.observable import bitstrings_to_integers diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index b6f839a..f0d95dd 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -9,6 +9,22 @@ kraus_op_type = Union[gate_type, Iterable[gate_type]] +def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: + """ + Computes a densitytensor representation of a pure quantum state + from its statetensor representaton + + Args: + statetensor: Input statetensor. + + Returns: + A densitytensor representing the quantum state. + """ + n_qubits = statetensor.ndim + st = statetensor + dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2*n_qubits)) + return dt + def _kraus_single(densitytensor: jnp.ndarray, array: jnp.ndarray, diff --git a/qujax/observable.py b/qujax/observable.py index b4631a5..f338161 100644 --- a/qujax/observable.py +++ b/qujax/observable.py @@ -5,11 +5,35 @@ from jax.lax import fori_loop from qujax.circuit import apply_gate -from qujax.gates import I, X, Y, Z +from qujax.gates import X, Y, Z +from qujax.density_matrix import statetensor_to_densitytensor -paulis = {'I': I, 'X': X, 'Y': Y, 'Z': Z} +paulis = {'X': X, 'Y': Y, 'Z': Z} +def densitytensor_to_single_expectation(densitytensor: jnp.ndarray, + hermitian: jnp.ndarray, + qubit_inds: Sequence[int]) -> float: + """ + Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). + + Args: + densitytensor: Input densitytensor. + hermitian: Hermitian matrix representing observable + must be in tensor form with shape (2,2,...). + qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. + Must have 2 * len(qubit_inds) == hermitian.ndim + Returns: + Expected value (float). + """ + n_qubits = densitytensor.ndim // 2 + dt_indices = 2 * list(range(n_qubits)) + hermitian_indices = [i + densitytensor.ndim //2 for i in range(hermitian.ndim)] + for n, q in enumerate(qubit_inds): + dt_indices[q] = hermitian_indices[n + len(qubit_inds)] + dt_indices[q + n_qubits] = hermitian_indices[n] + return jnp.einsum(densitytensor, dt_indices, hermitian, hermitian_indices).real + def statetensor_to_single_expectation(statetensor: jnp.ndarray, hermitian: jnp.ndarray, qubit_inds: Sequence[int]) -> float: @@ -20,8 +44,8 @@ def statetensor_to_single_expectation(statetensor: jnp.ndarray, statetensor: Input statetensor. hermitian: Hermitian array must be in tensor form with shape (2,2,...). - qubit_inds: Sequence of qubit indices for Hermitian to be applied to. - Must have 2 * len(qubit_inds) = hermitian.ndim + qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. + Must have 2 * len(qubit_inds) == hermitian.ndim Returns: Expected value (float). @@ -49,27 +73,7 @@ def check_hermitian(hermitian: Union[str, jnp.ndarray]): raise TypeError(f'Array not Hermitian: {hermitian}') -def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray], float]: - """ - Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and - a list of coefficients and returns a function that converts a statetensor into an expected value. - - Args: - hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - - Returns: - Function that takes statetensor and returns expected value (float). - """ - - def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: +def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: """ Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form into single array (in tensor form). @@ -92,6 +96,31 @@ def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jn full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size))) return full_mat + +def _get_tensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], + contraction_function: Callable) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a tensor into an expected value. + The contraction function performs the tensor contraction according to the type of tensor provided + (i.e. whether it is a statetensor or a densitytensor). + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + contraction_function: Function that performs the tensor contraction. + + Returns: + Function that takes tensor and returns expected value (float). + """ + hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: @@ -107,11 +136,54 @@ def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: """ out = 0 for hermitian, qubit_inds, coeff in zip(hermitian_tensors, qubits_seq_seq, coefficients): - out += coeff * statetensor_to_single_expectation(statetensor, hermitian, qubit_inds) + out += coeff * contraction_function(statetensor, hermitian, qubit_inds) return out return statetensor_to_expectation_func +def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a statetensor into an expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes statetensor and returns expected value (float). + """ + + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, statetensor_to_single_expectation) + +def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a densitytensor into an expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes densitytensor and returns expected value (float). + """ + + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, densitytensor_to_single_expectation) def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], qubits_seq_seq: Sequence[Sequence[int]], @@ -164,6 +236,57 @@ def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, return statetensor_to_sampled_expectation_func +def get_densitytensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: + """ + Converts strings (or arrays) representing Hermitian matrices, qubit indices and + coefficients into a function that converts a densitytensor into a sampled expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes densitytensor, random key and integer number of shots + and returns sampled expected value (float). + """ + densitytensor_to_expectation_func = get_densitytensor_to_expectation_func(hermitian_seq_seq, + qubits_seq_seq, + coefficients) + + def densitytensor_to_sampled_expectation_func(statetensor: jnp.ndarray, + random_key: random.PRNGKeyArray, + n_samps: int) -> float: + """ + Maps statetensor to sampled expected value. + + Args: + statetensor: Input statetensor. + random_key: JAX random key + n_samps: Number of samples contributing to sampled expectation. + + Returns: + Sampled expected value (float). + + """ + sampled_integers = sample_integers(random_key, statetensor, n_samps) + sampled_probs = fori_loop(0, n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1), + jnp.zeros(statetensor.size)) + + sampled_probs /= n_samps + sampled_dt = statetensor_to_densitytensor(jnp.sqrt(sampled_probs).reshape(statetensor.shape)) + return densitytensor_to_expectation_func(sampled_dt) + + return densitytensor_to_sampled_expectation_func + + def integers_to_bitstrings(integers: Union[int, jnp.ndarray], nbits: int = None) -> jnp.ndarray: """ diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 9f414f9..2a08421 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -4,6 +4,7 @@ from qujax import _kraus_single, kraus, get_params_to_densitytensor_func from qujax import get_params_to_statetensor_func from qujax import partial_trace +from qujax.density_matrix import statetensor_to_densitytensor def test_kraus_single(): @@ -121,7 +122,7 @@ def test_params_to_densitytensor_func(): params = jnp.arange(n_qubits)/10. st = params_to_st(params) - dt_test = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2*n_qubits)) + dt_test = statetensor_to_densitytensor(st) dt = params_to_dt(params) diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 151c3e3..92b4873 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -1,6 +1,30 @@ from jax import numpy as jnp, jit, grad, random, config import qujax +from math import floor, ceil +from qujax.observable import densitytensor_to_single_expectation, statetensor_to_single_expectation +from qujax.density_matrix import statetensor_to_densitytensor + +from qujax.gates import Z + + +def test_single_expectation(): + st1 = jnp.zeros((2,2,2)) + st2 = jnp.zeros((2,2,2)) + st1 = st1.at[(0,0,0)].set(1.) + st2 = st2.at[(1,0,0)].set(1.) + dt1 = statetensor_to_densitytensor(st1) + dt2 = statetensor_to_densitytensor(st2) + ZZ = jnp.kron(Z, Z).reshape(2,2,2,2) + + est1 = statetensor_to_single_expectation(dt1, ZZ, [0, 1]) + est2 = statetensor_to_single_expectation(dt2, ZZ, [0, 1]) + edt1 = densitytensor_to_single_expectation(dt1, ZZ, [0, 1]) + edt2 = densitytensor_to_single_expectation(dt2, ZZ, [0, 1]) + + assert est1.item() == edt1.item() == 1 + assert est2.item() == edt2.item() == -1 + def test_bitstring_expectation(): n_qubits = 4 @@ -66,10 +90,14 @@ def test_ZZ_Y(): st_to_exp = qujax.get_statetensor_to_expectation_func(hermitian_str_seq_seq, qubit_inds_seq, coefs) + dt_to_exp = qujax.get_statetensor_to_expectation_func(hermitian_str_seq_seq, + qubit_inds_seq, + coefs) state = random.uniform(random.PRNGKey(0), shape=(2 ** n_qubits,)) * 2 state /= jnp.linalg.norm(state) st_in = state.reshape((2,) * n_qubits) + dt_in = statetensor_to_densitytensor(st_in) def big_hermitian_matrix(hermitian_str_seq, qubit_inds): qubit_arrs = [getattr(qujax.gates, s) for s in hermitian_str_seq] @@ -97,22 +125,33 @@ def big_hermitian_matrix(hermitian_str_seq, qubit_inds): true_exp = jnp.dot(sv, sum_big_hs @ sv.conj()).real qujax_exp = st_to_exp(st_in) + qujax_dt_exp = dt_to_exp(dt_in) qujax_exp_jit = jit(st_to_exp)(st_in) + qujax_dt_exp_jit = jit(dt_to_exp)(dt_in) assert jnp.array(qujax_exp).shape == () assert jnp.array(qujax_exp).dtype.name[:5] == 'float' assert jnp.isclose(true_exp, qujax_exp) + assert jnp.isclose(true_exp, qujax_dt_exp) assert jnp.isclose(true_exp, qujax_exp_jit) + assert jnp.isclose(true_exp, qujax_dt_exp_jit) st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(hermitian_str_seq_seq, qubit_inds_seq, coefs) + dt_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(hermitian_str_seq_seq, + qubit_inds_seq, + coefs) qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) + qujax_samp_exp_dt = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) + qujax_samp_exp_dt_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) assert jnp.array(qujax_samp_exp).shape == () assert jnp.array(qujax_samp_exp).dtype.name[:5] == 'float' assert jnp.isclose(true_exp, qujax_samp_exp, rtol=1e-2) assert jnp.isclose(true_exp, qujax_samp_exp_jit, rtol=1e-2) + assert jnp.isclose(true_exp, qujax_samp_exp_dt, rtol=1e-2) + assert jnp.isclose(true_exp, qujax_samp_exp_dt_jit, rtol=1e-2) def test_sampling(): From 106ce53cfd9e07fd163877ed35d8cf59fb38a911 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 4 Nov 2022 10:44:17 +0000 Subject: [PATCH 20/25] add dm measurements --- qujax/__init__.py | 2 ++ qujax/density_matrix.py | 47 ++++++++++++++++++++++++ tests/test_density_matrix.py | 69 +++++++++++++++++++++++++++++------- 3 files changed, 105 insertions(+), 13 deletions(-) diff --git a/qujax/__init__.py b/qujax/__init__.py index 3294e3e..7757e82 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -23,6 +23,8 @@ from qujax.density_matrix import kraus from qujax.density_matrix import get_params_to_densitytensor_func from qujax.density_matrix import partial_trace +from qujax.density_matrix import densitytensor_to_measurement_probabilities +from qujax.density_matrix import densitytensor_to_measured_densitytensor del version del circuit diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index b6f839a..68ed8cd 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -199,3 +199,50 @@ def no_params_to_densitytensor_func(densitytensor_in: jnp.ndarray = None) -> jnp return no_params_to_densitytensor_func return params_to_densitytensor_func + + +def densitytensor_to_measurement_probabilities(densitytensor: jnp.ndarray, + qubit_inds: Sequence[int]) -> jnp.ndarray: + """ + Extract array of measurement probabilities given a densitytensor and some qubit indices to measure + (in the computational basis). + I.e. the ith element of the array corresponds to the probability of observing the bitstring + represented by the integer i on the measured qubits. + + Args: + densitytensor: Input densitytensor. + qubit_inds: Sequence of qubit indices to measure. + + Returns: + Normalised array of measurement probabilities. + """ + n_qubits = densitytensor.ndim // 2 + n_qubits_measured = len(qubit_inds) + qubit_inds_trace_out = [i for i in range(n_qubits) if i not in qubit_inds] + return jnp.diag(partial_trace(densitytensor, qubit_inds_trace_out).reshape(2 * n_qubits_measured, + 2 * n_qubits_measured)).real + + +def densitytensor_to_measured_densitytensor(densitytensor: jnp.ndarray, + qubit_inds: Sequence[int], + measured_int: int) -> jnp.ndarray: + """ + Returns the post-measurement densitytensor assuming that qubit_inds are measured + (in the computational basis) and the bitstring corresponding to integer + measured_int is observed. + + Args: + densitytensor: Input densitytensor. + qubit_inds: Sequence of qubit indices to measure. + measured_int: Observed integer. + + Returns: + Post-measurement densitytensor (same shape as input densitytensor). + """ + n_qubits = densitytensor.ndim // 2 + n_qubits_measured = len(qubit_inds) + qubit_inds_projector = jnp.diag(jnp.zeros(2 ** n_qubits_measured).at[measured_int].set(1)) \ + .reshape((2,) * 2 * n_qubits_measured) + unnorm_densitytensor = _kraus_single(densitytensor, qubit_inds_projector, qubit_inds) + norm_const = jnp.trace(unnorm_densitytensor.reshape(2**n_qubits, 2**n_qubits)).real + return unnorm_densitytensor / norm_const diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 9f414f9..c33da2c 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -1,15 +1,17 @@ +from itertools import combinations from jax import numpy as jnp, jit import qujax from qujax import _kraus_single, kraus, get_params_to_densitytensor_func from qujax import get_params_to_statetensor_func from qujax import partial_trace +from qujax import densitytensor_to_measurement_probabilities, densitytensor_to_measured_densitytensor def test_kraus_single(): n_qubits = 3 dim = 2 ** n_qubits - density_matrix = jnp.arange(dim**2).reshape(dim, dim) + density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) kraus_operator = qujax.gates.Rx(0.2) @@ -42,7 +44,7 @@ def test_kraus_single(): def test_kraus_single_2qubit(): n_qubits = 4 dim = 2 ** n_qubits - density_matrix = jnp.arange(dim**2).reshape(dim, dim) + density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) kraus_operator_tensor = qujax.gates.ZZPhase(0.1) kraus_operator = qujax.gates.ZZPhase(0.1).reshape(4, 4) @@ -80,7 +82,7 @@ def test_kraus_single_2qubit(): def test_kraus_multiple(): n_qubits = 3 dim = 2 ** n_qubits - density_matrix = jnp.arange(dim**2).reshape(dim, dim) + density_matrix = jnp.arange(dim ** 2).reshape(dim, dim) density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) kraus_operators = [0.25 * qujax.gates.H, 0.25 * qujax.gates.Rx(0.3), 0.5 * qujax.gates.Ry(0.1)] @@ -112,16 +114,16 @@ def test_params_to_densitytensor_func(): param_inds_seq = [(i,) for i in range(n_qubits)] gate_seq += ["CZ" for _ in range(n_qubits - 1)] - qubit_inds_seq += [(i, i+1) for i in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) params_to_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) - params = jnp.arange(n_qubits)/10. + params = jnp.arange(n_qubits) / 10. st = params_to_st(params) - dt_test = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2*n_qubits)) + dt_test = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) dt = params_to_dt(params) @@ -139,7 +141,7 @@ def test_params_to_densitytensor_func_with_bit_flip(): param_inds_seq = [(i,) for i in range(n_qubits)] gate_seq += ["CZ" for _ in range(n_qubits - 1)] - qubit_inds_seq += [(i, i+1) for i in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] params_to_pre_bf_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) @@ -154,10 +156,10 @@ def test_params_to_densitytensor_func_with_bit_flip(): params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) - params = jnp.arange(n_qubits)/10. + params = jnp.arange(n_qubits) / 10. pre_bf_st = params_to_pre_bf_st(params) - pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape(2 for _ in range(2*n_qubits)) + pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) dt_test = kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0]) dt = params_to_dt(params) @@ -169,7 +171,7 @@ def test_params_to_densitytensor_func_with_bit_flip(): def test_partial_trace_1(): - state1 = 1/jnp.sqrt(2) * jnp.array([1., 1.]) + state1 = 1 / jnp.sqrt(2) * jnp.array([1., 1.]) state2 = jnp.kron(state1, state1) state3 = jnp.kron(state1, state2) @@ -180,10 +182,10 @@ def test_partial_trace_1(): for i in range(3): assert jnp.allclose(partial_trace(dt3, [i]), dt2) - from itertools import combinations for i in combinations(range(3), 2): assert jnp.allclose(partial_trace(dt3, i), dt1) + def test_partial_trace_2(): n_qubits = 3 @@ -192,15 +194,56 @@ def test_partial_trace_2(): param_inds_seq = [(i,) for i in range(n_qubits)] gate_seq += ["CZ" for _ in range(n_qubits - 1)] - qubit_inds_seq += [(i, i+1) for i in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) - params = jnp.arange(1, n_qubits+1)/10. + params = jnp.arange(1, n_qubits + 1) / 10. dt = params_to_dt(params) dt_discard_test = jnp.trace(dt, axis1=0, axis2=n_qubits) dt_discard = partial_trace(dt, [0]) assert jnp.allclose(dt_discard, dt_discard_test) + + +def test_measure(): + n_qubits = 3 + + gate_seq = ["Rx" for _ in range(n_qubits)] + qubit_inds_seq = [(i,) for i in range(n_qubits)] + param_inds_seq = [(i,) for i in range(n_qubits)] + + gate_seq += ["CZ" for _ in range(n_qubits - 1)] + qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] + param_inds_seq += [() for _ in range(n_qubits - 1)] + + params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + + params = jnp.arange(1, n_qubits + 1) / 10. + + dt = params_to_dt(params) + + qubit_inds = [0] + + all_probs = jnp.diag(dt.reshape(2 ** n_qubits, 2 ** n_qubits)).real + all_probs_marginalise \ + = all_probs.reshape((2,) * n_qubits).sum(axis=[i for i in range(n_qubits) if i not in qubit_inds]) + + probs = densitytensor_to_measurement_probabilities(dt, qubit_inds) + + assert jnp.isclose(probs.sum(), 1.) + assert jnp.isclose(all_probs.sum(), 1.) + assert jnp.allclose(probs, all_probs_marginalise) + + dm = dt.reshape(2 ** n_qubits, 2 ** n_qubits) + projector = jnp.array([[1, 0], [0, 0]]) + for _ in range(n_qubits - 1): + projector = jnp.kron(projector, jnp.eye(2)) + measured_dm = projector @ dm @ projector.T.conj() + measured_dm /= jnp.trace(projector.T.conj() @ projector @ dm) + measured_dt_true = measured_dm.reshape((2,) * 2 * n_qubits) + + measured_dt = densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) + assert jnp.allclose(measured_dt_true, measured_dt) From 7792ddabe50a5c1749e3f9703df5a3671a9329ce Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 4 Nov 2022 13:54:22 +0000 Subject: [PATCH 21/25] bitstring measurement and reformat --- qujax/__init__.py | 2 +- qujax/density_matrix.py | 29 +++++----------- qujax/observable.py | 64 ++++++++++++++++++++++++------------ tests/test_density_matrix.py | 3 +- tests/test_expectations.py | 2 +- 5 files changed, 55 insertions(+), 45 deletions(-) diff --git a/qujax/__init__.py b/qujax/__init__.py index 5d8bbe8..c020a00 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -10,13 +10,13 @@ from qujax.circuit_tools import check_circuit from qujax.circuit_tools import print_circuit +from qujax.observable import statetensor_to_densitytensor from qujax.observable import statetensor_to_single_expectation from qujax.observable import densitytensor_to_single_expectation from qujax.observable import get_statetensor_to_expectation_func from qujax.observable import get_statetensor_to_sampled_expectation_func from qujax.observable import get_densitytensor_to_expectation_func from qujax.observable import get_densitytensor_to_sampled_expectation_func -from qujax.observable import statetensor_to_densitytensor from qujax.observable import check_hermitian from qujax.observable import integers_to_bitstrings from qujax.observable import bitstrings_to_integers diff --git a/qujax/density_matrix.py b/qujax/density_matrix.py index cdf6274..d7ca836 100644 --- a/qujax/density_matrix.py +++ b/qujax/density_matrix.py @@ -6,25 +6,10 @@ from qujax.circuit import apply_gate, UnionCallableOptionalArray, gate_type from qujax.circuit import _to_gate_func, _arrayify_inds, _gate_func_to_unitary from qujax.circuit_tools import check_circuit +from qujax.observable import bitstrings_to_integers kraus_op_type = Union[gate_type, Iterable[gate_type]] -def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: - """ - Computes a densitytensor representation of a pure quantum state - from its statetensor representaton - - Args: - statetensor: Input statetensor. - - Returns: - A densitytensor representing the quantum state. - """ - n_qubits = statetensor.ndim - st = statetensor - dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2*n_qubits)) - return dt - def _kraus_single(densitytensor: jnp.ndarray, array: jnp.ndarray, @@ -241,24 +226,26 @@ def densitytensor_to_measurement_probabilities(densitytensor: jnp.ndarray, def densitytensor_to_measured_densitytensor(densitytensor: jnp.ndarray, qubit_inds: Sequence[int], - measured_int: int) -> jnp.ndarray: + measurement: Union[int, jnp.ndarray]) -> jnp.ndarray: """ Returns the post-measurement densitytensor assuming that qubit_inds are measured - (in the computational basis) and the bitstring corresponding to integer - measured_int is observed. + (in the computational basis) and the given measurement (integer or bitstring) is observed. Args: densitytensor: Input densitytensor. qubit_inds: Sequence of qubit indices to measure. - measured_int: Observed integer. + measurement: Observed integer or bitstring. Returns: Post-measurement densitytensor (same shape as input densitytensor). """ + measurement = jnp.array(measurement) + measured_int = bitstrings_to_integers(measurement) if measurement.ndim == 1 else measurement + n_qubits = densitytensor.ndim // 2 n_qubits_measured = len(qubit_inds) qubit_inds_projector = jnp.diag(jnp.zeros(2 ** n_qubits_measured).at[measured_int].set(1)) \ .reshape((2,) * 2 * n_qubits_measured) unnorm_densitytensor = _kraus_single(densitytensor, qubit_inds_projector, qubit_inds) - norm_const = jnp.trace(unnorm_densitytensor.reshape(2**n_qubits, 2**n_qubits)).real + norm_const = jnp.trace(unnorm_densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits)).real return unnorm_densitytensor / norm_const diff --git a/qujax/observable.py b/qujax/observable.py index f338161..3719f49 100644 --- a/qujax/observable.py +++ b/qujax/observable.py @@ -6,7 +6,6 @@ from qujax.circuit import apply_gate from qujax.gates import X, Y, Z -from qujax.density_matrix import statetensor_to_densitytensor paulis = {'X': X, 'Y': Y, 'Z': Z} @@ -28,12 +27,13 @@ def densitytensor_to_single_expectation(densitytensor: jnp.ndarray, """ n_qubits = densitytensor.ndim // 2 dt_indices = 2 * list(range(n_qubits)) - hermitian_indices = [i + densitytensor.ndim //2 for i in range(hermitian.ndim)] + hermitian_indices = [i + densitytensor.ndim // 2 for i in range(hermitian.ndim)] for n, q in enumerate(qubit_inds): dt_indices[q] = hermitian_indices[n + len(qubit_inds)] dt_indices[q + n_qubits] = hermitian_indices[n] return jnp.einsum(densitytensor, dt_indices, hermitian, hermitian_indices).real + def statetensor_to_single_expectation(statetensor: jnp.ndarray, hermitian: jnp.ndarray, qubit_inds: Sequence[int]) -> float: @@ -74,7 +74,7 @@ def check_hermitian(hermitian: Union[str, jnp.ndarray]): def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: - """ + """ Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form into single array (in tensor form). Args: @@ -84,23 +84,23 @@ def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jn Hermitian matrix in tensor form (array). """ - for h in hermitian_seq: - check_hermitian(h) + for h in hermitian_seq: + check_hermitian(h) - single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] - single_arrs = [h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs] + single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] + single_arrs = [h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs] - full_mat = single_arrs[0] - for single_matrix in single_arrs[1:]: - full_mat = jnp.kron(full_mat, single_matrix) - full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size))) - return full_mat + full_mat = single_arrs[0] + for single_matrix in single_arrs[1:]: + full_mat = jnp.kron(full_mat, single_matrix) + full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size))) + return full_mat def _get_tensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], - contraction_function: Callable) \ + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], + contraction_function: Callable) \ -> Callable[[jnp.ndarray], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and @@ -141,6 +141,7 @@ def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: return statetensor_to_expectation_func + def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], qubits_seq_seq: Sequence[Sequence[int]], coefficients: Union[Sequence[float], jnp.ndarray]) \ @@ -161,11 +162,13 @@ def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Uni Function that takes statetensor and returns expected value (float). """ - return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, statetensor_to_single_expectation) + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, + statetensor_to_single_expectation) + def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ -> Callable[[jnp.ndarray], float]: """ Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and @@ -183,7 +186,9 @@ def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[U Function that takes densitytensor and returns expected value (float). """ - return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, densitytensor_to_single_expectation) + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, + densitytensor_to_single_expectation) + def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], qubits_seq_seq: Sequence[Sequence[int]], @@ -261,8 +266,8 @@ def get_densitytensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Se coefficients) def densitytensor_to_sampled_expectation_func(statetensor: jnp.ndarray, - random_key: random.PRNGKeyArray, - n_samps: int) -> float: + random_key: random.PRNGKeyArray, + n_samps: int) -> float: """ Maps statetensor to sampled expected value. @@ -358,3 +363,20 @@ def sample_bitstrings(random_key: random.PRNGKeyArray, """ return integers_to_bitstrings(sample_integers(random_key, statetensor, n_samps), statetensor.ndim) + + +def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: + """ + Computes a densitytensor representation of a pure quantum state + from its statetensor representaton + + Args: + statetensor: Input statetensor. + + Returns: + A densitytensor representing the quantum state. + """ + n_qubits = statetensor.ndim + st = statetensor + dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) + return dt diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index ace46d0..7fe84b4 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -3,7 +3,8 @@ import qujax from qujax import get_params_to_statetensor_func -from qujax import _kraus_single, kraus, get_params_to_densitytensor_func, partial_trace, statetensor_to_densitytensor +from qujax import _kraus_single, kraus, get_params_to_densitytensor_func, partial_trace +from qujax.observable import statetensor_to_densitytensor from qujax import densitytensor_to_measurement_probabilities, densitytensor_to_measured_densitytensor diff --git a/tests/test_expectations.py b/tests/test_expectations.py index d8f5660..733eccb 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -3,7 +3,7 @@ import qujax.gates import qujax from qujax import densitytensor_to_single_expectation, statetensor_to_single_expectation -from qujax import statetensor_to_densitytensor +from qujax.observable import statetensor_to_densitytensor def test_single_expectation(): From 95d41ffec833320dfc8e9d263d1a0bbdf9118949 Mon Sep 17 00:00:00 2001 From: Gabriel Matos Date: Fri, 4 Nov 2022 14:40:24 +0000 Subject: [PATCH 22/25] Add test_measure case to test bit string argument --- tests/test_density_matrix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 7fe84b4..8fc586b 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -246,4 +246,6 @@ def test_measure(): measured_dt_true = measured_dm.reshape((2,) * 2 * n_qubits) measured_dt = densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) + measured_dt_bits = densitytensor_to_measured_densitytensor(dt, qubit_inds, (0,)*n_qubits) assert jnp.allclose(measured_dt_true, measured_dt) + assert jnp.allclose(measured_dt_true, measured_dt_bits) From 97c51c7e811093fb55a805880d6e2f935214a358 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Fri, 4 Nov 2022 16:48:02 +0000 Subject: [PATCH 23/25] restructure and add dt docs --- README.md | 11 +- docs/conf.py | 4 +- docs/densitytensor.rst | 14 + ...ensitytensor_to_measured_densitytensor.rst | 5 + ...itytensor_to_measurement_probabilities.rst | 5 + .../get_densitytensor_to_expectation_func.rst | 5 + ...sitytensor_to_sampled_expectation_func.rst | 5 + .../get_params_to_densitytensor_func.rst | 5 + docs/densitytensor/kraus.rst | 5 + docs/densitytensor/partial_trace.rst | 5 + .../statetensor_to_densitytensor.rst | 5 + docs/index.rst | 1 + qujax/__init__.py | 65 +-- qujax/{density_matrix.py => densitytensor.py} | 67 +-- qujax/densitytensor_observable.py | 157 +++++++ qujax/observable.py | 382 ------------------ qujax/{circuit.py => statetensor.py} | 39 +- qujax/statetensor_observable.py | 171 ++++++++ qujax/{circuit_tools.py => utils.py} | 192 +++++++-- qujax/version.py | 2 +- ...ensity_matrix.py => test_densitytensor.py} | 58 +-- tests/test_expectations.py | 27 +- 22 files changed, 652 insertions(+), 578 deletions(-) create mode 100644 docs/densitytensor.rst create mode 100644 docs/densitytensor/densitytensor_to_measured_densitytensor.rst create mode 100644 docs/densitytensor/densitytensor_to_measurement_probabilities.rst create mode 100644 docs/densitytensor/get_densitytensor_to_expectation_func.rst create mode 100644 docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst create mode 100644 docs/densitytensor/get_params_to_densitytensor_func.rst create mode 100644 docs/densitytensor/kraus.rst create mode 100644 docs/densitytensor/partial_trace.rst create mode 100644 docs/densitytensor/statetensor_to_densitytensor.rst rename qujax/{density_matrix.py => densitytensor.py} (76%) create mode 100644 qujax/densitytensor_observable.py delete mode 100644 qujax/observable.py rename qujax/{circuit.py => statetensor.py} (81%) create mode 100644 qujax/statetensor_observable.py rename qujax/{circuit_tools.py => utils.py} (59%) rename tests/{test_density_matrix.py => test_densitytensor.py} (73%) diff --git a/README.md b/README.md index 60befbf..267b275 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,15 @@ # qujax Represent a (parameterised) quantum circuit as a pure [JAX](https://github.com/google/jax) function that -takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes of the quantum state and can then be used -downstream for exact expectations, gradients or sampling. +takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes +of the quantum state and can then be used downstream for exact expectations, gradients or sampling. -A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support for GPUs/TPUs. +qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix, +which has shape ($2^N$, $2^N$). +This allows for mixed states and generic Kraus operators. + +A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support +for GPUs/TPUs. Some useful links: - [Documentation](https://cqcl.github.io/qujax/api/) diff --git a/docs/conf.py b/docs/conf.py index c58ed0e..00bb712 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,7 +17,7 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.napoleon'] +extensions = ['sphinx.ext.autodoc', 'sphinx_rtd_theme', 'sphinx.ext.napoleon', 'sphinx.ext.mathjax'] templates_path = ['_templates'] exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] @@ -38,3 +38,5 @@ 'Callable[[Optional[ndarray]], ndarray]]' } + +latex_engine = 'pdflatex' diff --git a/docs/densitytensor.rst b/docs/densitytensor.rst new file mode 100644 index 0000000..8dbce28 --- /dev/null +++ b/docs/densitytensor.rst @@ -0,0 +1,14 @@ +densitytensor +======================= + +.. toctree:: + + densitytensor/kraus + densitytensor/get_params_to_densitytensor_func + densitytensor/partial_trace + densitytensor/get_densitytensor_to_expectation_func + densitytensor/get_densitytensor_to_sampled_expectation_func + densitytensor/densitytensor_to_measurement_probabilities + densitytensor/densitytensor_to_measured_densitytensor + densitytensor/statetensor_to_densitytensor + diff --git a/docs/densitytensor/densitytensor_to_measured_densitytensor.rst b/docs/densitytensor/densitytensor_to_measured_densitytensor.rst new file mode 100644 index 0000000..7c52298 --- /dev/null +++ b/docs/densitytensor/densitytensor_to_measured_densitytensor.rst @@ -0,0 +1,5 @@ +densitytensor_to_measured_densitytensor +============================================== + +.. autofunction:: qujax.densitytensor_to_measured_densitytensor + diff --git a/docs/densitytensor/densitytensor_to_measurement_probabilities.rst b/docs/densitytensor/densitytensor_to_measurement_probabilities.rst new file mode 100644 index 0000000..50beb3a --- /dev/null +++ b/docs/densitytensor/densitytensor_to_measurement_probabilities.rst @@ -0,0 +1,5 @@ +densitytensor_to_measurement_probabilities +============================================== + +.. autofunction:: qujax.densitytensor_to_measurement_probabilities + diff --git a/docs/densitytensor/get_densitytensor_to_expectation_func.rst b/docs/densitytensor/get_densitytensor_to_expectation_func.rst new file mode 100644 index 0000000..c4a2d09 --- /dev/null +++ b/docs/densitytensor/get_densitytensor_to_expectation_func.rst @@ -0,0 +1,5 @@ +get_densitytensor_to_expectation_func +======================================= + +.. autofunction:: qujax.get_densitytensor_to_expectation_func + diff --git a/docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst b/docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst new file mode 100644 index 0000000..c71cdbe --- /dev/null +++ b/docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst @@ -0,0 +1,5 @@ +get_densitytensor_to_sampled_expectation_func +================================================ + +.. autofunction:: qujax.get_densitytensor_to_sampled_expectation_func + diff --git a/docs/densitytensor/get_params_to_densitytensor_func.rst b/docs/densitytensor/get_params_to_densitytensor_func.rst new file mode 100644 index 0000000..554b887 --- /dev/null +++ b/docs/densitytensor/get_params_to_densitytensor_func.rst @@ -0,0 +1,5 @@ +get_params_to_densitytensor_func +=================================== + +.. autofunction:: qujax.get_params_to_densitytensor_func + diff --git a/docs/densitytensor/kraus.rst b/docs/densitytensor/kraus.rst new file mode 100644 index 0000000..8f3959e --- /dev/null +++ b/docs/densitytensor/kraus.rst @@ -0,0 +1,5 @@ +kraus +======================= + +.. autofunction:: qujax.kraus + diff --git a/docs/densitytensor/partial_trace.rst b/docs/densitytensor/partial_trace.rst new file mode 100644 index 0000000..e432d02 --- /dev/null +++ b/docs/densitytensor/partial_trace.rst @@ -0,0 +1,5 @@ +partial_trace +=============================== + +.. autofunction:: qujax.partial_trace + diff --git a/docs/densitytensor/statetensor_to_densitytensor.rst b/docs/densitytensor/statetensor_to_densitytensor.rst new file mode 100644 index 0000000..3b66c4d --- /dev/null +++ b/docs/densitytensor/statetensor_to_densitytensor.rst @@ -0,0 +1,5 @@ +statetensor_to_densitytensor +=============================== + +.. autofunction:: qujax.statetensor_to_densitytensor + diff --git a/docs/index.rst b/docs/index.rst index c52f038..2a75cb1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ Docs sample_bitstrings check_circuit print_circuit + densitytensor gates diff --git a/qujax/__init__.py b/qujax/__init__.py index c020a00..487f872 100644 --- a/qujax/__init__.py +++ b/qujax/__init__.py @@ -2,36 +2,39 @@ from qujax import gates -from qujax.circuit import UnionCallableOptionalArray -from qujax.circuit import apply_gate -from qujax.circuit import get_params_to_statetensor_func - -from qujax.circuit_tools import check_unitary -from qujax.circuit_tools import check_circuit -from qujax.circuit_tools import print_circuit - -from qujax.observable import statetensor_to_densitytensor -from qujax.observable import statetensor_to_single_expectation -from qujax.observable import densitytensor_to_single_expectation -from qujax.observable import get_statetensor_to_expectation_func -from qujax.observable import get_statetensor_to_sampled_expectation_func -from qujax.observable import get_densitytensor_to_expectation_func -from qujax.observable import get_densitytensor_to_sampled_expectation_func -from qujax.observable import check_hermitian -from qujax.observable import integers_to_bitstrings -from qujax.observable import bitstrings_to_integers -from qujax.observable import sample_integers -from qujax.observable import sample_bitstrings - -from qujax.density_matrix import _kraus_single -from qujax.density_matrix import kraus -from qujax.density_matrix import get_params_to_densitytensor_func -from qujax.density_matrix import partial_trace -from qujax.density_matrix import densitytensor_to_measurement_probabilities -from qujax.density_matrix import densitytensor_to_measured_densitytensor +from qujax.statetensor import apply_gate +from qujax.statetensor import get_params_to_statetensor_func + +from qujax.statetensor_observable import statetensor_to_single_expectation +from qujax.statetensor_observable import get_statetensor_to_expectation_func +from qujax.statetensor_observable import get_statetensor_to_sampled_expectation_func + +from qujax.densitytensor import _kraus_single +from qujax.densitytensor import kraus +from qujax.densitytensor import get_params_to_densitytensor_func +from qujax.densitytensor import partial_trace + +from qujax.densitytensor_observable import densitytensor_to_single_expectation +from qujax.densitytensor_observable import get_densitytensor_to_expectation_func +from qujax.densitytensor_observable import get_densitytensor_to_sampled_expectation_func +from qujax.densitytensor_observable import densitytensor_to_measurement_probabilities +from qujax.densitytensor_observable import densitytensor_to_measured_densitytensor + +from qujax.utils import UnionCallableOptionalArray +from qujax.utils import check_unitary +from qujax.utils import check_hermitian +from qujax.utils import check_circuit +from qujax.utils import print_circuit +from qujax.utils import integers_to_bitstrings +from qujax.utils import bitstrings_to_integers +from qujax.utils import sample_integers +from qujax.utils import sample_bitstrings +from qujax.utils import statetensor_to_densitytensor del version -del circuit -del observable -del circuit_tools -del density_matrix +del statetensor +del statetensor_observable +del densitytensor +del densitytensor_observable +del utils + diff --git a/qujax/density_matrix.py b/qujax/densitytensor.py similarity index 76% rename from qujax/density_matrix.py rename to qujax/densitytensor.py index d7ca836..394f991 100644 --- a/qujax/density_matrix.py +++ b/qujax/densitytensor.py @@ -3,21 +3,19 @@ from jax import numpy as jnp from jax.lax import scan -from qujax.circuit import apply_gate, UnionCallableOptionalArray, gate_type -from qujax.circuit import _to_gate_func, _arrayify_inds, _gate_func_to_unitary -from qujax.circuit_tools import check_circuit -from qujax.observable import bitstrings_to_integers - -kraus_op_type = Union[gate_type, Iterable[gate_type]] +from qujax.statetensor import apply_gate, UnionCallableOptionalArray +from qujax.statetensor import _to_gate_func, _arrayify_inds, _gate_func_to_unitary +from qujax.utils import check_circuit, kraus_op_type def _kraus_single(densitytensor: jnp.ndarray, array: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: - """ + r""" Performs single Kraus operation .. math:: + \rho_\text{out} = B \rho_\text{in} B^{\dagger} Args: @@ -37,7 +35,7 @@ def _kraus_single(densitytensor: jnp.ndarray, def kraus(densitytensor: jnp.ndarray, arrays: Iterable[jnp.ndarray], qubit_inds: Sequence[int]) -> jnp.ndarray: - """ + r""" Performs Kraus operation. .. math:: @@ -94,7 +92,7 @@ def _to_kraus_operator_seq_funcs(kraus_op: kraus_op_type, def partial_trace(densitytensor: jnp.ndarray, - indices_to_trace: Iterable[int]) -> jnp.ndarray: + indices_to_trace: Sequence[int]) -> jnp.ndarray: """ Traces out (discards) specified qubits, resulting in a densitytensor representing the mixed quantum state on the remaining qubits. @@ -120,7 +118,7 @@ def get_params_to_densitytensor_func(kraus_ops_seq: Sequence[kraus_op_type], param_inds_seq: Sequence[Union[None, Sequence[int], Sequence[Sequence[int]]]], n_qubits: int = None) -> UnionCallableOptionalArray: """ - Creates a function that maps circuit parameters to a density tensor. + Creates a function that maps circuit parameters to a density tensor (a density matrix in tensor form). densitytensor = densitymatrix.reshape((2,) * 2 * n_qubits) densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits) @@ -200,52 +198,3 @@ def no_params_to_densitytensor_func(densitytensor_in: jnp.ndarray = None) -> jnp return no_params_to_densitytensor_func return params_to_densitytensor_func - - -def densitytensor_to_measurement_probabilities(densitytensor: jnp.ndarray, - qubit_inds: Sequence[int]) -> jnp.ndarray: - """ - Extract array of measurement probabilities given a densitytensor and some qubit indices to measure - (in the computational basis). - I.e. the ith element of the array corresponds to the probability of observing the bitstring - represented by the integer i on the measured qubits. - - Args: - densitytensor: Input densitytensor. - qubit_inds: Sequence of qubit indices to measure. - - Returns: - Normalised array of measurement probabilities. - """ - n_qubits = densitytensor.ndim // 2 - n_qubits_measured = len(qubit_inds) - qubit_inds_trace_out = [i for i in range(n_qubits) if i not in qubit_inds] - return jnp.diag(partial_trace(densitytensor, qubit_inds_trace_out).reshape(2 * n_qubits_measured, - 2 * n_qubits_measured)).real - - -def densitytensor_to_measured_densitytensor(densitytensor: jnp.ndarray, - qubit_inds: Sequence[int], - measurement: Union[int, jnp.ndarray]) -> jnp.ndarray: - """ - Returns the post-measurement densitytensor assuming that qubit_inds are measured - (in the computational basis) and the given measurement (integer or bitstring) is observed. - - Args: - densitytensor: Input densitytensor. - qubit_inds: Sequence of qubit indices to measure. - measurement: Observed integer or bitstring. - - Returns: - Post-measurement densitytensor (same shape as input densitytensor). - """ - measurement = jnp.array(measurement) - measured_int = bitstrings_to_integers(measurement) if measurement.ndim == 1 else measurement - - n_qubits = densitytensor.ndim // 2 - n_qubits_measured = len(qubit_inds) - qubit_inds_projector = jnp.diag(jnp.zeros(2 ** n_qubits_measured).at[measured_int].set(1)) \ - .reshape((2,) * 2 * n_qubits_measured) - unnorm_densitytensor = _kraus_single(densitytensor, qubit_inds_projector, qubit_inds) - norm_const = jnp.trace(unnorm_densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits)).real - return unnorm_densitytensor / norm_const diff --git a/qujax/densitytensor_observable.py b/qujax/densitytensor_observable.py new file mode 100644 index 0000000..2de7079 --- /dev/null +++ b/qujax/densitytensor_observable.py @@ -0,0 +1,157 @@ +from __future__ import annotations +from typing import Sequence, Union, Callable +from jax import numpy as jnp, random +from jax.lax import fori_loop + +from qujax.densitytensor import _kraus_single, partial_trace +from qujax.statetensor_observable import _get_tensor_to_expectation_func +from qujax.utils import sample_integers, statetensor_to_densitytensor, bitstrings_to_integers + + +def densitytensor_to_single_expectation(densitytensor: jnp.ndarray, + hermitian: jnp.ndarray, + qubit_inds: Sequence[int]) -> float: + """ + Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). + + Args: + densitytensor: Input densitytensor. + hermitian: Hermitian matrix representing observable + must be in tensor form with shape (2,2,...). + qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. + Must have 2 * len(qubit_inds) == hermitian.ndim + Returns: + Expected value (float). + """ + n_qubits = densitytensor.ndim // 2 + dt_indices = 2 * list(range(n_qubits)) + hermitian_indices = [i + densitytensor.ndim // 2 for i in range(hermitian.ndim)] + for n, q in enumerate(qubit_inds): + dt_indices[q] = hermitian_indices[n + len(qubit_inds)] + dt_indices[q + n_qubits] = hermitian_indices[n] + return jnp.einsum(densitytensor, dt_indices, hermitian, hermitian_indices).real + + +def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a densitytensor into an expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) + or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes densitytensor and returns expected value (float). + """ + + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, + densitytensor_to_single_expectation) + + +def get_densitytensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: + """ + Converts strings (or arrays) representing Hermitian matrices, qubit indices and + coefficients into a function that converts a densitytensor into a sampled expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes densitytensor, random key and integer number of shots + and returns sampled expected value (float). + """ + densitytensor_to_expectation_func = get_densitytensor_to_expectation_func(hermitian_seq_seq, + qubits_seq_seq, + coefficients) + + def densitytensor_to_sampled_expectation_func(statetensor: jnp.ndarray, + random_key: random.PRNGKeyArray, + n_samps: int) -> float: + """ + Maps statetensor to sampled expected value. + + Args: + statetensor: Input statetensor. + random_key: JAX random key + n_samps: Number of samples contributing to sampled expectation. + + Returns: + Sampled expected value (float). + + """ + sampled_integers = sample_integers(random_key, statetensor, n_samps) + sampled_probs = fori_loop(0, n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1), + jnp.zeros(statetensor.size)) + + sampled_probs /= n_samps + sampled_dt = statetensor_to_densitytensor(jnp.sqrt(sampled_probs).reshape(statetensor.shape)) + return densitytensor_to_expectation_func(sampled_dt) + + return densitytensor_to_sampled_expectation_func + + +def densitytensor_to_measurement_probabilities(densitytensor: jnp.ndarray, + qubit_inds: Sequence[int]) -> jnp.ndarray: + """ + Extract array of measurement probabilities given a densitytensor and some qubit indices to measure + (in the computational basis). + I.e. the ith element of the array corresponds to the probability of observing the bitstring + represented by the integer i on the measured qubits. + + Args: + densitytensor: Input densitytensor. + qubit_inds: Sequence of qubit indices to measure. + + Returns: + Normalised array of measurement probabilities. + """ + n_qubits = densitytensor.ndim // 2 + n_qubits_measured = len(qubit_inds) + qubit_inds_trace_out = [i for i in range(n_qubits) if i not in qubit_inds] + return jnp.diag(partial_trace(densitytensor, qubit_inds_trace_out).reshape(2 * n_qubits_measured, + 2 * n_qubits_measured)).real + + +def densitytensor_to_measured_densitytensor(densitytensor: jnp.ndarray, + qubit_inds: Sequence[int], + measurement: Union[int, jnp.ndarray]) -> jnp.ndarray: + """ + Returns the post-measurement densitytensor assuming that qubit_inds are measured + (in the computational basis) and the given measurement (integer or bitstring) is observed. + + Args: + densitytensor: Input densitytensor. + qubit_inds: Sequence of qubit indices to measure. + measurement: Observed integer or bitstring. + + Returns: + Post-measurement densitytensor (same shape as input densitytensor). + """ + measurement = jnp.array(measurement) + measured_int = bitstrings_to_integers(measurement) if measurement.ndim == 1 else measurement + + n_qubits = densitytensor.ndim // 2 + n_qubits_measured = len(qubit_inds) + qubit_inds_projector = jnp.diag(jnp.zeros(2 ** n_qubits_measured).at[measured_int].set(1)) \ + .reshape((2,) * 2 * n_qubits_measured) + unnorm_densitytensor = _kraus_single(densitytensor, qubit_inds_projector, qubit_inds) + norm_const = jnp.trace(unnorm_densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits)).real + return unnorm_densitytensor / norm_const diff --git a/qujax/observable.py b/qujax/observable.py deleted file mode 100644 index 3719f49..0000000 --- a/qujax/observable.py +++ /dev/null @@ -1,382 +0,0 @@ -from __future__ import annotations -from typing import Sequence, Callable, Union, Optional - -from jax import numpy as jnp, random -from jax.lax import fori_loop - -from qujax.circuit import apply_gate -from qujax.gates import X, Y, Z - -paulis = {'X': X, 'Y': Y, 'Z': Z} - - -def densitytensor_to_single_expectation(densitytensor: jnp.ndarray, - hermitian: jnp.ndarray, - qubit_inds: Sequence[int]) -> float: - """ - Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). - - Args: - densitytensor: Input densitytensor. - hermitian: Hermitian matrix representing observable - must be in tensor form with shape (2,2,...). - qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. - Must have 2 * len(qubit_inds) == hermitian.ndim - Returns: - Expected value (float). - """ - n_qubits = densitytensor.ndim // 2 - dt_indices = 2 * list(range(n_qubits)) - hermitian_indices = [i + densitytensor.ndim // 2 for i in range(hermitian.ndim)] - for n, q in enumerate(qubit_inds): - dt_indices[q] = hermitian_indices[n + len(qubit_inds)] - dt_indices[q + n_qubits] = hermitian_indices[n] - return jnp.einsum(densitytensor, dt_indices, hermitian, hermitian_indices).real - - -def statetensor_to_single_expectation(statetensor: jnp.ndarray, - hermitian: jnp.ndarray, - qubit_inds: Sequence[int]) -> float: - """ - Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). - - Args: - statetensor: Input statetensor. - hermitian: Hermitian array - must be in tensor form with shape (2,2,...). - qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. - Must have 2 * len(qubit_inds) == hermitian.ndim - - Returns: - Expected value (float). - """ - statetensor_new = apply_gate(statetensor, hermitian, qubit_inds) - axes = tuple(range(statetensor.ndim)) - return jnp.tensordot(statetensor.conjugate(), statetensor_new, axes=(axes, axes)).real - - -def check_hermitian(hermitian: Union[str, jnp.ndarray]): - """ - Checks whether a matrix or tensor is Hermitian. - - Args: - hermitian: array containing potentially Hermitian matrix or tensor - - """ - if isinstance(hermitian, str): - if hermitian not in paulis: - raise TypeError(f'qujax only accepts {tuple(paulis.keys())} as Hermitian strings, received: {hermitian}') - else: - n_qubits = hermitian.ndim // 2 - hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits) - if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()): - raise TypeError(f'Array not Hermitian: {hermitian}') - - -def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: - """ - Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form into single array (in tensor form). - - Args: - hermitian_seq: Sequence of Hermitian strings or arrays. - - Returns: - Hermitian matrix in tensor form (array). - - """ - for h in hermitian_seq: - check_hermitian(h) - - single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] - single_arrs = [h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs] - - full_mat = single_arrs[0] - for single_matrix in single_arrs[1:]: - full_mat = jnp.kron(full_mat, single_matrix) - full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size))) - return full_mat - - -def _get_tensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray], - contraction_function: Callable) \ - -> Callable[[jnp.ndarray], float]: - """ - Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and - a list of coefficients and returns a function that converts a tensor into an expected value. - The contraction function performs the tensor contraction according to the type of tensor provided - (i.e. whether it is a statetensor or a densitytensor). - - Args: - hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - contraction_function: Function that performs the tensor contraction. - - Returns: - Function that takes tensor and returns expected value (float). - """ - - hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] - - def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: - """ - Maps statetensor to expected value. - - Args: - statetensor: Input statetensor. - - Returns: - Expected value (float). - - """ - out = 0 - for hermitian, qubit_inds, coeff in zip(hermitian_tensors, qubits_seq_seq, coefficients): - out += coeff * contraction_function(statetensor, hermitian, qubit_inds) - return out - - return statetensor_to_expectation_func - - -def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray], float]: - """ - Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and - a list of coefficients and returns a function that converts a statetensor into an expected value. - - Args: - hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - - Returns: - Function that takes statetensor and returns expected value (float). - """ - - return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, - statetensor_to_single_expectation) - - -def get_densitytensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray], float]: - """ - Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and - a list of coefficients and returns a function that converts a densitytensor into an expected value. - - Args: - hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - - Returns: - Function that takes densitytensor and returns expected value (float). - """ - - return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, - densitytensor_to_single_expectation) - - -def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: - """ - Converts strings (or arrays) representing Hermitian matrices, qubit indices and - coefficients into a function that converts a statetensor into a sampled expected value. - - Args: - hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - - Returns: - Function that takes statetensor, random key and integer number of shots - and returns sampled expected value (float). - """ - statetensor_to_expectation_func = get_statetensor_to_expectation_func(hermitian_seq_seq, - qubits_seq_seq, - coefficients) - - def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, - random_key: random.PRNGKeyArray, - n_samps: int) -> float: - """ - Maps statetensor to sampled expected value. - - Args: - statetensor: Input statetensor. - random_key: JAX random key - n_samps: Number of samples contributing to sampled expectation. - - Returns: - Sampled expected value (float). - - """ - sampled_integers = sample_integers(random_key, statetensor, n_samps) - sampled_probs = fori_loop(0, n_samps, - lambda i, sv: sv.at[sampled_integers[i]].add(1), - jnp.zeros(statetensor.size)) - - sampled_probs /= n_samps - sampled_st = jnp.sqrt(sampled_probs).reshape(statetensor.shape) - return statetensor_to_expectation_func(sampled_st) - - return statetensor_to_sampled_expectation_func - - -def get_densitytensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], - qubits_seq_seq: Sequence[Sequence[int]], - coefficients: Union[Sequence[float], jnp.ndarray]) \ - -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: - """ - Converts strings (or arrays) representing Hermitian matrices, qubit indices and - coefficients into a function that converts a densitytensor into a sampled expected value. - - Args: - hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. - Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). - E.g. [['Z', 'Z'], ['X']] - qubits_seq_seq: Sequence of sequences of integer qubit indices. - E.g. [[0,1], [2]] - coefficients: Sequence of float coefficients to scale the expected values. - - Returns: - Function that takes densitytensor, random key and integer number of shots - and returns sampled expected value (float). - """ - densitytensor_to_expectation_func = get_densitytensor_to_expectation_func(hermitian_seq_seq, - qubits_seq_seq, - coefficients) - - def densitytensor_to_sampled_expectation_func(statetensor: jnp.ndarray, - random_key: random.PRNGKeyArray, - n_samps: int) -> float: - """ - Maps statetensor to sampled expected value. - - Args: - statetensor: Input statetensor. - random_key: JAX random key - n_samps: Number of samples contributing to sampled expectation. - - Returns: - Sampled expected value (float). - - """ - sampled_integers = sample_integers(random_key, statetensor, n_samps) - sampled_probs = fori_loop(0, n_samps, - lambda i, sv: sv.at[sampled_integers[i]].add(1), - jnp.zeros(statetensor.size)) - - sampled_probs /= n_samps - sampled_dt = statetensor_to_densitytensor(jnp.sqrt(sampled_probs).reshape(statetensor.shape)) - return densitytensor_to_expectation_func(sampled_dt) - - return densitytensor_to_sampled_expectation_func - - -def integers_to_bitstrings(integers: Union[int, jnp.ndarray], - nbits: int = None) -> jnp.ndarray: - """ - Convert integer or array of integers into their binary expansion(s). - - Args: - integers: Integer or array of integers to be converted. - nbits: Length of output binary expansion. - Defaults to smallest possible. - - Returns: - Array of binary expansion(s). - """ - integers = jnp.atleast_1d(integers) - if nbits is None: - nbits = (jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5)).astype(int) - - return jnp.squeeze(((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int)) - - -def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: - """ - Convert binary expansion(s) into integers. - - Args: - bitstrings: Bitstring array or array of bitstring arrays. - - Returns: - Array of integers. - """ - bitstrings = jnp.atleast_2d(bitstrings) - convarr = 2 ** jnp.arange(bitstrings.shape[-1] - 1, -1, -1) - return jnp.squeeze(bitstrings.dot(convarr)).astype(int) - - -def sample_integers(random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1) -> jnp.ndarray: - """ - Generate random integer samples according to statetensor. - - Args: - random_key: JAX random key to seed samples. - statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). - n_samps: Number of samples to generate. Defaults to 1. - - Returns: - Array with sampled integers, shape=(n_samps,). - - """ - sv_probs = jnp.square(jnp.abs(statetensor.flatten())) - sampled_inds = random.choice(random_key, a=jnp.arange(statetensor.size), shape=(n_samps,), p=sv_probs) - return sampled_inds - - -def sample_bitstrings(random_key: random.PRNGKeyArray, - statetensor: jnp.ndarray, - n_samps: Optional[int] = 1) -> jnp.ndarray: - """ - Generate random bitstring samples according to statetensor. - - Args: - random_key: JAX random key to seed samples. - statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). - n_samps: Number of samples to generate. Defaults to 1. - - Returns: - Array with sampled bitstrings, shape=(n_samps, statetensor.ndim). - - """ - return integers_to_bitstrings(sample_integers(random_key, statetensor, n_samps), statetensor.ndim) - - -def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: - """ - Computes a densitytensor representation of a pure quantum state - from its statetensor representaton - - Args: - statetensor: Input statetensor. - - Returns: - A densitytensor representing the quantum state. - """ - n_qubits = statetensor.ndim - st = statetensor - dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) - return dt diff --git a/qujax/circuit.py b/qujax/statetensor.py similarity index 81% rename from qujax/circuit.py rename to qujax/statetensor.py index d5ef7b8..f0bda56 100644 --- a/qujax/circuit.py +++ b/qujax/statetensor.py @@ -1,26 +1,9 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, Protocol +from typing import Sequence, Union, Callable from jax import numpy as jnp from qujax import gates -from qujax.circuit_tools import check_circuit - - -class CallableArrayAndOptionalArray(Protocol): - def __call__(self, params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: - ... - - -class CallableOptionalArray(Protocol): - def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: - ... - - -UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] -gate_type = Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]] +from qujax.utils import check_circuit, _arrayify_inds, UnionCallableOptionalArray, gate_type def apply_gate(statetensor: jnp.ndarray, gate_unitary: jnp.ndarray, qubit_inds: Sequence[int]) -> jnp.ndarray: @@ -74,24 +57,6 @@ def _array_to_callable(arr: jnp.ndarray) -> Callable[[], jnp.ndarray]: return gate_func -def _arrayify_inds(param_inds_seq: Sequence[Union[None, Sequence[int]]]) -> Sequence[jnp.ndarray]: - """ - Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) - - Args: - param_inds_seq: Sequence of sequences representing parameter indices that gates are using, - i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter - (the float at position zero in the parameter vector/array), the second gate is not parameterised - and the third gates used the parameters at position five and two. - - Returns: - Sequence of arrays representing parameter indices. - """ - param_inds_seq = [jnp.array(p) for p in param_inds_seq] - param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] - return param_inds_seq - - def _gate_func_to_unitary(gate_func: Callable[[jnp.ndarray], jnp.ndarray], qubit_inds: Sequence[int], param_inds: jnp.ndarray, diff --git a/qujax/statetensor_observable.py b/qujax/statetensor_observable.py new file mode 100644 index 0000000..ab1a08f --- /dev/null +++ b/qujax/statetensor_observable.py @@ -0,0 +1,171 @@ +from __future__ import annotations +from typing import Sequence, Callable, Union +from jax import numpy as jnp, random +from jax.lax import fori_loop + +from qujax.statetensor import apply_gate +from qujax.utils import check_hermitian, sample_integers, paulis + + +def statetensor_to_single_expectation(statetensor: jnp.ndarray, + hermitian: jnp.ndarray, + qubit_inds: Sequence[int]) -> float: + """ + Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). + + Args: + statetensor: Input statetensor. + hermitian: Hermitian array + must be in tensor form with shape (2,2,...). + qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. + Must have 2 * len(qubit_inds) == hermitian.ndim + + Returns: + Expected value (float). + """ + statetensor_new = apply_gate(statetensor, hermitian, qubit_inds) + axes = tuple(range(statetensor.ndim)) + return jnp.tensordot(statetensor.conjugate(), statetensor_new, axes=(axes, axes)).real + + +def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jnp.ndarray]]) -> jnp.ndarray: + """ + Convert a sequence of observables represented by Pauli strings or Hermitian matrices in tensor form + into single array (in tensor form). + + Args: + hermitian_seq: Sequence of Hermitian strings or arrays. + + Returns: + Hermitian matrix in tensor form (array). + """ + for h in hermitian_seq: + check_hermitian(h) + + single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] + single_arrs = [h_arr.reshape((2,) * int(jnp.log2(h_arr.size))) for h_arr in single_arrs] + + full_mat = single_arrs[0] + for single_matrix in single_arrs[1:]: + full_mat = jnp.kron(full_mat, single_matrix) + full_mat = full_mat.reshape((2,) * int(jnp.log2(full_mat.size))) + return full_mat + + +def _get_tensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray], + contraction_function: Callable) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a tensor into an expected value. + The contraction function performs the tensor contraction according to the type of tensor provided + (i.e. whether it is a statetensor or a densitytensor). + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + contraction_function: Function that performs the tensor contraction. + + Returns: + Function that takes tensor and returns expected value (float). + """ + + hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] + + def statetensor_to_expectation_func(statetensor: jnp.ndarray) -> float: + """ + Maps statetensor to expected value. + + Args: + statetensor: Input statetensor. + + Returns: + Expected value (float). + """ + out = 0 + for hermitian, qubit_inds, coeff in zip(hermitian_tensors, qubits_seq_seq, coefficients): + out += coeff * contraction_function(statetensor, hermitian, qubit_inds) + return out + + return statetensor_to_expectation_func + + +def get_statetensor_to_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray], float]: + """ + Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and + a list of coefficients and returns a function that converts a statetensor into an expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian matrix is either represented by a tensor (jnp.ndarray) + or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes statetensor and returns expected value (float). + """ + + return _get_tensor_to_expectation_func(hermitian_seq_seq, qubits_seq_seq, coefficients, + statetensor_to_single_expectation) + + +def get_statetensor_to_sampled_expectation_func(hermitian_seq_seq: Sequence[Sequence[Union[str, jnp.ndarray]]], + qubits_seq_seq: Sequence[Sequence[int]], + coefficients: Union[Sequence[float], jnp.ndarray]) \ + -> Callable[[jnp.ndarray, random.PRNGKeyArray, int], float]: + """ + Converts strings (or arrays) representing Hermitian matrices, qubit indices and + coefficients into a function that converts a statetensor into a sampled expected value. + + Args: + hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. + Each Hermitian is either a tensor (jnp.ndarray) or a string in ('X', 'Y', 'Z'). + E.g. [['Z', 'Z'], ['X']] + qubits_seq_seq: Sequence of sequences of integer qubit indices. + E.g. [[0,1], [2]] + coefficients: Sequence of float coefficients to scale the expected values. + + Returns: + Function that takes statetensor, random key and integer number of shots + and returns sampled expected value (float). + """ + statetensor_to_expectation_func = get_statetensor_to_expectation_func(hermitian_seq_seq, + qubits_seq_seq, + coefficients) + + def statetensor_to_sampled_expectation_func(statetensor: jnp.ndarray, + random_key: random.PRNGKeyArray, + n_samps: int) -> float: + """ + Maps statetensor to sampled expected value. + + Args: + statetensor: Input statetensor. + random_key: JAX random key + n_samps: Number of samples contributing to sampled expectation. + + Returns: + Sampled expected value (float). + """ + sampled_integers = sample_integers(random_key, statetensor, n_samps) + sampled_probs = fori_loop(0, n_samps, + lambda i, sv: sv.at[sampled_integers[i]].add(1), + jnp.zeros(statetensor.size)) + + sampled_probs /= n_samps + sampled_st = jnp.sqrt(sampled_probs).reshape(statetensor.shape) + return statetensor_to_expectation_func(sampled_st) + + return statetensor_to_sampled_expectation_func diff --git a/qujax/circuit_tools.py b/qujax/utils.py similarity index 59% rename from qujax/circuit_tools.py rename to qujax/utils.py index 1d504ec..983320d 100644 --- a/qujax/circuit_tools.py +++ b/qujax/utils.py @@ -1,17 +1,33 @@ from __future__ import annotations -from typing import Sequence, Union, Callable, List, Tuple, Optional +from typing import Sequence, Union, Callable, List, Tuple, Optional, Protocol, Iterable import collections.abc from inspect import signature - -from jax import numpy as jnp +from jax import numpy as jnp, random from qujax import gates +paulis = {'X': gates.X, 'Y': gates.Y, 'Z': gates.Z} + + +class CallableArrayAndOptionalArray(Protocol): + def __call__(self, params: jnp.ndarray, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: + ... + + +class CallableOptionalArray(Protocol): + def __call__(self, statetensor_in: jnp.ndarray = None) -> jnp.ndarray: + ... -def check_unitary(gate: Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]): + +UnionCallableOptionalArray = Union[CallableArrayAndOptionalArray, CallableOptionalArray] +gate_type = Union[str, + jnp.ndarray, + Callable[[jnp.ndarray], jnp.ndarray], + Callable[[], jnp.ndarray]] +kraus_op_type = Union[gate_type, Iterable[gate_type]] + + +def check_unitary(gate: gate_type): """ Checks whether a matrix or tensor is unitary. @@ -43,10 +59,45 @@ def check_unitary(gate: Union[str, raise TypeError(f'Gate not unitary: {gate}') -def check_circuit(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], +def check_hermitian(hermitian: Union[str, jnp.ndarray]): + """ + Checks whether a matrix or tensor is Hermitian. + + Args: + hermitian: array containing potentially Hermitian matrix or tensor + + """ + if isinstance(hermitian, str): + if hermitian not in paulis: + raise TypeError(f'qujax only accepts {tuple(paulis.keys())} as Hermitian strings, received: {hermitian}') + else: + n_qubits = hermitian.ndim // 2 + hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits) + if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()): + raise TypeError(f'Array not Hermitian: {hermitian}') + + +def _arrayify_inds(param_inds_seq: Sequence[Union[None, Sequence[int]]]) -> Sequence[jnp.ndarray]: + """ + Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) + + Args: + param_inds_seq: Sequence of sequences representing parameter indices that gates are using, + i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter + (the float at position zero in the parameter vector/array), the second gate is not parameterised + and the third gates used the parameters at position five and two. + + Returns: + Sequence of arrays representing parameter indices. + """ + if param_inds_seq is None: + param_inds_seq = [None] + param_inds_seq = [jnp.array(p) for p in param_inds_seq] + param_inds_seq = [jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) for p in param_inds_seq] + return param_inds_seq + + +def check_circuit(gate_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], param_inds_seq: Sequence[Sequence[int]], n_qubits: int = None, @@ -59,6 +110,7 @@ def check_circuit(gate_seq: Sequence[Union[str, Each element is either a string matching an array or function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) or a function taking parameters and returning gate unitary in tensor form. + Or alternatively a sequence of the above representing Kraus operators. qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. param_inds_seq: Sequence of parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, @@ -91,11 +143,8 @@ def check_circuit(gate_seq: Sequence[Union[str, check_unitary(g) -def _get_gate_str(gate_obj: Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]], - param_inds: Sequence[int]) -> str: +def _get_gate_str(gate_obj: kraus_op_type, + param_inds: Union[None, Sequence[int], Sequence[Sequence[int]]]) -> str: """ Maps single gate object to a four character string representation @@ -103,12 +152,18 @@ def _get_gate_str(gate_obj: Union[str, gate_obj: Either a string matching a function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape e.g. (2,2,2,...) ) or a function taking parameters (can be empty) and returning gate unitary in tensor form. - param_inds: Parameter indices that gates are using, i.e. gate uses 1st and 666th parameter. + Or alternatively, a sequence of Krause operators represented by strings, arrays or functions. + param_inds: Parameter indices that gates are using, i.e. gate uses 1st and 5th parameter. Returns: Four character string representation of the gate """ + if isinstance(gate_obj, (tuple, list)) or (hasattr(gate_obj, '__array__') and gate_obj.ndim % 2 == 1): + # Kraus operators + gate_obj = 'Kr' + param_inds = jnp.unique(jnp.concatenate(_arrayify_inds(param_inds), axis=0)) + if isinstance(gate_obj, str): gate_str = gate_obj elif hasattr(gate_obj, '__array__'): @@ -126,7 +181,10 @@ def _get_gate_str(gate_obj: Union[str, if hasattr(param_inds, 'tolist'): param_inds = param_inds.tolist() - if param_inds == [] or param_inds == [None]: + if isinstance(param_inds, tuple): + param_inds = list(param_inds) + + if param_inds == [] or param_inds == [None] or param_inds is None: if len(gate_str) > 7: gate_str = gate_str[:6] + '.' else: @@ -172,10 +230,7 @@ def extend_row(row: str, qubit_row: bool) -> str: return out_rows, [True] * len(rows) -def print_circuit(gate_seq: Sequence[Union[str, - jnp.ndarray, - Callable[[jnp.ndarray], jnp.ndarray], - Callable[[], jnp.ndarray]]], +def print_circuit(gate_seq: Sequence[kraus_op_type], qubit_inds_seq: Sequence[Sequence[int]], param_inds_seq: Sequence[Sequence[int]], n_qubits: Optional[int] = None, @@ -192,6 +247,7 @@ def print_circuit(gate_seq: Sequence[Union[str, Each element is either a string matching an array or function in qujax.gates, a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) or a function taking parameters and returning gate unitary in tensor form. + Or alternatively a sequence of the above representing Kraus operators. qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. param_inds_seq: Sequence of parameter indices that gates are using, i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, @@ -207,7 +263,7 @@ def print_circuit(gate_seq: Sequence[Union[str, String representation of circuit """ - check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits, False) gate_ind_max = min(len(gate_seq) - 1, gate_ind_max) if gate_ind_max < gate_ind_min: @@ -259,3 +315,93 @@ def print_circuit(gate_seq: Sequence[Union[str, print(p) return rows + + +def integers_to_bitstrings(integers: Union[int, jnp.ndarray], + nbits: int = None) -> jnp.ndarray: + """ + Convert integer or array of integers into their binary expansion(s). + + Args: + integers: Integer or array of integers to be converted. + nbits: Length of output binary expansion. + Defaults to smallest possible. + + Returns: + Array of binary expansion(s). + """ + integers = jnp.atleast_1d(integers) + if nbits is None: + nbits = (jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5)).astype(int) + + return jnp.squeeze(((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int)) + + +def bitstrings_to_integers(bitstrings: jnp.ndarray) -> Union[int, jnp.ndarray]: + """ + Convert binary expansion(s) into integers. + + Args: + bitstrings: Bitstring array or array of bitstring arrays. + + Returns: + Array of integers. + """ + bitstrings = jnp.atleast_2d(bitstrings) + convarr = 2 ** jnp.arange(bitstrings.shape[-1] - 1, -1, -1) + return jnp.squeeze(bitstrings.dot(convarr)).astype(int) + + +def sample_integers(random_key: random.PRNGKeyArray, + statetensor: jnp.ndarray, + n_samps: Optional[int] = 1) -> jnp.ndarray: + """ + Generate random integer samples according to statetensor. + + Args: + random_key: JAX random key to seed samples. + statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). + n_samps: Number of samples to generate. Defaults to 1. + + Returns: + Array with sampled integers, shape=(n_samps,). + + """ + sv_probs = jnp.square(jnp.abs(statetensor.flatten())) + sampled_inds = random.choice(random_key, a=jnp.arange(statetensor.size), shape=(n_samps,), p=sv_probs) + return sampled_inds + + +def sample_bitstrings(random_key: random.PRNGKeyArray, + statetensor: jnp.ndarray, + n_samps: Optional[int] = 1) -> jnp.ndarray: + """ + Generate random bitstring samples according to statetensor. + + Args: + random_key: JAX random key to seed samples. + statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). + n_samps: Number of samples to generate. Defaults to 1. + + Returns: + Array with sampled bitstrings, shape=(n_samps, statetensor.ndim). + + """ + return integers_to_bitstrings(sample_integers(random_key, statetensor, n_samps), statetensor.ndim) + + +def statetensor_to_densitytensor(statetensor: jnp.ndarray) -> jnp.ndarray: + """ + Computes a densitytensor representation of a pure quantum state + from its statetensor representaton + + Args: + statetensor: Input statetensor. + + Returns: + A densitytensor representing the quantum state. + """ + n_qubits = statetensor.ndim + st = statetensor + dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) + return dt diff --git a/qujax/version.py b/qujax/version.py index cd9b137..0404d81 100644 --- a/qujax/version.py +++ b/qujax/version.py @@ -1 +1 @@ -__version__ = '0.2.9' +__version__ = '0.3.0' diff --git a/tests/test_density_matrix.py b/tests/test_densitytensor.py similarity index 73% rename from tests/test_density_matrix.py rename to tests/test_densitytensor.py index 8fc586b..8a79345 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_densitytensor.py @@ -2,10 +2,6 @@ from jax import numpy as jnp, jit import qujax -from qujax import get_params_to_statetensor_func -from qujax import _kraus_single, kraus, get_params_to_densitytensor_func, partial_trace -from qujax.observable import statetensor_to_densitytensor -from qujax import densitytensor_to_measurement_probabilities, densitytensor_to_measured_densitytensor def test_kraus_single(): @@ -22,21 +18,21 @@ def test_kraus_single(): check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T # qujax._kraus_single - qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dt = qujax._kraus_single(density_tensor, kraus_operator, qubit_inds) qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(_kraus_single, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) # qujax.kraus (but for a single array) - qujax_kraus_dt = kraus(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator, qubit_inds) qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operator, qubit_inds) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) @@ -56,25 +52,27 @@ def test_kraus_single_2qubit(): check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T # qujax._kraus_single - qujax_kraus_dt = _kraus_single(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dt = qujax._kraus_single(density_tensor, kraus_operator_tensor, qubit_inds) qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(_kraus_single, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(density_tensor, + kraus_operator_tensor, + qubit_inds) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) # qujax.kraus (but for a single array) - qujax_kraus_dt = kraus(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator_tensor, qubit_inds) qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt = kraus(density_tensor, kraus_operator, qubit_inds) # check reshape kraus_operator correctly + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator, qubit_inds) # check reshape kraus_operator correctly qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operator_tensor, qubit_inds) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) @@ -96,12 +94,12 @@ def test_kraus_multiple(): for um in unitary_matrices: check_kraus_dm += um @ density_matrix @ um.conj().T - qujax_kraus_dt = kraus(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operators, qubit_inds) qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) - qujax_kraus_dt_jit = jit(kraus, static_argnums=(2,))(density_tensor, kraus_operators, qubit_inds) + qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(density_tensor, kraus_operators, qubit_inds) qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) @@ -117,13 +115,13 @@ def test_params_to_densitytensor_func(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) - params_to_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_st = qujax.get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) params = jnp.arange(n_qubits) / 10. st = params_to_st(params) - dt_test = statetensor_to_densitytensor(st) + dt_test = qujax.statetensor_to_densitytensor(st) dt = params_to_dt(params) @@ -144,7 +142,7 @@ def test_params_to_densitytensor_func_with_bit_flip(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_pre_bf_st = get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_pre_bf_st = qujax.get_params_to_statetensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) kraus_ops = [[0.3 * jnp.eye(2), 0.7 * qujax.gates.X]] kraus_qubit_inds = [(0,)] @@ -154,13 +152,15 @@ def test_params_to_densitytensor_func_with_bit_flip(): qubit_inds_seq += kraus_qubit_inds param_inds_seq += kraus_param_inds - params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + _ = qujax.print_circuit(gate_seq, qubit_inds_seq, param_inds_seq) + + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) params = jnp.arange(n_qubits) / 10. pre_bf_st = params_to_pre_bf_st(params) pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape(2 for _ in range(2 * n_qubits)) - dt_test = kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0]) + dt_test = qujax.kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0]) dt = params_to_dt(params) @@ -180,10 +180,10 @@ def test_partial_trace_1(): dt3 = jnp.outer(state3, state3.conj()).reshape((2,) * 6) for i in range(3): - assert jnp.allclose(partial_trace(dt3, [i]), dt2) + assert jnp.allclose(qujax.partial_trace(dt3, [i]), dt2) for i in combinations(range(3), 2): - assert jnp.allclose(partial_trace(dt3, i), dt1) + assert jnp.allclose(qujax.partial_trace(dt3, i), dt1) def test_partial_trace_2(): @@ -197,13 +197,13 @@ def test_partial_trace_2(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) params = jnp.arange(1, n_qubits + 1) / 10. dt = params_to_dt(params) dt_discard_test = jnp.trace(dt, axis1=0, axis2=n_qubits) - dt_discard = partial_trace(dt, [0]) + dt_discard = qujax.partial_trace(dt, [0]) assert jnp.allclose(dt_discard, dt_discard_test) @@ -219,7 +219,7 @@ def test_measure(): qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] param_inds_seq += [() for _ in range(n_qubits - 1)] - params_to_dt = get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) + params_to_dt = qujax.get_params_to_densitytensor_func(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) params = jnp.arange(1, n_qubits + 1) / 10. @@ -231,7 +231,7 @@ def test_measure(): all_probs_marginalise \ = all_probs.reshape((2,) * n_qubits).sum(axis=[i for i in range(n_qubits) if i not in qubit_inds]) - probs = densitytensor_to_measurement_probabilities(dt, qubit_inds) + probs = qujax.densitytensor_to_measurement_probabilities(dt, qubit_inds) assert jnp.isclose(probs.sum(), 1.) assert jnp.isclose(all_probs.sum(), 1.) @@ -245,7 +245,7 @@ def test_measure(): measured_dm /= jnp.trace(projector.T.conj() @ projector @ dm) measured_dt_true = measured_dm.reshape((2,) * 2 * n_qubits) - measured_dt = densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) - measured_dt_bits = densitytensor_to_measured_densitytensor(dt, qubit_inds, (0,)*n_qubits) + measured_dt = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) + measured_dt_bits = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, (0,)*n_qubits) assert jnp.allclose(measured_dt_true, measured_dt) assert jnp.allclose(measured_dt_true, measured_dt_bits) diff --git a/tests/test_expectations.py b/tests/test_expectations.py index 733eccb..0f168a5 100644 --- a/tests/test_expectations.py +++ b/tests/test_expectations.py @@ -1,9 +1,12 @@ from jax import numpy as jnp, jit, grad, random, config -import qujax.gates import qujax -from qujax import densitytensor_to_single_expectation, statetensor_to_single_expectation -from qujax.observable import statetensor_to_densitytensor + + +def test_pauli_hermitian(): + for p_str in ('X', 'Y', 'Z'): + qujax.check_hermitian(p_str) + qujax.check_hermitian(qujax.gates.__dict__[p_str]) def test_single_expectation(): @@ -13,14 +16,14 @@ def test_single_expectation(): st2 = jnp.zeros((2, 2, 2)) st1 = st1.at[(0, 0, 0)].set(1.) st2 = st2.at[(1, 0, 0)].set(1.) - dt1 = statetensor_to_densitytensor(st1) - dt2 = statetensor_to_densitytensor(st2) + dt1 = qujax.statetensor_to_densitytensor(st1) + dt2 = qujax.statetensor_to_densitytensor(st2) ZZ = jnp.kron(Z, Z).reshape(2, 2, 2, 2) - est1 = statetensor_to_single_expectation(dt1, ZZ, [0, 1]) - est2 = statetensor_to_single_expectation(dt2, ZZ, [0, 1]) - edt1 = densitytensor_to_single_expectation(dt1, ZZ, [0, 1]) - edt2 = densitytensor_to_single_expectation(dt2, ZZ, [0, 1]) + est1 = qujax.statetensor_to_single_expectation(dt1, ZZ, [0, 1]) + est2 = qujax.statetensor_to_single_expectation(dt2, ZZ, [0, 1]) + edt1 = qujax.densitytensor_to_single_expectation(dt1, ZZ, [0, 1]) + edt2 = qujax.densitytensor_to_single_expectation(dt2, ZZ, [0, 1]) assert est1.item() == edt1.item() == 1 assert est2.item() == edt2.item() == -1 @@ -97,7 +100,7 @@ def test_ZZ_Y(): state = random.uniform(random.PRNGKey(0), shape=(2 ** n_qubits,)) * 2 state /= jnp.linalg.norm(state) st_in = state.reshape((2,) * n_qubits) - dt_in = statetensor_to_densitytensor(st_in) + dt_in = qujax.statetensor_to_densitytensor(st_in) def big_hermitian_matrix(hermitian_str_seq, qubit_inds): qubit_arrs = [getattr(qujax.gates, s) for s in hermitian_str_seq] @@ -144,8 +147,8 @@ def big_hermitian_matrix(hermitian_str_seq, qubit_inds): coefs) qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) - qujax_samp_exp_dt = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) - qujax_samp_exp_dt_jit = jit(st_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) + qujax_samp_exp_dt = dt_to_samp_exp(st_in, random.PRNGKey(1), 1000000) + qujax_samp_exp_dt_jit = jit(dt_to_samp_exp, static_argnums=2)(st_in, random.PRNGKey(2), 1000000) assert jnp.array(qujax_samp_exp).shape == () assert jnp.array(qujax_samp_exp).dtype.name[:5] == 'float' assert jnp.isclose(true_exp, qujax_samp_exp, rtol=1e-2) From 9d063d17816612ebe4a556db44ce66e96feb4744 Mon Sep 17 00:00:00 2001 From: SamDuffield <34280297+SamDuffield@users.noreply.github.com> Date: Mon, 7 Nov 2022 08:06:17 +0000 Subject: [PATCH 24/25] Update README latex comile --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 267b275..2083154 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ takes as input any parameters of the circuit and outputs a _statetensor_. The st of the quantum state and can then be used downstream for exact expectations, gradients or sampling. qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix, -which has shape ($2^N$, $2^N$). +which has shape $(2^N, 2^N)$. This allows for mixed states and generic Kraus operators. A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support From 9ba090301e36498db7ec510bac7ee35e6788f47d Mon Sep 17 00:00:00 2001 From: SamDuffield <34280297+SamDuffield@users.noreply.github.com> Date: Mon, 7 Nov 2022 08:07:46 +0000 Subject: [PATCH 25/25] Update README.md --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index 2083154..965d8c7 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,7 @@ Represent a (parameterised) quantum circuit as a pure [JAX](https://github.com/g takes as input any parameters of the circuit and outputs a _statetensor_. The statetensor encodes all $2^N$ amplitudes of the quantum state and can then be used downstream for exact expectations, gradients or sampling. -qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix, -which has shape $(2^N, 2^N)$. -This allows for mixed states and generic Kraus operators. +qujax also supports densitytensor simulations. A densitytensor is a tensor representation of the density matrix and allows for mixed states and generic Kraus operators. A JAX implementation of a quantum circuit is useful for runtime speedups, automatic differentiation and support for GPUs/TPUs.