-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Resolve issue of qjit(static_argnums=...)
fails when the marked static argument has a default value
#1295
base: main
Are you sure you want to change the base?
Resolve issue of qjit(static_argnums=...)
fails when the marked static argument has a default value
#1295
Conversation
…arguments to compare validity of `static_argnums`
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1295 +/- ##
=======================================
Coverage 97.95% 97.95%
=======================================
Files 77 77
Lines 11318 11320 +2
Branches 981 981
=======================================
+ Hits 11087 11089 +2
Misses 181 181
Partials 50 50 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @AniketDalvi, thanks for your submission!
I have added a couple of comments to your test. Don't worry too much about them as I don't expect a perfect submission and my comments are mostly for explaining and showing how things are done here :)
I do have one other comment that is outside of these other ones, and that is about python and keyword arguments.
When I make the following modifications to your test, in order to test if passing keyword arguments to f
work
diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py
index 34b553f47..987b3b8c6 100644
--- a/frontend/test/pytest/test_static_arguments.py
+++ b/frontend/test/pytest/test_static_arguments.py
@@ -135,6 +135,7 @@ class TestStaticArguments:
assert f(20) == 29
assert f(20, 3) == 23
assert f(20, 300000) == 42000
+ assert f(20, x=3) == 23
def test_mutable_static_arguments(self):
"""Test QJIT with mutable static arguments."""
it fails.
This is technically allowed under python and I also tested it with against jax.jit
(which also works with jax.jit
). Can you provide some more details and if time permits a solution that would make the test above pass as well?
Thanks!
Hi @erick-xanadu, thank you for these comments. Both the points about using On your other comment about calling the function using keyword arguments, I did some debugging and research. It looks like this is an issue independent of having default arguments. I made some simple test cases which do not use default arguments, and these fail on using keyword arguments too. It appears as though the |
Interesting, can you provide a small example in the comments? Does it only fail with qjit, or does jax.jit fail as well? |
Yes, and can you provide a bit more explanation as to why it fails? |
Hi @paul0403 and @erick-xanadu, So here's an example of a case that fails
This test case, however, passes when we use
Upon doing some step through debugging, I realized that this happens because
|
@AniketDalvi , thanks for your answer! Will you be implementing a solution to the test case regarding keyword arguments or are you done with your submission? Cheers! |
Hi! I am curious to understand how this would affect my assessment. It would be a slightly larger undertaking as it appears that the keyword argument test case doesn't just pertain to #1163. Although I would be glad to take a stab at it, I am currently also tied up in some PhD deadlines. Furthermore, I recently also notified the recruiter that I received a verbal offer from a company I was interviewing with and requested to expedite the recruitment process as this role with Xanadu is a top priority for me. |
This work is part of a coding assessment for a full-time role at Xanadu. It attempts to resolve #1163.
Currently using the decorator
qjit(static_argnums=...)
fails if the marked static argument has a default argument that isn't passed when the function is called. The exception thrown could be traced back to a condition here -catalyst/frontend/catalyst/tracing/type_signatures.py
Line 120 in 75dc517
It appears that it checks the index used to specify the static argument with the number of
args
passed to the function. However, in the case of a default parameter, the static argument index is likely to be greater thanlen(args)
as the default parameter may not be passed when the function is called. Considering this, my proposal for a solution would be to compare the static argument index to the larger value between the number of arguments passed in the function call and number of arguments in the function signature. This will accommodate for the current bug and the case when the function is declared with variable number of arguments (likefoo(*args)
).I have tested this with the examples outlined in the issue. It also passes all of the current frontend test cases.