-
Notifications
You must be signed in to change notification settings - Fork 603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Users can capture conditionals using qml.cond
#5999
Conversation
… are making progresses... [ci skip]
…e second one are ignored...
…o capture_qml_cond
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One suggestion to try.
What would you think of combining jaxpr_true
, jaxpr_false
, and all the jaxpr_elifs
? Same with combining the condition
with the elifs_conditions
.
We could essentially have:
jaxpr_branches = (jaxpr_true, *jaxpr_elifs, jaxpr_false)
conditions = jax.numpy.array([condition, *elifs_conditions, True])
And then the implementation could just be something like:
for condition, branch in zip(conditions, branches):
if condition:
return jax.core.eval_jaxpr(branch, consts, *args)
@PietropaoloFrisoni if you think this is in a reasonably good state, I would like to use this as a base for my PR so that I can capture |
@mudit2812 Sure, no problem for me. I'll tell you if I need to make relevant changes, but hopefully, this is not the case |
…o capture_qml_cond
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #5999 +/- ##
==========================================
- Coverage 99.66% 99.65% -0.01%
==========================================
Files 430 430
Lines 41544 41341 -203
==========================================
- Hits 41404 41200 -204
- Misses 140 141 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work!
I created a branch locally that uses the changes from this PR as well as from #6015. Seems like after creating a primitive for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work 🎉
Context: The purpose of this PR is to allow
qml.cond
to be captured into plxpr, although with some restrictions on its usage (see below for details).Description of the Change: Now,
qml.cond
can be captured into plxpr. The function's behavior resembles the one provided in Catalyst, with the executed branch determined at runtime.Benefits: If
qml.capture.enabled()
isTrue
,qml.cond
can be used.Possible Drawbacks: Here there's a list of necessary restrictions at this stage.
The arguments provided to the branches must be the same for all branches.
The arguments provided to branches must be JAX-compatible. For example, this example would not work:
because wires are not a valid JAX type, and the arguments provided to the branch are traced.
Tests on nested conditionals have not been included as they belong to a different (future) epic, but in principle they should work
If a branch returns one or more variables, every other branch must return the same abstract values.
For example, the following example would not work:
because
true_fn
andfalse_fn
do not return the same abstract values. This is similar to Catalyst, where values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical.Related GitHub Issues: None.
Related Shortcut Stories
[sc-66774]
[sc-69642]