[BUG] Some decompositions/transforms do not preserve derivatives #5715
Comments
While trying to fix this, I noticed that …
**Context:** The decomposition of `MottonenStatePreparation` skips some gates for special parameter values or input states. See the linked issue for details.

**Description of the Change:** This PR introduces a differentiability check so that the gates are only skipped when no derivatives are being computed. Note that this does *not* fix the non-differentiability at other special parameter points, which is also referenced in #5715 and is already warned against in the docs. Also, the linked issue concerns multiple operations; this PR only addresses `MottonenStatePreparation`.

**Benefits:** Fixes part of #5715. Unblocks #5620.

**Possible Drawbacks:**

**Related GitHub Issues:** #5715
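The fix described above, skipping a gate only when its parameter is not being differentiated, can be sketched with a toy forward-mode dual number. This is a hypothetical illustration, not PennyLane's code: `Dual`, `requires_grad`, and `apply_ry` are made-up names, `RY(theta)` stands in for any gate a decomposition might skip, and "trainable" here simply means "passed in as a `Dual`".

```python
import math


class Dual:
    """Minimal forward-mode dual number: a value and its derivative."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)

    def __sub__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __radd__ = __add__
    __rmul__ = __mul__


def cos(x):
    if isinstance(x, Dual):
        return Dual(math.cos(x.val), -math.sin(x.val) * x.der)
    return math.cos(x)


def sin(x):
    if isinstance(x, Dual):
        return Dual(math.sin(x.val), math.cos(x.val) * x.der)
    return math.sin(x)


def requires_grad(x):
    # Hypothetical trainability check: only Dual inputs are differentiated.
    return isinstance(x, Dual)


def apply_ry(state, theta):
    # Skip the gate only for a zero angle that is NOT being differentiated;
    # a trainable zero angle keeps the gate, so its derivative survives.
    if not requires_grad(theta) and theta == 0.0:
        return state
    c, s = cos(0.5 * theta), sin(0.5 * theta)
    a, b = state
    return [c * a - s * b, s * a + c * b]


def expval_x(theta):
    """<X> after RY(theta)|0>, which equals sin(theta)."""
    a, b = apply_ry([1.0, 0.0], theta)
    return 2 * a * b


print(expval_x(0.0))                 # 0.0 (gate skipped, nothing to differentiate)
print(expval_x(Dual(0.0, 1.0)).der)  # 1.0 (gate kept, derivative cos(0) = 1)
```

The optimization is kept for plain numerical evaluation, while any differentiated parameter always goes through the full gate.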
…6031)

**Context:** The current implementation of `fuse_rot_angles`, which is used by `merge_rotations` and `single_qubit_fusion`, has the following issues:

1. It does not necessarily preserve global phases. As we move towards global-phase-aware standards, this becomes an issue where it wasn't one before.
2. Its derivative is wrong, forming part (but not all) of bug #5715. In particular, the custom handling of special input values prevents the calculation of correct derivatives, and at singular points `fuse_rot_angles` returns wrong derivatives rather than NaN values, which would be mathematically well-motivated.
3. A minor technical issue is that the implementation requires nested conditionals, leading to a good bit of code and separate handling of traced JAX code, which makes it more complex.

**Description of the Change:** The implementation of `fuse_rot_angles` is remade entirely.

**Benefits:** The remade code uses a comparatively simple mathematical expression to compute the fused rotation angles that

1. preserves global phases;
2. has the correct derivative everywhere except at well-understood, predictable singular points. These singular points make sense because rotation fusion is not a smooth map everywhere. Their predictability allowed us to write a dedicated test that confirms our understanding of the singularities, at least on a large set of special test points;
3. does not require conditionals beyond those already implemented in `qml.math.arctan2`, and thus works in all ML interfaces, including JAX with JIT compatibility;
4. as a bonus, supports broadcasting/batching with an arbitrary number of leading dimensions if all angles in each set are broadcast in the same way (this is nice, easy to support, and lets us speed up tests a lot).
In summary, the global phases are fixed, Jacobians are NaN rather than wrong wherever they fail, and they are only NaN at points where the current implementation already returned NaN or wrong values.

**Possible Drawbacks:** N/A

**Related GitHub Issues:** #5715 (only partially fixed) [sc-63642]

---------

Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
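The property claimed for the new `fuse_rot_angles`, a single `Rot` that reproduces two consecutive `Rot` gates including the global phase, can be checked numerically. The sketch below does not reproduce the closed-form expression from the PR; it multiplies the two matrices and reads the ZYZ angles back off with `atan2`, assuming the convention Rot(φ, θ, ω) = RZ(ω) RY(θ) RZ(φ) (our understanding of `qml.Rot`). `rot_matrix`, `matmul`, and `fuse_rot_angles_sketch` are hypothetical helper names.

```python
import cmath
import math


def rot_matrix(phi, theta, omega):
    """Matrix of Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [
        [cmath.exp(-0.5j * (phi + omega)) * c, -cmath.exp(0.5j * (phi - omega)) * s],
        [cmath.exp(-0.5j * (phi - omega)) * s, cmath.exp(0.5j * (phi + omega)) * c],
    ]


def matmul(A, B):
    """Product of two 2x2 matrices given as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def fuse_rot_angles_sketch(first, second):
    """Angles of one Rot equal to applying `first`, then `second`."""
    u = matmul(rot_matrix(*second), rot_matrix(*first))
    # u is in SU(2): u = [[x, -conj(y)], [y, conj(x)]] with
    # x = e^{-i(phi+omega)/2} cos(theta/2) and y = e^{-i(phi-omega)/2} sin(theta/2).
    theta = 2 * math.atan2(abs(u[1][0]), abs(u[0][0]))
    # cmath.phase is atan2(imag, real). If x or y is 0 (fused theta = pi or 0),
    # its phase is arbitrary: the angles are not smooth functions there.
    p_sum = -cmath.phase(u[0][0])   # (phi + omega) / 2
    p_diff = -cmath.phase(u[1][0])  # (phi - omega) / 2
    return p_sum + p_diff, theta, p_sum - p_diff


first, second = (0.3, 0.7, -1.1), (1.9, -0.4, 2.5)
fused = fuse_rot_angles_sketch(first, second)
target = matmul(rot_matrix(*second), rot_matrix(*first))
got = rot_matrix(*fused)
err = max(abs(got[i][j] - target[i][j]) for i in range(2) for j in range(2))
print(err < 1e-12)  # True: equal as matrices, global phase included
```

Because the product of two `Rot` matrices lies in SU(2), its first column fixes all three angles, so no global phase is lost. Where an entry of that column vanishes, its phase is arbitrary and the fused angles stop being smooth in the inputs, consistent with the PR's remark that rotation fusion is not a smooth map everywhere.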
**Context:** As reported in #5715, `merge_rotations` and `single_qubit_fusion` have problems with differentiability at specific points. #6031 upgrades the `Rot`-angle fusion so that it only ever returns NaNs at mathematically non-differentiable points, rather than wrong results. However, the transforms themselves add further points, based on internal optimizations, where the derivative is flawed.

**Description of the Change:** This PR fixes the flawed derivatives caused by the code of the transforms themselves.

**Benefits:** Fixes derivatives of `merge_rotations` and `single_qubit_fusion` (where mathematically defined).

**Possible Drawbacks:** N/A

**Related GitHub Issues:** Fixes another part of #5715, though not all of it.

---------

Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
After merging #5620, #5774, #6031, and #6033, the issues with the mentioned decompositions and transforms are resolved. The last code example with …
Expected behavior
Using decompositions and transforms does not change the derivative of the overall workflow.
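To see how skipping a gate can preserve a function's value but not its derivative, consider a toy single-qubit circuit: `RY(theta)` on |0⟩ followed by a measurement of X, whose expectation value is sin(theta). The sketch below uses a minimal forward-mode dual number (a hypothetical helper, not PennyLane's autodiff machinery), and the `skip_zero_angles` flag mimics an optimization that drops a rotation whose angle is exactly zero.

```python
import math


class Dual:
    """Minimal forward-mode dual number: a value and its derivative."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)

    def __sub__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __radd__ = __add__
    __rmul__ = __mul__


def cos(x):
    return Dual(math.cos(x.val), -math.sin(x.val) * x.der)


def sin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.der)


def apply_ry(state, theta, skip_zero_angles):
    # Hypothetical optimization: drop RY when its angle is exactly zero.
    # The check only looks at the value, so the derivative carried by
    # theta is discarded together with the gate.
    if skip_zero_angles and theta.val == 0.0:
        return state
    c, s = cos(0.5 * theta), sin(0.5 * theta)
    a, b = state
    return [c * a - s * b, s * a + c * b]


def expval_x(theta, skip_zero_angles=False):
    """<X> after RY(theta)|0>, which equals sin(theta)."""
    a, b = apply_ry([Dual(1.0), Dual(0.0)], theta, skip_zero_angles)
    return 2 * a * b


theta = Dual(0.0, 1.0)  # evaluate at theta = 0, seeding d(theta)/d(theta) = 1
full = expval_x(theta)
skipped = expval_x(theta, skip_zero_angles=True)
print(full.val, full.der)        # 0.0 1.0 -> derivative cos(0) = 1, correct
print(skipped.val, skipped.der)  # 0.0 0.0 -> same value, wrong derivative
```

The value-based comparison plays the same role as a Python `if` on a concrete parameter value inside an autodiff framework: only the branch that runs is differentiated, so at the special point the contribution of the skipped gate silently disappears.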
Actual behavior
Some decompositions/transforms only reproduce the function, but not its derivative. I found this in the following parts of the codebase:

- `merge_rotations`: Some rotation gates are skipped for zero angles.
- `single_qubit_fusion`: Some rotation gates are skipped for zero angles.
- `MottonenStatePreparation`: Depending on the input state, gates are skipped, which leads to errors with JITting (no gradient entries to stack) or produces `nan` values.
- `fuse_rot_angles`: Used in `merge_rotations` and `single_qubit_fusion`; introduces further bugs in both functions.

Additional information
Note that JITting usually prevents the source of error (except for `MottonenStatePrep`), and in all examples above, the codebase has special logic for JITting. As a consequence, JITted derivatives tend to be unaffected by the type of bug observed in the transforms.

Under the hood, this seems similar to #5541, which concerns `AmplitudeEmbedding` and is being solved in #5620 by modifying the diff method of `GlobalPhase`. However, the bug described here has a different origin and was encountered while finalizing the tests for #5620 for `MottonenStatePreparation`.

Source code
Tracebacks
No response
System information
Existing GitHub issues