New Experimental JAX Check Numerics API. #25785

Open · epiqueras wants to merge 1 commit into main from feature/numerics

Conversation

@epiqueras (Collaborator) commented on Jan 8, 2025

Higher precision is generally better, but it also costs more. It would be useful to have an automated way of finding where in a program it is worth spending resources (compute/memory) on higher precision, particularly for models trained with first-order autodiff. We can do this with a fairly simple JAX transform that extracts one metric for every input to each primitive and one metric for each output of each primitive. Let's look at an example program:

import jax
import jax.numpy as jnp


@jax.jit
def matmul_with_residual(x, ys):
  def layer1(x, y):
    return x + jax.nn.gelu(x @ y)

  def layer2(x, y):
    return x + jax.nn.swish(x @ y)

  def layer3(x, y):
    return x + jax.nn.sigmoid(x @ y)

  def layer4(x, y):
    return x + jax.nn.tanh(x @ y)

  for l, y in zip([layer1, layer2, layer3, layer4], ys):
    x = l(x, y)
  return jnp.sum(x)

For each input to a primitive, we want to extract the mean squared error of the difference between its gradient computed in high precision and in low precision. This is analogous to the forward error of the backward-pass function:

jnp.mean(
  jnp.square((grad - low_precision_grad.astype(grad.dtype)).astype(jnp.float32))
)
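
As a concrete, standalone illustration (the names and the use of bfloat16 as the low-precision stand-in are assumptions, not the PR's API), here is what that in metric looks like for one input of a single matmul:

import jax
import jax.numpy as jnp

# Hedged sketch: the in metric for the `x` input of a single matmul feeding a
# sum, with bfloat16 standing in for the low-precision computation.
def loss(x, y):
  return jnp.sum(x @ y)

key = jax.random.key(0)
x = jax.random.uniform(key, (128, 256), dtype=jnp.float32)
y = jax.random.uniform(jax.random.split(key)[1], (256, 64), dtype=jnp.float32)

grad = jax.grad(loss)(x, y)
low_precision_grad = jax.grad(loss)(x.astype(jnp.bfloat16), y.astype(jnp.bfloat16))

in_metric = jnp.mean(
  jnp.square((grad - low_precision_grad.astype(grad.dtype)).astype(jnp.float32))
)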

For each output of a primitive, we want the dot product between the gradient and the difference between the output computed in high precision and in low precision. This is roughly an inverse of the backward error of the forward-pass function, and it tells us how sensitive the loss is to a given error:

delta = out - low_precision_out.astype(out.dtype)
jnp.sum((g * delta.astype(g.dtype)).astype(jnp.float32))
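
Again as a standalone sketch under the same assumptions (illustrative names, bfloat16 as the low-precision stand-in), for a matmul whose output feeds a jnp.sum loss the cotangent g is all ones:

import jax
import jax.numpy as jnp

# Hedged sketch: the out metric for a single matmul output under loss = jnp.sum(out).
key = jax.random.key(0)
x = jax.random.uniform(key, (128, 256), dtype=jnp.float32)
y = jax.random.uniform(jax.random.split(key)[1], (256, 64), dtype=jnp.float32)

out = x @ y
low_precision_out = x.astype(jnp.bfloat16) @ y.astype(jnp.bfloat16)
g = jnp.ones_like(out)  # cotangent of `out` when the loss is jnp.sum(out)

delta = out - low_precision_out.astype(out.dtype)
out_metric = jnp.sum((g * delta.astype(g.dtype)).astype(jnp.float32))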

The transform returns two functions. The first is run to trace the target function and extract the list of expected metrics, so that they can be passed as inputs to the computation. The second takes these metric inputs and uses the autodiff system in a way that makes their gradients the very values we want to extract (a minimal sketch of that trick follows the example output below). Utilities are also provided to sort and print the results so that the biggest numerical bottlenecks of a program are easy to spot. For example:

key = jax.random.key(42)
x = jax.random.uniform(key, (128, 1024), dtype=jnp.float32)
y = jax.random.uniform(jax.random.split(key)[1], (1024, 1024), dtype=jnp.float32)
args = x, (y, y, y, y)
check, source_metrics = numerics_check.numerics_check(matmul_with_residual)
metric_keys = source_metrics(*args)
metrics = numerics_check.metric_keys_to_metrics(metric_keys)
metrics: numerics_check.Metrics = jax.jit(jax.grad(check))(metrics, *args)
metric_keys, metrics = numerics_check.sort_metrics_by_out_metric(
  metric_keys, metrics
)
numerics_check.print_metrics(metric_keys, metrics, normalize_out_metric=True)
reduce_sum:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:167:13 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual):
  In metrics: (Array(0., dtype=float32),)
  Out metric: 1.0

dot_general:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:157:32 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer2):
  In metrics: (Array(0.7834606, dtype=float32), Array(3489.6716, dtype=float32))
  Out metric: 0.00011606767063664462

dot_general:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:154:31 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer1):
  In metrics: (Array(215670.23, dtype=float32), Array(6921.871, dtype=float32))
  Out metric: 6.192288453520269e-05

add:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:154:15 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer1):
  In metrics: (Array(0.8573911, dtype=float32), Array(0.8573911, dtype=float32))
  Out metric: 5.5814833872004735e-05

mul:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:154:19 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer1):
  In metrics: (Array(0.8573911, dtype=float32), Array(108569.2, dtype=float32))
  Out metric: 5.5804327095507726e-05

add:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:163:15 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer4):
  In metrics: (Array(0., dtype=float32), Array(0., dtype=float32))
  Out metric: 5.5169262114450295e-05

add:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:160:15 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer3):
  In metrics: (Array(0., dtype=float32), Array(0., dtype=float32))
  Out metric: 5.439651337370313e-05

add:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:157:15 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer2):
  In metrics: (Array(0., dtype=float32), Array(0., dtype=float32))
  Out metric: 4.7548649231166806e-05

mul:/Users/enriqueps/Developer/jax/tests/numerics_check_test.py:157:19 (NumericsCheckTest.test_demo.<locals>.matmul_with_residual.<locals>.layer2):
  In metrics: (Array(0., dtype=float32), Array(54109.715, dtype=float32))
  Out metric: 4.7412547292800515e-05

This shows that the second layer's dot product and the final sum are disproportionately sensitive to numerics for these inputs, so we should investigate why the values flowing into layer 2, or the gradients coming back, make that operation so sensitive. Perhaps the preceding layer or its own activation function needs massaging, or perhaps this is a more macro signal-propagation issue and that layer should run in higher precision than the others.
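
As promised above, here is a minimal sketch of the autodiff trick the transform relies on, shown outside the transform itself. Everything here is illustrative (the names and structure are assumptions, not the PR's implementation): a zero-valued metric input is threaded through the loss so that its gradient under jax.grad equals the quantity we want to read out.

import jax
import jax.numpy as jnp

# Illustrative only: `metric_probe` is a hypothetical auxiliary input.
def loss_with_metric_probe(metric_probe, x):
  y = jnp.tanh(x)
  probed_value = jnp.sum(y)  # the quantity we want to read out
  # Adding metric_probe * probed_value leaves the loss unchanged when
  # metric_probe == 0.0, but makes d(loss)/d(metric_probe) == probed_value.
  return jnp.sum(y ** 2) + metric_probe * probed_value

x = jnp.arange(4.0)
extracted = jax.grad(loss_with_metric_probe)(0.0, x)  # == jnp.sum(jnp.tanh(x))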

@epiqueras epiqueras requested a review from mattjj January 8, 2025 20:26
@epiqueras epiqueras force-pushed the feature/numerics branch 2 times, most recently from f4c6d8c to 5508cb9 on January 8, 2025 20:37
@epiqueras epiqueras requested a review from froystig January 8, 2025 20:58
@mattjj (Collaborator) commented on Jan 9, 2025

Idea from our convo: try adding a dup primitive, so we can pick up on the backward-pass cotangent addition stemming from fan-out (i.e. variable reuse), which seems important.
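
For context, a tiny standalone illustration (not from this PR) of the cotangent addition that fan-out produces:

import jax
import jax.numpy as jnp

# When x fans out to two consumers, the backward pass adds the cotangents
# coming from each use; a dup primitive would make that addition visible.
def f(x):
  return jnp.sum(2.0 * x) + jnp.sum(3.0 * x)  # x is reused (fan-out)

g = jax.grad(f)(jnp.ones(4))  # every entry is 5.0 = 2.0 + 3.0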

@epiqueras (Collaborator, Author) replied:

> Idea from our convo: try adding a dup primitive, so we can pick up on the backward-pass cotangent addition stemming from fan-out (i.e. variable reuse), which seems important.

Done; it wasn't that hard with the current code.
