qjit(static_argnums=...) fails when the marked static argument has a default value #1163

Open
paul0403 opened this issue Sep 30, 2024 · 23 comments · May be fixed by #1295
Labels
bug Something isn't working good first issue Good for newcomers

Comments

@paul0403
Contributor

paul0403 commented Sep 30, 2024

Context

When jit-compiling a Python function, the arguments of the compiled function lose their concrete values and are replaced by tracers, which at a high level are abstract variables with the same type and shape as the concrete values. The traced program, called a jaxpr, uses these abstract tracers to represent how the function's arguments are used.

The example below shows how to use Catalyst to jit-compile a function and how to inspect the resulting jaxpr.

import catalyst
from catalyst import qjit

@qjit
def f(x, y):
  return x+y

result = f(10, 20)
print(result)
print(f.jaxpr)

result = f(1.1, 2.2)
print(result)
print(f.jaxpr)
30
{ lambda ; a:i64[] b:i64[]. let
    c:i64[] = add a b
  in (c,) }
3.3
{ lambda ; a:f64[] b:f64[]. let
    c:f64[] = add a b
  in (c,) }
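For comparison, here is a minimal sketch of the same tracing step in plain JAX, assuming only that jax is installed. jax.make_jaxpr traces a function with abstract tracers and returns the resulting jaxpr. Plain JAX defaults to 32-bit types, so the printed jaxprs will most likely show i32/f32 rather than the i64/f64 in the Catalyst output above, unless x64 mode is enabled.

```python
import jax


def f(x, y):
    return x + y


# make_jaxpr replaces x and y with abstract tracers and records the operations.
jaxpr_int = jax.make_jaxpr(f)(10, 20)
jaxpr_float = jax.make_jaxpr(f)(1.1, 2.2)
print(jaxpr_int)
print(jaxpr_float)
```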

Notice that in each jaxpr, the argument types (i64 and f64) match the types of the concrete arguments passed in the corresponding call. The process of converting Python to a jaxpr is called tracing.

One consequence of abstract arguments is that tracing fails whenever an argument's concrete value is needed, for example when it is compared against a concrete value in a Python if statement, since an abstract tracer cannot be interpreted as a concrete value. See here for more details.

@qjit
def f(x, y):
  if x < 100:
    return x+y
  return 42

result = f(10, 20)
print(result)
print(f.jaxpr)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
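The failure is not specific to Catalyst. Below is a minimal plain-JAX reproduction, assuming a recent jax release. It catches jax.errors.ConcretizationTypeError, of which TracerBoolConversionError is a subclass. It then shows that marking x static with jax.jit's own static_argnums lets the same function trace.

```python
import jax


def f(x, y):
    if x < 100:  # a Python `if` needs a concrete bool, but x is a tracer here
        return x + y
    return 42


try:
    jax.jit(f)(10, 20)
    failed = False
except jax.errors.ConcretizationTypeError as err:
    # TracerBoolConversionError is a subclass of ConcretizationTypeError.
    failed = True
    print("tracing failed:", type(err).__name__)

# With x static, the comparison happens on a concrete Python int.
static_result = jax.jit(f, static_argnums=0)(10, 20)
print(static_result)
```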

To avoid this problem, some of the function arguments can be marked static. During tracing, static arguments keep their concrete values instead of being replaced by tracers. This is done with the static_argnums keyword argument of qjit, which takes a list of the indices of the arguments to mark static.

@qjit(static_argnums=[0])
def f(x, y):
  if x < 100:
    return x+y
  return 42

result = f(10, 20)
print(result)
print(f.jaxpr)

result = f(1000, 20)
print(result)
print(f.jaxpr)
30
{ lambda ; a:i64[]. let
    b:i64[] = add 10 a
  in (b,) }
42
{ lambda ; a:i64[]. let
  in (42,) }
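One way to see that a static argument's concrete value is baked into the trace is to count how often the Python body runs. The sketch below uses plain jax.jit rather than qjit, and the traces list is hypothetical and exists purely for illustration. Because the body only executes during tracing, a repeated static value with same-typed dynamic arguments should reuse the cached trace, while a new static value triggers a retrace.

```python
from functools import partial

import jax

traces = []  # static values seen each time the Python body is traced


@partial(jax.jit, static_argnums=[0])
def f(x, y):
    traces.append(x)  # Python side effect: only runs while tracing
    if x < 100:  # fine: x is a concrete Python int here
        return x + y
    return 42


# Same static value and same argument types: the second call reuses the cached trace.
results = [f(10, 20), f(10, 30), f(1000, 20)]
print([int(r) for r in results])  # [30, 40, 42]
print(traces)                     # [10, 1000]
```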

However, Catalyst currently does not allow arguments with default values to be marked static via static_argnums:

@qjit(static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


res = f(20)
print(res)
catalyst.utils.exceptions.CompileError: argnum 1 is beyond the valid range of [0, 1).

Goal

We would like qjit's static_argnums to support marking arguments that have default values, as native jax.jit already does:

from functools import partial
import jax
@partial(jax.jit, static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000

print(f(20), f(20, 3), f(20, 30000))
29 23 42000

Requirements:

  • We would like to support marking arguments with default values as static, matching the behavior of jax.jit. Explicitly:
    • If a value is not provided for the default-valued argument in the call, the default concrete value should be used in the jaxpr.
    • If a value is provided for the default-valued argument in the call, then the provided concrete value should be used in the jaxpr.
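One way to pin down the two requirements is to compute the argument values the function body actually sees. Below is a minimal sketch using Python's standard inspect module; effective_args is a hypothetical helper name, not part of Catalyst. Binding the call to the signature and applying defaults gives the concrete value that should end up in the jaxpr for each static index, whether or not the caller supplied it.

```python
import inspect


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


def effective_args(fn, *args, **kwargs):
    """Hypothetical helper: the positional arguments the body of `fn` sees,
    with omitted default-valued arguments filled in from the signature."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.args


static_argnums = [1]
for call in [(20,), (20, 3), (20, 30000)]:
    values = effective_args(f, *call)
    # The concrete value each static index should contribute to the jaxpr.
    print(call, [values[i] for i in static_argnums])
```

The same helper also normalizes keyword calls such as effective_args(f, 20, x=30000) to the positional tuple (20, 30000). This may help with some of the more complicated call patterns.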

Technical details

  • For reasons that do not concern us here, all jaxprs produced by qjit carry a transform_named_sequence. You can safely ignore it.

  • The qjit function takes in a python function and returns a QJIT object, which is a callable. In the QJIT object, there is a capture method that determines how a python function is traced into a jaxpr. See frontend/catalyst/jit.py.

  • It should be possible to implement this functionality completely in the capture layer, without delving into the actual underlying machinery of the trace_to_jaxpr methods. For example, one option is to create two versions of the function, both without any default-valued arguments, adjusted to behave correctly, and trace these two functions depending on whether a default value was supplied by the user's call. Other options might be possible too.

  • There are many possible ways a user might call a function. For the purpose of this assessment, only the simplest case is required, i.e. the example above that works with pure JAX (though bonus points if you can make more complicated patterns work).
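Below is a rough sketch of the "two versions" idea from the technical details above. It assumes only plain positional-or-keyword parameters (no *args, **kwargs, or keyword-only parameters). specialize_for_call is a hypothetical name, and this is not how Catalyst's capture method is actually structured. For a given call, it builds a function with no default-valued parameters. Omitted arguments are baked in as Python constants, so they never become tracers, and static indices that point at them are dropped.

```python
import inspect


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


def specialize_for_call(fn, num_supplied, static_argnums):
    """Hypothetical helper: return a version of `fn` taking exactly the first
    `num_supplied` positional arguments, plus the static indices that still
    refer to one of those arguments."""
    params = list(inspect.signature(fn).parameters.values())
    omitted = params[num_supplied:]
    if any(p.default is inspect.Parameter.empty for p in omitted):
        raise TypeError("a parameter without a default value was not supplied")
    # Omitted arguments keep their concrete default values, so they behave
    # as if static and are never replaced by tracers.
    baked = {p.name: p.default for p in omitted}

    def specialized(*args):
        return fn(*args, **baked)

    # Indices of omitted arguments no longer exist in `specialized`.
    remaining_static = tuple(i for i in static_argnums if i < num_supplied)
    return specialized, remaining_static


# f(20): x is omitted, so it is baked in as 9 and index 1 is dropped.
g1, static1 = specialize_for_call(f, 1, [1])
print(g1(20), static1)     # 29 ()
# f(20, 3): x is supplied, so index 1 stays static.
g2, static2 = specialize_for_call(f, 2, [1])
print(g2(20, 3), static2)  # 23 (1,)
```

In a capture layer, the specialized function and the remaining static indices could then be traced in place of the original. Whether that fits cleanly into QJIT.capture is an assumption to verify against frontend/catalyst/jit.py.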

Installation help

To save time, instead of installing Catalyst from source, it is also possible to download the PyPI wheel and extract it into the frontend directory of a cloned Catalyst repository (taking care to match git tags beforehand), followed by make frontend. This allows modifying the Python files in place.

Alternatively, complete instructions to install Catalyst from source can be found here, but due to the size of the llvm-project it can take a while (~3 hrs on a personal laptop) to compile.

@paul0403 paul0403 added the bug Something isn't working label Sep 30, 2024
@paul0403
Contributor Author

paul0403 commented Sep 30, 2024

I think this is because static_argnums tracks argument indices at the call, not at the definition, though this is just a hunch and might be completely wrong.
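The hunch can be checked in plain Python. For the failing call f(20), counting arguments at the call site and counting parameters at the definition give different upper bounds for argnum 1. This sketch uses only the standard inspect module and does not call into Catalyst.

```python
import inspect


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


call_args = (20,)  # positional arguments at the call f(20)
static_argnums = [1]

num_at_call = len(call_args)                              # counted at the call site
num_at_definition = len(inspect.signature(f).parameters)  # counted at the definition

for argnum in static_argnums:
    print(argnum,
          "valid at call:", 0 <= argnum < num_at_call,
          "valid at definition:", 0 <= argnum < num_at_definition)
```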

@paul0403 paul0403 changed the title qjit(static_argnums=...) fails when the marks static argument has a default value qjit(static_argnums=...) fails when the marked static argument has a default value Oct 18, 2024
@paul0403 paul0403 added the good first issue Good for newcomers label Oct 19, 2024
@AniketDalvi

Hi! I am Aniket, a PhD candidate at Duke interviewing for a role on the compiler team. I was given this issue as a technical challenge. Apart from the details mentioned in this issue, are there any other pointers that might help me tackle this issue? Thank you!

@paul0403
Contributor Author

paul0403 commented Nov 4, 2024

Hi @AniketDalvi , do you have a more specific question in mind? I would love to give out more pointers and resolve any confusion you may have!

@AniketDalvi

AniketDalvi commented Nov 4, 2024

Hi! So for installation - it says to download the PyPI wheel. Which version should I be downloading the wheel for? pip says there are 3 available versions - 0.1.0, 0.1.1 and 0.1.2.

More specifically, I get the following error when downloading the wheel on my linux computer:

ERROR: Cannot install pennylane-catalyst==0.1.0, pennylane-catalyst==0.1.1 and pennylane-catalyst==0.1.2 because these package versions have conflicting dependencies.

The conflict is caused by:
    pennylane-catalyst 0.1.2 depends on jaxlib==0.4.1
    pennylane-catalyst 0.1.1 depends on jaxlib==0.4.1
    pennylane-catalyst 0.1.0 depends on jaxlib==0.4.1

@paul0403
Contributor Author

paul0403 commented Nov 4, 2024

I don't think there's gonna be any difference w.r.t. this particular issue; any one of the three should be good.

cc @rauletorresc who worked on the frontend dev plug-in

@erick-xanadu
Contributor

Hey @AniketDalvi, I am not too sure why your pip says that there are only 3 versions available. We have version 0.8.1 on PyPI and will be releasing version 0.9.0 soon.

@AniketDalvi

Okay, I am just going to download the .whl file from the above link. I will then extract it into the frontend directory, as directed in the installation instructions.

@AniketDalvi

Hi! Okay, so I have followed the instructions and seem to have successfully installed the repository. Is there a quick sanity-check experiment or file I can run to verify the installation?

@erick-xanadu
Contributor

You can run the tests: pytest frontend/test/pytest -n auto. Some debugging tests may fail on your machine because they expect a specific path, but if the vast majority pass, the installation should be good.

@AniketDalvi

Running pytest gave me an error stating that there was an interpreter mismatch. The interpreter is Python 3.10, while the package is only compatible with 3.12.

@erick-xanadu
Contributor

erick-xanadu commented Nov 5, 2024

When you install via pip, pip enforces this compatibility check, but when downloading a wheel manually you need to make sure to pick the one that matches your Python version. See here for a list of several wheels with different Python version compatibility.

Maybe it would be easier to install from source? It just takes a long time to build LLVM initially.

EDIT: It also looks like the new Catalyst version 0.9.0 is now available for download :)

@AniketDalvi

Understood, that makes sense. I might re-try it with a different wheel with a compatible python version. If not, I will resort to installing from source.

@AniketDalvi

Okay, when trying to run the make frontend command with the new wheel, I get this error: error: command '/usr/bin/g++' failed with exit code 1. I have tried all the solutions Stack Overflow has to offer, but to no avail. Any thoughts on what this could be?

@paul0403
Contributor Author

paul0403 commented Nov 6, 2024

There are some packages required before building Catalyst. Maybe some of them are missing? See the build-from-source guide.

@paul0403
Contributor Author

paul0403 commented Nov 6, 2024

(If wheels are too complicated I recommend just building from source.)

@AniketDalvi

Yup installed all the required packages, but the error persists. I am now just going to build from source instead.

@paul0403
Contributor Author

paul0403 commented Nov 6, 2024

To avoid all package version issues, I also recommend using a fresh virtual environment when developing, e.g.

python3 -m venv pyenv
source ./pyenv/bin/activate

after which you can pip install all the requirements and make all from source.

@AniketDalvi

Okay, I seem to have gotten it to work from source. Most tests pass, some are skipped, and 4 debugging tests fail (which @erick-xanadu said is expected). I am running all of this from within a conda environment on a Linux machine.

@erick-xanadu
Contributor

Yeah, a bunch are skipped, 4 failing ones.

Awesome! We don't normally use conda so I am happy to hear this worked for you :)

@AniketDalvi

AniketDalvi commented Nov 7, 2024

Hi! From my initial analysis, I traced the issue down to this check, which throws the exception:

if argnum < 0 or argnum >= len(args):

It appears to compare the static argument index against the number of arguments passed to the function. However, when a parameter has a default value, the static argument index can be greater than or equal to len(args), since the default-valued argument may be omitted from the call. Given this, my initial proposal for a solution is to compare the static argument index against the number of parameters in the function signature, rather than the number of arguments passed at the call.
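Below is a minimal sketch of the proposed check. It is standalone Python with hypothetical helper names (validate_static_argnums, is_accepted), not a patch to Catalyst. Indices are validated against the positional parameters in the signature instead of len(args). It keeps the original error message and still rejects negative indices and indices beyond the definition. It accepts any non-negative index when the function has a *args parameter.

```python
import inspect

_POSITIONAL = (inspect.Parameter.POSITIONAL_ONLY,
               inspect.Parameter.POSITIONAL_OR_KEYWORD)


def validate_static_argnums(fn, static_argnums):
    """Hypothetical check: validate indices against the definition of `fn`
    rather than against the arguments supplied at one particular call."""
    params = list(inspect.signature(fn).parameters.values())
    num_positional = sum(p.kind in _POSITIONAL for p in params)
    # A *args parameter accepts any number of extra positional arguments.
    has_var_positional = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params)
    for argnum in static_argnums:
        if argnum < 0 or (argnum >= num_positional and not has_var_positional):
            raise ValueError(
                f"argnum {argnum} is beyond the valid range of [0, {num_positional})."
            )


def is_accepted(fn, static_argnums):
    try:
        validate_static_argnums(fn, static_argnums)
    except ValueError:
        return False
    return True


def f(y, x=9):
    return x + y


def g(a, *rest):
    return a


print(is_accepted(f, [1]))   # True: x exists in the signature
print(is_accepted(f, [2]))   # False: f has only two parameters
print(is_accepted(f, [-1]))  # False: negative indices stay rejected
print(is_accepted(g, [3]))   # True: *rest can absorb extra positional arguments
```

Passing this check is probably not sufficient on its own. Any later code that reads the static value via args[argnum] would still find the default-valued argument missing, so the default likely needs to be filled in as well (for example with inspect.Signature.bind and apply_defaults).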

Would like to get your thoughts on this!

@paul0403
Contributor Author

paul0403 commented Nov 7, 2024

Hi! Usually for these challenges we ask you to fork the catalyst repo, push your changes, and open a PR. That will be easier to review.

It's good if it turns out to be a simple fix, but one thing I'm worried about is whether loosening the verification would let some errors through. Can you run the frontend test suite to make sure this does not happen? You can run make test-frontend in the root catalyst directory.

@AniketDalvi

Hi! Okay, that sounds good. I am working off of my own branch. Does that work, or does it have to be a fork?

@erick-xanadu
Contributor

It doesn't really matter how you develop, as long as you are able to push a pull request :). I think to do that, you do need a fork. But you can always just add a new remote to your local git workspace.

git clone $pennylane/catalyst
# work on the issue
# fork $pennylane/catalyst to $yourrepo/catalyst
git remote add myrepo $yourrepo/catalyst
git push myrepo $yourbranch
# open a PR from $yourrepo/catalyst to $pennylane/catalyst

Projects
None yet
3 participants