New Experimental JAX Check Numerics API. #25785
Higher precision is generally better, but more precision also costs more. It would be useful to have an automated way of finding where in a program it's worth spending resources (compute/memory) on higher precision, in the context of models trained with first-order autodiff. We can do this by implementing a fairly simple JAX transform that aims to extract one metric for every input to each primitive and one metric for each output of each primitive. Let's look at an example program:
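The program used in the PR isn't reproduced here, but a minimal stand-in of the kind of model this targets might look like the following (the two-layer MLP, parameter names, and sum-of-squares loss are illustrative assumptions, not the PR's actual example):

```python
import jax.numpy as jnp

def predict(params, x):
    # Two dense layers with a tanh nonlinearity in between.
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def loss(params, x, y):
    # Sum-of-squares loss with a final sum reduction.
    return jnp.sum((predict(params, x) - y) ** 2)
```

Every primitive here (the two `dot`s, `tanh`, the adds, and the final `sum`) would get per-input and per-output metrics.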
For the inputs to a primitive we want to extract the mean squared errors of the difference in their gradients when computed in high vs. low precision. This is analogous to the forward error of the backward pass function:
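A small sketch of the per-input metric, computed directly rather than via the transform (the choice of float32 as "high" and bfloat16 as "low" precision is an assumption for illustration):

```python
import jax
import jax.numpy as jnp

def input_grad_mse(f, x):
    # Mean squared error between the gradient of f at x computed in
    # float32 (the "high" precision reference) and in bfloat16 ("low").
    g_hi = jax.grad(f)(x.astype(jnp.float32))
    g_lo = jax.grad(f)(x.astype(jnp.bfloat16)).astype(jnp.float32)
    return jnp.mean((g_hi - g_lo) ** 2)
```

The transform extracts this quantity per primitive input without re-running the whole function twice per input.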
For the outputs of a primitive we want to compute the dot product of the difference in the output when computed in high vs. low precision and the gradient. This is analogous to a sort of inverse of the backward error of the forward pass function and tells us how sensitive the loss is to a given error:
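Sketched concretely for a single primitive (again assuming float32 vs. bfloat16 as the high/low pair; `grad_at_out` is the gradient flowing into that output during the backward pass):

```python
import jax.numpy as jnp

def output_sensitivity(prim, grad_at_out, *args):
    # First-order estimate of how much the loss moves when this
    # primitive's output is computed in low precision: the dot product
    # of the forward-pass output error with the incoming gradient.
    out_hi = prim(*[a.astype(jnp.float32) for a in args])
    out_lo = prim(*[a.astype(jnp.bfloat16) for a in args]).astype(jnp.float32)
    return jnp.vdot(out_hi - out_lo, grad_at_out)
```

A large magnitude here means the low-precision error at this output lines up with a direction the loss cares about.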
The transform works by returning two functions. The first traces the target function to extract the list of expected metrics, so that you can pass them as inputs to the computation. The second takes these inputs and uses the autodiff system in a way that makes the gradients with respect to these metric inputs the values the function aims to extract. Utilities are also provided to sort the results and print them in a way that makes it easy to quickly identify the biggest numerical bottlenecks of a program. For example:
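The core autodiff trick can be shown in isolation (this is a minimal sketch of the underlying idea, not the PR's actual API): insert zero-valued "probe" inputs after intermediate values, and the gradient with respect to each probe is exactly the gradient flowing through that point of the program.

```python
import jax
import jax.numpy as jnp

def f_with_probes(probes, x):
    # Adding a zero probe leaves the forward value unchanged, but
    # grad-with-respect-to-the-probe recovers the cotangent at that site.
    h = jnp.tanh(x) + probes["after_tanh"]
    return jnp.sum(h ** 2) + probes["after_sum"]

x = jnp.array([0.5, -1.0])
probes = {"after_tanh": jnp.zeros_like(x), "after_sum": jnp.zeros(())}
grads = jax.grad(f_with_probes)(probes, x)
# grads["after_tanh"] equals 2 * tanh(x), the gradient at that intermediate.
```

The real transform automates the insertion of such probes at every primitive boundary and folds in the high-vs-low-precision comparison.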
This shows us that the second layer's dot and the final sum are disproportionately sensitive to numerics for these inputs, which tells us to investigate why the values going into layer 2, or the gradients coming back, make the operation so sensitive. Perhaps the preceding layer or its own activation function needs massaging. Or perhaps this is a more macro signal-propagation issue, and that layer should be run in higher precision than the others.