[BUG] Incorrect output pytree when using qml.counts() in specific output patterns #1016

Open
mehrdad2m opened this issue Aug 13, 2024 · 2 comments · May be fixed by #1219
Labels
bug Something isn't working

Comments

@mehrdad2m
Contributor

mehrdad2m commented Aug 13, 2024

Context

When qml.counts() appears in the output of a quantum circuit compiled with qjit, the output pytree is modified: the element corresponding to qml.counts is replaced with tree_structure(("keys", "counts")). However, this transformation is buggy. It works for simple cases but incorrectly transforms more complex output patterns.
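
For reference, here is a minimal sketch of what that replacement subtree looks like on its own. It uses only jax.tree_util and assumes nothing about Catalyst internals; the variable name counts_tree is purely illustrative.

```python
from jax.tree_util import tree_structure

# Strings are leaves in a JAX pytree, so ("keys", "counts") is a tuple
# node with two leaves: one slot for the keys, one for the counts.
counts_tree = tree_structure(("keys", "counts"))

print(counts_tree)
print(counts_tree.num_leaves)
```

This two-leaf tuple is what should be substituted in place of the single leaf that qml.counts() occupies in the user's return structure.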

An example that works fine:

import pennylane as qml
from catalyst import qjit
from jax.tree_util import tree_flatten

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

The result is as expected:

PyTreeDef({'1': (*, *)})

The following two patterns each result in the wrong output pytree:

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(((*, *), {'2': *}))

instead of the expected pytree of:

PyTreeDef(({'1': (*, *)}, {'2': *}))

The second pattern:

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(([{'1': *}, {'2': *}], (*, *)))

while the expected pytree is:

PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))

A possible solution would update trace_quantum_measurements(), which is where the output pytree is modified. You could write a function replace_child_tree(tree, i, subtree) that receives a pytree and replaces the ith node of tree visited in a DFS with subtree.
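
A rough sketch of such a helper, using only jax.tree_util, might look like the following. This is not the Catalyst implementation. It makes one simplifying assumption: i indexes the leaves of tree in flattening order (which is depth-first), rather than arbitrary interior nodes. The placeholder value 0 is arbitrary.

```python
from jax.tree_util import tree_structure, tree_unflatten


def replace_child_tree(tree, i, subtree):
    """Return a PyTreeDef like `tree`, with its i-th leaf replaced by `subtree`.

    Assumption: `i` counts leaves in flattening (depth-first) order.
    """
    if not 0 <= i < tree.num_leaves:
        raise IndexError(f"leaf index {i} out of range for {tree.num_leaves} leaves")

    # Rebuild a concrete container from `tree` with dummy leaves ...
    placeholders = [0] * tree.num_leaves
    # ... except at position i, where we put a container shaped like `subtree`.
    placeholders[i] = tree_unflatten(subtree, [0] * subtree.num_leaves)

    # Re-reading the structure of the rebuilt container absorbs the nested
    # container into the resulting PyTreeDef.
    return tree_structure(tree_unflatten(tree, placeholders))


# The first failing pattern from this issue: counts is leaf 0.
original = tree_structure(({"1": 0}, {"2": 0}))
counts_tree = tree_structure(("keys", "counts"))
fixed = replace_child_tree(original, 0, counts_tree)
print(fixed)
```

Under these assumptions, fixed compares equal to tree_structure(({"1": (0, 0)}, {"2": 0})), which is the expected pytree for the first failing pattern. Because only the chosen leaf is touched, the surrounding dict and tuple nodes are preserved.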

@mehrdad2m mehrdad2m added the bug Something isn't working label Aug 14, 2024
@josh146
Member

josh146 commented Aug 18, 2024

Thanks for catching this @mehrdad2m! How involved would you say the fix is -- is it straightforward, or would it require some exploration?

@mehrdad2m
Contributor Author

> Thanks for catching this @mehrdad2m! How involved would you say the fix is -- is it straightforward, or would it require some exploration?

Hi @josh146, it is pretty straightforward. The fix should be made in trace_quantum_measurements, which is where the output pytree is modified. The simple version of the problem is to write a function replace_child_tree(tree, i, subtree) that receives a pytree and replaces the ith node of tree visited in a DFS with subtree. The only tricky part is working with pytrees :)
