-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[better_errors] Add debug_info to DynamicJaxprTrace and JaxprStackFrame #25827
base: main
Are you sure you want to change the base?
Conversation
a457184
to
4baa6f2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It took me a minute to work out all the logic, but this looks good to me! My only high level comment is that it seems like it could be useful to add some tests that exercise this new behavior, even just as an example of what this enables.
@@ -83,9 +83,10 @@ def jvpfun(f, instantiate, transform_stack, primals, tangents): | |||
return out_primals, out_tangents | |||
|
|||
@lu.transformation_with_aux2 | |||
def linearize_subtrace(_f, _store, _tag, nzs_in, *primals, **params): | |||
def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the type of _f
be lu.WrappedFun
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, no, it should be a Callable. (I added it, because I had the same confusion earlier.) See how below we invoke it, and WrappedFun
is not a Callable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that's quite right! linearize_subtrace
is called in LinearizeTrace.process_call
, which means that the _f
argument is the fun
from call_p
, which (for better or worse) is a WrappedFun
.
Also, as far as I know, the first input to a linear_util
transformation must always be a WrappedFun
!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Edit: Ignore what I said above! I understand what you mean now. In the new version of linear_util
without context the inputs are magically Callable
s (and I now parsed your explanation properly). Sorry for my confusion!
c47b287
to
d0f6b66
Compare
This is part of a sequence of changes to ensure that the debugging information is propagated properly. Additional cleanup: * Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the difference from the similar field in `JaxprDebugInfo` * Added more type declarations
This is part of a sequence of changes to ensure that the debugging information is propagated properly.
Additional cleanup:
result_paths
toresult_paths_thunk
inTracingDebugInfo
to clarify the difference from the similar field inJaxprDebugInfo