Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jax compatible api #1207

Merged
merged 8 commits into from
Nov 27, 2024
Merged

Add jax compatible api #1207

merged 8 commits into from
Nov 27, 2024

Conversation

sky-2002
Copy link
Contributor

This PR adds a JAX compatible API, refer issue #1027

Looking forward for review and open to feedback(especially in writing tests).

Copy link
Contributor

@lapp0 lapp0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work so far!

Could you please add jax to benchmarks/bench_processors.py as well?

Also could you fill me in on how you intend to use a jax based model, just so I have context on how Outlines is being used here?

Thanks!

outlines/processors/base_logits_processor.py Show resolved Hide resolved
tests/processors/test_base_processor.py Outdated Show resolved Hide resolved
@sky-2002
Copy link
Contributor Author

Hey @lapp0 , thanks for the feedback.
Have added to benchmarks(though I haven't run those yet, will possibly do)

Regarding intent to use with JAX, @borisdayma(author of issue mentioned above) can help us understand jax based model usage with outlines.

PS: I currently have no use case with JAX, but I was trying out outlines and found it very interesting, so wanted to contribute.

Copy link
Contributor

@lapp0 lapp0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PS: I currently have no use case with JAX, but I was trying out outlines and found it very interesting, so wanted to contribute.

Thanks for clarifying and your interest in the project.

Have added to benchmarks(though I haven't run those yet, will possibly do)

I'll add the run-benchmarks label, which will trigger their run once the workflow is approved.

Regarding intent to use with JAX, @borisdayma(author of issue mentioned above) can help us understand jax based model usage with outlines.

@borisdayma could you review and smoke test to ensure this fits your desired use case?

tests/processors/test_base_processor.py Show resolved Hide resolved
MLX_AVAILABLE = True
except ImportError:
MLX_AVAILABLE = False

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for this if you use the below comment

outlines/processors/base_logits_processor.py Outdated Show resolved Hide resolved
outlines/processors/base_logits_processor.py Outdated Show resolved Hide resolved
@sky-2002 sky-2002 marked this pull request as ready for review October 15, 2024 05:19
@borisdayma
Copy link

borisdayma commented Oct 16, 2024

This looks great but I’ll still need some time to try it out.

The goal is to use it with repo’s such as:

Since the basic tests seem to work I expect it should be fine and could create a new issue if for some reason it prevents JAX compilation.

@rlouf rlouf force-pushed the jax-compatible-api branch from 722a25a to 5072edc Compare November 27, 2024 15:18
@rlouf
Copy link
Member

rlouf commented Nov 27, 2024

Thank you for contributing and addressing @lapp0's comments. It looks like we'll be good to merge if the CI turns green.

@rlouf rlouf merged commit 5608dd8 into dottxt-ai:main Nov 27, 2024
4 of 6 checks passed
@sky-2002 sky-2002 deleted the jax-compatible-api branch November 28, 2024 06:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants