Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Users can capture conditionals using qml.cond #5999

Merged
merged 45 commits into from
Jul 31, 2024

Conversation

PietropaoloFrisoni
Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni commented Jul 15, 2024

Context: The purpose of this PR is to allow qml.cond to be captured into plxpr, although with some restrictions on its usage (see below for details).

Description of the Change: Now, qml.cond can be captured into plxpr. The function's behavior resembles the one provided in Catalyst, with the executed branch determined at runtime.

Benefits: If qml.capture.enabled() is True, qml.cond can be used.

Possible Drawbacks: Here there's a list of necessary restrictions at this stage.

  • The arguments provided to the branches must be the same for all branches.

  • The arguments provided to branches must be JAX-compatible. For example, this example would not work:

qml.capture.enable()

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(pred, wires):
    def true_fn(wires):
        qml.RY(0.1, wires=wires)
    qml.cond(pred > 0, true_fn)(wires)
    return qml.expval(qml.PauliZ(wires=0))
circuit(1, qml.wires.Wires([0]))

because wires are not a valid JAX type, and the arguments provided to the branch are traced.

  • Tests on nested conditionals have not been included as they belong to a different (future) epic, but in principle they should work

  • If a branch returns one or more variables, every other branch must return the same abstract values.

For example, the following example would not work:

qml.capture.enable()

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(pred, arg):
    def true_fn(arg):
        return qml.RY(arg, wires=0)
    def false_fn(arg):
        return qml.RZ(arg, wires=0), 5
    qml.cond(pred > 0, true_fn, false_fn)(arg)
    return qml.expval(qml.PauliZ(wires=0))

because true_fn and false_fn do not return the same abstract values. This is similar to Catalyst, where values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical.

  • The function cannot branch on mid-circuit measurements, because these cannot be captured into plxpr yet.

Related GitHub Issues: None.

Related Shortcut Stories
[sc-66774]
[sc-69642]

@PennyLaneAI PennyLaneAI deleted a comment from codecov bot Jul 23, 2024
Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One suggestion to try.

What would you think of combining jaxpr_true, jaxpr_false, and all the jaxpr_elifs? Same with combining the condition with the elifs_conditions.

We could essentially have:

jaxpr_branches = (jaxpr_true, *jaxpr_elifs, jaxpr_false)
conditions = jax.numpy.array([condition, *elifs_conditions, True])

And then the implementation could just be something like:

for condition, branch in zip(conditions, branches):
    if condition:
        return jax.core.eval_jaxpr(branch, consts, *args)

@mudit2812
Copy link
Contributor

@PietropaoloFrisoni if you think this is in a reasonably good state, I would like to use this as a base for my PR so that I can capture qml.cond with predicates that use mid-circuit measurements.

@PietropaoloFrisoni
Copy link
Contributor Author

@mudit2812 Sure, no problem for me. I'll tell you if I need to make relevant changes, but hopefully, this is not the case

@PennyLaneAI PennyLaneAI deleted a comment from codecov bot Jul 29, 2024
Copy link

codecov bot commented Jul 30, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.65%. Comparing base (f9adf90) to head (db95c9b).
Report is 296 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5999      +/-   ##
==========================================
- Coverage   99.66%   99.65%   -0.01%     
==========================================
  Files         430      430              
  Lines       41544    41341     -203     
==========================================
- Hits        41404    41200     -204     
- Misses        140      141       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work!

pennylane/ops/op_math/condition.py Show resolved Hide resolved
@mudit2812
Copy link
Contributor

I created a branch locally that uses the changes from this PR as well as from #6015. Seems like after creating a primitive for qml.measure which has a scalar boolean abstract eval, using mid-circuit measurements with qml.cond just works out of the box for capture. It cannot be executed for the same reason why any other classical processing on MCMs can't be executed, but I guess I don't need to do any further work other than adding tests 😄

Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent work 🎉

@PietropaoloFrisoni PietropaoloFrisoni enabled auto-merge (squash) July 31, 2024 20:15
@PietropaoloFrisoni PietropaoloFrisoni merged commit 6715095 into master Jul 31, 2024
40 checks passed
@PietropaoloFrisoni PietropaoloFrisoni deleted the capture_qml_cond branch July 31, 2024 20:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants