-
Notifications
You must be signed in to change notification settings - Fork 532
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add jax compatible api #1207
Add jax compatible api #1207
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work so far!
Could you please add jax to benchmarks/bench_processors.py
as well?
Also could you fill me in on how you intend to use a jax based model, just so I have context on how Outlines is being used here?
Thanks!
Hey @lapp0 , thanks for the feedback. Regarding intent to use with JAX, @borisdayma(author of issue mentioned above) can help us understand jax based model usage with PS: I currently have no use case with JAX, but I was trying out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PS: I currently have no use case with JAX, but I was trying out outlines and found it very interesting, so wanted to contribute.
Thanks for clarifying and your interest in the project.
Have added to benchmarks(though I haven't run those yet, will possibly do)
I'll add the run-benchmarks
label, which will trigger their run once the workflow is approved.
Regarding intent to use with JAX, @borisdayma(author of issue mentioned above) can help us understand jax based model usage with outlines.
@borisdayma could you review and smoke test to ensure this fits your desired use case?
MLX_AVAILABLE = True | ||
except ImportError: | ||
MLX_AVAILABLE = False | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for this if you use the below comment
This looks great but I’ll still need some time to try it out. The goal is to use it with repo’s such as:
Since the basic tests seem to work I expect it should be fine and could create a new issue if for some reason it prevents JAX compilation. |
722a25a
to
5072edc
Compare
Thank you for contributing and addressing @lapp0's comments. It looks like we'll be good to merge if the CI turns green. |
This PR adds a JAX compatible API, refer issue #1027
Looking forward for review and open to feedback(especially in writing tests).