From 482209da90b9319f3dc4e05688f1953332bed21e Mon Sep 17 00:00:00 2001 From: Enrique Piqueras <19157096+epiqueras@users.noreply.github.com> Date: Fri, 27 Dec 2024 18:10:11 -0800 Subject: [PATCH] New Experimental JAX Check Numerics API. --- jax/BUILD | 1 + jax/experimental/numerics_check/__init__.py | 46 ++ jax/experimental/numerics_check/checks.py | 66 +++ .../numerics_check/numerics_check.py | 423 ++++++++++++++++++ tests/BUILD | 6 + tests/numerics_check_test.py | 184 ++++++++ 6 files changed, 726 insertions(+) create mode 100644 jax/experimental/numerics_check/__init__.py create mode 100644 jax/experimental/numerics_check/checks.py create mode 100644 jax/experimental/numerics_check/numerics_check.py create mode 100644 tests/numerics_check_test.py diff --git a/jax/BUILD b/jax/BUILD index aed97b2a3243..7c4c8d362b27 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -235,6 +235,7 @@ py_library_providing_imports_info( "_src/state/**/*.py", "_src/third_party/**/*.py", "experimental/key_reuse/**/*.py", + "experimental/numerics_check/**/*.py", "experimental/roofline/**/*.py", "image/**/*.py", "interpreters/**/*.py", diff --git a/jax/experimental/numerics_check/__init__.py b/jax/experimental/numerics_check/__init__.py new file mode 100644 index 000000000000..dcf88ffb5bc0 --- /dev/null +++ b/jax/experimental/numerics_check/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import jax.experimental.numerics_check.checks as checks +from jax.experimental.numerics_check.numerics_check import ( + MetricKeys as MetricKeys, +) +from jax.experimental.numerics_check.numerics_check import ( + Metrics as Metrics, +) +from jax.experimental.numerics_check.numerics_check import ( + MetricsKey as MetricsKey, +) +from jax.experimental.numerics_check.numerics_check import ( + MetricsValue as MetricsValue, +) +from jax.experimental.numerics_check.numerics_check import ( + metric_keys_to_metrics as metric_keys_to_metrics, +) +from jax.experimental.numerics_check.numerics_check import ( + numerics_check as numerics_check, +) +from jax.experimental.numerics_check.numerics_check import ( + print_metrics as print_metrics, +) +from jax.experimental.numerics_check.numerics_check import ( + register_numerics_check as register_numerics_check, +) +from jax.experimental.numerics_check.numerics_check import ( + sort_metrics_by_in_metrics as sort_metrics_by_in_metrics, +) +from jax.experimental.numerics_check.numerics_check import ( + sort_metrics_by_out_metric as sort_metrics_by_out_metric, +) + +del checks diff --git a/jax/experimental/numerics_check/checks.py b/jax/experimental/numerics_check/checks.py new file mode 100644 index 000000000000..251e42f13c48 --- /dev/null +++ b/jax/experimental/numerics_check/checks.py @@ -0,0 +1,66 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from functools import partial +from typing import Any + +from jax._src import core, pjit, typing +from jax.experimental.numerics_check.numerics_check import ( + Val, + register_numerics_check, + NumericsCheckTrace, + numerics_check_subtrace, +) +from jax._src.interpreters import partial_eval as pe +from jax._src import linear_util as lu + + +def _numerics_check_jaxpr_trace( + trace: NumericsCheckTrace, + jaxpr: core.ClosedJaxpr, +) -> tuple[core.ClosedJaxpr, list[Any]]: + f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) + f, subtrace_thunk = numerics_check_subtrace( + f, + core.TraceTag(), + trace.high_precision_dtype, + trace.low_precision_dtype, + trace.next_metric_index, + trace.metrics, + ) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_) + subtrace = subtrace_thunk()["trace"] + trace.metric_keys.extend(subtrace.metric_keys) + trace.next_metric_index = subtrace.next_metric_index + return core.ClosedJaxpr(jaxpr_, []), consts + + +@register_numerics_check(pjit.pjit_p) +def _pjit_numerics_check( + trace: NumericsCheckTrace, + in_metrics: tuple[typing.Array, ...], + out_metric: typing.Array, + *args: Val, + jaxpr: core.ClosedJaxpr, + **kwargs: Val, +) -> Val: + del in_metrics, out_metric + jaxpr, consts = _numerics_check_jaxpr_trace(trace, jaxpr) + kwargs["in_shardings"] = (pjit.UNSPECIFIED,) * len(consts) + kwargs["in_shardings"] + kwargs["in_layouts"] = (None,) * len(consts) + kwargs["in_layouts"] + kwargs["donated_invars"] = (False,) * len(consts) + kwargs["donated_invars"] + out_vals = pjit.pjit_p.bind(*consts, *args, jaxpr=jaxpr, **kwargs) + return out_vals diff --git a/jax/experimental/numerics_check/numerics_check.py b/jax/experimental/numerics_check/numerics_check.py new file mode 100644 index 000000000000..f198e42541db --- /dev/null +++ b/jax/experimental/numerics_check/numerics_check.py @@ -0,0 +1,423 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
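+# Overview of this module:
+#
+# * `numerics_check(fun)` returns a pair of callables: a checked version of
+#   `fun` that takes a `Metrics` list ahead of the original arguments, and a
+#   helper that traces `fun` abstractly and returns the matching `MetricKeys`
+#   (primitive, source location, and input avals for each metric slot).
+# * Each primitive encountered during tracing is handled by a registered
+#   `_NumericsCheckRule`, falling back to a default rule that wraps the
+#   primitive in a `custom_vjp` comparing high- and low-precision execution.
+# * Differentiating the checked function with respect to its `Metrics`
+#   argument (e.g. with `jax.grad`, as in tests/numerics_check_test.py)
+#   yields the recorded per-primitive sensitivities.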
+from __future__ import annotations + +from functools import lru_cache, partial +from typing import Any, Callable, Concatenate, NamedTuple, ParamSpec, Protocol, TypeVar + +import jax.numpy as jnp +from jax._src import ( + api, + api_util, + core, + custom_derivatives, + source_info_util, + traceback_util, + tree_util, + typing, + util, +) +from jax._src import linear_util as lu +from jax._src.interpreters import partial_eval as pe +from jax._src.lax import lax + +zip = util.safe_zip + + +Val = Any + +class PrecisionDtype(NamedTuple): + dtype: jnp.dtype + exponent_bits: int + mantissa_bits: int + + +PRECISION_DTYPE_F32 = PrecisionDtype(jnp.float32, 8, 23) +PRECISION_DTYPE_BF16 = PrecisionDtype(jnp.bfloat16, 8, 7) + + +# Rules + + +class _NumericsCheckRule(Protocol): + def __call__( + self, + trace: "NumericsCheckTrace", + in_metrics: tuple[typing.Array, ...], + out_metric: typing.Array, + *args: Val, + **params: Val, + ) -> tuple[Val, ...]: ... + + +_numerics_checks: dict[core.Primitive, _NumericsCheckRule] = {} + + +def register_numerics_check(prim: core.Primitive): + def register(rule: _NumericsCheckRule): + _numerics_checks[prim] = rule + return rule + + return register + + +# Default Rule + + +def lower_precision(hp: PrecisionDtype, lp: PrecisionDtype, val: Val) -> Val: + if isinstance(val, typing.Array) and val.dtype == hp.dtype: + val = lax.reduce_precision( + val, + exponent_bits=lp.exponent_bits, + mantissa_bits=lp.mantissa_bits, + ) + val = val.astype(lp.dtype) + return val + + +@lru_cache +def _make_default_numerics_check(primitive: core.Primitive) -> _NumericsCheckRule: + @lru_cache + def make_default_numerics_check_with_kwargs( + high_precision_dtype: PrecisionDtype, + low_precision_dtype: PrecisionDtype, + **params: Val, + ) -> Val: + lp = partial(lower_precision, high_precision_dtype, low_precision_dtype) + + @custom_derivatives.custom_vjp + def default_numerics_check( + in_metrics: tuple[typing.Array, ...], out_metric: typing.Array, *args: Val + ) -> Val: + del in_metrics, out_metric + return primitive.bind(*args, **params) + + def default_numerics_fwd( + in_metrics: tuple[typing.Array, ...], out_metric: typing.Array, *args: Val + ): + def bind_primitive(*args): + return primitive.bind(*args, **params) + + out, f_vjp = api.vjp( + bind_primitive, + *args, + ) + low_precision_out, low_precision_f_vjp = api.vjp( + lambda *args: lp(bind_primitive(*args)), + *tuple(map(lp, args)), + ) + delta = out - low_precision_out.astype(out.dtype) + return out, (f_vjp, low_precision_f_vjp, delta, in_metrics) + + def default_numerics_bwd( + res: tuple[Callable, Callable, Val, tuple[typing.Array, ...]], g: Val + ): + f_vjp, low_precision_f_vjp, delta, in_metrics = res + out_metric = jnp.sum((g * delta.astype(g.dtype)).astype(jnp.float32)) + grads = f_vjp(g) + low_precision_grads = low_precision_f_vjp(lp(g)) + in_metrics = tuple( + jnp.mean( + jnp.square((grad - low_precision_grad.astype(grad.dtype)).astype(jnp.float32)) + ) + for grad, low_precision_grad in zip(grads, low_precision_grads) + ) + return (in_metrics, out_metric, *grads) + + default_numerics_check.defvjp(default_numerics_fwd, default_numerics_bwd) + return default_numerics_check + + def default_numerics_check( + trace: "NumericsCheckTrace", + in_metrics: tuple[typing.Array, ...], + out_metric: typing.Array, + *args: Val, + **params: Val, + ) -> Val: + return make_default_numerics_check_with_kwargs( + trace.high_precision_dtype, trace.low_precision_dtype, **params + )(in_metrics, out_metric, *args) + + return default_numerics_check 
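+# The default rule above covers any primitive without an entry in
+# `_numerics_checks`. Its forward pass runs the primitive at full precision
+# and again with every input of the high-precision dtype reduced to the
+# low-precision dtype, saving the output difference `delta` as a residual.
+# Its backward pass records `sum(g * delta)` as the out metric and the mean
+# squared difference between the full- and reduced-precision input cotangents
+# as the in metrics, then passes the full-precision cotangents through
+# unchanged. Primitives that need special handling (such as `pjit_p` in
+# checks.py) register their own rule via `register_numerics_check`.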
+ + +# Trace + + +class NumericsCheckTracer(core.Tracer): + _trace: "NumericsCheckTrace" + val: Val + + def __init__(self, trace, val): + self._trace = trace + self.val = val + + @property + def aval(self) -> core.AbstractValue: + return core.get_aval(self.val) + + def to_concrete_value(self) -> Val: + return core.to_concrete_value(self.val) + + +class MetricsKey: + primitive: core.Primitive + source_info: source_info_util.SourceInfo + in_metrics: int + in_avals: tuple[api.ShapeDtypeStruct, ...] + + def __init__( + self, + primitive: core.Primitive, + source_info: source_info_util.SourceInfo, + in_metrics: int, + in_avals: tuple[api.ShapeDtypeStruct, ...], + ): + self.primitive = primitive + self.source_info = source_info + self.in_metrics = in_metrics + self.in_avals = in_avals + + # It's slightly faster to use a class with __slots__ than a NamedTuple. + __slots__ = ["primitive", "source_info", "in_metrics", "in_avals"] + + +MetricsValue = tuple[tuple[typing.Array, ...], typing.Array] + +MetricKeys = list[MetricsKey] +Metrics = list[MetricsValue] + + +class NumericsCheckTrace(core.Trace[NumericsCheckTracer]): + parent_trace: core.Trace + tag: core.TraceTag + high_precision_dtype: PrecisionDtype + low_precision_dtype: PrecisionDtype + metric_keys: MetricKeys + + metrics: None | Metrics + next_metric_index: int + + def __init__( + self, + parent_trace, + tag, + high_precision_dtype: PrecisionDtype, + low_precision_dtype: PrecisionDtype, + next_metric_index: int, + metrics: None | Metrics, + ): + self.parent_trace = parent_trace + self.tag = tag + self.high_precision_dtype = high_precision_dtype + self.low_precision_dtype = low_precision_dtype + self.metric_keys = [] + self.next_metric_index = next_metric_index + self.metrics = metrics + + def to_val(self, val: Val | NumericsCheckTracer) -> Val: + if isinstance(val, NumericsCheckTracer) and val._trace.tag is self.tag: + return val.val + else: + return val + + @staticmethod + def make_metric() -> typing.Array: + return jnp.zeros((), dtype=jnp.float32) + + def process_primitive( + self, primitive: core.Primitive, tracers: tuple[Val, ...], params: dict[str, Val] + ) -> Val: + rule = _numerics_checks.get(primitive, None) + if rule is None: + rule = _make_default_numerics_check(primitive) + in_vals = tuple(map(self.to_val, tracers)) + metrics: MetricsValue | tuple[tuple[None, ...], None] + if self.metrics is None: + metrics = ( + (None,) * len(in_vals), + None, + ) + else: + metrics = self.metrics[self.next_metric_index] + self.next_metric_index += 1 + in_metrics = tuple( + NumericsCheckTrace.make_metric() if metric is None else metric + for metric in metrics[0] + ) + out_metric = NumericsCheckTrace.make_metric() if metrics[1] is None else metrics[1] + self.metric_keys.append( + MetricsKey( + primitive, + source_info_util.current(), + len(in_metrics), + tuple(api.ShapeDtypeStruct(x.shape, x.dtype) for x in in_vals), + ) + ) + with core.set_current_trace(self.parent_trace): + out_vals = rule(self, in_metrics, out_metric, *in_vals, **params) + if primitive.multiple_results: + out_tracers = tuple(map(partial(NumericsCheckTracer, self), out_vals)) + return out_tracers + else: + out_tracer = NumericsCheckTracer(self, out_vals) + return out_tracer + + +# Transformation + + +P = ParamSpec("P") +R = TypeVar("R") + + +@lu.transformation_with_aux2 +def numerics_check_subtrace( + f: Callable, + store: lu.Store, + tag: core.TraceTag, + high_precision_dtype: PrecisionDtype, + low_precision_dtype: PrecisionDtype, + next_metric_index: int, + metrics: Metrics, + 
*args_flat: Val, +) -> tuple[Val, ...]: + with core.take_current_trace() as parent_trace: + trace = NumericsCheckTrace( + parent_trace, + tag, + high_precision_dtype, + low_precision_dtype, + next_metric_index, + metrics, + ) + in_tracers = tuple(map(partial(NumericsCheckTracer, trace), args_flat)) + with core.set_current_trace(trace): + out_tracers = f(*in_tracers) + out = tuple(map(trace.to_val, out_tracers)) + store.store(dict(trace=trace)) + return out + + +@lu.transformation2 +def numerics_check_trace( + f: Callable[ + Concatenate[core.TraceTag, PrecisionDtype, PrecisionDtype, int, Metrics, ...], R + ], + subtrace_thunk: Callable, + high_precision_dtype: PrecisionDtype, + low_precision_dtype: PrecisionDtype, + metrics: Metrics, + *args_flat: Val, +) -> R: + tag = core.TraceTag() + with source_info_util.transform_name_stack("numerics_check"): + out = f(tag, high_precision_dtype, low_precision_dtype, 0, metrics, *args_flat) + trace = subtrace_thunk()["trace"] + with core.ensure_no_leaks(trace): + del trace + return out + + +def numerics_check( + fun: Callable[P, R], + high_precision_dtype: PrecisionDtype = PRECISION_DTYPE_F32, + low_precision_dtype: PrecisionDtype = PRECISION_DTYPE_BF16, +) -> tuple[ + Callable[Concatenate[Metrics, P], Val], + Callable[P, MetricKeys], +]: + api_util.check_callable(fun) + docstr = "Takes similar arguments as {fun} but adds additional arrays in which numerical sensitivities are deposited." + if fun.__doc__: + docstr += "\n\nOriginal documentation:\n\n" + docstr += fun.__doc__ + + @util.wraps(fun, docstr=docstr) + @traceback_util.api_boundary + def numerics_check_f( + metrics: Metrics, + *args: P.args, + **kwargs: P.kwargs, + ) -> Val: + args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) + f = lu.wrap_init(fun) + f, out_tree_thunk = api_util.flatten_fun(f, in_tree) + f, subtrace_thunk = numerics_check_subtrace(f) + f = numerics_check_trace( + f, subtrace_thunk, high_precision_dtype, low_precision_dtype, metrics + ) + out_flat = f.call_wrapped(*args_flat) + return tree_util.tree_unflatten(out_tree_thunk(), out_flat) + + def numerics_check_metrics_f( + *args: P.args, + **kwargs: P.kwargs, + ) -> MetricKeys: + args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) + f = lu.wrap_init(fun) + f, _ = api_util.flatten_fun(f, in_tree) + f, subtrace_thunk = numerics_check_subtrace(f) + f = numerics_check_trace( + f, subtrace_thunk, high_precision_dtype, low_precision_dtype, None + ) + pe.trace_to_jaxpr_dynamic(f, tuple(core.get_aval(x) for x in args_flat)) + return subtrace_thunk()["trace"].metric_keys + + return numerics_check_f, numerics_check_metrics_f + + +def metric_keys_to_metrics(metric_keys: MetricKeys) -> Metrics: + return [ + ( + tuple(NumericsCheckTrace.make_metric() for _ in range(key.in_metrics)), + NumericsCheckTrace.make_metric(), + ) + for key in metric_keys + ] + +def sort_metrics_by_in_metrics( + metric_keys: MetricKeys, metrics: Metrics +) -> tuple[MetricKeys, Metrics]: + def sort_key(key_and_metric: tuple[MetricsKey, MetricsValue]) -> int: + _, (in_metrics, _) = key_and_metric + return max(abs(x.item()) for x in in_metrics) + + sorted_pairs = sorted(zip(metric_keys, metrics), key=sort_key, reverse=True) + return [x[0] for x in sorted_pairs], [x[1] for x in sorted_pairs] + + +def sort_metrics_by_out_metric( + metric_keys: MetricKeys, metrics: Metrics +) -> tuple[MetricKeys, Metrics]: + def sort_key(key_and_metric: tuple[MetricsKey, MetricsValue]) -> int: + _, (_, out_metric) = key_and_metric + return abs(out_metric.item()) + + 
sorted_pairs = sorted(zip(metric_keys, metrics), key=sort_key, reverse=True) + return [x[0] for x in sorted_pairs], [x[1] for x in sorted_pairs] + +def print_metrics( + metric_keys: MetricKeys, metrics: Metrics, *, normalize_out_metric: bool = False +): + out_metrics = [m[1].item() for m in metrics] + if normalize_out_metric: + normalizer = max(abs(m) for m in out_metrics) + out_metrics = [m / normalizer for m in out_metrics] + + for key, (in_metrics, _), out_metric in zip(metric_keys, metrics, out_metrics): + print(f"\n{key.primitive}:{source_info_util.summarize(key.source_info)}:") + print(f" In avals: {key.in_avals}") + print(f" In metrics: {in_metrics}") + print(f" Out metric: {out_metric}") diff --git a/tests/BUILD b/tests/BUILD index 4868bcf75e2e..6ae7b3657a25 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1206,6 +1206,12 @@ jax_multiplatform_test( srcs = ["key_reuse_test.py"], ) +jax_multiplatform_test( + name = "numerics_check_test", + srcs = ["numerics_check_test.py"], + enable_backends = ["cpu"], +) + jax_multiplatform_test( name = "roofline_test", srcs = ["roofline_test.py"], diff --git a/tests/numerics_check_test.py b/tests/numerics_check_test.py new file mode 100644 index 000000000000..74759cec757d --- /dev/null +++ b/tests/numerics_check_test.py @@ -0,0 +1,184 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
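+# These tests exercise the experimental numerics-check API end to end: they
+# build a checked function with `numerics_check.numerics_check`, recover its
+# metric keys, and verify that differentiating the checked function with
+# respect to the metrics reproduces the expected float32-vs-bfloat16
+# sensitivities, both with and without `jax.jit`.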
+from __future__ import annotations + +from absl.testing import absltest + +import jax +import jax.lax as lax +import jax.numpy as jnp +from jax._src import test_util as jtu +from jax._src.lax.lax import _reduce_sum +from jax.experimental import numerics_check + +jax.config.parse_flags_with_absl() + +def _f32_to_bf16(arr: jax.Array) -> jax.Array: + arr = lax.reduce_precision(arr, exponent_bits=8, mantissa_bits=7) + arr = arr.astype(jnp.bfloat16) + return arr + + +def _mse(arr1: jax.Array, arr2: jax.Array) -> jax.Array: + return jnp.mean(jnp.square(arr1 - arr2.astype(arr1.dtype))) + + +def _g_dot_delta(g: jax.Array, arr1: jax.Array, arr2: jax.Array) -> jax.Array: + delta = arr1 - arr2.astype(arr1.dtype) + return jnp.sum(g * delta) + +class NumericsCheckTest(jtu.JaxTestCase): + def test_numerics_check(self): + def matmul(x, y): + return jnp.sum(x @ y) + + key = jax.random.key(42) + x = jax.random.uniform(key, (128, 128), dtype=jnp.float32) * 1e7 + y = ( + jax.random.uniform(jax.random.split(key)[1], (128, 128), dtype=jnp.float32) * 1e-7 + ) + args = x, y + check, source_metrics = numerics_check.numerics_check(matmul) + metric_keys = source_metrics(*args) + metrics = numerics_check.metric_keys_to_metrics(metric_keys) + result = check(metrics, *args) + expected = matmul(*args) + self.assertArraysEqual(result, expected) + + metrics: numerics_check.Metrics = jax.grad(check)(metrics, *args) + print("\n\nTrace Order:") + numerics_check.print_metrics(metric_keys, metrics) + + pjit_1_metrics, dot_metrics, pjit_2_metrics, sum_metrics = metrics + zero = jnp.zeros((), dtype=jnp.float32) + ones = jnp.ones((128, 128), dtype=jnp.float32) + one = jnp.ones((), dtype=jnp.float32) + + self.assertEqual(pjit_1_metrics, ((zero, zero), zero)) + self.assertEqual(pjit_2_metrics, ((zero,), zero)) + + dot_in_x = _mse(ones @ y, _f32_to_bf16(ones) @ _f32_to_bf16(y)) + dot_in_y = _mse(ones @ x, _f32_to_bf16(ones) @ _f32_to_bf16(x)) + self.assertAllClose(dot_metrics[0], (dot_in_x, dot_in_y)) + dot_out = _g_dot_delta(ones, x @ y, _f32_to_bf16(x) @ _f32_to_bf16(y)) + self.assertAllClose(dot_metrics[1], dot_out) + + self.assertEqual(sum_metrics[0], (zero,)) + sum_out = _g_dot_delta( + one, + _reduce_sum(x @ y, axes=(0, 1)), + _reduce_sum(_f32_to_bf16(x @ y), axes=(0, 1)), + ) + self.assertAllClose(sum_metrics[1], sum_out) + + print("\n\nSorted by In Metrics:") + metric_keys, metrics = numerics_check.sort_metrics_by_in_metrics( + metric_keys, metrics + ) + numerics_check.print_metrics(metric_keys, metrics) + print("\n\nSorted by Out Metric:") + metric_keys, metrics = numerics_check.sort_metrics_by_out_metric( + metric_keys, metrics + ) + numerics_check.print_metrics(metric_keys, metrics) + + def test_jit_numerics_check(self): + @jax.jit + def matmul(x, y): + return jnp.sum(x @ y) + + key = jax.random.key(42) + x = jax.random.uniform(key, (128, 128), dtype=jnp.float32) * 1e7 + y = ( + jax.random.uniform(jax.random.split(key)[1], (128, 128), dtype=jnp.float32) * 1e-7 + ) + args = x, y + check, source_metrics = numerics_check.numerics_check(matmul) + metric_keys = source_metrics(*args) + metrics = numerics_check.metric_keys_to_metrics(metric_keys) + result = jax.jit(check)(metrics, *args) + expected = matmul(*args) + self.assertArraysEqual(result, expected) + + metrics: numerics_check.Metrics = jax.jit(jax.grad(check))(metrics, *args) + print("\n\nTrace Order:") + numerics_check.print_metrics(metric_keys, metrics) + + pjit_1_metrics, dot_metrics, sum_metrics = metrics + zero = jnp.zeros((), dtype=jnp.float32) + ones = 
jnp.ones((128, 128), dtype=jnp.float32) + one = jnp.ones((), dtype=jnp.float32) + + self.assertEqual(pjit_1_metrics, ((zero, zero), zero)) + + dot_in_x = _mse(ones @ y, _f32_to_bf16(ones) @ _f32_to_bf16(y)) + dot_in_y = _mse(ones @ x, _f32_to_bf16(ones) @ _f32_to_bf16(x)) + self.assertAllClose(dot_metrics[0], (dot_in_x, dot_in_y)) + dot_out = _g_dot_delta(ones, x @ y, _f32_to_bf16(x) @ _f32_to_bf16(y)) + self.assertAllClose(dot_metrics[1], dot_out) + + self.assertEqual(sum_metrics[0], (zero,)) + sum_out = _g_dot_delta( + one, + _reduce_sum(x @ y, axes=(0, 1)), + _reduce_sum(_f32_to_bf16(x @ y), axes=(0, 1)), + ) + self.assertAllClose(sum_metrics[1], sum_out) + + print("\n\nSorted by In Metrics:") + metric_keys, metrics = numerics_check.sort_metrics_by_in_metrics( + metric_keys, metrics + ) + numerics_check.print_metrics(metric_keys, metrics) + print("\n\nSorted by Out Metric:") + metric_keys, metrics = numerics_check.sort_metrics_by_out_metric( + metric_keys, metrics + ) + numerics_check.print_metrics(metric_keys, metrics) + + def test_demo(self): + @jax.jit + def matmul_with_residual(x, ys): + def layer1(x, y): + return x + jax.nn.gelu(x @ y) + + def layer2(x, y): + return x + jax.nn.swish(x @ y) + + def layer3(x, y): + return x + jax.nn.sigmoid(x @ y) + + def layer4(x, y): + return x + jax.nn.tanh(x @ y) + + for l, y in zip([layer1, layer2, layer3, layer4], ys): + x = l(x, y) + return jnp.sum(x) + + key = jax.random.key(42) + x = jax.random.uniform(key, (128, 1024), dtype=jnp.float32) + y = jax.random.uniform(jax.random.split(key)[1], (1024, 1024), dtype=jnp.float32) + args = x, (y, y, y, y) + check, source_metrics = numerics_check.numerics_check(matmul_with_residual) + metric_keys = source_metrics(*args) + metrics = numerics_check.metric_keys_to_metrics(metric_keys) + metrics: numerics_check.Metrics = jax.jit(jax.grad(check))(metrics, *args) + metric_keys, metrics = numerics_check.sort_metrics_by_out_metric( + metric_keys, metrics + ) + numerics_check.print_metrics(metric_keys, metrics, normalize_out_metric=False) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())
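
For reference, a minimal usage sketch of the API this patch adds, distilled
from tests/numerics_check_test.py above. The `loss` function, key seeds, and
array shapes are illustrative only and are not part of the patch:

    import jax
    import jax.numpy as jnp
    from jax.experimental import numerics_check

    def loss(x, y):
      return jnp.sum(x @ y)

    x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32)
    y = jax.random.uniform(jax.random.key(1), (128, 128), dtype=jnp.float32)

    # Build the checked function and allocate one metric slot per traced
    # primitive application.
    check, metric_keys_fn = numerics_check.numerics_check(loss)
    metric_keys = metric_keys_fn(x, y)
    metrics = numerics_check.metric_keys_to_metrics(metric_keys)

    # The checked function reproduces the original output; differentiating it
    # with respect to the metrics yields the f32-vs-bf16 sensitivities.
    out = check(metrics, x, y)
    metrics = jax.grad(check)(metrics, x, y)

    # Rank primitives by how much reduced precision perturbs their output.
    metric_keys, metrics = numerics_check.sort_metrics_by_out_metric(
        metric_keys, metrics
    )
    numerics_check.print_metrics(metric_keys, metrics)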