
Commit

Merge pull request #25426 from hawkinsp:tls
PiperOrigin-RevId: 705510277
Google-ML-Automation committed Dec 12, 2024
2 parents 3630756 + 62e66b6 commit 99d675a
Showing 25 changed files with 213 additions and 321 deletions.
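The recurring change in the files shown below is the addition of `util.test_event(...)` calls at cache, tracing, lowering, and transfer boundaries, so tests can observe internal activity without indexing private counters. A minimal sketch of the kind of hook these calls assume follows; the listener names are illustrative assumptions, not the actual `jax/_src/util.py` definitions.

# A minimal sketch, assuming a module-level listener; the real jax/_src/util.py
# implementation may differ. In production the call should be a cheap no-op;
# tests install a listener to observe cache misses, lowerings, and transfers.
from typing import Any, Callable, Optional

_test_event_listener: Optional[Callable[..., None]] = None

def set_test_event_listener(listener: Optional[Callable[..., None]]) -> None:
  # Hypothetical installer: tests set a listener, then clear it afterwards.
  global _test_event_listener
  _test_event_listener = listener

def test_event(name: str, *payload: Any) -> None:
  # Does nothing unless a test registered a listener, keeping hot paths cheap.
  if _test_event_listener is None:
    return
  _test_event_listener(name, *payload)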
4 changes: 2 additions & 2 deletions examples/ffi/tests/cpu_examples_test.py
@@ -37,10 +37,10 @@ def test_array_attr_jit_cache(self):
     jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
     with jtu.count_jit_and_pmap_lowerings() as count:
       jit_array_attr(5)
-    self.assertEqual(count[0], 1)  # compiles once the first time
+    self.assertEqual(count(), 1)  # compiles once the first time
     with jtu.count_jit_and_pmap_lowerings() as count:
       jit_array_attr(5)
-    self.assertEqual(count[0], 0)  # cache hit
+    self.assertEqual(count(), 0)  # cache hit

   def test_array_attr_no_jit(self):
     with jax.disable_jit():
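The test update above tracks an API change in the counting helper: the result of `jtu.count_jit_and_pmap_lowerings()` is now read by calling it rather than indexing a shared list. A usage sketch, assuming (not shown in this diff) that the helper counts lowerings triggered inside the `with` block:

# Usage sketch only; the exact semantics of count_jit_and_pmap_lowerings are
# assumed from the test above rather than taken from its implementation.
import jax
from jax._src import test_util as jtu

@jax.jit
def f(x):
  return x + 1

with jtu.count_jit_and_pmap_lowerings() as count:
  f(1)
assert count() == 1  # first call traces, lowers, and compiles

with jtu.count_jit_and_pmap_lowerings() as count:
  f(1)
assert count() == 0  # cache hit: no new lowering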
3 changes: 3 additions & 0 deletions jax/_src/array.py
@@ -33,6 +33,7 @@
 from jax._src import dtypes
 from jax._src import errors
 from jax._src import profiler
+from jax._src import util
 from jax._src import xla_bridge
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
@@ -1131,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):


 def _array_shard_arg(xs, shardings, layouts, copy_semantics):
+  util.test_event("_array_shard_arg")
   results = []
   batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
   batch_cs = []
@@ -1168,6 +1170,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
       results.append(
           shard_sharded_device_array_slow_path(x, devices, indices, sharding))

+  util.test_event("batched_copy_array")
   copy_outs = xc.batched_copy_array_to_devices_with_sharding(
       batch_xs, batch_devs, batch_shardings, batch_cs)
   for i, copy_out in safe_zip(batch_indices, copy_outs):
1 change: 1 addition & 0 deletions jax/_src/dispatch.py
@@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params):

 @util.cache()
 def xla_primitive_callable(prim: core.Primitive, **params):
+  util.test_event("xla_primitive_callable_cache_miss")
   def prim_fun(*args):
     with config.eager_constant_folding(False):
       return prim.bind(*args, **params)
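Because the new event sits inside the body of a function wrapped with `util.cache()`, it fires only when the cache misses, which is what the name `xla_primitive_callable_cache_miss` signals. The placement can be illustrated with a plain `functools.cache` stand-in (an analogy, not JAX's own cache):

# Analogy using functools.cache: code inside a cached body runs only on a miss,
# so an event emitted there counts cache misses, not calls.
import functools

events = []

@functools.cache
def callable_for(prim_name: str):
  events.append(("cache_miss", prim_name))  # runs only when not cached yet
  return lambda *args: f"run {prim_name}"

callable_for("add")
callable_for("add")
assert events == [("cache_miss", "add")]  # second lookup hit the cache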
2 changes: 2 additions & 0 deletions jax/_src/interpreters/mlir.py
@@ -1085,6 +1085,7 @@ def lower_jaxpr_to_module(
   Handles the quirks of the argument/return value passing conventions of the
   runtime.
   """
+  util.test_event("lower_jaxpr_to_module")
   platforms = tuple(map(xb.canonicalize_platform, platforms))

   in_avals = (jaxpr.in_avals if arg_shardings is None else
@@ -1378,6 +1379,7 @@ def lower_jaxpr_to_fun(
   Returns:
     MLIR func op
   """
+  util.test_event("lower_jaxpr_to_fun", name)

   # The first dimension variable may be the platform index
   num_dim_vars = len(ctx.shape_poly_state.dim_vars)
25 changes: 15 additions & 10 deletions jax/_src/interpreters/pxla.py
@@ -231,16 +231,20 @@ def _shard_mutable_array(xs, shardings, layouts, copy_semantics):
 def batched_device_put(aval: core.ShapedArray,
                        sharding: JSharding, xs: Sequence[Any],
                        devices: Sequence[jax.Device], committed: bool = True):
-  from jax._src import array
-
-  bufs = [x for x, d in safe_zip(xs, devices)
-          if (isinstance(x, array.ArrayImpl) and
-              dispatch.is_single_device_sharding(x.sharding) and
-              x.devices() == {d})]
-  if len(bufs) == len(xs):
-    return array.ArrayImpl(
-        aval, sharding, bufs, committed=committed, _skip_checks=True)
-  return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
+  util.test_event("batched_device_put_start")
+  try:
+    from jax._src import array
+
+    bufs = [x for x, d in safe_zip(xs, devices)
+            if (isinstance(x, array.ArrayImpl) and
+                dispatch.is_single_device_sharding(x.sharding) and
+                x.devices() == {d})]
+    if len(bufs) == len(xs):
+      return array.ArrayImpl(
+          aval, sharding, bufs, committed=committed, _skip_checks=True)
+    return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
+  finally:
+    util.test_event("batched_device_put_end")

 def _shard_aval(size, axis: int, aval):
   try:
@@ -2850,6 +2854,7 @@ def from_hlo(name: str,
         mesh = i.mesh
         break

+  util.test_event("pxla_cached_compilation")
   xla_executable = _cached_compilation(
       hlo, name, mesh, spmd_lowering,
       tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
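`batched_device_put` now emits a start event and, via `try`/`finally`, a matching end event even if the transfer raises, so a listener can pair them to count or time device puts. A sketch of such a listener, reusing the hypothetical installer from the sketch near the top of this page:

# Sketch of a paired-event listener; set_test_event_listener is the
# hypothetical installer from the earlier sketch, not a real JAX API.
import time

device_put_timings = []
_pending_starts = []

def listener(name, *payload):
  if name == "batched_device_put_start":
    _pending_starts.append(time.perf_counter())
  elif name == "batched_device_put_end" and _pending_starts:
    device_put_timings.append(time.perf_counter() - _pending_starts.pop())

# set_test_event_listener(listener)  # install in a test, run JAX code, inspect timings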
3 changes: 3 additions & 0 deletions jax/_src/pjit.py
@@ -549,6 +549,7 @@ def _infer_params_impl(
     kwargs: dict[str, Any],
     in_avals: tuple[core.AbstractValue, ...] | None,
 ) -> tuple[PjitParams, list[Any]]:
+  util.test_event("pjit._infer_params_impl", fun)
   have_kwargs = bool(kwargs)
   if have_kwargs and ji.user_specified_in_shardings:
     raise ValueError(
@@ -1297,6 +1298,7 @@ def _create_pjit_jaxpr(
     ignored_inline: IgnoreKey
 ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
            list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
+  util.test_event("create_pjit_jaxpr")
   del ignored_inline  # just for explain_cache_miss
   if config.no_tracing.value:
     raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
@@ -1784,6 +1786,7 @@ def _pjit_lower(
     lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     pgle_profiler: profiler.PGLEProfiler | None):
+  util.test_event("pjit_lower")
   if config.sharding_in_types.value:
     mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
   else:
1 change: 1 addition & 0 deletions jax/_src/stages.py
@@ -533,6 +533,7 @@ def output_layouts(self):

   @staticmethod
   def call(*args, **kwargs):
+    util.test_event("stages_compiled_call")
     # This is because `__call__` passes in `self._params` as the first argument.
     # Instead of making the call signature `call(params, *args, **kwargs)`
     # extract it from args because `params` can be passed as a kwarg by users