Skip to content

Commit

Permalink
[Feature] Finite shots
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz committed Jul 11, 2024
1 parent 92490b8 commit ccccc91
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 14 deletions.
18 changes: 4 additions & 14 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,21 @@
from horqrux.utils import OperationType, inner


def expectation(
@custom_vjp
def adjoint_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return inner(out_state, projected_state).real


@custom_vjp
def adjoint_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
return expectation(state, gates, observable, values)


def adjoint_expectation_fwd(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
) -> Array:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return inner(out_state, projected_state).real, (out_state, projected_state, gates, values)
return inner(out_state, projected_state).real


def adjoint_expectation_bwd(
Expand Down
77 changes: 77 additions & 0 deletions horqrux/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from collections import Counter

import jax
import jax.numpy as jnp
from jax import Array

from horqrux.adjoint import adjoint_expectation
from horqrux.apply import apply_gate
from horqrux.primitive import Primitive
from horqrux.utils import OperationType, inner


def run(
circuit: list[Primitive],
state: Array,
values: dict[str, float] = dict(),
) -> Array:
return apply_gate(state, circuit, values)


def sample(
state: Array,
gates: list[Primitive],
values: dict[str, float] = dict(),
n_shots: int = 1000,
) -> Counter:
if n_shots < 1:
raise ValueError("You can only call sample with n_shots>0.")

wf = apply_gate(state, gates, values)
probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel()
key = jax.random.PRNGKey(0)
n_qubits = len(state.shape)
# JAX handles pseudo random number generation by tracking an explicit state via a random key
# For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html
samples = jax.vmap(
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs)
)(jax.random.split(key, n_shots))

return Counter(
{
format(k, "0{}b".format(n_qubits)): count.item()
for k, count in enumerate(jnp.bincount(samples))
if count > 0
}
)


def ad_expectation(
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return inner(out_state, projected_state).real


def expectation(
state: Array,
gates: list[Primitive],
observable: list[Primitive],
values: dict[str, float],
diff_mode: str = "ad",
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
if diff_mode == "ad":
return ad_expectation(state, gates, observable, values)
else:
return adjoint_expectation(state, gates, observable, values)
38 changes: 38 additions & 0 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

import jax
import jax.numpy as jnp
from jax import Array

from horqrux.apply import apply_gate
from horqrux.primitive import Primitive


def finite_shots(
state: Array,
gates: list[Primitive],
observable: Primitive,
values: dict[str, float],
diff_mode: str = "gpsr",
n_shots: int = 100,
) -> Array:
"""
Run 'state' through a sequence of 'gates' given parameters 'values'
and compute the expectation given an observable.
"""
state = apply_gate(state, gates, values)
# NOTE this only works now for an observable comprised of a single gate
# to get eigvals,eigvecs for arbitary compositions of paulis, we need to
# create the full tensor. check `block_to_jax` in qadence for this
eigvals, eigvecs = jnp.linalg.eig(observable.unitary())
probs = jnp.abs(jnp.float_power(jnp.inner(state, eigvecs), 2.0)).ravel()
key = jax.random.PRNGKey(0)
n_qubits = len(state.shape)
samples = jax.vmap(
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs)
)(jax.random.split(key, n_shots))
# samples now contains a list of indices
# i forgot the formula
# something here which is correct
counts = jnp.bincount(samples)
return jnp.mean(counts / n_shots)
25 changes: 25 additions & 0 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import numpy as np

from horqrux import random_state
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, H, I, S, T, X, Y, Z
from horqrux.shots import finite_shots

MAX_QUBITS = 2
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T)


def test_gradcheck() -> None:
ops = [RX("theta", 0), RY("epsilon", 0), RX("phi", 0), NOT(1, 0), RX("omega", 0, 1)]
observable = Z(0)
values = {
"theta": np.random.uniform(0, 1),
"epsilon": np.random.uniform(0, 1),
"phi": np.random.uniform(0, 1),
"omega": np.random.uniform(0, 1),
}
state = random_state(MAX_QUBITS)
exp_shots = finite_shots(state, ops, observable, values)

0 comments on commit ccccc91

Please sign in to comment.