Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Tensor Puzzlers from torchtyping to jaxtyping. #25

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

davideger
Copy link

jaxtyping is a new, improved, and maintained version of torchtyping by the same author (patrick kidger).
When used with beartype, jaxtyping can inform the user of shape mismatch errors at run time.

Other minor formatting issues were fixed in the Tensor Puzzles notebook.

jaxtyping is a new, improved, and maintained version of torchtyping by the
same author (patrick kidger).  When used with beartype, jaxtyping can inform
the user of shape mismatch errors at run time.

Other minor formatting issues were fixed in the Tensor Puzzles notebook.
@srush
Copy link
Owner

srush commented Jan 12, 2024

Amazing! I was just planning on doing this.

@srush
Copy link
Owner

srush commented Jan 12, 2024

Just curious, Why do you need "{j j}"?

@davideger
Copy link
Author

My understanding (and @patrick-kidger correct me if I am wrong) is that if you want jaxtyping to ensure that a return type dimension matches a scalar function parameter, you need to use f-string escaping to reference the function parameter. At least that's what I get from interpreting https://github.com/google/jaxtyping/blob/main/docs/api/array.md.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants