Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Add debug_info to DynamicJaxprTrace and JaxprStackFrame #25827

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jan 10, 2025

This is part of a sequence of changes to ensure that the debugging information is propagated properly.

Additional cleanup:

  • Rename result_paths to result_paths_thunk in TracingDebugInfo to clarify the difference from the similar field in JaxprDebugInfo
  • Added more type declarations

@gnecula gnecula self-assigned this Jan 10, 2025
@gnecula gnecula added the better_errors Improve the error reporting label Jan 10, 2025
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jan 10, 2025
@gnecula gnecula force-pushed the debug_info_2 branch 6 times, most recently from a457184 to 4baa6f2 Compare January 13, 2025 12:45
@gnecula gnecula requested review from mattjj and dfm January 13, 2025 12:46
Copy link
Collaborator

@dfm dfm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a minute to work out all the logic, but this looks good to me! My only high level comment is that it seems like it could be useful to add some tests that exercise this new behavior, even just as an example of what this enables.

@@ -83,9 +83,10 @@ def jvpfun(f, instantiate, transform_stack, primals, tangents):
return out_primals, out_tangents

@lu.transformation_with_aux2
def linearize_subtrace(_f, _store, _tag, nzs_in, *primals, **params):
def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the type of _f be lu.WrappedFun?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, no, it should be a Callable. (I added it, because I had the same confusion earlier.) See how below we invoke it, and WrappedFun is not a Callable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's quite right! linearize_subtrace is called in LinearizeTrace.process_call, which means that the _f argument is the fun from call_p, which (for better or worse) is a WrappedFun.

Also, as far as I know, the first input to a linear_util transformation must always be a WrappedFun!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: Ignore what I said above! I understand what you mean now. In the new version of linear_util without context the inputs are magically Callables (and I now parsed your explanation properly). Sorry for my confusion!

jax/_src/linear_util.py Show resolved Hide resolved
jax/_src/api.py Show resolved Hide resolved
@gnecula gnecula closed this Jan 14, 2025
@gnecula gnecula reopened this Jan 14, 2025
@gnecula gnecula force-pushed the debug_info_2 branch 2 times, most recently from c47b287 to d0f6b66 Compare January 14, 2025 10:10
This is part of a sequence of changes to ensure that the debugging information
is propagated properly.

Additional cleanup:
* Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the
  difference from the similar field in `JaxprDebugInfo`
* Added more type declarations
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
better_errors Improve the error reporting pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants