-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Return DShapeArray objects in measure primitives when possible #1170
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1170 +/- ##
==========================================
+ Coverage 97.89% 97.96% +0.07%
==========================================
Files 76 76
Lines 10879 10885 +6
Branches 1289 1292 +3
==========================================
+ Hits 10650 10664 +14
+ Misses 178 174 -4
+ Partials 51 47 -4 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Raul!
shots, shape_value = jaxpr.eqns[1].params.values() | ||
|
||
assert jaxpr.eqns[1].primitive == probs_p | ||
assert isinstance(shape_value[0], DynamicJaxprTracer) | ||
assert shots == 5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know what changed that means there are no escaped tracer issues compared to your previous test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note the jaxpr for the example here is different than what JAX would generate, and I'm wondering if that isn't going to lead to issues (like escaped tracers).
Our jaxpr, contains tracer in the jaxpr equation:
{ lambda ; a:i64[]. let
b:AbstractObs(num_qubits=0,primitive=compbasis) = compbasis
c:f64[a] d:i64[a] = counts[
shape=(Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,)
shots=5
] b
in (c, d) }
JAX's jaxpr for similar example (jnp.ones((dim_0, 5))
), does not insert the tracer as constant data into the jaxpr equation:
{ lambda ; a:i64[]. let
b:f64[a,5] = broadcast_in_dim[broadcast_dimensions=() shape=(None, 5)] 1.0 a
in (b,) }
Instead we find None
in the place where the tracer would be.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I managed to remove shots
from the jaxpr by making it a traceable argument. However, shape
is still a keyword argument, that's why we still see the shots
tracer as constant data inside shape. Once we support shape
as traceable, I guess we will see the None
keyword instead.
def f(dim_0): | ||
obs = compbasis_p.bind() | ||
return state_p.bind(obs, shots=5, shape=(dim_0,)) | ||
|
||
jaxpr = jax.make_jaxpr(f)(1) | ||
shots, shape_value = jaxpr.eqns[1].params.values() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we still need to cover more of the pipeline in our tests (ignoring the MLIR core for now, which can be addressed in a different PR).
From what I can see we test the bind function and abstract evaluation of jaxpr primitives, but that leaves out most of the code that constitutes the frontend. The first step might be to identify which functions could be affected by this change, from the start of tracing until jax spits out the mlir, and then ensure those functions are tested for the dynamic case. Or alternatively develop tests that cover roughly that entire span.
I decided to tackle the MLIR changes in the same PR |
Context: Measure primitives only return
ShapedArray
objects, which forces variables like the number of shots or number of wires to be known at compilation time. With the support forDShapedArray
objects, this constraint can be removed.Description of the Change: Return
DShapedArray
objects when possible.Benefits: Eliminating the need to know the number of shots/wires statically is a step towards allowing us to compile programs that don't have a fixed number of qubits.
[sc-74736]