-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
92490b8
commit ccccc91
Showing
4 changed files
with
144 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from __future__ import annotations | ||
|
||
from collections import Counter | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import Array | ||
|
||
from horqrux.adjoint import adjoint_expectation | ||
from horqrux.apply import apply_gate | ||
from horqrux.primitive import Primitive | ||
from horqrux.utils import OperationType, inner | ||
|
||
|
||
def run( | ||
circuit: list[Primitive], | ||
state: Array, | ||
values: dict[str, float] = dict(), | ||
) -> Array: | ||
return apply_gate(state, circuit, values) | ||
|
||
|
||
def sample( | ||
state: Array, | ||
gates: list[Primitive], | ||
values: dict[str, float] = dict(), | ||
n_shots: int = 1000, | ||
) -> Counter: | ||
if n_shots < 1: | ||
raise ValueError("You can only call sample with n_shots>0.") | ||
|
||
wf = apply_gate(state, gates, values) | ||
probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() | ||
key = jax.random.PRNGKey(0) | ||
n_qubits = len(state.shape) | ||
# JAX handles pseudo random number generation by tracking an explicit state via a random key | ||
# For more details, see https://jax.readthedocs.io/en/latest/random-numbers.html | ||
samples = jax.vmap( | ||
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) | ||
)(jax.random.split(key, n_shots)) | ||
|
||
return Counter( | ||
{ | ||
format(k, "0{}b".format(n_qubits)): count.item() | ||
for k, count in enumerate(jnp.bincount(samples)) | ||
if count > 0 | ||
} | ||
) | ||
|
||
|
||
def ad_expectation( | ||
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float] | ||
) -> Array: | ||
""" | ||
Run 'state' through a sequence of 'gates' given parameters 'values' | ||
and compute the expectation given an observable. | ||
""" | ||
out_state = apply_gate(state, gates, values, OperationType.UNITARY) | ||
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY) | ||
return inner(out_state, projected_state).real | ||
|
||
|
||
def expectation( | ||
state: Array, | ||
gates: list[Primitive], | ||
observable: list[Primitive], | ||
values: dict[str, float], | ||
diff_mode: str = "ad", | ||
) -> Array: | ||
""" | ||
Run 'state' through a sequence of 'gates' given parameters 'values' | ||
and compute the expectation given an observable. | ||
""" | ||
if diff_mode == "ad": | ||
return ad_expectation(state, gates, observable, values) | ||
else: | ||
return adjoint_expectation(state, gates, observable, values) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from __future__ import annotations | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import Array | ||
|
||
from horqrux.apply import apply_gate | ||
from horqrux.primitive import Primitive | ||
|
||
|
||
def finite_shots( | ||
state: Array, | ||
gates: list[Primitive], | ||
observable: Primitive, | ||
values: dict[str, float], | ||
diff_mode: str = "gpsr", | ||
n_shots: int = 100, | ||
) -> Array: | ||
""" | ||
Run 'state' through a sequence of 'gates' given parameters 'values' | ||
and compute the expectation given an observable. | ||
""" | ||
state = apply_gate(state, gates, values) | ||
# NOTE this only works now for an observable comprised of a single gate | ||
# to get eigvals,eigvecs for arbitary compositions of paulis, we need to | ||
# create the full tensor. check `block_to_jax` in qadence for this | ||
eigvals, eigvecs = jnp.linalg.eig(observable.unitary()) | ||
probs = jnp.abs(jnp.float_power(jnp.inner(state, eigvecs), 2.0)).ravel() | ||
key = jax.random.PRNGKey(0) | ||
n_qubits = len(state.shape) | ||
samples = jax.vmap( | ||
lambda subkey: jax.random.choice(key=subkey, a=jnp.arange(0, 2**n_qubits), p=probs) | ||
)(jax.random.split(key, n_shots)) | ||
# samples now contains a list of indices | ||
# i forgot the formula | ||
# something here which is correct | ||
counts = jnp.bincount(samples) | ||
return jnp.mean(counts / n_shots) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from horqrux import random_state | ||
from horqrux.parametric import PHASE, RX, RY, RZ | ||
from horqrux.primitive import NOT, H, I, S, T, X, Y, Z | ||
from horqrux.shots import finite_shots | ||
|
||
MAX_QUBITS = 2 | ||
PARAMETRIC_GATES = (RX, RY, RZ, PHASE) | ||
PRIMITIVE_GATES = (NOT, H, X, Y, Z, I, S, T) | ||
|
||
|
||
def test_gradcheck() -> None: | ||
ops = [RX("theta", 0), RY("epsilon", 0), RX("phi", 0), NOT(1, 0), RX("omega", 0, 1)] | ||
observable = Z(0) | ||
values = { | ||
"theta": np.random.uniform(0, 1), | ||
"epsilon": np.random.uniform(0, 1), | ||
"phi": np.random.uniform(0, 1), | ||
"omega": np.random.uniform(0, 1), | ||
} | ||
state = random_state(MAX_QUBITS) | ||
exp_shots = finite_shots(state, ops, observable, values) |