Skip to content

Commit

Permalink
New Experimental JAX Check Numerics API.
Browse files Browse the repository at this point in the history
  • Loading branch information
epiqueras committed Jan 9, 2025
1 parent 6e1f060 commit 482209d
Show file tree
Hide file tree
Showing 6 changed files with 726 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ py_library_providing_imports_info(
"_src/state/**/*.py",
"_src/third_party/**/*.py",
"experimental/key_reuse/**/*.py",
"experimental/numerics_check/**/*.py",
"experimental/roofline/**/*.py",
"image/**/*.py",
"interpreters/**/*.py",
Expand Down
46 changes: 46 additions & 0 deletions jax/experimental/numerics_check/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.experimental.numerics_check.checks as checks
from jax.experimental.numerics_check.numerics_check import (
MetricKeys as MetricKeys,
)
from jax.experimental.numerics_check.numerics_check import (
Metrics as Metrics,
)
from jax.experimental.numerics_check.numerics_check import (
MetricsKey as MetricsKey,
)
from jax.experimental.numerics_check.numerics_check import (
MetricsValue as MetricsValue,
)
from jax.experimental.numerics_check.numerics_check import (
metric_keys_to_metrics as metric_keys_to_metrics,
)
from jax.experimental.numerics_check.numerics_check import (
numerics_check as numerics_check,
)
from jax.experimental.numerics_check.numerics_check import (
print_metrics as print_metrics,
)
from jax.experimental.numerics_check.numerics_check import (
register_numerics_check as register_numerics_check,
)
from jax.experimental.numerics_check.numerics_check import (
sort_metrics_by_in_metrics as sort_metrics_by_in_metrics,
)
from jax.experimental.numerics_check.numerics_check import (
sort_metrics_by_out_metric as sort_metrics_by_out_metric,
)

del checks
66 changes: 66 additions & 0 deletions jax/experimental/numerics_check/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from functools import partial
from typing import Any

from jax._src import core, pjit, typing
from jax.experimental.numerics_check.numerics_check import (
Val,
register_numerics_check,
NumericsCheckTrace,
numerics_check_subtrace,
)
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu


def _numerics_check_jaxpr_trace(
trace: NumericsCheckTrace,
jaxpr: core.ClosedJaxpr,
) -> tuple[core.ClosedJaxpr, list[Any]]:
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
f, subtrace_thunk = numerics_check_subtrace(
f,
core.TraceTag(),
trace.high_precision_dtype,
trace.low_precision_dtype,
trace.next_metric_index,
trace.metrics,
)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_)
subtrace = subtrace_thunk()["trace"]
trace.metric_keys.extend(subtrace.metric_keys)
trace.next_metric_index = subtrace.next_metric_index
return core.ClosedJaxpr(jaxpr_, []), consts


@register_numerics_check(pjit.pjit_p)
def _pjit_numerics_check(
trace: NumericsCheckTrace,
in_metrics: tuple[typing.Array, ...],
out_metric: typing.Array,
*args: Val,
jaxpr: core.ClosedJaxpr,
**kwargs: Val,
) -> Val:
del in_metrics, out_metric
jaxpr, consts = _numerics_check_jaxpr_trace(trace, jaxpr)
kwargs["in_shardings"] = (pjit.UNSPECIFIED,) * len(consts) + kwargs["in_shardings"]
kwargs["in_layouts"] = (None,) * len(consts) + kwargs["in_layouts"]
kwargs["donated_invars"] = (False,) * len(consts) + kwargs["donated_invars"]
out_vals = pjit.pjit_p.bind(*consts, *args, jaxpr=jaxpr, **kwargs)
return out_vals
Loading

0 comments on commit 482209d

Please sign in to comment.