Skip to content

Commit

Permalink
test: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Gentoo-p3nguin committed Jun 7, 2024
1 parent 14e6f14 commit 8eca1c9
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions test/integ_tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,86 @@
class TestSample:
"""Tests for the sample return type"""

def test_sample_default(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when no observable is specified
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 4, wires=0)
qml.CNOT(wires=[0, 1])
return qml.sample()

shot_vector = circuit()

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (shots, 2)

def test_sample_wires(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when only wires are specified
"""
dev = device(3)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 4, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample(wires=[0, 2])

shot_vector = circuit()

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (shots, 2)

def test_sample_batch_default(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when no observable is specified and
the batch dimension is returned
"""
dev = device(3)

@qml.qnode(dev)
def circuit(a):
qml.RX(a, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample()

shot_vector = circuit([np.pi / 4, np.pi / 3])

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (2, shots, 3)

def test_sample_batch_wires(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when only wires are specified
"""
dev = device(4)

@qml.qnode(dev)
def circuit(a):
qml.RX(a, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample(wires=[0, 2, 3])

shot_vector = circuit([np.pi / 4, np.pi / 3])

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (2, shots, 3)

def test_sample_values(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values
Expand Down

0 comments on commit 8eca1c9

Please sign in to comment.