[BUG] Some decompositions/transforms do not preserve derivatives #5715

Closed
dwierichs opened this issue May 21, 2024 · 2 comments
Labels: bug 🐛 Something isn't working

dwierichs commented May 21, 2024

Expected behavior

Using decompositions and transforms does not change the derivative of the overall workflow.

Actual behavior

Some decompositions/transforms only reproduce the function, but not its derivative. I found this in the following parts of the codebase:

  • merge_rotations: Some rotation gates are skipped for zero angles
  • single_qubit_fusion: Some rotation gates are skipped for zero angles
  • MottonenStatePreparation: Depending on the input state, gates are skipped, which either leads to errors with JITting (no gradient entries to stack) or produces NaN values.
  • fuse_rot_angles: Used in merge_rotations and single_qubit_fusion; introduces a second, separate bug in each of these transforms.
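The failure mode behind the first two bullets can be reproduced in a few lines of plain Python. The following is a toy model, not PennyLane code: it assumes a single qubit prepared in |0⟩ and the two `RX(x)` gates from the `merge_rotations` example in the source code below, fuses them into `RX(2x)`, and optionally drops the fused gate when its angle is exactly zero (the hypothetical `skip_zero` flag). The derivative is then taken with the two-term parameter-shift rule on whatever gates remain.

```python
import math


def expval_y(angle):
    """<Y> after a single RX(angle) acting on |0>, which equals -sin(angle)."""
    return -math.sin(angle)


def fused_gates(x, skip_zero):
    """Toy merge of RX(x) RX(x) into one RX(2x).

    Returns (angle, d_angle/dx) pairs for the gates that survive. With
    skip_zero=True the fused gate is dropped when its angle is exactly 0,
    mimicking a transform that skips zero-angle rotations.
    """
    angle = 2 * x
    if skip_zero and angle == 0.0:
        return []
    return [(angle, 2.0)]


def value(x, skip_zero):
    # RX angles on one wire add up; an empty gate list is the identity.
    return expval_y(sum(a for a, _ in fused_gates(x, skip_zero)))


def param_shift_grad(x, skip_zero):
    """d<Y>/dx via the two-term parameter-shift rule and the chain rule."""
    gates = fused_gates(x, skip_zero)
    total = sum(a for a, _ in gates)
    grad = 0.0
    for _, da_dx in gates:
        # Shifting one RX angle by +-pi/2 shifts the total angle by the same amount.
        shifted = expval_y(total + math.pi / 2) - expval_y(total - math.pi / 2)
        grad += da_dx * shifted / 2
    return grad


for skip_zero in (True, False):
    # Same circuit value at x = 0, but gradient 0.0 with the skip and -2.0 without it.
    print(skip_zero, value(0.0, skip_zero), param_shift_grad(0.0, skip_zero))
```

At `x = 0` both variants return the same expectation value, yet dropping the gate removes the only dependence on `x`, so the derivative becomes 0.0 instead of -2.0, matching the 0.0 reported for the non-JITted nodes in the first example.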

Additional information

Note that JITting usually avoids the source of the error (except for MottonenStatePreparation), because in all of the cases above the codebase has special logic for JITting.
As a consequence, JITted derivatives tend to be unaffected by the type of bug observed in these transforms.

Under the hood, this seems similar to #5541, which concerns AmplitudeEmbedding and is being solved in #5620 by modifying the diff method of GlobalPhase. However, the bug described here has a different origin and was encountered while finalizing the tests for MottonenStatePreparation in #5620.

Source code

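from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

# Double precision, matching the float64 values printed below
jax.config.update("jax_enable_x64", True)

#### BUG caused by merge_rotations itself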

@qml.transforms.merge_rotations
def _node(x):
    qml.RX(x, 1)
    qml.RX(x, 1)
    return qml.expval(qml.Y(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(0.))

# Derivatives at 0:
# 0.0                    (node_ps)
# -2.0                   (node_ps_jit)
# 0.0                    (node_ad)
# -1.9999999999999996    (node_ad_jit)

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-8))

# Derivatives close to 0:
# -1.9999999999999993    (node_ps)
# -1.9999999999999993    (node_ps_jit)
# -1.9999999999999993    (node_ad)
# -1.9999999999999993    (node_ad_jit)

#### BUG caused by fuse_rot_angles via merge_rotations

@qml.transforms.merge_rotations
def _node(x):
    qml.Rot(x, x, x, 1)
    qml.Rot(x, x, x, 1)
    return qml.expval(qml.X(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(0.))

# Derivatives at 0:
# 0.0                    (node_ps)
# 2.0                    (node_ps_jit)
# 0.0                    (node_ad)
# 1.9999999999999996     (node_ad_jit)

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-6))

# Derivatives close to 0:
# 2.0000221220668224     (node_ps)
# 1.9999999999840001     (node_ps_jit)
# 2.0000221220668215     (node_ad)
# 1.9999999999839995     (node_ad_jit)


#### BUGS in single_qubit_fusion, one in the function itself, one from fuse_rot_angles
@partial(qml.transforms.single_qubit_fusion, atol=1e-6)
def _node(x):
    qml.RX(x, 1)
    qml.RX(x, 1)
    return qml.expval(qml.Y(1))

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

print("Derivatives at 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(0.))

# Derivatives at 0:
# 0.0    (node_ps)
# 0.0    (node_ps_jit)
# 0.0    (node_ad)
# 0.0    (node_ad_jit)

print("Derivatives close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-7))

# Derivatives close to 0:
# 0.0                    (node_ps)
# -2.000799757290469     (node_ps_jit)
# 0.0                    (node_ad)
# -2.000799757290469     (node_ad_jit)

print("Derivatives less close to 0:")
for node_ in [node_ps, node_ps_jit, node_ad, node_ad_jit]:
    print(jax.jacobian(node_)(1e-5))

# Derivatives less close to 0:
# -1.9999999168263007    (node_ps)
# -1.9999999168263007    (node_ps_jit)
# -1.9999999168263003    (node_ad)
# -1.9999999168263003    (node_ad_jit)

#### BUGS with MottonenStatePreparation
def _node(x):
    qml.MottonenStatePreparation(x, wires=[0, 1])
    return qml.probs()

dev = qml.device("default.qubit")
node_ps = qml.QNode(_node, dev, diff_method="parameter-shift")
node_ps_jit = jax.jit(qml.QNode(_node, dev, diff_method="parameter-shift"))
node_ad = qml.QNode(_node, dev)
node_ad_jit = jax.jit(qml.QNode(_node, dev))

x1 = jnp.array([1, 1, 0, 1]) / np.sqrt(3)

for node_ in [node_ps, node_ad, node_ps_jit, node_ad_jit]:  # Fails with the JITted nodes
    print(jax.jacobian(node_)(x1))

# node_ps:
# [[ 7.69800359e-01 -3.84900179e-01             nan             nan]
#  [-3.84900179e-01  7.69800359e-01             nan             nan]
#  [-4.80740672e-17  4.80740672e-17             nan             nan]
#  [-3.84900179e-01 -3.84900179e-01             nan             nan]]
# node_ad:
# [[ 7.69800359e-01 -3.84900179e-01             nan             nan]
#  [-3.84900179e-01  7.69800359e-01             nan             nan]
#  [-4.80740672e-17  4.80740672e-17             nan             nan]
#  [-3.84900179e-01 -3.84900179e-01             nan             nan]]

x2 = jnp.array([1, 0, 0, 1]) / np.sqrt(2)
for node_ in [node_ps, node_ad, node_ps_jit, node_ad_jit]:  # Fails with the JITted nodes
    print(jax.jacobian(node_)(x2))

# node_ps:
# [[nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]]
# node_ad:
# [[nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]
#  [nan nan nan nan]]

Tracebacks

No response

System information

PennyLane dev (development version)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
dwierichs added the bug 🐛 Something isn't working label May 21, 2024
dwierichs self-assigned this May 30, 2024
dwierichs (author) commented:

While trying to fix this, I noticed that fuse_rot_angles uses a function that, as it stands, is not differentiable everywhere. At those singular points, we're returning wrong derivatives in yet another way :/

dwierichs added a commit that referenced this issue Jun 13, 2024
**Context:**
The decomposition of `MottonenStatePreparation` skips some gates for
special parameter values/input states.
See the linked issue for details.

**Description of the Change:**
This PR introduces a check for differentiability so that the gates are only
skipped when no derivatives are being computed.
Note that this does *not* fix the non-differentiability at other special
parameter points, which is also referenced in #5715 and is already warned
against in the docs.
Also, the linked issue is about multiple operations, whereas here we only
address `MottonenStatePreparation`.
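The idea of this change can be illustrated on a one-qubit toy version of a state-preparation decomposition. The helpers are hypothetical, not PennyLane's implementation: `ry_angle` returns the `RY` angle that prepares `a|0> + b|1>` (up to normalization) from real amplitudes, and `decompose` drops a zero-angle gate only when `differentiable` is `False`. At `(a, b) = (1, 0)` the angle is zero, but its derivative with respect to `b` is `2a / (a**2 + b**2) = 2`, so dropping the gate while differentiating would discard a nonzero contribution.

```python
import math


def ry_angle(a, b):
    """Angle theta such that RY(theta)|0> is proportional to a|0> + b|1> (real a, b)."""
    return 2 * math.atan2(b, a)


def decompose(a, b, differentiable, atol=1e-8):
    """Hypothetical one-qubit state-prep decomposition as a list of (name, angle) ops."""
    theta = ry_angle(a, b)
    if not differentiable and abs(theta) < atol:
        return []  # RY(0) is the identity; dropping it is only safe without gradients
    return [("RY", theta)]


# Central finite difference of the angle with respect to b at (a, b) = (1, 0).
eps = 1e-6
dtheta_db = (ry_angle(1.0, eps) - ry_angle(1.0, -eps)) / (2 * eps)
print(dtheta_db)                                   # close to 2.0
print(decompose(1.0, 0.0, differentiable=False))   # []
print(decompose(1.0, 0.0, differentiable=True))    # [('RY', 0.0)]
```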

**Benefits:**
Fixes parts of #5715. Unblocks #5620.

**Possible Drawbacks:**

**Related GitHub Issues:**
#5715
dwierichs added a commit that referenced this issue Aug 14, 2024
…6031)

**Context:**
The current implementation of `fuse_rot_angles`, which is used by
`merge_rotations` and `single_qubit_fusion`, has the following issues:
1. It does not necessarily preserve global phases. As we move towards
global-phase aware standards, this becomes an issue where it wasn't one
before.
2. Its derivative is wrong, forming part (but not all) of the bug #5715.
In particular, the custom handling of special input values prevents the
calculation of correct derivatives, and `fuse_rot_angles` at singular
points leads to wrong derivatives, rather than NaN values, which are
mathematically well-motivated.
3. A minor technical issue is that the implementation requires nested
conditionals, leading to a good bit of code and separate handling of
traced JAX code, making it more complex.

**Description of the Change:**
The implementation of `fuse_rot_angles` is rewritten entirely.

**Benefits:**
The rewritten code uses a comparatively simple mathematical expression to
compute the fused rotation angles that
1. preserves global phases
2. has the correct derivative everywhere except for well-understood,
predictable singular points. These singular points make sense because
rotation fusion is not a smooth map everywhere. The predictability
allowed us to write a dedicated test that confirms our understanding of
the singularities, at least within a large set of special test points.
3. does not require conditionals beyond those that are implemented in
`qml.math.arctan2` anyway, and thus available in all ML interfaces,
including JAX with JIT-compatibility.
4. As a bonus, the new implementation supports broadcasting/batching
with an arbitrary number of leading dimensions if all angles in each set
are broadcast in the same way (because this is nice, easy to support,
and allows us to speed up tests a lot).

In summary, global phases are preserved, Jacobians are never wrong (at worst
NaN), and they are only NaN at points where the current implementation
returned NaN or wrong values.
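To see where rotation fusion can have singular points, here is a sketch that fuses two `Rot` gates by multiplying their matrices and reading the angles back off, assuming PennyLane's convention `Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)`. This is a generic matrix-based approach for illustration, not the arctan2 expression introduced in the PR, and `rot`, `matmul`, `fuse` and `close` are hypothetical helpers. Because both factors have determinant 1, the recovered angles reproduce the product exactly, global phase included. The angles `phi` and `omega` are only determined through `phi + omega` (phase of `u[0][0]`) and `phi - omega` (phase of `u[1][0]`), so when the fused `theta` passes through zero, `u[1][0]` vanishes and `phi - omega` is undefined there.

```python
import cmath
import math


def rot(phi, theta, omega):
    """Matrix of Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi), as nested lists."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [
        [cmath.exp(-0.5j * (phi + omega)) * c, -cmath.exp(0.5j * (phi - omega)) * s],
        [cmath.exp(-0.5j * (phi - omega)) * s, cmath.exp(0.5j * (phi + omega)) * c],
    ]


def matmul(a, b):
    """Product a @ b of two 2x2 complex matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def fuse(first, second):
    """Angles of one Rot equal to Rot(*first) followed by Rot(*second).

    Both factors have determinant 1, so the product has the form
    [[a, -conj(b)], [b, conj(a)]] and is reproduced exactly, global phase included.
    """
    u = matmul(rot(*second), rot(*first))
    theta = 2 * math.atan2(abs(u[1][0]), abs(u[0][0]))
    plus = -2 * cmath.phase(u[0][0])   # phi + omega; undefined where u[0][0] == 0
    minus = -2 * cmath.phase(u[1][0])  # phi - omega; undefined where u[1][0] == 0
    return (plus + minus) / 2, theta, (plus - minus) / 2


def close(a, b, tol=1e-12):
    """Entrywise comparison of two 2x2 matrices."""
    return all(abs(a[i][j] - b[i][j]) < tol for i in range(2) for j in range(2))


first, second = (0.1, 0.2, 0.3), (0.4, 0.5, 0.6)
print(close(rot(*fuse(first, second)), matmul(rot(*second), rot(*first))))  # True

# Fusing Rot(0.3, x, 0) with the identity: the fused theta passes through 0 at x = 0.
for x in (1e-3, -1e-3):
    print(x, fuse((0.3, x, 0.0), (0.0, 0.0, 0.0)))
```

For `x = 1e-3` the fused angles come out close to `(0.3, 1e-3, 0)`, and for `x = -1e-3` close to `(0.3 - pi, 1e-3, pi)`. Both are valid fusions, but in this extraction convention `phi` and `omega` jump by about pi as `x` crosses zero, so no finite derivative of the fused angles is meaningful at that point; a NaN Jacobian there is preferable to a finite but wrong one.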

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
#5715 (not fixed entirely, just partially)
[sc-63642]

---------

Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
dwierichs added a commit that referenced this issue Aug 19, 2024
**Context:**
As reported in #5715, `merge_rotations` and `single_qubit_fusion` have
problems with differentiability at specific points.
#6031 takes care of upgrading the Rot-angle fusion to only ever return
NaNs at mathematically non-differentiable points, rather than wrong
results.
However, the transforms' own internal optimizations introduce additional
points where the derivative is flawed.

**Description of the Change:**
This PR fixes the flawed derivatives caused by the code of the
transforms themselves.
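The tolerance in `single_qubit_fusion` widens the problem from a single point to an interval: in the example in the issue with `atol=1e-6`, the fused angle `2e-7` at `x = 1e-7` counts as close to the identity, the gate is removed, and the non-JITted derivatives drop to 0.0. The sketch below contrasts that behaviour with a fusion step that never drops a gate whose angle depends on a trainable parameter. It is a toy one-qubit model, assuming `<Y> = -sin(angle)` after `RX(angle)` on |0⟩; `fuse_rx_pair` and `respect_trainability` are hypothetical names and this is not the code of the PR.

```python
import math


def fuse_rx_pair(x, atol, respect_trainability):
    """Hypothetical fusion of RX(x) RX(x) into RX(2x), with x treated as trainable.

    Returns (angle, d_angle/dx) pairs for the surviving gates. The old behaviour
    drops any fused gate with |angle| < atol; the fixed behaviour keeps it
    because its angle depends on a trainable parameter.
    """
    angle = 2 * x
    if not respect_trainability and abs(angle) < atol:
        return []
    return [(angle, 2.0)]


def grad_expval_y(x, atol, respect_trainability):
    """Parameter-shift derivative of <Y> = -sin(total RX angle) with respect to x."""
    gates = fuse_rx_pair(x, atol, respect_trainability)
    total = sum(a for a, _ in gates)
    grad = 0.0
    for _, da_dx in gates:
        grad += da_dx * (-math.sin(total + math.pi / 2) + math.sin(total - math.pi / 2)) / 2
    return grad


for x in (1e-7, 1e-5):
    # Old behaviour first, then the trainability-aware version.
    print(x, grad_expval_y(x, 1e-6, False), grad_expval_y(x, 1e-6, True))
```

Inside the window `|2x| < atol` the old behaviour returns 0.0, while the trainability-aware version returns approximately -2.0 everywhere, which is the derivative of `-sin(2x)` near zero.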

**Benefits:**
Fixes derivatives of `merge_rotations` and `single_qubit_fusion` (where
mathematically defined)

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
Fixes another part of #5715, still not all of it.

---------

Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
dwierichs (author) commented:

After merging #5620, #5774, #6031, and #6033, the issues with the mentioned decompositions and transforms are resolved. The last code example with MottonenStatePreparation still fails because of a more general bug related to non-scalar outputs, see #3480.
As that bug is not related to the transforms/decompositions part of the codebase that this issue is about, I suggest closing this issue.
