Skip to content

Commit

Permalink
Fixes for new autoray version (#4396)
Browse files Browse the repository at this point in the history
* Fixes for new autoray version

* unused import

* Use if instead of try-except

* Fix case of two torch tensors
  • Loading branch information
eddddddy authored and mudit2812 committed Jul 27, 2023
1 parent 9680486 commit 5508b4a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
22 changes: 16 additions & 6 deletions pennylane/math/multi_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,12 +769,22 @@ def add(*args, like=None, **kwargs):
"""Add arguments element-wise."""
if like == "scipy":
return onp.add(*args, **kwargs) # Dispatch scipy add to numpy backed specifically.
try:
return np.add(*args, **kwargs)
except TypeError:
# catch arg1 = torch, arg2=numpy error
# works fine with opposite order
return np.add(args[1], args[0], *args[2:], **kwargs)

arg_interfaces = {get_interface(args[0]), get_interface(args[1])}

# case of one torch tensor and one vanilla numpy array
if like == "torch" and len(arg_interfaces) == 2:
# In autoray 0.6.5, np.add dispatches to torch instead of
# numpy if one parameter is a torch tensor and the other is
# a numpy array. torch.add raises an Exception if one of the
# arguments is a numpy array, so here we cast both arguments
# to be tensors.
dev = getattr(args[0], "device", None) or getattr(args[1], "device")
arg0 = np.asarray(args[0], device=dev, like=like)
arg1 = np.asarray(args[1], device=dev, like=like)
return np.add(arg0, arg1, *args[2:], **kwargs)

return np.add(*args, **kwargs)


@multi_dispatch()
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_phaseshift(self, method, wire, ml_framework):
initial_state = qml.math.asarray(initial_state, like=ml_framework)

phase = qml.math.asarray(-2.3, like=ml_framework)
shift = np.exp(qml.math.multiply(1j, phase))
shift = qml.math.exp(1j * qml.math.cast(phase, np.complex128))

new_state = method(qml.PhaseShift(phase, wire), initial_state)

Expand Down

0 comments on commit 5508b4a

Please sign in to comment.