diff --git a/examples/ffi/tests/cpu_examples_test.py b/examples/ffi/tests/cpu_examples_test.py index cb2653d2e928..0e2cfde02db6 100644 --- a/examples/ffi/tests/cpu_examples_test.py +++ b/examples/ffi/tests/cpu_examples_test.py @@ -37,10 +37,10 @@ def test_array_attr_jit_cache(self): jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,)) with jtu.count_jit_and_pmap_lowerings() as count: jit_array_attr(5) - self.assertEqual(count[0], 1) # compiles once the first time + self.assertEqual(count(), 1) # compiles once the first time with jtu.count_jit_and_pmap_lowerings() as count: jit_array_attr(5) - self.assertEqual(count[0], 0) # cache hit + self.assertEqual(count(), 0) # cache hit def test_array_attr_no_jit(self): with jax.disable_jit(): diff --git a/jax/_src/array.py b/jax/_src/array.py index 7c5385f97e40..cf2cc6248bd8 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -33,6 +33,7 @@ from jax._src import dtypes from jax._src import errors from jax._src import profiler +from jax._src import util from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.interpreters import pxla @@ -1131,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): def _array_shard_arg(xs, shardings, layouts, copy_semantics): + util.test_event("_array_shard_arg") results = [] batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] batch_cs = [] @@ -1168,6 +1170,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): results.append( shard_sharded_device_array_slow_path(x, devices, indices, sharding)) + util.test_event("batched_copy_array") copy_outs = xc.batched_copy_array_to_devices_with_sharding( batch_xs, batch_devs, batch_shardings, batch_cs) for i, copy_out in safe_zip(batch_indices, copy_outs): diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 54c3a43e8a84..72e918241ea5 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params): @util.cache() def xla_primitive_callable(prim: core.Primitive, **params): + util.test_event("xla_primitive_callable_cache_miss") def prim_fun(*args): with config.eager_constant_folding(False): return prim.bind(*args, **params) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1dad0ab9720d..0e3fdea02301 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1085,6 +1085,7 @@ def lower_jaxpr_to_module( Handles the quirks of the argument/return value passing conventions of the runtime. """ + util.test_event("lower_jaxpr_to_module") platforms = tuple(map(xb.canonicalize_platform, platforms)) in_avals = (jaxpr.in_avals if arg_shardings is None else @@ -1378,6 +1379,7 @@ def lower_jaxpr_to_fun( Returns: MLIR func op """ + util.test_event("lower_jaxpr_to_fun", name) # The first dimension variable may be the platform index num_dim_vars = len(ctx.shape_poly_state.dim_vars) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 0f16bab19181..c9c1628a71b6 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -231,16 +231,20 @@ def _shard_mutable_array(xs, shardings, layouts, copy_semantics): def batched_device_put(aval: core.ShapedArray, sharding: JSharding, xs: Sequence[Any], devices: Sequence[jax.Device], committed: bool = True): - from jax._src import array - - bufs = [x for x, d in safe_zip(xs, devices) - if (isinstance(x, array.ArrayImpl) and - dispatch.is_single_device_sharding(x.sharding) and - x.devices() == {d})] - if len(bufs) == len(xs): - return array.ArrayImpl( - aval, sharding, bufs, committed=committed, _skip_checks=True) - return xc.batched_device_put(aval, sharding, xs, list(devices), committed) + util.test_event("batched_device_put_start") + try: + from jax._src import array + + bufs = [x for x, d in safe_zip(xs, devices) + if (isinstance(x, array.ArrayImpl) and + dispatch.is_single_device_sharding(x.sharding) and + x.devices() == {d})] + if len(bufs) == len(xs): + return array.ArrayImpl( + aval, sharding, bufs, committed=committed, _skip_checks=True) + return xc.batched_device_put(aval, sharding, xs, list(devices), committed) + finally: + util.test_event("batched_device_put_end") def _shard_aval(size, axis: int, aval): try: @@ -2850,6 +2854,7 @@ def from_hlo(name: str, mesh = i.mesh break + util.test_event("pxla_cached_compilation") xla_executable = _cached_compilation( hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index e5279c7a19bc..0fabb8c21c52 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -549,6 +549,7 @@ def _infer_params_impl( kwargs: dict[str, Any], in_avals: tuple[core.AbstractValue, ...] | None, ) -> tuple[PjitParams, list[Any]]: + util.test_event("pjit._infer_params_impl", fun) have_kwargs = bool(kwargs) if have_kwargs and ji.user_specified_in_shardings: raise ValueError( @@ -1297,6 +1298,7 @@ def _create_pjit_jaxpr( ignored_inline: IgnoreKey ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + util.test_event("create_pjit_jaxpr") del ignored_inline # just for explain_cache_miss if config.no_tracing.value: raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but " @@ -1784,6 +1786,7 @@ def _pjit_lower( lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): + util.test_event("pjit_lower") if config.sharding_in_types.value: mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' else: diff --git a/jax/_src/stages.py b/jax/_src/stages.py index db26813de8bc..10963330ff92 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -533,6 +533,7 @@ def output_layouts(self): @staticmethod def call(*args, **kwargs): + util.test_event("stages_compiled_call") # This is because `__call__` passes in `self._params` as the first argument. # Instead of making the call signature `call(params, *args, **kwargs)` # extract it from args because `params` can be passed as a kwarg by users diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index ce418b686603..4e2a4f4fdeee 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -30,6 +30,7 @@ import sys import tempfile import textwrap +import threading from typing import Any, TextIO import unittest import warnings @@ -40,22 +41,17 @@ import jax from jax import lax from jax._src import api -from jax._src import array from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes as _dtypes from jax._src import lib as _jaxlib -from jax._src import linear_util as lu from jax._src import monitoring -from jax._src import pjit as pjit_lib -from jax._src import stages from jax._src import xla_bridge +from jax._src import util from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir -from jax._src.interpreters import pxla -from jax._src.lib import xla_client as xc from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, @@ -235,223 +231,109 @@ def get_output() -> str: capture_stderr = partial(_capture_output, sys.stderr) -@contextmanager -def count_device_put(): - batched_device_put = pxla.batched_device_put - count = [0] - - def make_fn_and_count(fn): - def fn_and_count(*args, **kwargs): - count[0] += 1 - # device_put handlers might call `dispatch.device_put` (e.g. on an - # underlying payload or several). We only want to count these - # recursive puts once, so we skip counting more than the outermost - # one in such a call stack. - pxla.batched_device_put = batched_device_put - try: - return fn(*args, **kwargs) - finally: - pxla.batched_device_put = batched_device_put_and_count - return fn_and_count - - batched_device_put_and_count = make_fn_and_count(batched_device_put) - - pxla.batched_device_put = batched_device_put_and_count - try: - yield count - finally: - pxla.batched_device_put = batched_device_put - - -@contextmanager -def count_primitive_compiles(): - dispatch.xla_primitive_callable.cache_clear() +class EventThreadLocalState(threading.local): + def __init__(self): + self.counts = {} # Mapping from string name to count. + self.nested_device_put_count = 0 # Number of recursive calls to device_put - count = [-1] - try: - yield count - finally: - count[0] = dispatch.xla_primitive_callable.cache_info().misses + # Per-function counts + self.infer_params_fun_counts = None + self.lower_jaxpr_to_fun_counts = None +thread_local_state = EventThreadLocalState() -@contextmanager -def count_device_put_fast_path_hit(): - original_fn = xc.batched_copy_array_to_devices_with_sharding - count = [0] - def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs): - count[0] += 1 - return original_fn(*args, **kwargs) +def event_listener(name, *args): + counts = thread_local_state.counts + counts[name] = counts.get(name, 0) + 1 - xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count - try: - yield count - finally: - xc.batched_copy_array_to_devices_with_sharding = original_fn + # device_put handlers might call `dispatch.device_put` (e.g. on an + # underlying payload or several). We only want to count these + # recursive puts once, so we skip counting more than the outermost + # one in such a call stack. + if name == "batched_device_put_start": + if thread_local_state.nested_device_put_count == 0: + counts["batched_device_put"] = counts.get("batched_device_put", 0) + 1 + thread_local_state.nested_device_put_count += 1 + elif name == "batched_device_put_end": + thread_local_state.nested_device_put_count -= 1 + elif name == "pjit._infer_params_impl": + # For infer_params, we collect per-function data, but only while a context + # manager is active. + infer_counts = thread_local_state.infer_params_fun_counts + if infer_counts is not None: + (fun,) = args + infer_counts[fun] += 1 + elif name == "lower_jaxpr_to_fun": + # For infer_params, we collect per-function data, but only while a context + # manager is active. + lower_counts = thread_local_state.lower_jaxpr_to_fun_counts + if lower_counts is not None: + (fun,) = args + lower_counts[fun] += 1 -@contextmanager -def count_pjit_cpp_cache_miss(): - original_pjit_lower = pjit_lib._pjit_lower - count = [0] - def pjit_lower_and_count(*args, **kwargs): - count[0] += 1 - return original_pjit_lower(*args, **kwargs) +util.test_event_listener = event_listener - pjit_lib._pjit_lower = pjit_lower_and_count - try: - yield count - finally: - pjit_lib._pjit_lower = original_pjit_lower -@contextmanager -def count_cached_compilation_cache_miss(): - original_cached_compilation = pxla._cached_compilation - count = [0] +def count_events(event): + "Returns a context-manager that yields a function that counts a test event." + @contextmanager + def count_event(): + before = thread_local_state.counts.get(event, 0) + yield lambda: thread_local_state.counts.get(event, 0) - before + return count_event - def cached_compilation_and_count(*args, **kwargs): - count[0] += 1 - return original_cached_compilation(*args, **kwargs) +count_device_put = count_events("batched_device_put") +count_device_put_fast_path_hit = count_events("batched_copy_array") +count_pjit_cpp_cache_miss = count_events("pjit_lower") +count_jit_tracing_cache_miss = count_events("create_pjit_jaxpr") +count_aot_jit_cpp_cache_miss = count_events("stages_compiled_call") +count_jit_and_pmap_lowerings = count_events("lower_jaxpr_to_module") +count_jit_compilation_cache_miss = count_events("pxla_cached_compilation") +count_jax_array_shard_arg_calls = count_events("_array_shard_arg") - pxla._cached_compilation = cached_compilation_and_count - try: - yield count - finally: - pxla._cached_compilation = original_cached_compilation @contextmanager -def count_jit_tracing_cache_miss(): - original_create_pjit_jaxpr = pjit_lib._create_pjit_jaxpr - count = [0] - - @lu.cache - def create_pjit_jaxpr_and_count(*args): - count[0] += 1 - return original_create_pjit_jaxpr(*args) +def count_primitive_compiles(): + dispatch.xla_primitive_callable.cache_clear() - pjit_lib._create_pjit_jaxpr = create_pjit_jaxpr_and_count + count = [-1] try: - yield count + yield lambda: count[0] finally: - pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr + count[0] = dispatch.xla_primitive_callable.cache_info().misses @contextmanager def count_jit_infer_params_cache_miss(): - original_infer_params_impl = pjit_lib._infer_params_impl - count = collections.defaultdict(int) - - def infer_params_impl_and_count(fun, *args, **kw): - count[fun] += 1 - return original_infer_params_impl(fun, *args, **kw) - - pjit_lib._infer_params_impl = infer_params_impl_and_count + assert thread_local_state.infer_params_fun_counts is None + counts = collections.Counter() + thread_local_state.infer_params_fun_counts = counts try: - yield count + yield counts finally: - pjit_lib._infer_params_impl = original_infer_params_impl - + thread_local_state.infer_params_fun_counts = None @contextmanager -def count_aot_jit_cpp_cache_miss(): - original_call = stages.Compiled.call - count = [0] - - def compiled_call_count(*args, **kwargs): - count[0] += 1 - return original_call(*args, **kwargs) - - stages.Compiled.call = compiled_call_count +def count_subjaxpr_to_hlo_conversion(fun_name): + assert thread_local_state.lower_jaxpr_to_fun_counts is None + counts = collections.Counter() + thread_local_state.lower_jaxpr_to_fun_counts = counts try: - yield count + yield lambda: counts[fun_name] finally: - stages.Compiled.call = original_call + thread_local_state.lower_jaxpr_to_fun_counts = None -@contextmanager -def count_jit_and_pmap_lowerings(): - # No need to clear any caches since we generally jit and pmap fresh callables - # in tests. - - mlir_lower = mlir.lower_jaxpr_to_module - count = [0] - - def mlir_lower_and_count(*args, **kwargs): - count[0] += 1 - return mlir_lower(*args, **kwargs) - - mlir.lower_jaxpr_to_module = mlir_lower_and_count - try: - yield count - finally: - mlir.lower_jaxpr_to_module = mlir_lower - - -@contextmanager -def count_jax_array_shard_arg_calls(): - # No need to clear any caches since we generally jit and pmap fresh callables - # in tests. - - array_shard_arg = array._array_shard_arg - count = [0] - - def array_shard_arg_and_count(*args, **kwargs): - count[0] += 1 - return array_shard_arg(*args, **kwargs) - - pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg_and_count - try: - yield count - finally: - pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg - - -@contextmanager -def count_jit_compilation_cache_miss(): - # No need to clear any caches since we generally jit and pmap fresh callables - # in tests. - - jit_compilation = pxla._cached_compilation - count = [0] - - def compile_and_count(*args, **kwargs): - count[0] += 1 - return jit_compilation(*args, **kwargs) - - pxla._cached_compilation = compile_and_count - try: - yield count - finally: - pxla._cached_compilation = jit_compilation - - -@contextmanager -def count_subjaxpr_to_hlo_conversion(fun_name: str): - # No need to clear any caches since we generally jit and pmap fresh callables - # in tests. - - mlir_lower = mlir.lower_jaxpr_to_fun - count = [0] - - def mlir_lower_and_count(ctx, name, *args, **kwargs): - if name == fun_name: - count[0] += 1 - return mlir_lower(ctx, name, *args, **kwargs) - - mlir.lower_jaxpr_to_fun = mlir_lower_and_count - try: - yield count - finally: - mlir.lower_jaxpr_to_fun = mlir_lower - @contextmanager def assert_num_jit_and_pmap_compilations(times): with count_jit_and_pmap_lowerings() as count: yield - if count[0] != times: + if count() != times: raise AssertionError(f"Expected exactly {times} XLA compilations, " - f"but executed {count[0]}") + f"but executed {count()}") def jaxlib_version() -> tuple[int, ...]: diff --git a/jax/_src/util.py b/jax/_src/util.py index 8dcc5eaa5804..f262570b9a1b 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -676,3 +676,12 @@ def register(cls, subclass): class StrictABC(metaclass=StrictABCMeta): __slots__ = () + + + +test_event_listener: Callable | None = None + +def test_event(name: str, *args) -> None: + if not test_event_listener: + return + test_event_listener(name, *args) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 5567f2f765c1..8e2883ea4dea 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -24,6 +24,7 @@ import jax from jax._src import api from jax._src import tree_util +from jax._src import util from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc from jax._src.traceback_util import api_boundary @@ -283,6 +284,7 @@ def _get_specialized_func( info: FunctionInfo, specialization: Specialization ) -> Callable[..., Any]: """Returns a specialized function for the given specialization.""" + util.test_event("colocated_python_func._get_specialized_func") assert specialization.in_specs_treedef is not None assert specialization.in_specs_leaves is not None assert specialization.devices is not None diff --git a/tests/api_test.py b/tests/api_test.py index 38467809e9d3..4e68a3edb5b1 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1081,7 +1081,7 @@ def f(*args): args = range(num_args) with jtu.count_device_put() as count: np.testing.assert_allclose(f_pruned(*args), 3) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def testBuffersAreFreedPromptly(self): # Regression test for a bug where garbage collection was delayed too long @@ -1246,7 +1246,7 @@ def test_jit_lower_no_pruning(self): jitted_f = jit(lambda x, y: x, keep_unused=True) with jtu.count_pjit_cpp_cache_miss() as count: _ = jitted_f(1, 2) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_jit_lower_compile_compiler_ir(self): f = jit(lambda x: x + 4).lower(1.).compile() @@ -1428,7 +1428,7 @@ def f(x): with jtu.count_jit_compilation_cache_miss() as count: jit(f, compiler_options={"xla_embed_ir_in_executable": True})(1.) jit(f, compiler_options={"xla_embed_ir_in_executable": False})(1.) - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) # We should still error on invalid options after some valid compiles with self.assertRaisesRegex( @@ -1511,7 +1511,7 @@ def test_caches_dont_depend_on_unnamed_axis_env(self): expected = f() with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = jax.vmap(f, axis_size=2, out_axes=None)() - self.assertEqual(count[0], 0) # no compiles + self.assertEqual(count(), 0) # no compiles self.assertArraysAllClose(ans, expected, check_dtypes=True) def test_cache_key_defaults(self): @@ -2737,7 +2737,7 @@ def f(x): jax.eval_shape(f, inp) jax.jit(f)(inp) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_jit_infer_params_cache(self): def f(x): @@ -3384,12 +3384,12 @@ def f(x): with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _ = jax.grad(f)(3.) - self.assertEqual(count[0], 2) # one for fwd, one for bwd + self.assertEqual(count(), 2) # one for fwd, one for bwd with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _ = jax.grad(f)(3.) _ = jax.grad(f)(4.) - self.assertEqual(count[0], 0) # cache hits on both fwd and bwd + self.assertEqual(count(), 0) # cache hits on both fwd and bwd def test_grad_does_not_unflatten_tree_with_none(self): # https://github.com/jax-ml/jax/issues/7546 @@ -3458,7 +3458,7 @@ def test_primitive_compilation_cache(self): with jtu.count_primitive_compiles() as count: lax.add(1, 2) lax.add(2, 3) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_arange_jit(self): # see https://github.com/jax-ml/jax/issues/553 @@ -4021,7 +4021,7 @@ def test_eval_shape_weak_type(self): jax.eval_shape(jax.numpy.array, 1) out = jax.eval_shape(jax.numpy.array, 1) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertTrue(out.weak_type) self.assertEqual(out.weak_type, arr.weak_type) @@ -4296,19 +4296,19 @@ def test_vmap_caching(self): for _ in range(5): jax.hessian(jf)(x).block_until_ready() - n = count[0] + n = count() # The exact number of compilations may vary depending on the number of # jit decorators in the function above, but it should not grow after an # initial warmup phase. for _ in range(5): jax.hessian(jf)(x).block_until_ready() - self.assertEqual(count[0], n) + self.assertEqual(count(), n) def test_jnp_array_doesnt_device_put(self): with jtu.count_device_put() as count: api.make_jaxpr(lambda: jnp.array(3))() - self.assertEqual(count[0], 0) + self.assertEqual(count(), 0) def test_rank_promotion_forces_retrace(self): num_traces = 0 @@ -5910,7 +5910,7 @@ def test_linearize_caching(self): with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): f_lin(1.).block_until_ready() - self.assertEqual(count[0], 1) # cached after first execution + self.assertEqual(count(), 1) # cached after first execution def test_vjp_caching(self): # https://github.com/jax-ml/jax/issues/9661 @@ -5919,7 +5919,7 @@ def test_vjp_caching(self): with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready() - self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd + self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd def test_vjp_caching_static_argnums(self): identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x), @@ -5928,7 +5928,7 @@ def test_vjp_caching_static_argnums(self): with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready() - self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd + self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd def test_fwd_caching(self): # see above test also @@ -5937,7 +5937,7 @@ def test_fwd_caching(self): for _ in range(20): y, _ = jax.vjp(identity, 1.) y.block_until_ready() - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_fwd_caching_static_argnums(self): # see above test also @@ -5946,7 +5946,7 @@ def test_fwd_caching_static_argnums(self): for _ in range(20): y = identity(1.) y.block_until_ready() - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 2ddd5870268d..6a1660b28578 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -858,7 +858,7 @@ def test_retracing(self): _ = f(3.) with jtu.count_jit_and_pmap_lowerings() as count: _ = f(3.) - self.assertEqual(count[0], 0) + self.assertEqual(count(), 0) def test_goodfellow_custom_jvp(self): def h(fext): diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index bbd5c38068f3..43a887b8ea55 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -24,7 +24,6 @@ from jax._src import test_util as jtu from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member from jax.experimental import colocated_python -from jax.experimental.colocated_python import func as colocated_python_func from jax.experimental.colocated_python import serialization from jax.extend.ifrt_programs import ifrt_programs import jax.numpy as jnp @@ -52,23 +51,8 @@ def _colocated_cpu_devices( ] -@contextlib.contextmanager -def _count_colocated_python_specialization_cache_miss() -> list[int]: - """Counts the number of cache misses for colocated_python specialization.""" - original_get_specialized_func = colocated_python_func._get_specialized_func - count = [0] - - @jax.util.cache(max_size=None) - def get_specialized_func(*args, **kwargs): - count[0] += 1 - return original_get_specialized_func(*args, **kwargs) - - colocated_python_func._get_specialized_func = get_specialized_func - try: - yield count - finally: - colocated_python_func._get_specialized_func = original_get_specialized_func - +_count_colocated_python_specialization_cache_miss = jtu.count_events( + "colocated_python_func._get_specialized_func") _exit_stack = contextlib.ExitStack() @@ -117,12 +101,12 @@ def add_one(x): out = add_one(x) out = jax.device_get(out) self.assertEqual(out, np.array(2)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) out = add_one(x) out = jax.device_get(out) self.assertEqual(out, np.array(2)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def testSimpleFunctioWithTree(self): @colocated_python.colocated_python @@ -137,12 +121,12 @@ def add_one(x): out = add_one(x) out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) out = add_one(x) out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def testEmptyInputFailsWithoutSpecialization(self): @colocated_python.colocated_python @@ -168,12 +152,12 @@ def make_zero(): out = make_zero() out = jax.device_get(out) self.assertEqual(out, np.array(0)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) out = make_zero() out = jax.device_get(out) self.assertEqual(out, np.array(0)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def testInputPolymorphismWithoutOutSpecsFn(self): @colocated_python.colocated_python @@ -188,12 +172,12 @@ def add_one(x): out = add_one(x) out = jax.device_get(out) self.assertEqual(out, np.array(2)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) out = add_one(x) out = jax.device_get(out) self.assertEqual(out, np.array(2)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) # Different input tree structure and dtype/shape. x = [np.array(1), (np.array(2), {"v": np.array(3)})] @@ -202,12 +186,12 @@ def add_one(x): out = add_one(x) out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) out = add_one(x) out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) def testInputPolymorphismAllowedWithOutSpecsFn(self): @colocated_python.colocated_python @@ -223,12 +207,12 @@ def add_one(x): out = add_one(x) out = jax.device_get(out) self.assertEqual(out, np.array(2)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) out = add_one(x) out = jax.device_get(out) self.assertEqual(out, np.array(2)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) # Different input tree structure and dtype/shape. x = [np.array(1), (np.array(2), {"v": np.array(3)})] @@ -237,12 +221,12 @@ def add_one(x): out = add_one(x) out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) out = add_one(x) out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) @parameterized.named_parameters( ("on_main_thread", True), diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 4b0420fda8f9..1f04a788914d 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2621,7 +2621,7 @@ def g(x): lax.cond(x, f, g, x) # Should observe a maximum of 4 compiles: convert_element_type, f, g, cond # In #14058, this was observed to be 31 compiles. - self.assertLess(count[0], 5) + self.assertLess(count(), 5) @parameterized.named_parameters( {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype} diff --git a/tests/layout_test.py b/tests/layout_test.py index afddab916723..ad2746a55deb 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -87,7 +87,7 @@ def init(x, y): with jtu.count_aot_jit_cpp_cache_miss() as init_count: init_out = init_compiled(arr1, arr2) init_compiled(arr1, arr2) - self.assertEqual(init_count[0], 1) + self.assertEqual(init_count(), 1) self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0]) self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1]) @@ -95,7 +95,7 @@ def init(x, y): with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) - self.assertEqual(apply_count[0], 1) + self.assertEqual(apply_count(), 1) self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0]) self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1]) diff --git a/tests/memories_test.py b/tests/memories_test.py index fcb1d6bdc2b1..c92ea0ba3895 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1308,7 +1308,7 @@ def test_jit_cpp_cache_hit(self): with jtu.count_pjit_cpp_cache_miss() as count: out = f(inp) out2 = f(inp2) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertArraysEqual(out, np_inp @ np_inp.T) self.assertArraysEqual(out2, np_inp @ np_inp.T) @@ -1329,8 +1329,8 @@ def test_jit_compilation_cache_hit(self): jtu.count_jit_and_pmap_lowerings() as compile_count): f(inp) f(inp2) - self.assertEqual(cpp_count[0], 2) - self.assertEqual(compile_count[0], 1) + self.assertEqual(cpp_count(), 2) + self.assertEqual(compile_count(), 1) def test_jit_cpp_cache_output_hit(self): _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device") @@ -1342,7 +1342,7 @@ def mul_two(x): with jtu.count_pjit_cpp_cache_miss() as count: out = mul_two(inp) mul_two(out) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_jit_cache_hit_with_default_and_specified_mem_kind(self): _, s, np_inp, _ = _create_inputs((8, 2), P("x", "y")) @@ -1357,7 +1357,7 @@ def mul(x): with jtu.count_jit_and_pmap_lowerings() as count: out = f(np_inp) out2 = g(np_inp2) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertArraysEqual(out, np_inp @ np_inp.T) self.assertArraysEqual(out2, np_inp2 @ np_inp2.T) @@ -1633,7 +1633,7 @@ def f(x): f(inp) # 2 for `f` and `2` for `mul` (compute type changes for `mul`) - self.assertEqual(count[0], 4) + self.assertEqual(count(), 4) def test_offload_take_host(self): @compute_on('device_host') diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index a3b6b1efaa76..057731cb5d55 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -141,7 +141,7 @@ def test_primitive_compilation_cache(self): y = lax.add(x, x) z = lax.add(y, y) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assert_committed_to_device(y, devices[1]) self.assert_committed_to_device(z, devices[1]) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index fb144cacbc98..dbf67b1421ab 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -149,14 +149,14 @@ def f(x): with config.pgle_profiling_runs(2), config.enable_pgle(True): # Run 1: Module should be compiled without FDO. Two modules are expected # One is the funtion f, the other one is multi slice module - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 2) + self.assertEqual(cache_miss_count(), 2) # Run 2: Second PGLE run. Profile should be empty. - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 2) + self.assertEqual(cache_miss_count(), 2) fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) # One for before and one for after optimization. self.assertLen(fdo_profiles_before_pgle, 2) @@ -165,9 +165,9 @@ def f(x): os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0) # Run 3: The module should be recompiled with FDO profiles - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 2) + self.assertEqual(cache_miss_count(), 2) fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir) # One for before and one for after optimization. self.assertLen(fdo_profiles_after_pgle, 4) @@ -179,9 +179,9 @@ def f(x): ) # Run 4: Fast-path should be used after PGLE is done - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertLess(cache_miss_count[0], 2) + self.assertLess(cache_miss_count(), 2) def testAutoPgleWithAot(self): @jax.jit @@ -197,14 +197,14 @@ def f(x): with config.pgle_profiling_runs(1), config.enable_pgle(True): # Run 1 - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(compiled(x), expected) - self.assertEqual(cache_miss_count[0], 0) + self.assertEqual(cache_miss_count(), 0) # Run 2 - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(compiled(x), expected) - self.assertEqual(cache_miss_count[0], 0) + self.assertEqual(cache_miss_count(), 0) def testAutoPgleWithPersistentCache(self): its = 50 @@ -243,24 +243,24 @@ def f(x): cc.reset_cache() cc.set_cache_dir(cache_dir) # Run 1: Module should be compiled without FDO - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: f(x) - self.assertGreater(cache_miss_count[0], 0) + self.assertGreater(cache_miss_count(), 0) # Non-pgle profiled version of module should be saved non_pgle_profiled_files = os.listdir(cache_dir) self.assertNotEmpty(non_pgle_profiled_files) # Run 2: Compilation should not be called - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: f(x) - self.assertGreater(cache_miss_count[0], 0) + self.assertGreater(cache_miss_count(), 0) fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) # Run 3: Module should be compiled with FDO and stored to persistent cache - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with jtu.count_jit_compilation_cache_miss() as cache_miss_count: f(x) - self.assertGreater(cache_miss_count[0], 0) + self.assertGreater(cache_miss_count(), 0) # Check if FDO profile file of the biggest module is not empty fdo_profiles_after_pgle = [ diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2bb8af8d7363..ad57434f1beb 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -660,7 +660,7 @@ def testAutodiffCache(self): jax.grad(f)(x) # Warm up the cache. with jtu.count_pjit_cpp_cache_miss() as count: jax.grad(f)(x) - self.assertEqual(count[0], 0) # no cache miss i.e. cache hit + self.assertEqual(count(), 0) # no cache miss i.e. cache hit @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -2200,7 +2200,7 @@ def test_pjit_single_device_sharding_cache(self): with jtu.count_pjit_cpp_cache_miss() as count: out = f(a) _ = f(out) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_pjit_different_device_recompilation(self): if jax.device_count() < 2: @@ -2217,7 +2217,7 @@ def test_pjit_different_device_recompilation(self): with jtu.count_jit_compilation_cache_miss() as count: out1 = f(a) out2 = f(b) - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) self.assertArraysEqual(out1, val1) self.assertArraysEqual(out2, val2) @@ -2645,7 +2645,7 @@ def test_single_device_pjit_cpp_dispatch(self): inp_data, jax.sharding.NamedSharding(mesh, P('x'))) with mesh: f(arr1) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_single_device_add_single_compile(self): f1 = pjit(lambda x, y: x + y) @@ -2657,7 +2657,7 @@ def test_single_device_add_single_compile(self): with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(2): f1(a, b) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_global_array_to_host_local_array_already_host_local(self): inp_shape = (8, 2) @@ -2791,7 +2791,7 @@ def f(y, **kwargs): f_names = pjit(f, static_argnames='x') f_names(y, x='foo') f_names(y, x='foo') - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_new_static_argnum_on_keyword_arguments(self): f = pjit(lambda x: x, static_argnums=0) @@ -2834,7 +2834,7 @@ def test_pjit_different_default_device(self): # The count here is 0 because before `count_pjit_cpp_cache_miss`, `f` was # called with `system_default_device` and `test_device` so it was added # to the cache. Subsequent calls hit the C++ cache. - self.assertEqual(count[0], 0) + self.assertEqual(count(), 0) def test_pjit_with_mismatched_static_argnames(self): x_is_tracer, y_is_tracer = False, False @@ -3287,7 +3287,7 @@ def f(x): with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): pjit(f)(inp) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_pjit_no_global_cache_hit_axis_resources(self): mesh = jtu.create_mesh((1,), ('x',)) @@ -3297,20 +3297,20 @@ def test_pjit_no_global_cache_hit_axis_resources(self): with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)(inp) - self.assertEqual(count[0], 10) + self.assertEqual(count(), 10) with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): pjit(lambda x: x * 2, device=jax.devices()[0])(inp) - self.assertEqual(count[0], 10) + self.assertEqual(count(), 10) pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s) with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): pf(inp) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): @@ -3318,7 +3318,7 @@ def test_pjit_no_global_cache_hit_axis_resources(self): with jtu.count_pjit_cpp_cache_miss() as count: for _ in range(10): pf1(inp) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_with_sharding_constraint_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2, 2), ('replica', 'data', 'mdl')) @@ -3484,7 +3484,7 @@ def mul(x): out = f(arr) self.assertIsInstance(out.sharding, NamedSharding) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) with jtu.count_pjit_cpp_cache_miss() as count: out2 = f2(arr) @@ -3493,7 +3493,7 @@ def mul(x): out2 = f2(arr) self.assertIsInstance(out2.sharding, PositionalSharding) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertEqual(cache_info2.hits, cache_info1.hits + 1) self.assertEqual(cache_info2.misses, cache_info1.misses) @@ -3530,7 +3530,7 @@ def mul(x): self.assertIsInstance(out2.sharding, NamedSharding) # Drops out of C++ cache i.e. cache miss - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) # Still gets a hit on pjit_lower cache. self.assertEqual(cache_info2.hits, cache_info1.hits + 1) self.assertEqual(cache_info2.misses, cache_info1.misses) @@ -3600,7 +3600,7 @@ def test_sharding_on_output_with_vmap(self): out3 = vf(out2) self.assertIsInstance(out3.sharding, NamedSharding) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_jit_mul_sum_sharding_preserved(self): if config.use_shardy_partitioner.value: @@ -3625,7 +3625,7 @@ def test_jit_mul_sum_sharding_preserved(self): # This will hit the cpp cache. out3 = f(out2) self.assertIsInstance(out3.sharding, PositionalSharding) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertEqual(cache_info2.hits, cache_info1.hits + 1) self.assertEqual(cache_info2.misses, cache_info1.misses) @@ -3755,7 +3755,7 @@ def top(x): top(jnp.arange(8)) # The count should be 1 because `nest`'s lowering to MHLO should be cached. - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_wsc_eager(self): mesh = jtu.create_mesh((2,), ('x',)) @@ -3792,7 +3792,7 @@ def f(x): with jtu.count_pjit_cpp_cache_miss() as count: out1 = core.jaxpr_as_fun(jaxpr)(inp) out2 = core.jaxpr_as_fun(jaxpr)(inp) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertArraysEqual(out1[0], inp * 2) self.assertArraysEqual(out2[0], inp * 2) @@ -3970,7 +3970,7 @@ def g(a): with jtu.count_jit_and_pmap_lowerings() as count: g(np.arange(8)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_lowering_cache_miss_different_devices_and_sharding(self): if jax.device_count() < 4: @@ -3993,7 +3993,7 @@ def g(a): with jtu.count_jit_and_pmap_lowerings() as count: g(np.arange(8)) - self.assertEqual(count[0], 2) + self.assertEqual(count(), 2) def test_single_device_named_sharding_preserved(self): mesh = jax.sharding.Mesh([jax.devices()[0]], 'x') @@ -4021,7 +4021,7 @@ def test_mpmd_device_put_fast_path(self): # that is expected. with jtu.count_device_put_fast_path_hit() as count: out = jax.device_put(arr1, NamedSharding(mesh2, P('x'))) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertTupleEqual(out.sharding._device_assignment, mesh2._flat_devices_tuple) self.assertArraysEqual(out, inp) @@ -4536,9 +4536,9 @@ def f(x): # same num_devices but different devices. b = jax.device_put(out_a, NamedSharding(mesh2, P())) f(b) # tracing and lowering cache *hit* - self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin` - self.assertEqual(lowering_count[0], 1) - self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. + self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin` + self.assertEqual(lowering_count(), 1) + self.assertEqual(compilation_count(), 2) # 2 misses since devices differ. def test_wsc_abstract_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -4621,7 +4621,7 @@ def f(x): with jtu.count_pjit_cpp_cache_miss() as count: jax.jit(f, out_shardings=s)(np.arange(8)) jax.jit(f, out_shardings=s)(np.arange(8)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_input_shardings_single_device(self): @jax.jit diff --git a/tests/pmap_test.py b/tests/pmap_test.py index f611ee981335..a9de8c896414 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1287,7 +1287,7 @@ def testPmapConstant(self): x = jnp.arange(device_count) with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) - # self.assertEqual(count[0], 0) # TODO(mattjj): fix this + # self.assertEqual(count(), 0) # TODO(mattjj): fix this expected = np.repeat(3, device_count) self.assertAllClose(ans, expected, check_dtypes=False) @@ -1308,7 +1308,7 @@ def testPmapConstantDevices(self): x = jnp.arange(len(devices)) with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) - # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants + # self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants expected = np.repeat(3, len(devices)) self.assertAllClose(ans, expected, check_dtypes=False) @@ -1344,7 +1344,7 @@ def testNestedPmapConstant(self): x = jnp.arange(math.prod(shape)).reshape(shape) with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) - # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants + # self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants expected = 3 * np.ones(shape[:2]) self.assertAllClose(ans, expected, check_dtypes=False) @@ -1370,7 +1370,7 @@ def testNestedPmapConstantDevices(self): x = jnp.arange(math.prod(shape)).reshape(shape) with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) - # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants + # self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants expected = 3 * np.ones(shape[:2]) self.assertAllClose(ans, expected, check_dtypes=False) @@ -2043,7 +2043,7 @@ def f(x): _, f_bwd2 = jax.vjp(f, x) _ = f_bwd(x) _ = f_bwd2(x) - self.assertEqual(count[0], 0) # cache hits on fwd and bwd + self.assertEqual(count(), 0) # cache hits on fwd and bwd def testSizeOverflow(self): if config.disable_jit.value: diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 78be353dd345..199b90fe524e 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -588,7 +588,7 @@ def fun(x): with jtu.count_primitive_compiles() as count: for _ in range(3): self.assertAllClose(2 * x, fun(x)) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) class PureCallbackTest(jtu.JaxTestCase): diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 63510b7295d6..af7a3fb7a3b7 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1041,7 +1041,7 @@ def test_random_split_doesnt_device_put_during_tracing(self): key = self.make_key(1).block_until_ready() with jtu.count_device_put() as count: jax.jit(random.split)(key) - self.assertLessEqual(count[0], 1) # 1 for the argument device_put + self.assertLessEqual(count(), 1) # 1 for the argument device_put @jtu.sample_product(dtype=int_dtypes + uint_dtypes) def test_randint_bounds(self, dtype): diff --git a/tests/random_test.py b/tests/random_test.py index d8c5a70b995a..a51e387dca76 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -709,7 +709,7 @@ def f(key): f(key).block_until_ready() f(key).block_until_ready() - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) # TODO(jakevdp) remove this decorator when reuse checks move to C++ @jax.debug_key_reuse(False) @@ -726,7 +726,7 @@ def f(key): f(key).block_until_ready() f(key).block_until_ready() - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_cpp_dispatch_aot_normal(self): # Ensure we stay on the C++ dispatch path when calling an @@ -739,7 +739,7 @@ def test_cpp_dispatch_aot_normal(self): f(key).block_until_ready() f(key).block_until_ready() - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) def test_cpp_dispatch_aot_split(self): # Ensure we stay on the C++ dispatch path when calling an @@ -753,7 +753,7 @@ def test_cpp_dispatch_aot_split(self): f(key).block_until_ready() f(key).block_until_ready() - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) # -- prng primitives diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index c5f80a6d97f3..383746899570 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -201,7 +201,7 @@ def f(x): with jtu.count_pjit_cpp_cache_miss() as count: f(np_inp) out1, out2 = f(np_inp) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertTrue(s.is_equivalent_to(out1.sharding, np_inp.ndim)) self.assertTrue(s.is_equivalent_to(out2.sharding, np_inp.ndim)) @@ -213,7 +213,7 @@ def g(x): with jtu.count_pjit_cpp_cache_miss() as count: g(arr) out3, out4 = g(arr) - self.assertEqual(count[0], 1) + self.assertEqual(count(), 1) self.assertEqual(out3.sharding, s) self.assertEqual(out4.sharding, s) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 56cf9987911d..6d799b52d1ef 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -825,9 +825,9 @@ def f(x): b = jax.device_put(out_a, NamedSharding(mesh2, P())) f(b) # tracing and lowering cache *hit* - self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin` - self.assertEqual(lowering_count[0], 1) - self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. + self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin` + self.assertEqual(lowering_count(), 1) + self.assertEqual(compilation_count(), 2) # 2 misses since devices differ. def test_shmap_abstract_mesh_errors(self): mesh = jtu.create_mesh((2,), ('x',))