Resolve issue where qjit(static_argnums=...) fails when the marked static argument has a default value #1295

Open
Wants to merge 8 commits into base: main

AniketDalvi

This work is part of a coding assessment for a full-time role at Xanadu. It attempts to resolve #1163.

Currently, using the decorator qjit(static_argnums=...) fails if the marked static argument has a default value and is not passed when the function is called. The exception can be traced back to this condition:

    if argnum < 0 or argnum >= len(args):

This condition compares the index of each static argument with the number of arguments passed to the function. When a parameter has a default value and is omitted from the call, its index can be greater than or equal to len(args), so the check fails even though the index is valid for the signature. My proposed solution is to compare each static argument index against the larger of two values: the number of arguments passed in the call and the number of parameters in the function signature. This handles the current bug as well as functions declared with a variable number of arguments (like foo(*args)), where more values may be passed than the signature declares.
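A minimal sketch of the proposed bounds check, using only the standard-library `inspect` module. The function name `check_static_argnums` and the `ValueError` it raises are hypothetical stand-ins, not Catalyst's actual `verify_static_argnums`:

```python
import inspect


def check_static_argnums(args, fn, static_argnums):
    """Hypothetical bounds check for static_argnums (a sketch, not Catalyst's code).

    An index is accepted if it falls inside either the arguments actually passed
    or the parameters declared in the signature, so a static argument left at
    its default value no longer triggers an error.
    """
    # Parameters declared by the function; a *args parameter counts as one entry.
    n_params = len(inspect.signature(fn).parameters)
    # With *args, more values may be passed than parameters are declared.
    upper = max(len(args), n_params)
    for argnum in static_argnums:
        if argnum < 0 or argnum >= upper:
            raise ValueError(f"argnum {argnum} is out of range [0, {upper})")


def f(y, x=9):
    return x + y


def g(*args):
    return sum(args)


# The static argument `x` (index 1) is omitted and takes its default: no error.
check_static_argnums((20,), f, [1])
# Index 2 is valid for g because three values were passed through *args.
check_static_argnums((1, 2, 3), g, [2])
```

This sketch counts every declared parameter, including keyword-only ones, which is a simplification; it only widens the accepted index range and does not itself fill in the default value.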

I have tested this with the examples outlined in the issue. It also passes all of the current frontend test cases.


codecov bot commented Nov 8, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.95%. Comparing base (b12e047) to head (d9774f5).
Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1295   +/-   ##
=======================================
  Coverage   97.95%   97.95%           
=======================================
  Files          77       77           
  Lines       11318    11320    +2     
  Branches      981      981           
=======================================
+ Hits        11087    11089    +2     
  Misses        181      181           
  Partials       50       50           


Contributor

@erick-xanadu erick-xanadu left a comment


Hey @AniketDalvi, thanks for your submission!

I have added a couple of comments to your test. Don't worry too much about them: I don't expect a perfect submission, and my comments are mostly meant to explain and show how things are done here :)

I have one more comment, separate from those, about Python and keyword arguments.

When I make the following modification to your test, to check whether passing keyword arguments to f works,

diff --git a/frontend/test/pytest/test_static_arguments.py b/frontend/test/pytest/test_static_arguments.py
index 34b553f47..987b3b8c6 100644
--- a/frontend/test/pytest/test_static_arguments.py
+++ b/frontend/test/pytest/test_static_arguments.py
@@ -135,6 +135,7 @@ class TestStaticArguments:
         assert f(20) == 29
         assert f(20, 3) == 23
         assert f(20, 300000) == 42000
+        assert f(20, x=3) == 23
 
     def test_mutable_static_arguments(self):
         """Test QJIT with mutable static arguments."""

it fails.

This is valid Python, and I also tested it against jax.jit, where it works. Can you provide some more details and, if time permits, a solution that would make the test above pass as well?
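In plain Python, a positional-or-keyword parameter has the same index whether it is passed by position or by name. The standard-library `inspect.Signature.bind` makes this explicit. The `f` below is a plain-Python stand-in for the test function; its default of `9` is inferred from `f(20) == 29`, since the test body is not shown in the diff:

```python
import inspect


def f(y, x=9):
    # Stand-in for the test function (default inferred from f(20) == 29).
    if x < 10:
        return x + y
    return 42000


sig = inspect.signature(f)

# Passing x by position or by keyword binds it to the same parameter.
by_position = sig.bind(20, 3)
by_keyword = sig.bind(20, x=3)

# BoundArguments.args lists positional-or-keyword values in signature order,
# so x lands at index 1 in both cases.
print(by_position.args)  # (20, 3)
print(by_keyword.args)   # (20, 3)

# An omitted default only appears after apply_defaults().
omitted = sig.bind(20)
omitted.apply_defaults()
print(omitted.args)      # (20, 9)
```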

Thanks!

frontend/catalyst/jit.py (outdated, resolved)
frontend/test/pytest/test_static_arguments.py (resolved)
@AniketDalvi
Author

Hi @erick-xanadu, thank you for these comments. Both points, about using pytest.mark.parametrize and about LLVM's lit functionality, are very insightful, and I will keep them in mind for future development.

On your other comment about calling the function with keyword arguments, I did some debugging and research. This appears to be an issue independent of default arguments: I wrote some simple test cases that do not use default arguments, and these also fail when keyword arguments are used. It seems the static_argnums=... option in the decorator does not apply when the index refers to an argument passed by keyword, resulting in an error such as jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]. (if the argument in question is used in a boolean expression).

@paul0403
Contributor

paul0403 commented Nov 11, 2024

It appears as though the static_argnums=... option in the decorator does not apply if the index refers to a keyword argument

Interesting, can you provide a small example in the comments? Does it only fail with qjit, or does jax.jit fail as well?

@erick-xanadu
Contributor

Interesting, can you provide a small example in the comments? Does it only fail with qjit, or does jax.jit fail as well?

Yes, and can you provide a bit more explanation as to why it fails?

@AniketDalvi
Author

AniketDalvi commented Nov 12, 2024

Hi @paul0403 and @erick-xanadu,

Here's an example of a case that fails:

        from catalyst import qjit

        @qjit(static_argnums=[1])
        def f(y, x):
            if x < 10:
                return x + y
            return 42000

        assert f(20, x=3) == 23  # fails under qjit: x is passed by keyword

This test case, however, passes with jax.jit:

        from functools import partial

        import jax

        @partial(jax.jit, static_argnums=[1])
        def f(y, x):
            if x < 10:
                return x + y
            return 42000

        assert f(20, x=3) == 23  # passes under jax.jit
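The jax.jit documentation states that when only `static_argnums` is given (and not `static_argnames`), JAX uses `inspect.signature` to find the parameter names that correspond to those indices, which is plausibly why the keyword call works there. A minimal stdlib sketch of that mapping; `argnames_for_argnums` is a hypothetical helper, not a JAX or Catalyst function:

```python
import inspect


def argnames_for_argnums(fn, static_argnums):
    """Hypothetical helper: names of the parameters at the given positional indices."""
    params = list(inspect.signature(fn).parameters.values())
    names = []
    for i in static_argnums:
        # Negative or out-of-range indices are ignored in this sketch.
        if 0 <= i < len(params):
            param = params[i]
            # Only positional-or-keyword parameters can be passed both ways.
            if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
                names.append(param.name)
    return tuple(names)


def f(y, x):
    if x < 10:
        return x + y
    return 42000


print(argnames_for_argnums(f, [1]))  # ('x',)
```

With both an index set and a name set, a call like `f(20, x=3)` can be checked against the indices for `args` and the names for `kwargs`.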

Stepping through with a debugger, I found that this happens because capture() only filters static arguments from args and not from kwargs, and in the case above the static argument is in kwargs. I have added comments to the function snippet below to pinpoint the issue.

    def capture(self, args, **kwargs):
        """Capture the JAX program representation (JAXPR) of the wrapped function.

        Args:
            args (Iterable): arguments to use for program capture

        Returns:
            ClosedJaxpr: captured JAXPR
            PyTreeDef: PyTree metadata of the function output
            Tuple[Any]: the dynamic argument signature
        """

        # In the case above, `args=(20,)` and `kwargs={'x': 3}`.
        # The lines below only process static arguments from `args`.

        verify_static_argnums(args, self.original_function, self.compile_options.static_argnums)
        static_argnums = self.compile_options.static_argnums
        abstracted_axes = self.compile_options.abstracted_axes

        dynamic_args = filter_static_args(args, static_argnums)
        dynamic_sig = get_abstract_signature(dynamic_args)
        full_sig = merge_static_args(dynamic_sig, args, static_argnums)
        ...
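One possible direction, sketched here under the assumption that capture() could receive the normalized arguments, is to bind the call against the signature before filtering, so keyword-passed positional parameters move into `args` and omitted defaults are filled in. `normalize_call` and `split_static` are hypothetical helpers, not Catalyst's `filter_static_args` or the change in this PR:

```python
import inspect


def normalize_call(fn, args, kwargs):
    """Hypothetical helper: move keyword-passed positional parameters into args.

    apply_defaults() also fills in omitted defaults, which covers the #1163 case.
    Keyword-only parameters stay in the returned kwargs.
    """
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.args, bound.kwargs


def split_static(args, static_argnums):
    """Hypothetical helper: separate static and dynamic positional arguments."""
    static = {i: a for i, a in enumerate(args) if i in static_argnums}
    dynamic = tuple(a for i, a in enumerate(args) if i not in static_argnums)
    return static, dynamic


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


args, kwargs = normalize_call(f, (20,), {"x": 3})
static, dynamic = split_static(args, [1])
print(args, kwargs)     # (20, 3) {}
print(static, dynamic)  # {1: 3} (20,)

# Omitting x entirely now yields its default at index 1.
args_default, _ = normalize_call(f, (20,), {})
print(args_default)     # (20, 9)
```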

@erick-xanadu
Contributor

@AniketDalvi , thanks for your answer! Will you be implementing a solution to the test case regarding keyword arguments or are you done with your submission? Cheers!

@AniketDalvi
Author

AniketDalvi commented Nov 13, 2024

Hi! I am curious to understand how this would affect my assessment. It would be a slightly larger undertaking as it appears that the keyword argument test case doesn't just pertain to #1163. Although I would be glad to take a stab at it, I am currently also tied up in some PhD deadlines.

Furthermore, I recently also notified the recruiter that I received a verbal offer from a company I was interviewing with and requested to expedite the recruitment process as this role with Xanadu is a top priority for me.

@erick-xanadu
Contributor

Hi @AniketDalvi, don't worry about it. Submission looks good. We will contact you back by e-mail. Cheers. :)

Successfully merging this pull request may close these issues.

qjit(static_argnums=...) fails when the marked static argument has a default value