From 53e781a1eede391d633ebf35079e5824090546fa Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Thu, 10 Oct 2024 03:16:26 -0700 Subject: [PATCH 001/698] Options to control heartbeat monitor timeouts. --- jax/_src/distributed.py | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 3ea9304b67aa..dd156a649c10 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -26,6 +26,33 @@ logger = logging.getLogger(__name__) +_DISTRIBUTED_SERVICE_HEARTBEAT_INTERVAL_SECONDS = config.int_flag( + name="jax_distributed_service_heartbeat_interval_seconds", + default=10, + help="Number of heartbeats that a client may miss in a row before the " + "coordinator concludes that a client has vanished.", +) + +_DISTRIBUTED_SERVICE_MAX_MISSING_HEARTBEATS = config.int_flag( + name="jax_distributed_service_max_missing_heartbeats", + default=10, + help=".", +) + +_DISTRIBUTED_CLIENT_HEARTBEAT_INTERVAL_SECONDS = config.int_flag( + name="jax_distributed_client_heartbeat_interval_seconds", + default=10, + help="Interval at which the client should send heartbeat RPCs to the " + "coordinator.", +) + +_DISTRIBUTED_CLIENT_MAX_MISSING_HEARTBEATS = config.int_flag( + name="jax_distributed_client_max_missing_heartbeats", + default=10, + help="How many failed heartbeat RPCs may fail due to a possibly-ephemeral " + "reason before we decide the coordinator has vanished and that we " + "should shut down.", +) class State: process_id: int = 0 @@ -107,7 +134,9 @@ def initialize(self, 'Starting JAX distributed service on %s', coordinator_bind_address ) self.service = xla_extension.get_distributed_runtime_service( - coordinator_bind_address, num_processes) + coordinator_bind_address, num_processes, + heartbeat_interval=_DISTRIBUTED_SERVICE_HEARTBEAT_INTERVAL_SECONDS.value, + max_missing_heartbeats=_DISTRIBUTED_SERVICE_MAX_MISSING_HEARTBEATS.value) self.num_processes = num_processes @@ -115,7 +144,9 @@ def initialize(self, raise RuntimeError('distributed.initialize should only be called once.') self.client = xla_extension.get_distributed_runtime_client( - coordinator_address, process_id, init_timeout=initialization_timeout) + coordinator_address, process_id, init_timeout=initialization_timeout, + heartbeat_interval=_DISTRIBUTED_CLIENT_HEARTBEAT_INTERVAL_SECONDS.value, + max_missing_heartbeats=_DISTRIBUTED_CLIENT_MAX_MISSING_HEARTBEATS.value) logger.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() From fafebd254aafc1dd57557466e5ae24dc5954cad9 Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Thu, 10 Oct 2024 09:58:58 -0700 Subject: [PATCH 002/698] Pass the heartbeat timeouts in parameter --- jax/_src/distributed.py | 41 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index dd156a649c10..e0155e012736 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -26,33 +26,6 @@ logger = logging.getLogger(__name__) -_DISTRIBUTED_SERVICE_HEARTBEAT_INTERVAL_SECONDS = config.int_flag( - name="jax_distributed_service_heartbeat_interval_seconds", - default=10, - help="Number of heartbeats that a client may miss in a row before the " - "coordinator concludes that a client has vanished.", -) - -_DISTRIBUTED_SERVICE_MAX_MISSING_HEARTBEATS = config.int_flag( - name="jax_distributed_service_max_missing_heartbeats", - default=10, - help=".", -) - -_DISTRIBUTED_CLIENT_HEARTBEAT_INTERVAL_SECONDS = config.int_flag( - name="jax_distributed_client_heartbeat_interval_seconds", - default=10, - help="Interval at which the client should send heartbeat RPCs to the " - "coordinator.", -) - -_DISTRIBUTED_CLIENT_MAX_MISSING_HEARTBEATS = config.int_flag( - name="jax_distributed_client_max_missing_heartbeats", - default=10, - help="How many failed heartbeat RPCs may fail due to a possibly-ephemeral " - "reason before we decide the coordinator has vanished and that we " - "should shut down.", -) class State: process_id: int = 0 @@ -69,7 +42,11 @@ def initialize(self, local_device_ids: int | Sequence[int] | None = None, cluster_detection_method: str | None = None, initialization_timeout: int = 300, - coordinator_bind_address: str | None = None): + coordinator_bind_address: str | None = None, + service_heartbeat_interval_seconds: int = 10, + service_max_missing_heartbeats: int = 10, + client_heartbeat_interval_seconds: int = 10, + client_max_missing_heartbeats: int = 10): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): @@ -135,8 +112,8 @@ def initialize(self, ) self.service = xla_extension.get_distributed_runtime_service( coordinator_bind_address, num_processes, - heartbeat_interval=_DISTRIBUTED_SERVICE_HEARTBEAT_INTERVAL_SECONDS.value, - max_missing_heartbeats=_DISTRIBUTED_SERVICE_MAX_MISSING_HEARTBEATS.value) + heartbeat_interval=service_heartbeat_interval_seconds, + max_missing_heartbeats=service_max_missing_heartbeats) self.num_processes = num_processes @@ -145,8 +122,8 @@ def initialize(self, self.client = xla_extension.get_distributed_runtime_client( coordinator_address, process_id, init_timeout=initialization_timeout, - heartbeat_interval=_DISTRIBUTED_CLIENT_HEARTBEAT_INTERVAL_SECONDS.value, - max_missing_heartbeats=_DISTRIBUTED_CLIENT_MAX_MISSING_HEARTBEATS.value) + heartbeat_interval=client_heartbeat_interval_seconds, + max_missing_heartbeats=client_max_missing_heartbeats) logger.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() From 5b8e4db85558f5eb27f40de2ac6488eb78e02c03 Mon Sep 17 00:00:00 2001 From: keshavb96 Date: Fri, 18 Oct 2024 11:25:06 -0700 Subject: [PATCH 003/698] document jax config to disable remat HLO pass --- docs/gpu_memory_allocation.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 1fde02a14655..8e25807dcb95 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -60,3 +60,9 @@ Common causes of OOM failures **Running JAX on the display GPU.** Use :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` or :code:`XLA_PYTHON_CLIENT_PREALLOCATE`. + +**Disabling rematerialization HLO pass** + Sometimes disabling the rematerialization HLO pass is favorable to avoid + poor remat choices by the compiler. The pass can be disabled by + :code:`jax.config.update('enable_remat_opt_pass', False)`. But this can + sometimes lead to OOM failures. From 2789b0d4db0806fa7ab3afff136b354533ee88c6 Mon Sep 17 00:00:00 2001 From: Keshav Balasubramanian Date: Fri, 18 Oct 2024 11:31:51 -0700 Subject: [PATCH 004/698] minor change --- docs/gpu_memory_allocation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 8e25807dcb95..5fc7b503a5c3 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -63,6 +63,6 @@ Common causes of OOM failures **Disabling rematerialization HLO pass** Sometimes disabling the rematerialization HLO pass is favorable to avoid - poor remat choices by the compiler. The pass can be disabled by + poor remat choices by the compiler. The pass can be disabled by adding :code:`jax.config.update('enable_remat_opt_pass', False)`. But this can sometimes lead to OOM failures. From 7409bae64c85790927a679cc86a01248ffb4adb6 Mon Sep 17 00:00:00 2001 From: kaixih Date: Mon, 21 Oct 2024 17:00:04 +0000 Subject: [PATCH 005/698] Adjusted atol/rtol for jax sdpa tests --- tests/nn_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/nn_test.py b/tests/nn_test.py index df719256a921..0856b259c190 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -99,8 +99,8 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) - self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) - self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) + self.assertAllClose(dK_ref, dK_ans, rtol=.01, atol=.01) + self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01) @parameterized.product( mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), @@ -164,10 +164,10 @@ def testDotProductAttentionMask(self, mask_mode): self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) - self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) + self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02) self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) - self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) - self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03) + self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01) + self.assertAllClose(dbias_ref, dbias_ans, rtol=.02, atol=.02) @parameterized.product( batch_size=[1, 16], @@ -224,7 +224,7 @@ def bwd_ans(x, bias, mask): else: _, dbias_ref, _ = bwd_ref(x, bias, mask) _, dbias_ans, _ = bwd_ans(x, bias, mask) - self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03) + self.assertAllClose(dbias_ans, dbias_ref, rtol=.02, atol=.02) @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSoftplusGrad(self): From 57507668981658e9290758f24001f7ac2a483011 Mon Sep 17 00:00:00 2001 From: Keshav Balasubramanian Date: Mon, 21 Oct 2024 12:08:58 -0700 Subject: [PATCH 006/698] more detail --- docs/gpu_memory_allocation.rst | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 5fc7b503a5c3..dac52c194603 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -62,7 +62,11 @@ Common causes of OOM failures :code:`XLA_PYTHON_CLIENT_PREALLOCATE`. **Disabling rematerialization HLO pass** - Sometimes disabling the rematerialization HLO pass is favorable to avoid - poor remat choices by the compiler. The pass can be disabled by adding - :code:`jax.config.update('enable_remat_opt_pass', False)`. But this can - sometimes lead to OOM failures. + Sometimes disabling the automatic rematerialization HLO pass is favorable to avoid + poor remat choices by the compiler. The pass can be enable/disable by setting + :code:`jax.config.update('enable_remat_opt_pass', True)` or + :code:`jax.config.update('enable_remat_opt_pass', False)` respectively. Enabling or + disabling the automatic remat pass produces different trade-offs between compute and + memory. Note however, that the algorithm is basic and you can often get better + trade-off between compute and memory by disabling the automatic remat pass and doing + it manually with `the jax.remat API `_ From 7ed65c89a417b4277bdb44385ade8ff42f84d699 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 22 Oct 2024 09:11:56 +0200 Subject: [PATCH 007/698] [docs] Added two new APIs to the export API docs --- docs/jax.export.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/jax.export.rst b/docs/jax.export.rst index d458b6c64e8e..2095758dd3b3 100644 --- a/docs/jax.export.rst +++ b/docs/jax.export.rst @@ -28,6 +28,8 @@ Functions minimum_supported_calling_convention_version maximum_supported_calling_convention_version default_export_platform + register_pytree_node_serialization + register_namedtuple_serialization Functions related to shape polymorphism --------------------------------------- From 8c85f744ff04a09c6eee7d4f5d3edacd1e1b5fc1 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Tue, 22 Oct 2024 13:32:09 -0400 Subject: [PATCH 008/698] Add newline at the end of `.bazelrc` --- .bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index 60e7326adf09..98bca5901d47 100644 --- a/.bazelrc +++ b/.bazelrc @@ -379,4 +379,4 @@ build:debug --config debug_symbols -c fastbuild try-import %workspace%/.jax_configure.bazelrc # Load rc file with user-specific options. -try-import %workspace%/.bazelrc.user \ No newline at end of file +try-import %workspace%/.bazelrc.user From 32be1992eef01be03ecdb58c501e77dd72bae667 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 22 Oct 2024 12:27:56 -0700 Subject: [PATCH 009/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/bf8dafb2a7dfe5ea32988515d491ca6c0fd2c83f. PiperOrigin-RevId: 688648145 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3dd8cbd33712..6a80394a2e7e 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "76da730179313b3bebad6dea6861768421b7358c" -XLA_SHA256 = "d67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757" +XLA_COMMIT = "bf8dafb2a7dfe5ea32988515d491ca6c0fd2c83f" +XLA_SHA256 = "a74647bd55cc0c9765d02bdaa29c5a78580afa34a0c9180a895f3e7bd06ac1b1" def repo(): tf_http_archive( From f8a1f02d6b39957ee5afd704d5816514416d9626 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 22 Oct 2024 13:10:05 -0700 Subject: [PATCH 010/698] [sharding_in_types][Take 2] Add `out_type` argument to `einsum` and `dot_general` to allow specifying for the output type. Right now, it only accept a `NamedSharding` but in the future we can allow a polymorphic type of: `jax.ShapeDtypeStruct | Sharding | Layout`. Reverts 0b3f0e11fb0c37342b3c05ad5d53f3435b6ca44c PiperOrigin-RevId: 688663504 --- jax/_src/lax/lax.py | 78 +++++++++++++++++++++++------- jax/_src/numpy/lax_numpy.py | 42 ++++++++++++---- jax/_src/pallas/triton/lowering.py | 3 +- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/sparse/bcoo.py | 12 +++-- jax/experimental/sparse/bcsr.py | 5 +- jax/experimental/sparse/util.py | 2 +- tests/pjit_test.py | 54 +++++++++++++++++++++ 8 files changed, 163 insertions(+), 35 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 113c87b60ee0..a41c4c4cec5a 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1040,7 +1040,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + out_type=None) -> Array: """General dot product/contraction operator. Wraps XLA's `DotGeneral @@ -1086,6 +1087,13 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. """ + if out_type is not None and not config.sharding_in_types.value: + raise NotImplementedError("out_type only works when sharding_in_types " + "config is True.") + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError( + '`out_type` argument of `dot_general` only supports NamedSharding ' + 'instances. Please file a bug if this is not enough for your use case.') (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers cdims = (api_util._ensure_index_tuple(lhs_contract), api_util._ensure_index_tuple(rhs_contract)) @@ -1097,7 +1105,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_type=out_type) def ragged_dot( @@ -1123,7 +1132,8 @@ def ragged_dot( """ return ragged_dot_p.bind(lhs, rhs, group_sizes, precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type, group_offset=group_offset) + preferred_element_type=preferred_element_type, + group_offset=group_offset) def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array: @@ -3002,7 +3012,11 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, not dtypes.issubdtype(new_dtype, np.complexfloating)): operand = hlo.real(operand) aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype)) - return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)] + out = mlir.convert_hlo(ctx, operand, aval_in, aval_out) + if config.sharding_in_types.value: + proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [out] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -3164,7 +3178,10 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): @@ -3241,24 +3258,29 @@ def _check_specs_match(lhs_spec, rhs_spec, msg): raise TypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): if lhs.sharding.mesh != rhs.sharding.mesh: raise ValueError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') + if out_type is not None: + assert isinstance(out_type, NamedSharding) + return out_type + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch) rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch) msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " - f"to have the consistent sharding, got {lhs_batch_spec} and " - f"{rhs_batch_spec}.") + f"to have the consistent sharding, got {lhs_batch_spec} and " + f"{rhs_batch_spec}.") _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg) lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) msg = ("dot_general requires contracting dimensions to have consistent " - f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") + f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg) return _dot_general_sharding_computation( @@ -3280,7 +3302,10 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError del dimension_numbers # unused # We're mostly matching XLA's logic here, namely in shape_inference.cc and # primitive_util.h's HigherPrecisionType, e.g. @@ -3327,7 +3352,9 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width): def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, - swap_ans=False): + out_type, swap_ans=False): + if out_type is not None: + raise NotImplementedError (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim x_kept = remaining(range(x_ndim), x_contract, x_batch) @@ -3347,12 +3374,16 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + out_type): + if out_type is not None: + raise NotImplementedError (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) y_bar = _dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type, swap_ans=True) + preferred_element_type=preferred_element_type, out_type=out_type, + swap_ans=True) if y_bar.dtype != y.aval.dtype: y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type) return y_bar @@ -3366,6 +3397,7 @@ def _dot_batch_rule( batch_dims, *, dimension_numbers, + out_type, precision, preferred_element_type: DTypeLike | None, **_, @@ -3395,12 +3427,16 @@ def _dot_batch_rule( rhs_shape = batching.bdim_as_shape(rbd, rhs.shape) else: rhs_shape = np.shape(rhs) + if out_type is not None: + raise NotImplementedError("vmap with out_type is not supported. " + "Please open an issue.") batched_out = invoke_prim( lhs, rhs, new_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, + out_type=out_type, ) result_batch_dim = batching.shape_as_bdim( result_stack_dim, @@ -3570,7 +3606,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, - platform: str = "default"): + out_type, platform: str = "default"): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) @@ -3658,6 +3694,8 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): **algorithm_kwarg, ) if config.sharding_in_types.value: + if out_type is not None: + assert aval_out.sharding == out_type out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp) if accumulation_aval.dtype != aval_out.dtype: @@ -3711,12 +3749,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S return (m, n) def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: + precision, preferred_element_type: DTypeLike | None, + **_) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. - return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, - precision=precision, preferred_element_type=preferred_element_type) + return _dot_general_dtype_rule( + lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + precision=precision, preferred_element_type=preferred_element_type, + out_type=None) def _ragged_dot_jvp_rule( @@ -3839,7 +3880,9 @@ def _ragged_dot_invoke_prim( new_dimension_numbers, precision, preferred_element_type, + out_type, ): + del out_type return ragged_dot( lhs, rhs, @@ -3868,6 +3911,7 @@ def _ragged_dot_batch_rule( dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type, + out_type=None, ) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6b1bd9acf3ca..714f21577b28 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -67,10 +67,10 @@ DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar, ) from jax._src.util import ( - NumpyComplexWarning, - canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) -from jax.sharding import Sharding, SingleDeviceSharding + NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, + ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) +from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, + PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map import numpy as np import opt_einsum @@ -9081,6 +9081,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: ... @overload @@ -9093,6 +9094,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: ... def einsum( @@ -9103,6 +9105,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general, + out_type=None, ) -> Array: """Einstein summation @@ -9334,11 +9337,11 @@ def einsum( contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) - einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) + einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True) if spec is not None: einsum = jax.named_call(einsum, name=spec) return einsum(operands, contractions, precision, - preferred_element_type, _dot_general) + preferred_element_type, _dot_general, out_type) # Enable other modules to override einsum_contact_path. @@ -9437,7 +9440,15 @@ def _einsum( precision, preferred_element_type, _dot_general=lax.dot_general, + out_type=None, ): + if out_type is not None and not config.sharding_in_types.value: + raise NotImplementedError("out_type only works when sharding_in_types " + "config is True.") + if out_type is not None and not isinstance(out_type, NamedSharding): + raise NotImplementedError( + "`out_type` argument of `einsum` only supports NamedSharding instances." + " Please file a bug if this is not enough for your use case.") dtypes.check_user_dtype_supported(preferred_element_type, "einsum") operands = list(map(asarray, operands)) if preferred_element_type is None: @@ -9559,13 +9570,25 @@ def filter_singleton_dims(operand, names, other_shape, other_names): names = batch_names_str + remaining_rhs_names + remaining_lhs_names if names == result_names: dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch)) + k_out_type = {} if out_type is None else {'out_type': out_type} operand = _dot_general(rhs, lhs, dimension_numbers, precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + **k_out_type) else: names = batch_names_str + remaining_lhs_names + remaining_rhs_names + if (config.sharding_in_types.value and out_type is not None and + names != result_names): + spec = out_type.spec + inverse_spec = tuple(spec[result_names.index(name)] for name in names) + dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec)) + else: + dot_general_out_type = out_type # type: ignore dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch)) + dot_general_out_type = ({} if dot_general_out_type is None else # type: ignore + {'out_type': dot_general_out_type}) operand = _dot_general(lhs, rhs, dimension_numbers, precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + **dot_general_out_type) else: raise NotImplementedError # if this is actually reachable, open an issue! @@ -9578,7 +9601,8 @@ def filter_singleton_dims(operand, names, other_shape, other_names): operand = lax.transpose(operand, perm) operands.append(operand) # used in next iteration - return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type) + return lax_internal._convert_element_type(operands[0], preferred_element_type, + output_weak_type) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index b0a2b4dbcae0..79919e638d1d 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2089,10 +2089,11 @@ def _dot_general_lowering( b, *, dimension_numbers, + out_type, precision, preferred_element_type, ): - del preferred_element_type # Unused. + del preferred_element_type, out_type # Unused. ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers assert batch_dims == ((), ()) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index a5cfa5f9b928..dcf9cafb5117 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2180,7 +2180,7 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None): tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated -def _dot_general(lhs, rhs, *, dimension_numbers, +def _dot_general(lhs, rhs, *, dimension_numbers, out_type, precision: lax_internal.CanonicalPrecision, preferred_element_type: DType | None, _in_avals: Sequence[core.ShapedArray], diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index f65f7b0a194b..477f634744ed 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -606,8 +606,11 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequenc bcoo_dot_general_p = core.Primitive('bcoo_dot_general') -def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers, - precision: None = None, preferred_element_type: None = None) -> BCOO | Array: +def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, + dimension_numbers: DotDimensionNumbers, + precision: None = None, + preferred_element_type: None = None, + out_type=None) -> BCOO | Array: """A general contraction operation. Args: @@ -625,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: the result will be dense, of type ndarray. """ # TODO(jakevdp) make use of these? - del precision # unused + del precision, out_type # unused if isinstance(lhs, BCOO) and isinstance(rhs, BCOO): shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers) @@ -1051,7 +1054,8 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers) indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True) kwds = {'dimension_numbers': dimension_numbers, 'precision': None, - 'preferred_element_type': None} + 'preferred_element_type': None, + 'out_type': None} A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds) return A, B, indices diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 8aa7d80c7a29..372bce0344ba 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -462,7 +462,8 @@ def _bcsr_extract_batching_rule(batched_args, batch_dims): def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, dimension_numbers: DotDimensionNumbers, precision: None = None, - preferred_element_type: None = None) -> Array: + preferred_element_type: None = None, + out_type=None) -> Array: """A general contraction operation. Args: @@ -479,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, are sparse, the result will be sparse, of type BCSR. If either input is dense, the result will be dense, of type ndarray. """ - del precision # unused + del precision, out_type # unused if isinstance(rhs, (np.ndarray, jax.Array)): if isinstance(lhs, (np.ndarray, jax.Array)): return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers, diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 7ef1ed781c15..2cb765676411 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -111,4 +111,4 @@ def _dot_general_validated_shape( rhs = core.ShapedArray(rhs_shape, np.float32) return _dot_general_shape_rule( lhs, rhs, dimension_numbers=dimension_numbers, - precision=None, preferred_element_type=None) + precision=None, preferred_element_type=None, out_type=None) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index d3b96676afdc..fd65c79538cf 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4945,6 +4945,60 @@ def g(x): _, out = g(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + def test_einsum_with_out_type(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertArraysEqual(out, np_inp @ np_inp.T) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + lowered_text = f.lower(arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + + @jax.jit + def g(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr4 = jax.device_put(np_inp.T, NamedSharding(mesh, P('x', 'y'))) + out2 = g(arr3, arr4) + self.assertArraysEqual(out2, np_inp @ np_inp.T) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + + def test_einsum_inverse(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(64) + + @jax.jit + def h(x, y): + s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None)) + out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s) + self.assertEqual(out.sharding.spec, s.spec) + return out + + arr1 = jax.device_put(np_inp.reshape(8, 4, 2), + NamedSharding(mesh, P('x', 'y', None))) + arr2 = jax.device_put(np_inp.reshape(2, 4, 8), + NamedSharding(mesh, P(None, 'x', 'y'))) + out = h(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None))) + + lowered_text = h.lower(arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 9038bb2664a6d5249025c4590feb7d430e509253 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 22 Oct 2024 08:41:58 -0700 Subject: [PATCH 011/698] Better documentation for jnp.indices --- jax/_src/numpy/lax_numpy.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 714f21577b28..b92e5e250d1c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6913,6 +6913,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, A length-N list of grid arrays. See also: + - :func:`jax.numpy.indices`: generate a grid of indices. - :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax. - :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax. @@ -7085,9 +7086,38 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: ... -@util.implements(np.indices) def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: + """Generate arrays of grid indices. + + JAX implementation of :func:`numpy.indices`. + + Args: + dimensions: the shape of the grid. + dtype: the dtype of the indices (defaults to integer). + sparse: if True, then return sparse indices. Default is False, which + returns dense indices. + + Returns: + An array of shape ``(len(dimensions), *dimensions)`` If ``sparse`` is False, + or a sequence of arrays of the same length as ``dimensions`` if ``sparse`` is True. + + See also: + - :func:`jax.numpy.meshgrid`: generate a grid from arbitrary input arrays. + - :obj:`jax.numpy.mgrid`: generate dense indices using a slicing syntax. + - :obj:`jax.numpy.ogrid`: generate sparse indices using a slicing syntax. + + Examples: + >>> jnp.indices((2, 3)) + Array([[[0, 0, 0], + [1, 1, 1]], + + [[0, 1, 2], + [0, 1, 2]]], dtype=int32) + >>> jnp.indices((2, 3), sparse=True) + (Array([[0], + [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32)) + """ dtypes.check_user_dtype_supported(dtype, "indices") dtype = dtype or dtypes.canonicalize_dtype(int_) dimensions = tuple( From 4688da31183d2270c5059af46b5065c6e0a1d077 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 22 Oct 2024 16:53:03 -0700 Subject: [PATCH 012/698] Fix jax2tf failure coming from dot_general PiperOrigin-RevId: 688738110 --- jax/experimental/jax2tf/impl_no_xla.py | 1 + jax/experimental/jax2tf/jax2tf.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 310cbaab6d59..0d8c95d42676 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -364,6 +364,7 @@ def _conv_general_dilated( def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, + out_type=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index dcf9cafb5117..29a1034e51ed 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2180,9 +2180,10 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None): tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated -def _dot_general(lhs, rhs, *, dimension_numbers, out_type, +def _dot_general(lhs, rhs, *, dimension_numbers, precision: lax_internal.CanonicalPrecision, preferred_element_type: DType | None, + out_type=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" From d6f4ce1612d58421ad61d3e1231b9a93986d9d0a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 23 Oct 2024 07:58:31 -0700 Subject: [PATCH 013/698] Better docs for jnp.unwrap --- jax/_src/numpy/lax_numpy.py | 55 ++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b92e5e250d1c..69f6b6ebbdf3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3900,10 +3900,63 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None, return nonzero(ravel(a), size=size, fill_value=fill_value)[0] -@util.implements(np.unwrap) @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, axis: int = -1, period: ArrayLike = 2 * pi) -> Array: + """Unwrap a periodic signal. + + JAX implementation of :func:`numpy.unwrap`. + + Args: + p: input array + discont: the maximum allowable discontinuity in the sequence. The + default is ``period / 2`` + axis: the axis along which to unwrap; defaults to -1 + period: the period of the signal, which defaults to :math:`2\\pi` + + Returns: + An unwrapped copy of ``p``. + + Examples: + Consider a situation in which you are making measurements of the position of + a rotating disk via the ``x`` and ``y`` locations of some point on that disk. + The underlying variable is an always-increating angle which we'll generate + this way, using degrees for ease of representation: + + >>> rng = np.random.default_rng(0) + >>> theta = rng.integers(0, 90, size=(20,)).cumsum() + >>> theta + array([ 76, 133, 179, 203, 230, 233, 239, 240, 255, 328, 386, 468, 513, + 567, 654, 719, 775, 823, 873, 957]) + + Our observations of this angle are the ``x`` and ``y`` coordinates, given by + the sine and cosine of this underlying angle: + + >>> x, y = jnp.sin(jnp.deg2rad(theta)), jnp.cos(jnp.deg2rad(theta)) + + Now, say that given these ``x`` and ``y`` coordinates, we wish to recover + the original angle ``theta``. We might do this via the :func:`atan2` function: + + >>> theta_out = jnp.rad2deg(jnp.atan2(x, y)).round() + >>> theta_out + Array([ 76., 133., 179., -157., -130., -127., -121., -120., -105., + -32., 26., 108., 153., -153., -66., -1., 55., 103., + 153., -123.], dtype=float32) + + The first few values match the input angle ``theta`` above, but after this the + values are wrapped because the ``sin`` and ``cos`` observations obscure the phase + information. The purpose of the :func:`unwrap` function is to recover the original + signal from this wrapped view of it: + + >>> jnp.unwrap(theta_out, period=360) + Array([ 76., 133., 179., 203., 230., 233., 239., 240., 255., 328., 386., + 468., 513., 567., 654., 719., 775., 823., 873., 957.], dtype=float32) + + It does this by assuming that the true underlying sequence does not differ by more than + ``discont`` (which defaults to ``period / 2``) within a single step, and when it encounters + a larger discontinuity it adds factors of the period to the data. For periodic signals + that satisfy this assumption, :func:`unwrap` can recover the original phased signal. + """ util.check_arraylike("unwrap", p) p = asarray(p) if issubdtype(p.dtype, np.complexfloating): From ce8ecbd16d324f1f1e214eee5df047cb7eac9306 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 23 Oct 2024 08:00:02 -0700 Subject: [PATCH 014/698] Add an extension mechanism to run_state that allows: * Uninitialized values * Custom ref aval construction This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values. Co-authored-by: Sharad Vikram PiperOrigin-RevId: 688965089 --- jax/_src/state/discharge.py | 197 +++++++++++++++++++++++++++++------- jax/_src/state/types.py | 25 ++++- tests/state_test.py | 21 ++++ 3 files changed, 204 insertions(+), 39 deletions(-) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index f1c4994b473b..ecfedad971f4 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -17,6 +17,7 @@ from collections.abc import Callable, Sequence import dataclasses from functools import partial +import math import operator from typing import Any, Protocol, TypeVar @@ -34,7 +35,14 @@ from jax._src.lax import slicing as lax_slicing from jax._src.state import indexing from jax._src.state.primitives import addupdate_p, get_p, swap_p -from jax._src.state.types import AbstractRef, RefBitcaster, RefEffect, RefReshaper +from jax._src.state.types import ( + AbstractRef, + RefBitcaster, + RefEffect, + RefReshaper, + get_ref_aval_from_value, + uninitialized, +) from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array from jax._src.util import ( @@ -44,6 +52,7 @@ safe_zip, split_dict, split_list, + unzip2, weakref_lru_cache, ) import numpy as np @@ -470,39 +479,87 @@ def _closed_call_discharge_rule( run_state_p.multiple_results = True def _run_state_bind(*args: Any, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): if config.enable_checks.value: core.check_jaxpr(jaxpr) - assert len(jaxpr.invars) == len(args) - assert len(which_linear) == len(args) + num_uninitialized = sum(not i for i in is_initialized) + assert len(jaxpr.invars) == len(args) + num_uninitialized + assert len(which_linear) == len(args) + num_uninitialized return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr, - which_linear=which_linear) + which_linear=which_linear, + is_initialized=is_initialized) run_state_p.def_custom_bind(_run_state_bind) + +def _default_initialization(x): + assert hasattr(x, 'shape') + assert hasattr(x, 'dtype') + dtype = np.dtype(x) + if np.issubdtype(dtype, np.integer): + value = np.iinfo(dtype).min + else: + value = math.nan + return lax.full(x.shape, value, dtype) + + def _run_state_impl(*args: Any, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): del which_linear discharged_jaxpr, consts = discharge_state(jaxpr, ()) + # Initialize the args that are not initialized. + args_it = iter(args) + args = tuple( + next(args_it) if is_init else _default_initialization(var.aval) + for is_init, var in zip(is_initialized, discharged_jaxpr.invars) + ) return core.eval_jaxpr(discharged_jaxpr, consts, *args) run_state_p.def_impl(_run_state_impl) mlir.register_lowering(run_state_p, mlir.lower_fun(_run_state_impl)) def _run_state_abstract_eval(*avals: core.AbstractValue, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): del which_linear + assert sum(is_initialized) == len(avals) # When we abstractly evaluate `run_state`, we want to keep track of which # input avals are `Ref`s and which are not. If an aval is a `Ref`, we want to # "propagate" out its inner effects. Otherwise, the effects are local to this # `run_state`. + inner_to_outer_aval_mapping = {} + outer_ref_index = 0 + for i, is_init in enumerate(is_initialized): + if not is_init: + pass + inner_to_outer_aval_mapping[i] = outer_ref_index + outer_ref_index += 1 + nonlocal_effects = set() is_ref = {i for i, aval in enumerate(avals) if isinstance(aval, AbstractRef)} - nonlocal_effects = {e for e in jaxpr.effects - if (isinstance(e, RefEffect) and e.input_index in is_ref) - or not isinstance(e, RefEffect)} + for eff in jaxpr.effects: + if not isinstance(eff, RefEffect): + nonlocal_effects.add(eff) + continue + if eff.input_index not in inner_to_outer_aval_mapping: + # This means that this effect corresponds to an uninitialized Ref and + # should not propagate out of the primitive. + continue + # If we do propagate the effect, we need to update the input index to + # correspond to the outer index. + outer_index = inner_to_outer_aval_mapping[eff.input_index] + if outer_index in is_ref: + # This means that the effect corresponds to a Ref from an outside scope. + nonlocal_effects.add( + eff.replace(input_index=inner_to_outer_aval_mapping[eff.input_index]) + ) return avals, nonlocal_effects run_state_p.def_effectful_abstract_eval(_run_state_abstract_eval) def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, - jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]): + jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError("Uninitialized Refs are not supported in jvp.") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) for _ in range(len(nonzero_tangents)): @@ -524,7 +581,9 @@ def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, jvp_jaxpr = hoist_consts_to_refs(jvp_jaxpr_) jvp_which_linear = (*(False,) * len(jvp_consts), *which_linear, *(True,) * len(tangents)) out = run_state_p.bind(*jvp_consts, *primals, *tangents, jaxpr=jvp_jaxpr, - which_linear=jvp_which_linear) + which_linear=jvp_which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jvp_jaxpr.invars)) out_consts, out_primals, out_tangents = split_list(out, [len(jvp_consts), len(primals)]) del out_consts @@ -576,7 +635,12 @@ def eval_jaxpr(*refs): return jaxpr def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, - jaxpr: core.Jaxpr, which_linear: tuple[bool, ...]): + jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in partial_eval." + ) num_inputs = len(tracers) assert num_inputs == len(jaxpr.invars) in_unknowns = [not t.pval.is_known() for t in tracers] @@ -636,7 +700,9 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res) out_flat = run_state_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, - which_linear=jaxpr_known_which_linear) + which_linear=jaxpr_known_which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_known.invars)) known_outputs, residuals = split_list(out_flat, [len(known_tracers)]) residuals = map(trace.new_instantiated_const, residuals) ref_res, nonref_res = split_list(residuals, [num_res_ref]) @@ -664,7 +730,9 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, assert len(unknown_inputs) == len(res_ref_unknown_outputs) assert len(unknown_inputs) == len(jaxpr_unknown.invars) - uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear) + uk_params = dict(jaxpr=jaxpr_unknown, which_linear=unknown_which_linear, + # TODO(sharadmv); compute this in the general case + is_initialized=(True,) * len(jaxpr_unknown.invars)) _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs], **uk_params) eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, @@ -682,7 +750,13 @@ def _run_state_partial_eval_custom( eqn: core.JaxprEqn): if not any(in_unknowns): return eqn, None, in_unknowns, [False] * len(in_unknowns), [] - jaxpr, which_linear = split_dict(eqn.params, ["jaxpr", "which_linear"]) + jaxpr, which_linear, is_initialized = split_dict( + eqn.params, ["jaxpr", "which_linear", "is_initialized"] + ) + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in partial_eval_custom." + ) num_inputs = len(eqn.invars) # We first need to run a fixpoint to determine which of the `Ref`s are unknown # after running the for loop. However, the jaxpr has no outputs. Instead, we @@ -709,7 +783,8 @@ def _run_state_partial_eval_custom( break in_unknowns = map(operator.or_, in_unknowns, out_unknowns) else: - if num_inputs > 0: raise Exception("Invalid fixpoint") + if num_inputs > 0: + raise Exception("Invalid fixpoint") del out_unknowns # Redundant since it's the same as `in_unknowns` new_inst = [x for x, already, inst in zip(eqn.invars, in_inst, out_inst) if type(x) is core.Var and inst and not already] @@ -748,7 +823,9 @@ def _run_state_partial_eval_custom( jaxpr_known_which_linear = (*known_which_linear, *(False,) * num_res) known_and_res_invars = [*known_invars, *ref_resvars, *nonref_resvars] - known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear) + known_params = dict(jaxpr=jaxpr_known, which_linear=jaxpr_known_which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_known.invars)) _, known_effects = run_state_p.abstract_eval( *[v.aval for v in known_and_res_invars], **known_params) eqn_known = pe.new_jaxpr_eqn(known_and_res_invars, @@ -760,7 +837,9 @@ def _run_state_partial_eval_custom( _, staged_which_linear = partition_list(in_unknowns, which_linear) which_linear_unknown = (*[False] * num_res, *staged_which_linear) - staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown) + staged_params = dict(jaxpr=jaxpr_staged, which_linear=which_linear_unknown, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_staged.invars)) rejiggered_resvars = [*nonref_resvars, *ref_resvars] _, staged_invars = partition_list(in_unknowns, eqn.invars) res_staged_invars = [*rejiggered_resvars, *staged_invars] @@ -791,8 +870,12 @@ def staged(*args): return eqn_known, eqn_staged, in_unknowns, in_unknowns, new_vars pe.partial_eval_jaxpr_custom_rules[run_state_p] = _run_state_partial_eval_custom -def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool] - ) -> tuple[core.Jaxpr, Any]: +def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool], + is_initialized: tuple[bool, ...]) -> tuple[core.Jaxpr, Any]: + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in transpose." + ) def trans(*args): # First we want to run the computation to read all the residual refs. We can # do that by using partial evaluation with all linear inputs unknown. @@ -811,8 +894,14 @@ def trans(*args): all_avals = [*res_avals, *[v.aval for v in res_jaxpr_.outvars]] empty_res = map(ad.zeros_like_aval, all_avals) res_jaxpr, _ = _convert_outputs_to_writes(res_jaxpr_) - res = run_state_p.bind(*res_args, *empty_res, jaxpr=res_jaxpr, - which_linear=(False,) * (len(res_args) + len(empty_res))) + res = run_state_p.bind( + *res_args, + *empty_res, + jaxpr=res_jaxpr, + which_linear=(False,) * (len(res_args) + len(empty_res)), + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(res_jaxpr.invars), + ) res = res[len(res_args):] ref_res_, nonref_res_ = split_list(res, [num_res_ref]) @@ -835,7 +924,12 @@ def trans(*args): return jaxpr_trans, consts def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...]): + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in transpose." + ) # if any in_ct is nonzero, we definitely want it in args_ (and the # corresponding x in args could be an undefined primal, but doesn't have to be) # for non-res stuff: @@ -859,12 +953,19 @@ def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, elif type(ct) is not ad_util.Zero and ad.is_undefined_primal(x): # the loop was 'getting and setting', grab that cotangent! transpose_args.append(ct) - jaxpr_transpose_, consts = _transpose_jaxpr(jaxpr, which_linear) + jaxpr_transpose_, consts = _transpose_jaxpr( + jaxpr, which_linear, is_initialized + ) jaxpr_transpose = hoist_consts_to_refs(jaxpr_transpose_) which_linear = (*[False] * len(consts), *which_linear) - const_all_outs = run_state_p.bind(*consts, *transpose_args, - jaxpr=jaxpr_transpose, - which_linear=which_linear) + const_all_outs = run_state_p.bind( + *consts, + *transpose_args, + jaxpr=jaxpr_transpose, + which_linear=which_linear, + # TODO(sharadmv): compute this in the general case + is_initialized=(True,) * len(jaxpr_transpose.invars), + ) _, all_outs = split_list(const_all_outs, [len(consts)]) ct_outs = [ct if ad.is_undefined_primal(x) else None for x, ct in zip(args, all_outs)] @@ -875,9 +976,15 @@ def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr, def _run_state_discharge_rule(in_avals: Sequence[core.AbstractValue], out_avals: Sequence[core.AbstractValue], *args: Any, jaxpr: core.Jaxpr, - which_linear: Sequence[bool]): + which_linear: Sequence[bool], + is_initialized: tuple[bool, ...]): + if not all(is_initialized): + raise NotImplementedError( + "Uninitialized Refs are not supported in discharge." + ) del out_avals - out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear) + out_vals = run_state_p.bind(*args, jaxpr=jaxpr, which_linear=which_linear, + is_initialized=is_initialized) new_invals = [] for aval, out_val in zip(in_avals, out_vals): new_invals.append(out_val if isinstance(aval, AbstractRef) else None) @@ -896,16 +1003,23 @@ def _initial_style_jaxpr(fun, in_tree, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) return jaxpr, consts, out_tree_thunk() + T = TypeVar('T') def run_state(f: Callable[..., None]) -> Callable[[T], T]: def wrapped(args): flat_args, in_tree = tree_util.tree_flatten(args) - avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args] - jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals)) + ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) + # There may be some uninitialized values here in ref_args. + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals) jaxpr = hoist_consts_to_refs(jaxpr_) - which_linear = (False,) * (len(consts) + len(flat_args)) - out_const_flat = run_state_p.bind(*consts, *flat_args, jaxpr=jaxpr, - which_linear=which_linear) + which_linear = (False,) * (len(consts) + len(ref_args)) + refs_is_initialized = tuple(r is not uninitialized for r in ref_args) + init_args = tuple(r for r in ref_args if r is not uninitialized) + # Consts are always initialized. + is_initialized = (True,) * len(consts) + refs_is_initialized + out_const_flat = run_state_p.bind(*consts, *init_args, jaxpr=jaxpr, + which_linear=which_linear, + is_initialized=is_initialized) _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped @@ -913,12 +1027,19 @@ def wrapped(args): def run_state_reference(f: Callable[..., None]): def wrapped(args): flat_args, in_tree = tree_util.tree_flatten(args) - avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in flat_args] - jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, map(AbstractRef, avals)) + ref_avals, ref_args = unzip2(map(get_ref_aval_from_value, flat_args)) + jaxpr_, consts, _ = initial_style_jaxpr(f, in_tree, ref_avals) jaxpr = hoist_consts_to_refs(jaxpr_) discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ()) + + # Initialize any uninitialized values here in ref_args in the reference. + ref_args = [ + _default_initialization(aval) if r is uninitialized else r + for r, aval in zip(ref_args, ref_avals) + ] + out_const_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, - *consts, *args) + *consts, *ref_args) _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 993eeb814e30..634617102d6c 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,7 +18,7 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Protocol, Union +from typing import Any, Callable, Protocol, Union from jax._src import core from jax._src import dtypes @@ -404,3 +404,26 @@ def _unshard_ref(mesh, names, ref_aval: AbstractRef): raise NotImplementedError("Can't unshard a Ref") return ref_aval core.unshard_aval_handlers[AbstractRef] = _unshard_ref + + +# Sentinel type for indicating an uninitialized value. +class Uninitialized: + pass +uninitialized = Uninitialized() + + +_ref_type_aval_mappings: dict[ + type[Any], Callable[[Any], tuple[AbstractRef, Array | Uninitialized]], +] = {} + + +def _default_value_to_ref_aval(x: Any) -> tuple[AbstractRef, Array]: + # Default type mapping just creates an AbstractRef from the array's aval. + aval = core.raise_to_shaped(core.get_aval(x)) + return AbstractRef(aval), x + + +def get_ref_aval_from_value(x: Any): + if type(x) in _ref_type_aval_mappings: + return _ref_type_aval_mappings[type(x)](x) + return _default_value_to_ref_aval(x) diff --git a/tests/state_test.py b/tests/state_test.py index 0d6cddfc88c8..92ea2473811c 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -30,6 +30,7 @@ from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu +from jax._src.state import types as state_types from jax._src.util import tuple_insert import jax.numpy as jnp from jax._src.lax.control_flow import for_loop @@ -1411,6 +1412,26 @@ def f(refs): self.assertEqual(x, 2 + 2 * 3 * 2) self.assertEqual(y, 2 * 3 * 2) + def test_run_state_with_uninitialized_input(self): + def f(refs): + x_ref, y_ref = refs + # y_ref is uninitialized so we shouldn't read from it until we write into + # it. + x = x_ref[...] + y_ref[...] = x * 2 + x_ref[...] = y_ref[...] + x_ref[...] + # x + x * 2, x * 2 + # jax.ShapeDtypeStruct is weirdly special to JAX, so we make our own class. + class MyArrayType: + pass + state_types._ref_type_aval_mappings[MyArrayType] = lambda _: ( + AbstractRef(core.ShapedArray((), jnp.int32)), + state_types.uninitialized, + ) + x, y = run_state(f)((jnp.int32(2), MyArrayType())) + self.assertEqual(x, 2 + 2 * 2) + self.assertEqual(y, 2 * 2) + def test_nontrivial_run_state_jit(self): def f(refs): x_ref, y_ref = refs From 155aa6caa4f2d28d8fa529358642ce62c81d514a Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 23 Oct 2024 08:03:37 -0700 Subject: [PATCH 015/698] [pallas:mosaic_gpu] Memref reshape Transform to allow the user to reshape references. It is not possible for primitives to return references so in order to support reshaping we need to use TransformRef. This CL introduces both a reshape memref transform and a function for the user to create transformed refs of that type. PiperOrigin-RevId: 688966337 --- jax/_src/pallas/mosaic_gpu/core.py | 25 ++++++++++++++------- jax/_src/pallas/mosaic_gpu/lowering.py | 30 ++++++++++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 16 ++++++++++++++ 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 2ed8910bf3d7..891e674cdc95 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -21,6 +21,7 @@ from collections.abc import Sequence import dataclasses import enum +import itertools as it from typing import Any, ClassVar, Literal from jax._src import core as jax_core @@ -29,6 +30,7 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import pallas_call from jax._src.state.types import Transform +from jax._src.state import indexing import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp from jaxlib.mlir import ir @@ -39,6 +41,20 @@ DimensionSemantics = Literal["parallel", "sequential"] +def is_trivial_index(idx, shape) -> bool: + """Checks if the index selects the entire shape.""" + + # Slices that select the entire dimension. + def _slices(d): + slices = [slice(b, e, s) for b, e, s in it.product([0, None], [d, None], [1, None])] + return [indexing.Slice(0, d, 1), *slices] + + if isinstance(idx, tuple): + return all(i in _slices(d) for d, i in zip(shape, idx)) + + return idx is ... or (len(shape) == 1 and idx in _slices(shape[0])) + + @dataclasses.dataclass(frozen=True, kw_only=True) class GPUCompilerParams(pallas_core.CompilerParams): """Mosaic GPU compiler parameters. @@ -405,13 +421,6 @@ def get_ref_aval(self) -> AbstractMemoryRef: ) -def _is_trivial_index(idx): - _is_deref1 = lambda i: i is Ellipsis or i == slice(None) - if isinstance(idx, tuple): - return all(_is_deref1(i) for i in idx) - - return _is_deref1(idx) - class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): __slots__ = ["inner_aval", "memory_space"] @@ -431,7 +440,7 @@ def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error arr = wgmma_accumulator_deref(tracer) - if not _is_trivial_index(idx): + if not is_trivial_index(idx, tracer.shape): arr = arr[idx] return arr diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ec5584233552..f72c901a9e76 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -44,6 +44,7 @@ from jax._src.state import discharge from jax._src.state import indexing from jax._src.state import primitives as sp +from jax._src.state.types import RefReshaper import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import utils as mgpu_utils @@ -866,6 +867,33 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): ) +def _handle_reshaping( + ref: ir.Value, transforms: Sequence[gpu_core.Transform] +) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: + is_trivial_indexer = lambda t: isinstance( + t, indexing.NDIndexer + ) and gpu_core.is_trivial_index(t.indices, t.shape) + + last_reshaper_idx = next( + reversed([i for i, t in enumerate(transforms) if isinstance(t, RefReshaper)]), + None, + ) + if last_reshaper_idx is None: + return ref, transforms + # Check that before the reshape are only trivial indexes and or + # other reshapes. + # TODO(cperivol): Reshapes should bubble up rather than being + # expected to effectively be the first ref transform. + if not all(isinstance(t, RefReshaper) or is_trivial_indexer(t) for t in transforms[:last_reshaper_idx]): + raise NotImplementedError( + "Reshapes do not compose with other transforms and indexers must be" + f" trivial (transforms: {transforms})" + ) + reshaper = cast(RefReshaper, transforms[last_reshaper_idx]) + # Skip all the reshapes and trivial indexes. + return mgpu.memref_reshape(ref, reshaper.shape), transforms[last_reshaper_idx + 1:] + + def _handle_indexing( ref: ir.Value, transforms: Sequence[gpu_core.Transform] ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: @@ -916,6 +944,7 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): raise TypeError(f"Can only load from references (got {x_smem}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) + x_smem, transforms = _handle_reshaping(x_smem, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): @@ -942,6 +971,7 @@ def _swap_lowering_rule( raise TypeError(f"Can only store to references (got {x_smem}).") x_aval = ctx.avals_in[0] transforms = jax.tree.unflatten(tree, leaves) + x_smem, transforms = _handle_reshaping(x_smem, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 22ae3e699b38..aec13a826863 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -74,6 +74,22 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) + def test_reshape(self): + shape1, shape2 = (128,), (2, 16, 4) + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + ) + def kernel(x_ref, out_ref): + x_ref_reshaped = x_ref.reshape(shape2) + self.assertEqual(x_ref.shape, shape1) + self.assertEqual(x_ref_reshaped.shape, shape2) + out_ref[...] = x_ref_reshaped[...] + + x = jnp.arange(math.prod(shape1)).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) + def test_add_xy(self): @functools.partial( pl.pallas_call, From 40c92c1f8c24ec217b3ad49e1e9290f93c2b5e79 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 23 Oct 2024 08:24:09 -0700 Subject: [PATCH 016/698] [pallas:mosaic_gpu] An extremely specific heuristic to allow swiglu. PiperOrigin-RevId: 688973012 --- jax/_src/pallas/mosaic_gpu/primitives.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 13d76174472c..a3eed49c208b 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -483,6 +483,25 @@ def _wgmma_lowering( gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims ): rhs_transpose = True + case ( + gpu_core.UnswizzleRef(rhs_swizzle), + gpu_core.TransposeRef((1, 0, 2, 3, 4)), + gpu_core.UntileRef(rhs_tiling), + gpu_core.TransposeRef(permutation=(1, 0, 2)), + state.types.RefReshaper(shape=new_shape), + ): + if len(rhs_tiling) != 2 or len(new_shape) != 2: + raise ValueError("WGMMA expects shapes 2D tiled into 2D tiles.") + + if any(d % t != 0 for d, t in util.safe_zip(new_shape, rhs_tiling)): + raise ValueError( + f"The last reshape {new_shape} is not divisible by the tiling" + f" {rhs_tiling}." + ) + + high_dims = [d // t for d, t in util.safe_zip(new_shape, rhs_tiling)] + b = mgpu.memref_reshape(b, (*high_dims, *rhs_tiling)) + rhs_transpose = False case _: raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.") From 5ea6215436f7aab2a62e533f43b5b0a032480606 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 23 Oct 2024 08:34:09 -0700 Subject: [PATCH 017/698] Add test for jax2tf conversion of dot general with algorithm. Fixes https://github.com/jax-ml/jax/issues/24236 To be fair, the fix was actually in https://github.com/openxla/xla/pull/18222, but this adds a test to JAX to confirm. PiperOrigin-RevId: 688976685 --- jax/experimental/jax2tf/tests/jax2tf_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 0e8396e6455c..27f830511eab 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1690,6 +1690,22 @@ def f_jax(x): res, x + _testing_multi_platform_to_add[tf_device_jax_platform]) + def test_dot_algorithm(self): + # ref: https://github.com/jax-ml/jax/issues/24236 + if tf.version.VERSION.split(".") <= ["2", "17", "0"]: + self.skipTest("This test works only with newer versions of TF") + + if jtu.test_device_matches(["tpu"]): + algorithm = "BF16_BF16_F32" + else: + algorithm = "F32_F32_F32" + + def f_jax(x): + return jax.lax.dot(x, x, precision=algorithm) + + f_tf = jax2tf.convert(f_jax, native_serialization=True) + f_tf(np.ones((128, 128), dtype=np.float32)) # no crash + def test_dot_algorithm_non_native_unsupported(self): def f_jax(x): return jax.lax.dot(x, x, precision="F32_F32_F32") From 5b3b6e84dbe0e8155359bdfbfbbf3cc0573bd692 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 23 Oct 2024 08:35:26 -0700 Subject: [PATCH 018/698] [Pallas:MGPU] Allow initializing accumulators with values in registers This is useful to avoid unnecessary shared stores and fences in some kernels like flash attention. PiperOrigin-RevId: 688977199 --- jax/_src/pallas/mosaic_gpu/core.py | 18 +++++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 51 ++++++++++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 24 ++++++++++++ 3 files changed, 93 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 891e674cdc95..c0d40799e8d8 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -31,6 +31,7 @@ from jax._src.pallas import pallas_call from jax._src.state.types import Transform from jax._src.state import indexing +from jax._src.state import types as state_types import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp from jaxlib.mlir import ir @@ -414,12 +415,29 @@ def get_ref_aval(self) -> AbstractMemoryRef: class WGMMAAccumulatorRef: shape: tuple[int, int] dtype: jnp.dtype = jnp.float32 + _init: Any = state_types.uninitialized def get_ref_aval(self) -> AbstractMemoryRef: + if self._init is not state_types.uninitialized: + raise ValueError( + "Preinitialized WGMMAAccumulatorRef only supported in pl.run_state." + ) return WGMMAAbstractAccumulatorRef( jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS ) + @staticmethod + def init(array): + return WGMMAAccumulatorRef(array.shape, array.dtype, _init=array) + + +def _wgmma_ref_type_mapping(ref: WGMMAAccumulatorRef): + aval = WGMMAAbstractAccumulatorRef( + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), GPUMemorySpace.REGS + ) + return aval, ref._init +state_types._ref_type_aval_mappings[WGMMAAccumulatorRef] = _wgmma_ref_type_mapping + class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): __slots__ = ["inner_aval", "memory_space"] diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f72c901a9e76..21f6f2570a5d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -36,6 +36,7 @@ from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives @@ -43,6 +44,7 @@ from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import discharge from jax._src.state import indexing +from jax._src.state import types as state_types from jax._src.state import primitives as sp from jax._src.state.types import RefReshaper import jax.experimental.mosaic.gpu as mgpu @@ -1261,6 +1263,55 @@ def _run_scoped_lowering_rule( return outs +@register_lowering_rule(discharge.run_state_p) +def _run_state_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr: jax_core.Jaxpr, + which_linear: tuple[bool, ...], + is_initialized: tuple[bool, ...], +): + del which_linear + # TODO(apaszke): This should be unified with run_scoped. + if not all(is_initialized): + raise NotImplementedError("Uninitialized Refs are not supported in lowering of run_state.") + + should_discharge = [] + new_input_vals = [] + for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out): + aval = v.aval + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) + should_discharge.append(True) + assert isinstance(out_aval, jax_core.ShapedArray) + else: + new_input_vals.append(arg) + should_discharge.append(not isinstance(out_aval, state_types.AbstractRef)) + if not any(should_discharge): + raise NotImplementedError( + "Expected at least one accumulator to in run_state." + ) + + discharged_jaxpr, new_consts = discharge.discharge_state( + jaxpr, (), should_discharge=should_discharge + ) + assert not new_consts + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, () + ) + # Await the accumulators and extract their final values. + nvvm_dialect.wgmma_wait_group_sync_aligned(0) + outs = [ + out.value if isinstance(out, mgpu.WGMMAAccumulator) else out + for out in outs + ] + # Blend the discharge results with refs we closed over. I don't fully + # understand the reasons behind this calling convention, but sharadmv@ has + # assured me that this is ok. + outs_it = iter(outs) + return [next(outs_it) if d else a for d, a in zip(should_discharge, args)] + + def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index aec13a826863..5c546ad17be3 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -767,6 +767,30 @@ def scope(acc_ref): )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) + def test_wgmma_registers_init(self): + def kernel(a_ref, b_ref, i_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref[...], b_ref) + o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...])) + + key1, key2, key3 = jax.random.split(jax.random.key(42), 3) + a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) + i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 + + transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), + plgpu.GPUBlockSpec((64, 192), lambda: (0, 0), transforms=transforms), + ], + out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), + )(a, b, i) + np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) + def test_wgmma_sliced_ref(self): def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): From 88d231a3f2e73f0396a8f1f1444bf3c0a4b45a4c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 23 Oct 2024 09:15:42 -0700 Subject: [PATCH 019/698] [Pallas] Allow core_map's mesh to discharge backend specific effects Backends often have custom effectful primitives, but their effects do not extend beyond the scope of a single kernel, so we should remove them in core_map's abstract eval. PiperOrigin-RevId: 688990275 --- jax/_src/pallas/core.py | 4 +++ jax/_src/pallas/mosaic/core.py | 4 +++ jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/core.py | 20 ++++++++++++++ jax/_src/pallas/mosaic_gpu/primitives.py | 35 ++++++------------------ 5 files changed, 38 insertions(+), 26 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index abbd7154d1b7..fd7883109b9c 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1059,6 +1059,8 @@ def _core_map_abstract_eval(*args, jaxpr, mesh): raise ValueError("core_map must not return any outputs.") effs = set() for eff in jaxpr.effects: + if mesh.discharges_effect(eff): + continue if not isinstance(eff, jax_core.NamedAxisEffect): effs.add(eff) continue @@ -1083,6 +1085,8 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh): jax_core.check_jaxpr(jaxpr) effs = set() for eff in jaxpr.effects: + if mesh.discharges_effect(eff): + continue if not isinstance(eff, jax_core.NamedAxisEffect): effs.add(eff) continue diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 82fe9c2baa96..60207256fd05 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -220,6 +220,10 @@ class TensorCoreMesh: def shape(self): return collections.OrderedDict(zip(self.axis_names, self.devices.shape)) + def discharges_effect(self, effect: jax_core.Effect): + del effect + return False + def create_tensorcore_mesh( axis_name: str, devices: Sequence[jax.Device] | None = None diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 91616948be49..9ee7c04c3c3e 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -74,6 +74,7 @@ pytype_strict_library( "//jax", "//jax:core", "//jax:dtypes", + "//jax:effects", "//jax:mosaic_gpu", "//jax:tree_util", "//jax/_src/pallas", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index c0d40799e8d8..9d28d72f08ad 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -26,6 +26,7 @@ from jax._src import core as jax_core from jax._src import dtypes +from jax._src import effects from jax._src import tree_util from jax._src.pallas import core as pallas_core from jax._src.pallas import pallas_call @@ -511,6 +512,9 @@ def shape(self): ) return collections.OrderedDict(pairs) + def discharges_effect(self, effect: jax_core.Effect): + return effect is _wgmma_pipeline_effect or effect is _memory_effect + def _gpu_mesh_discharge_rule( in_avals, @@ -544,3 +548,19 @@ def body(*args): return out, () pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule + + +class MemoryEffect(jax_core.Effect): + pass + + +effects.control_flow_allowed_effects.add_type(MemoryEffect) +_memory_effect = MemoryEffect() + + +class _WGMMAPipelineEffect(effects.Effect): + pass + + +effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect) +_wgmma_pipeline_effect = _WGMMAPipelineEffect() diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index a3eed49c208b..b3c5a8c5839c 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -21,7 +21,6 @@ import jax from jax._src import core as jax_core -from jax._src import effects from jax._src import state from jax._src import tree_util from jax._src import util @@ -246,15 +245,6 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: raise ValueError("Barrier does not support arbirary transforms") -class MemoryEffect(jax_core.Effect): - ... - - -effects.control_flow_allowed_effects.add_type(MemoryEffect) - -_memory_effect = MemoryEffect() - - barrier_arrive_p = jax_core.Primitive("barrier_arrive") barrier_arrive_p.multiple_results = True @@ -262,7 +252,7 @@ class MemoryEffect(jax_core.Effect): @barrier_arrive_p.def_effectful_abstract_eval def _barrier_arrive_abstract_eval(*avals, **params): del avals, params # Unused. - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(barrier_arrive_p) @@ -299,7 +289,7 @@ def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None: @barrier_wait_p.def_effectful_abstract_eval def _barrier_wait_abstract_eval(*avals, **params): del avals, params # Unused. - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(barrier_wait_p) @@ -336,7 +326,7 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None: @wait_smem_to_gmem_p.def_effectful_abstract_eval def _wait_smem_to_gmem_abstract_eval(n): del n # Unused. - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(wait_smem_to_gmem_p) @@ -350,13 +340,6 @@ def wait_smem_to_gmem(n: int) -> None: wait_smem_to_gmem_p.bind(n) -class _WGMMAPipelineEffect(effects.Effect): - pass - - -_wgmma_pipeline_effect = _WGMMAPipelineEffect() -effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect) - # WGMMA on an accumulator reference wgmma_ref_p = jax_core.Primitive("wgmma_ref") wgmma_ref_p.multiple_results = True @@ -419,7 +402,7 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): if not isinstance(acc_aval, gpu_core.WGMMAAbstractAccumulatorRef): raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc_aval}") return (), { - _wgmma_pipeline_effect, + gpu_core._wgmma_pipeline_effect, state.WriteEffect(0), state.ReadEffect(0), state.ReadEffect(2), @@ -529,7 +512,7 @@ def _wgmma_lowering( def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs): del args, kwargs return acc, { - _wgmma_pipeline_effect, + gpu_core._wgmma_pipeline_effect, state.ReadEffect(2), *([state.ReadEffect(1)] if isinstance(lhs_ref, state.AbstractRef) else []) } @@ -545,7 +528,7 @@ def wgmma_wait(n: int): @wgmma_wait_p.def_effectful_abstract_eval def wgmma_wait_effectful_abstract_eval(_): - return [], {_wgmma_pipeline_effect} + return [], {gpu_core._wgmma_pipeline_effect} @lowering.register_lowering_rule(wgmma_wait_p) @@ -570,7 +553,7 @@ def _wgmma_accumulator_deref_abstract_eval(acc): # Dereferencing implies flushing so we have a wgmma pipeline effect. ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc assert isinstance(ret, jax_core.ShapedArray), acc - return ret, {_wgmma_pipeline_effect} + return ret, {gpu_core._wgmma_pipeline_effect} @discharge.register_discharge_rule(wgmma_accumulator_deref_p) @@ -620,7 +603,7 @@ def layout_cast(x: Any, new_layout: Layout): @set_max_registers_p.def_effectful_abstract_eval def _set_max_registers_abstract_eval(n, *, action): del n, action # Unused. - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(set_max_registers_p) @@ -648,7 +631,7 @@ def set_max_registers(n: int, *, action: Literal["increase", "decrease"]): @commit_smem_p.def_effectful_abstract_eval def _commit_smem_abstract_eval(): - return (), {_memory_effect} + return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(commit_smem_p) From 11faf68018250b3681be70e0181b8a1da7da5662 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 23 Oct 2024 09:37:23 -0700 Subject: [PATCH 020/698] [Pallas:TPU] Match lax.pow(float, int) behavior in Pallas. Both math::PowF and Exp2Op require a floating point exponent so casting it to x.dtype for parity of lax.pow. PiperOrigin-RevId: 688997089 --- jax/_src/pallas/mosaic/lowering.py | 5 +++++ tests/pallas/ops_test.py | 3 --- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 18b73a66ca23..8b2c18165f61 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2014,6 +2014,11 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): + # jax accepts float base (x) and integer/float exponent (y), and integer + # exponent is casted to float. + out_type = aval_to_ir_type(ctx.avals_out[0]) + if jnp.issubdtype(ctx.avals_in[1].dtype, jnp.integer): + y = arith.sitofp(out_type, y) if not isinstance(x, ir.Value) and x == 2.: return math.exp2(y) x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 7b54ef5f9f88..79a50f562271 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -868,9 +868,6 @@ def kernel(x_ref, o_ref): ("float64", "float64"), ) def test_pow(self, x_dtype, y_dtype): - if jtu.test_device_matches(["tpu"]): - self.skipTest("TODO: Error on TPU") - if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") From 84cd3567b58c65e1753c5c45b14b3debae68021e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 23 Oct 2024 10:08:14 -0700 Subject: [PATCH 021/698] Avoid querying metadata query to check if it's GCE if `TPU_SKIP_MDS_QUERY` is set. PiperOrigin-RevId: 689009215 --- jax/_src/clusters/cloud_tpu_cluster.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index 02aea2cd64d5..c8aa765c181c 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -180,6 +180,9 @@ def is_env_present(cls) -> bool: if not running_in_cloud_tpu_vm: logger.debug("Did not detect cloud TPU VM") return False + if os.environ.get("TPU_SKIP_MDS_QUERY") is not None: + logger.debug("TPU_SKIP_MDS_QUERY is set to True, so it's probably not a GCE TPU cluster.") + return False metadata_response, metadata_code = get_metadata('agent-worker-number') if metadata_code == metadata_response_code_success: logger.debug("Gce Tpu Cluster detected for Jax Distributed System") From 148f9d655978be3b7f2945ffbd5b2b4113c1622f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 23 Oct 2024 05:37:26 -0700 Subject: [PATCH 022/698] Better docs for jnp.cov & jnp.corrcoef --- jax/_src/numpy/lax_numpy.py | 166 +++++++++++++++++++++++++++++++++++- 1 file changed, 164 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 69f6b6ebbdf3..6c12c87cbd2f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -12418,12 +12418,102 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, return moveaxis(result, 0, axis) -@util.implements(np.cov) @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, bias: bool = False, ddof: int | None = None, fweights: ArrayLike | None = None, aweights: ArrayLike | None = None) -> Array: + r"""Estimate the weighted sample covariance. + + JAX implementation of :func:`numpy.cov`. + + The covariance :math:`C_{ij}` between variable *i* and variable *j* is defined + as + + .. math:: + + cov[X_i, X_j] = E[(X_i - E[X_i])(X_j - E[X_j])] + + Given an array of *N* observations of the variables :math:`X_i` and :math:`X_j`, + this can be estimated via the sample covariance: + + .. math:: + + C_{ij} = \frac{1}{N - 1} \sum_{n=1}^N (X_{in} - \overline{X_i})(X_{jn} - \overline{X_j}) + + Where :math:`\overline{X_i} = \frac{1}{N} \sum_{k=1}^N X_{ik}` is the mean of the + observations. + + Args: + m: array of shape ``(M, N)`` (if ``rowvar`` is True), or ``(N, M)`` + (if ``rowvar`` is False) representing ``N`` observations of ``M`` variables. + ``m`` may also be one-dimensional, representing ``N`` observations of a + single variable. + y: optional set of additional observations, with the same form as ``m``. If + specified, then ``y`` is combined with ``m``, i.e. for the default + ``rowvar = True`` case, ``m`` becomes ``jnp.vstack([m, y])``. + rowvar: if True (default) then each row of ``m`` represents a variable. If + False, then each column represents a variable. + bias: if False (default) then normalize the covariance by ``N - 1``. If True, + then normalize the covariance by ``N`` + ddof: specify the degrees of freedom. Defaults to ``1`` if ``bias`` is False, + or to ``0`` if ``bias`` is True. + fweights: optional array of integer frequency weights of shape ``(N,)``. This + is an absolute weight specifying the number of times each observation is + included in the computation. + aweights: optional array of observation weights of shape ``(N,)``. This is + a relative weight specifying the "importance" of each observation. In the + ``ddof=0`` case, it is equivalent to assigning probabilities to each + observation. + + Returns: + A covariance matrix of shape ``(M, M)``. + + See also: + - :func:`jax.numpy.corrcoef`: compute the correlation coefficient, a normalized + version of the covariance matrix. + + Examples: + Consider these observations of two variables that correlate perfectly. + The covariance matrix in this case is a 2x2 matrix of ones: + + >>> x = jnp.array([[0, 1, 2], + ... [0, 1, 2]]) + >>> jnp.cov(x) + Array([[1., 1.], + [1., 1.]], dtype=float32) + + Now consider these observations of two variables that are perfectly + anti-correlated. The covariance matrix in this case has ``-1`` in the + off-diagonal: + + >>> x = jnp.array([[-1, 0, 1], + ... [ 1, 0, -1]]) + >>> jnp.cov(x) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + Equivalently, these sequences can be specified as separate arguments, + in which case they are stacked before continuing the computation. + + >>> x = jnp.array([-1, 0, 1]) + >>> y = jnp.array([1, 0, -1]) + >>> jnp.cov(x, y) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + In general, the entries of the covariance matrix may be any positive + or negative real value. For example, here is the covariance of 100 + points drawn from a 3-dimensional standard normal distribution: + + >>> key = jax.random.key(0) + >>> x = jax.random.normal(key, shape=(3, 100)) + >>> with jnp.printoptions(precision=2): + ... print(jnp.cov(x)) + [[ 1.22 -0. 0.11] + [-0. 0.84 -0.1 ] + [ 0.11 -0.1 0.88]] + """ if y is not None: m, y = util.promote_args_inexact("cov", m, y) if y.ndim > 2: @@ -12486,9 +12576,81 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze() -@util.implements(np.corrcoef) @partial(jit, static_argnames=('rowvar',)) def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array: + r"""Compute the Pearson correlation coefficients. + + JAX implementation of :func:`numpy.corrcoef`. + + This is a normalized version of the sample covariance computed by :func:`jax.numpy.cov`. + For a sample covariance :math:`C_{ij}`, the correlation coefficients are + + .. math:: + + R_{ij} = \frac{C_{ij}}{\sqrt{C_{ii}C_{jj}}} + + they are constructed such that the values satisfy :math:`-1 \le R_{ij} \le 1`. + + Args: + x: array of shape ``(M, N)`` (if ``rowvar`` is True), or ``(N, M)`` + (if ``rowvar`` is False) representing ``N`` observations of ``M`` variables. + ``x`` may also be one-dimensional, representing ``N`` observations of a + single variable. + y: optional set of additional observations, with the same form as ``m``. If + specified, then ``y`` is combined with ``m``, i.e. for the default + ``rowvar = True`` case, ``m`` becomes ``jnp.vstack([m, y])``. + rowvar: if True (default) then each row of ``m`` represents a variable. If + False, then each column represents a variable. + + Returns: + A covariance matrix of shape ``(M, M)``. + + See also: + - :func:`jax.numpy.cov`: compute the covariance matrix. + + Examples: + Consider these observations of two variables that correlate perfectly. + The correlation matrix in this case is a 2x2 matrix of ones: + + >>> x = jnp.array([[0, 1, 2], + ... [0, 1, 2]]) + >>> jnp.corrcoef(x) + Array([[1., 1.], + [1., 1.]], dtype=float32) + + Now consider these observations of two variables that are perfectly + anti-correlated. The correlation matrix in this case has ``-1`` in the + off-diagonal: + + >>> x = jnp.array([[-1, 0, 1], + ... [ 1, 0, -1]]) + >>> jnp.corrcoef(x) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + Equivalently, these sequences can be specified as separate arguments, + in which case they are stacked before continuing the computation. + + >>> x = jnp.array([-1, 0, 1]) + >>> y = jnp.array([1, 0, -1]) + >>> jnp.corrcoef(x, y) + Array([[ 1., -1.], + [-1., 1.]], dtype=float32) + + The entries of the correlation matrix are normalized such that they + lie within the range -1 to +1, where +1 indicates perfect correlation + and -1 indicates perfect anti-correlation. For example, here is the + correlation of 100 points drawn from a 3-dimensional standard normal + distribution: + + >>> key = jax.random.key(0) + >>> x = jax.random.normal(key, shape=(3, 100)) + >>> with jnp.printoptions(precision=2): + ... print(jnp.corrcoef(x)) + [[ 1. -0. 0.1 ] + [-0. 1. -0.12] + [ 0.1 -0.12 1. ]] + """ util.check_arraylike("corrcoef", x) c = cov(x, y, rowvar) if len(shape(c)) == 0: From 62351585824c5f5bb70009456e4ea4949fad0320 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 23 Oct 2024 11:17:12 -0700 Subject: [PATCH 023/698] Dot algorithms are now supported for all types, change the test to reflect it. PiperOrigin-RevId: 689036316 --- tests/lax_test.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 1581c61d57eb..4d20240e1940 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1122,11 +1122,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on TPU." ) - if algorithm != lax.DotAlgorithmPreset.DEFAULT and dtype != np.float32: - raise SkipTest( - f"The dot algorithm '{algorithm}' is only supported for float32 on" - " TPU." - ) lhs_shape = (3, 4) rhs_shape = (4, 3) rng = jtu.rand_default(self.rng()) From 9bf1516abebdd431aa80eed4e39f42df593c3444 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 23 Oct 2024 11:37:19 -0700 Subject: [PATCH 024/698] Improve docs for jnp.block --- jax/_src/numpy/lax_numpy.py | 70 ++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c12c87cbd2f..397a818ae405 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5098,9 +5098,77 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: else: return asarray(xs), 1 -@util.implements(np.block) + @jit def block(arrays: ArrayLike | list[ArrayLike]) -> Array: + """Create an array from a list of blocks. + + JAX implementation of :func:`numpy.block`. + + Args: + arrays: an array, or nested list of arrays which will be concatenated + together to form the final array. + + Returns: + a single array constructed from the inputs. + + See also: + - :func:`concatenate`, :func:`concat`: concatenate arrays along an existing axis. + - :func:`stack`, :func:`vstack`, :func:`hstack`, :func:`dstack` concatenate + arrays along a new axis. + + Examples: + consider these blocks: + + >>> zeros = jnp.zeros((2, 2)) + >>> ones = jnp.ones((2, 2)) + >>> twos = jnp.full((2, 2), 2) + >>> threes = jnp.full((2, 2), 3) + + Passing a single array to :func:`block` returns the array: + + >>> jnp.block(zeros) + Array([[0., 0.], + [0., 0.]], dtype=float32) + + Passing a simple list of arrays concatenates them along the last axis: + + >>> jnp.block([zeros, ones]) + Array([[0., 0., 1., 1.], + [0., 0., 1., 1.]], dtype=float32) + + Passing a doubly-nested list of arrays concatenates the inner list along + the last axis, and the outer list along the second-to-last axis: + + >>> jnp.block([[zeros, ones], + ... [twos, threes]]) + Array([[0., 0., 1., 1.], + [0., 0., 1., 1.], + [2., 2., 3., 3.], + [2., 2., 3., 3.]], dtype=float32) + + Note that blocks need not align in all dimensions, though the size along the axis + of concatenation must match. For example, this is valid because after the inner, + horizontal concatenation, the resulting blocks have a valid shape for the outer, + vertical concatenation. + + >>> a = jnp.zeros((2, 1)) + >>> b = jnp.ones((2, 3)) + >>> c = jnp.full((1, 2), 2) + >>> d = jnp.full((1, 2), 3) + >>> jnp.block([[a, b], [c, d]]) + Array([[0., 1., 1., 1.], + [0., 1., 1., 1.], + [2., 2., 3., 3.]], dtype=float32) + + Note also that this logic generalizes to blocks in 3 or more dimensions. + Here's a 3-dimensional block-wise array: + + >>> x = jnp.arange(6).reshape((1, 2, 3)) + >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] + >>> jnp.block(blocks).shape + (5, 8, 9) + """ out, _ = _block(arrays) return out From a7d711513c01bbe756c816b1517ad9a2cfbcf9ed Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 23 Oct 2024 15:09:46 -0400 Subject: [PATCH 025/698] Perform searchsorted binary search using unsigned intermediate values. Midpoint computation for a binary search should be performed unsigned, see https://research.google/blog/extra-extra-read-all-about-it-nearly-all-binary-searches-and-mergesorts-are-broken/ In addition, we can avoid the somewhat verbose floor_divide HLO since we know the values in question are positive. --- jax/_src/numpy/lax_numpy.py | 4 +++- tests/lax_numpy_test.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c12c87cbd2f..7cb9f3969cdf 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -12672,9 +12672,11 @@ def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> A @partial(vectorize, excluded={0, 1, 3, 4}) def _searchsorted_via_scan(unrolled: bool, sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: op = _sort_le_comparator if side == 'left' else _sort_lt_comparator + unsigned_dtype = np.uint32 if dtype == np.int32 else np.uint64 def body_fun(state, _): low, high = state - mid = (low + high) // 2 + mid = low.astype(unsigned_dtype) + high.astype(unsigned_dtype) + mid = lax.div(mid, unsigned_dtype(2)).astype(dtype) go_left = op(query, sorted_arr[mid]) return (where(go_left, low, mid), where(go_left, mid, high)), () n_levels = int(np.ceil(np.log2(len(sorted_arr) + 1))) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a6d9c613379c..ccfce51ec909 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2785,7 +2785,7 @@ def testSearchsortedDtype(self): message="NumPy will stop allowing conversion.*"): out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) else: - with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): + with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype.*int64"): with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"): out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) From ea1fc65c69f58f0881ae7e37f682bbb94853127e Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 23 Oct 2024 12:22:58 -0700 Subject: [PATCH 026/698] [Pallas TPU] Fix `OpsTest.test_elementwise` test for bf16 inputs For bf16 inputs, the shape must be (8, 128) PiperOrigin-RevId: 689060557 --- tests/pallas/ops_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 79a50f562271..3b68dc839ef1 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -793,12 +793,15 @@ def test_elementwise(self, fn, dtype): self.skipTest(f"{fn.__name__} not implemented on TPU") @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1 + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + grid=1, ) def kernel(x_ref, o_ref): o_ref[:] = fn(x_ref[...]) - x = jnp.array([0.42, 2.4]).astype(dtype) + # create an array with shape (8, 128) + x = jnp.array([0.42, 2.4] * (8 * 128 // 2)).reshape(8, 128).astype(dtype) self.assertAllClose(kernel(x), fn(x), rtol=1e-6) @parameterized.named_parameters( From 16f8958ececb55fe9036296e4ad027ddd7e89fd7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 23 Oct 2024 12:29:34 -0700 Subject: [PATCH 027/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ffcd64e30efda9610a89af347f2023d050d788a3. PiperOrigin-RevId: 689062763 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6a80394a2e7e..8777af77c673 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "bf8dafb2a7dfe5ea32988515d491ca6c0fd2c83f" -XLA_SHA256 = "a74647bd55cc0c9765d02bdaa29c5a78580afa34a0c9180a895f3e7bd06ac1b1" +XLA_COMMIT = "ffcd64e30efda9610a89af347f2023d050d788a3" +XLA_SHA256 = "02bd9cccb4bf1b1616ca6585942679d9653e125a497f560fd72f1fa0c572cdd1" def repo(): tf_http_archive( From b8bacda2d914118bbd763b88d67d2c37f4171b70 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 23 Oct 2024 16:21:36 -0700 Subject: [PATCH 028/698] [Mosaic TPU] Use native vector tiling to load and store with untiled memref. PiperOrigin-RevId: 689142734 --- .../tpu/transforms/apply_vector_layout.cc | 69 ++++++++++++++----- .../tpu/transforms/infer_vector_layout.cc | 39 +++++++++-- jaxlib/mosaic/dialect/tpu/util.cc | 29 ++++++++ jaxlib/mosaic/dialect/tpu/util.h | 8 +++ 4 files changed, 122 insertions(+), 23 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 3a9a36f6c0b7..ee95feb4d18e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3020,14 +3020,29 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( Tiling memref_tiling, getMemRefTiling(load_op.getBase(), ctx.target_shape)); - if (memref_tiling != layout_out.tiling() && - !(memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 && - memref_tiling[1] % layout_out.tiling()[1] == 0)) { - // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). - // TODO(b/295393167): need to support strided load for bitwidth < 32. - if (layout_out.bitwidth() != 32 || - layout_out.tiling() != std::array{1, ctx.target_shape[1]}) { - return op.emitOpError("Not implemented"); + if (memref_tiling != layout_out.tiling()) { + if (memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 && + memref_tiling[1] % layout_out.tiling()[1] == 0) { + // In this case, it is valid to use output tiling (1, 128 * packing) when + // loading from a 1D memref. + } else if (layout_out.bitwidth() == 32 && + layout_out.tiling() == + std::array{1, ctx.target_shape[1]}) { + // In this case, it is valid to use output tiling (1, TARGET_SHAPE.lanes) + // because we strided-load one row from each tile of the memref. This can + // save us a bunch of loads! + // TODO(b/295393167): need to support strided load for bitwidth < 32. + } else if (layout_out.bitwidth() == 32 && + canReinterpretToUntiledMemref( + memref_ty, ctx.target_shape, + /*allow_minormost_padding=*/true)) { + // In this case, if the memref can be reinterpreted to untiled, it is + // valid to use any tiling for output. But using native tiling can save us + // a bunch of loads! + } else { + return op.emitOpError( + "Not implemented: dismatch in memref tiling and vector tiling in " + "load"); } } // TODO(apaszke): Check that loads are from vmem! @@ -4204,14 +4219,31 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( const Tiling memref_tiling, getMemRefTiling(store_op.getBase(), ctx.target_shape)); - if (memref_tiling != to_store_layout.tiling() && - !(memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 && - memref_tiling[1] % to_store_layout.tiling()[1] == 0)) { - // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes). - // TODO(b/295393167): need to support strided store for bitwidth < 32. - if (to_store_layout.bitwidth() != 32 || - to_store_layout.tiling() != Tiling{1, ctx.target_shape[1]}) { - return op.emitOpError("Not implemented"); + if (memref_tiling != to_store_layout.tiling()) { + if (memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 && + memref_tiling[1] % to_store_layout.tiling()[1] == 0) { + // In this case, it is valid to have to_store tiling (1, 128 * packing) + // when storing to a 1D memref. + } else if (to_store_layout.bitwidth() == 32 && + to_store_layout.tiling() == + std::array{1, ctx.target_shape[1]}) { + // In this case, it is valid to have to_store tiling (1, + // TARGET_SHAPE.lanes) because we strided-store one row to each tile of + // the memref. This can save us a bunch of stores! + // TODO(b/295393167): need to support strided store for bitwidth < 32. + } else if (to_store_layout.bitwidth() == 32 && + // We accept padding in the minormost dim, because + // apply_vector_layout will properly mask stores。 + canReinterpretToUntiledMemref( + memref_ty, ctx.target_shape, + /*allow_minormost_padding=*/true)) { + // In this case, if the memref can be reinterpreted to untiled, it is + // valid to use any tiling for to_store. But using native tiling can save + // us a bunch of stores! + } else { + return op.emitOpError( + "Not implemented: dismatch in memref tiling and vector tiling in " + "store"); } } @@ -5143,8 +5175,10 @@ FailureOr> tpu_rotate_with_overflow( // Compute the mask for the blend. // Positive blends blend "forward" and negative blends blend "backward". auto mask_val = amount; + auto vreg_rot_amount = amount; if (amount < 0) { mask_val = layout_in.tiling()[tiling_dim] - std::abs(amount); + vreg_rot_amount += target_shape[tiling_dim]; } auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, loc); auto mask = builder.create( @@ -5156,7 +5190,8 @@ FailureOr> tpu_rotate_with_overflow( in_tiles.Each([&](absl::Span idxs, Value *v) { if (dim >= in_tiles.num_dimensions() - 2) { *v = builder.create(loc, res_vreg_ty, in_tiles(idxs), - amount, tiling_dim, nullptr, nullptr); + vreg_rot_amount, tiling_dim, nullptr, + nullptr); } }); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 2adc3bf0768e..6bec22403724 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -50,6 +50,7 @@ limitations under the License. #include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" namespace mlir::tpu { @@ -1187,7 +1188,7 @@ class VectorLayoutInferer { } LogicalResult infer(vector::LoadOp op) { - auto src_ty = op.getMemRefType(); + auto src_ty = getMemRefType(op.getBase()); auto res_ty = op.getVectorType(); TPU_CHECK_OP(src_ty.getRank() == res_ty.getRank(), "memref and vector rank mismatch"); @@ -1280,6 +1281,16 @@ class VectorLayoutInferer { setLayout(op, in_layout, VectorLayout(bitwidth, {std::nullopt, offsets[1]}, layout_tiling, ImplicitDim::kNone)); + } else if (bitwidth == 32 && + canReinterpretToUntiledMemref( + src_ty, target_shape_, /*allow_minormost_padding=*/true) && + *(src_ty.getShape().end() - 2) > 1) { + // Since it is untiled, we can load from any arbitrary address which + // means we can always set the sublane offset to 0. + // Note: if the src_shape[-2] == 1, we can just use the tiling from ref. + setLayout(op, in_layout, + VectorLayout(bitwidth, {0, offsets[1].value_or(0)}, + nativeTiling(bitwidth), ImplicitDim::kNone)); } else { setLayout( op, in_layout, @@ -1515,7 +1526,7 @@ class VectorLayoutInferer { } LogicalResult infer(vector::StoreOp op) { - auto ref_ty = op.getMemRefType(); + auto ref_ty = getMemRefType(op.getBase()); auto store_ty = op.getValueToStore().getType(); TPU_CHECK_OP(ref_ty.getRank() == store_ty.getRank(), "memref and vector rank mismatch"); @@ -1596,11 +1607,27 @@ class VectorLayoutInferer { // We can strided store sublanes if we're storing a single sublane for // multiple times. Enabling this helps store one entire row to memref // more efficiently. - store_layout = VectorLayout(store_ty.getElementTypeBitWidth(), offsets, - {1, tiling[1]}, ImplicitDim::kNone); + store_layout = + VectorLayout(bitwidth, offsets, {1, tiling[1]}, ImplicitDim::kNone); + } else if (bitwidth == 32 && + // We accept padding in the minormost dim, because + // apply_vector_layout will properly mask stores. + canReinterpretToUntiledMemref( + ref_ty, target_shape_, /*allow_minormost_padding=*/true)) { + // Since it is untiled, we can store to any arbitrary address which + // means the sublane offset can be any value and we can fold it to + // 2nd minor index. + // TODO(jevinjiang): We can fold the sublane offset into the 2nd minor + // index. But we need to handle negative index in lower-to-llo. For + // now, we just force the sublane offset to be 0. + if (offsets[1].value_or(0) < 0 || offsets[1].value_or(0) >= tiling[1]) { + offsets[1] = 0; + } + store_layout = VectorLayout(bitwidth, {0, offsets[1]}, + nativeTiling(bitwidth), ImplicitDim::kNone); } else { - store_layout = VectorLayout(store_ty.getElementTypeBitWidth(), offsets, - {tiling[0], tiling[1]}, ImplicitDim::kNone); + store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]}, + ImplicitDim::kNone); } } SmallVector in_layout{store_layout}; diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 0e3e6d0d9cd8..638a76fa5683 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -15,12 +15,14 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/util.h" +#include #include #include "llvm/Support/MathExtras.h" #include "absl/types/span.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { SmallVector ComputeTileStrides(MemRefType memref_ty, @@ -39,4 +41,31 @@ SmallVector ComputeTileStrides(MemRefType memref_ty, } return tile_strides; } + +bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty, + const std::array& target_shape, + bool allow_minormost_padding) { + auto tiled_layout = + dyn_cast(tiled_memref_ty.getLayout()); + if (!tiled_layout) { + // We expect the tiled memref to have a tiled layout. + return false; + } + if (tiled_layout.getTiles().empty() || + tiled_layout.getTiles().front().dimensions().size() != 2 || + tiled_memref_ty.getRank() < 2) { + // TODO(jevinjiang): Currently we only support >= 2D memref, we might + // need to handle 1D memref if we find a use case. + return false; + } + if (!allow_minormost_padding && + *(tiled_memref_ty.getShape().end() - 1) != target_shape[1]) { + return false; + } + auto packing = 32 / tiled_memref_ty.getElementTypeBitWidth(); + return (*(tiled_memref_ty.getShape().end() - 1) <= target_shape[1] && + *(tiled_memref_ty.getShape().end() - 2) % packing == 0 && + *(tiled_layout.getTileStrides().end() - 1) == 1 && + *(tiled_layout.getTileStrides().end() - 2) == 1); +} } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index c18bd1b3fbc2..f1771b948304 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -98,6 +98,14 @@ std::string shapeToString(const T &shape) { SmallVector ComputeTileStrides(MemRefType memref_ty, absl::Span tiling); + +// Returns true if a >=2D memref has a tiled layout and can be equivalently +// considered as an untiled memref, except for potential padding in the +// minormost dimension up to target_shape[1] (if allow_minormost_padding is +// true). +bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty, + const std::array &target_shape, + bool allow_minormost_padding = false); } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ From bd9a10e4eb42ed8151f15aac8f90c180a5353b7b Mon Sep 17 00:00:00 2001 From: ZincCat Date: Thu, 24 Oct 2024 02:20:54 -0400 Subject: [PATCH 029/698] fix the wrong output of pallas attention kernel when q_len!=kv_len --- jax/experimental/pallas/ops/gpu/attention.py | 64 +++++++++++--------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 66c9dea39734..198340ec0d11 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -177,13 +177,14 @@ def mha( debug: bool = False, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) + batch_size, q_seq_len, num_heads, head_dim = q.shape + kv_seq_len = k.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) # Heuristics. grid_ = grid if grid_ is None: - grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) + grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) num_warps_ = num_warps if num_warps_ is None: @@ -198,16 +199,16 @@ def mha( (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None - else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) + else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) return pl.pallas_call( @@ -243,13 +244,14 @@ def _mha_forward( debug: bool, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) + batch_size, q_seq_len, num_heads, head_dim = q.shape + kv_seq_len = k.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) # Heuristics. grid_ = grid if grid_ is None: - grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) + grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) num_warps_ = num_warps if num_warps_ is None: @@ -260,7 +262,7 @@ def _mha_forward( out_shape = [ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out jax.ShapeDtypeStruct( - shape=(batch_size, num_heads, seq_len), dtype=jnp.float32 # lse + shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32 # lse ), ] in_specs = [ @@ -268,16 +270,16 @@ def _mha_forward( (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) ), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None - else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) + else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) out, lse = pl.pallas_call( kernel, @@ -362,7 +364,8 @@ def mha_backward_kernel( block_d: int, ): del out_ref # Not needed - seq_len = q_ref.shape[0] + q_seq_len = q_ref.shape[0] + kv_seq_len = k_ref.shape[0] # Scan #1: dK and dV # 1. Load a block of K and V of size (block_k1, head_dim) in SMEM. @@ -423,7 +426,7 @@ def inner_loop_dkdv(start_q, carry): lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0 dv, dk = lax.fori_loop( - lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk) + lower_bound, pl.cdiv(q_seq_len, block_q1), inner_loop_dkdv, (dv, dk) ) dv_ref[...] = dv.astype(dv_ref.dtype) dk_ref[...] = dk.astype(dk_ref.dtype) @@ -486,7 +489,7 @@ def inner_loop_dq(start_k, dq): if causal: upper_bound = lax.div((start_q + 1) * block_q2, block_k2) else: - upper_bound = pl.cdiv(seq_len, block_k2) + upper_bound = pl.cdiv(kv_seq_len, block_k2) dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) dq_ref[...] = dq.astype(dq_ref.dtype) @@ -508,9 +511,10 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, segment_ids, )[1](do) elif backward_pass_impl == "triton": - batch_size, seq_len, num_heads, head_dim = q.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) + batch_size, q_seq_len, num_heads, head_dim = q.shape + kv_seq_len = k.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) delta = _preprocess_backward(out, do, lse, block_q, debug, interpret) out_shapes = [ jax.ShapeDtypeStruct(q.shape, q.dtype), @@ -520,29 +524,29 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) + (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), - pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), - pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), + pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), + pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] else: - in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0))) + in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0))) - grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k)) + grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_k)) num_warps = 8 dq, dk, dv = pl.pallas_call( functools.partial( From 14e0f0e7fae55e31b7fc22ecf8f19cd4675ef19a Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Thu, 24 Oct 2024 01:45:01 -0700 Subject: [PATCH 030/698] [Mosaic GPU] Query SM and PTX ISA dynamically using driver and LLVM Originally proposed in #24021. Slightly rewritter to make testing with internal LLVM toolchains better. Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`). Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`). Then use LLVM to map the SM string to a PTX ISA version that supports the SM. Co-authored-by: Andrey Portnoy PiperOrigin-RevId: 689286774 --- jaxlib/mosaic/gpu/BUILD | 15 ++++++ jaxlib/mosaic/gpu/custom_call.cc | 50 ++++++++++++++---- jaxlib/mosaic/gpu/target.cc | 88 ++++++++++++++++++++++++++++++++ jaxlib/mosaic/gpu/target.h | 30 +++++++++++ 4 files changed, 174 insertions(+), 9 deletions(-) create mode 100644 jaxlib/mosaic/gpu/target.cc create mode 100644 jaxlib/mosaic/gpu/target.h diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index e5eaeb347137..49c488e73850 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -26,6 +26,19 @@ py_library( deps = [":_mosaic_gpu_ext"], ) +cc_library( + name = "target", + srcs = ["target.cc"], + hdrs = ["target.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:MC", + ], +) + cc_library( name = "passes", srcs = [ @@ -104,12 +117,14 @@ cc_library( srcs = ["custom_call.cc"], deps = [ ":passes", + ":target", "//jaxlib/cuda:cuda_vendor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 103f9f78c32f..05c3725fc3a0 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -32,6 +32,7 @@ limitations under the License. #include #include "jaxlib/gpu/vendor.h" +#include "jaxlib/mosaic/gpu/target.h" #include "absl/base/optimization.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" @@ -89,8 +90,30 @@ namespace { using MosaicInitFunc = void(void****); using MosaicHostFunc = void(void**); +absl::StatusOr> GetSmAndPtxIsaVersion() { + // Assumes driver has been initialized and a context exists. XLA already has + // some utilities to query this, but we try to stay runtime-agnostic, so we + // build our own here. + CUdevice device; + if (cuCtxGetDevice(&device) != CUDA_SUCCESS) { + return absl::InternalError("Failed to get device for current context"); + } + int major = 0; + if (cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, + device) != CUDA_SUCCESS) { + return absl::InternalError("Failed to get major compute capability"); + } + int minor = 0; + if (cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, + device) != CUDA_SUCCESS) { + return absl::InternalError("Failed to get minor compute capability"); + } + return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor); +} + mlir::FailureOr GetPassPipeline( - mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target) { + mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, + const std::string& sm, const std::string& ptx_isa) { static bool register_once = []() { llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget(); @@ -131,7 +154,9 @@ mlir::FailureOr GetPassPipeline( convert-scf-to-cf, convert-nvvm-to-llvm, expand-strided-metadata, - nvvm-attach-target{O=3 chip=sm_90a fast=false features=+ptx80 ftz=false module= triple=nvptx64-nvidia-cuda}, + nvvm-attach-target{O=3 chip=)" + + sm + R"( fast=false features=+)" + ptx_isa + + R"( ftz=false module= triple=nvptx64-nvidia-cuda}, lower-affine, convert-arith-to-llvm{index-bitwidth=0}, convert-index-to-llvm{index-bitwidth=64}, @@ -251,7 +276,8 @@ class TemporaryDirectory { std::string path; }; -void DumpCompilationOutput(mlir::ModuleOp module) { +void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, + const std::string& ptx_isa) { bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; @@ -260,8 +286,8 @@ void DumpCompilationOutput(mlir::ModuleOp module) { } module = module.clone(); // Prevent accidental modification. - auto passes = GetPassPipeline(module.getContext(), - mlir::gpu::CompilationTarget::Assembly); + auto passes = GetPassPipeline( + module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa); if (mlir::failed(passes) || mlir::failed(RunPasses(std::move(*passes), module))) { return; @@ -297,7 +323,7 @@ void DumpCompilationOutput(mlir::ModuleOp module) { // Run ptxas to generate SASS. std::vector ptxas_args = { "ptxas", "--opt-level", "3", - "--gpu-name", "sm_90a", "--output-file", + "--gpu-name", sm.c_str(), "--output-file", elf_path.c_str(), ptx_path.c_str()}; if (dump_ptxas) { ptxas_args.push_back("-v"); @@ -321,9 +347,15 @@ void DumpCompilationOutput(mlir::ModuleOp module) { absl::StatusOr> Compile( mlir::ModuleOp module) { - DumpCompilationOutput(module); - auto passes = GetPassPipeline(module.getContext(), - mlir::gpu::CompilationTarget::Binary); + auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); + if (!sm_and_ptx_isa.ok()) { + return sm_and_ptx_isa.status(); + } + const std::string sm = sm_and_ptx_isa.value().first; + const std::string ptx_isa = sm_and_ptx_isa.value().second; + DumpCompilationOutput(module, sm, ptx_isa); + auto passes = GetPassPipeline( + module.getContext(), mlir::gpu::CompilationTarget::Binary, sm, ptx_isa); if (mlir::failed(passes)) { return absl::InternalError("Failed to construct pass pipeline"); } diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc new file mode 100644 index 000000000000..a1a66a709cbe --- /dev/null +++ b/jaxlib/mosaic/gpu/target.cc @@ -0,0 +1,88 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/mosaic/gpu/target.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "llvm/include/llvm/MC/MCSubtargetInfo.h" +#include "llvm/include/llvm/MC/TargetRegistry.h" + +namespace mosaic::gpu { + +absl::StatusOr> GetSmAndPtxIsaVersion( + int major, int minor) { + // "base" compute capability as reported by the driver. + // For example for a Hopper H200 GPU this would return sm_90, and never + // sm_90a. + std::string sm_base = absl::StrCat("sm_", major, minor); + + const std::string triple = "nvptx64-nvidia-cuda"; + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(triple, error); + if (target == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to lookup LLVM target based on triple %s: %s", triple, error)); + } + + // Check if there's a variant of the current SM that ends in "a" + // (has architecture-specific capabilities) + const char* sm_arch_specific = nullptr; + { + // generic subtarget + std::unique_ptr subtarget_info{ + target->createMCSubtargetInfo(triple, "", "")}; + if (subtarget_info == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to get generic LLVM subtarget info for triple %s", triple)); + } + for (const llvm::SubtargetSubTypeKV& subtype : + subtarget_info->getAllProcessorDescriptions()) { + if (absl::StartsWith(subtype.Key, sm_base) && + absl::EndsWith(subtype.Key, "a")) { + sm_arch_specific = subtype.Key; + break; + } + } + } + + const std::string sm = sm_arch_specific ? sm_arch_specific : sm_base; + + std::unique_ptr subtarget_info{ + target->createMCSubtargetInfo(triple, sm, "")}; + if (subtarget_info == nullptr) { + return absl::InternalError( + absl::StrFormat("Failed to get LLVM subtarget info for sm %s", sm)); + } + + for (const llvm::SubtargetFeatureKV& feature : + subtarget_info->getEnabledProcessorFeatures()) { + if (absl::StartsWith(feature.Key, "ptx")) { + std::string ptx_isa = feature.Key; + return std::make_pair(sm, ptx_isa); + } + } + return absl::InternalError(absl::StrFormat( + "Failed to find a PTX ISA LLVM subtarget feature for %s", sm)); +} + +} // namespace mosaic::gpu diff --git a/jaxlib/mosaic/gpu/target.h b/jaxlib/mosaic/gpu/target.h new file mode 100644 index 000000000000..070ecedebd01 --- /dev/null +++ b/jaxlib/mosaic/gpu/target.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_TARGET_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_TARGET_H_ + +#include +#include + +#include "absl/status/statusor.h" + +namespace mosaic::gpu { + +absl::StatusOr> GetSmAndPtxIsaVersion( + int major, int minor); + +} // namespace mosaic::gpu + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_TARGET_H_ From bb2e2303d79972bfb38fe02099b3231eb1a3f3ad Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 24 Oct 2024 01:53:31 -0700 Subject: [PATCH 031/698] [Pallas:MGPU] Treat each warpgroup as a single logical thread. As an extra minor change, we now disallow specifying the predicate when uniform is unset, as that implies that we're going to use two different mechanisms to select a single thread. PiperOrigin-RevId: 689289365 --- jax/_src/pallas/mosaic_gpu/lowering.py | 17 ++++++++++------- jax/_src/pallas/mosaic_gpu/primitives.py | 18 +++++++++++++++++- jax/experimental/mosaic/gpu/core.py | 2 ++ tests/pallas/mosaic_gpu_test.py | 20 ++++++++++++++++++++ 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 21f6f2570a5d..7517597d637f 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -67,6 +67,7 @@ # sensitive to alignment and while this is quite conservative, it gets the job # done. We should make this more refined in the future. _SMEM_ALIGNMENT = 1024 +WARPGROUP_SIZE = 128 def _align_to(x: int, alignment: int): if (rem := x % alignment): @@ -164,9 +165,11 @@ def _run_scoped_resource_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int: aval = v.aval if isinstance(aval.dtype, gpu_core.BarrierType): rs += Resources( - barrier_counts=collections.Counter( - [mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape)] - ) + barrier_counts=collections.Counter([ + mgpu.Barrier( + aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape + ) + ]) ) else: rs += Resources( @@ -592,7 +595,6 @@ def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: gmem_transform=gmem_transforms, swizzle=swizzle, arrive=False, # The caller must do ``arrive_expect_tx`` manually! - uniform=False, predicate=is_memory_thread, ) @@ -645,7 +647,6 @@ def store( gmem_slice=store_slice, gmem_transform=gmem_transforms, swizzle=swizzle, - uniform=False, predicate=do_store, ) return base_offset @@ -747,7 +748,7 @@ def _(step, carry): ) rs = _estimate_resources(jaxpr) extra_barriers = [ - mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) + mgpu.Barrier(aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape) for aval in scratch_avals if isinstance(aval.dtype, gpu_core.BarrierType) ] @@ -1216,7 +1217,9 @@ def _run_scoped_lowering_rule( elif isinstance(aval.dtype, gpu_core.BarrierType): input_refs.append( ctx.module_ctx.reserve_barrier( - mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) + mgpu.Barrier( + aval.dtype.num_arrivals * WARPGROUP_SIZE, *aval.shape + ) ) ) should_discharge.append(False) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b3c5a8c5839c..3874c7125d5e 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -17,6 +17,7 @@ from __future__ import annotations import enum +import math from typing import Any, Literal import jax @@ -24,6 +25,7 @@ from jax._src import state from jax._src import tree_util from jax._src import util +from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -34,6 +36,9 @@ import jax.experimental.mosaic.gpu as mgpu +WARPGROUP_SIZE = 128 + + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True @@ -169,8 +174,19 @@ def _copy_gmem_to_smem_lowering( barrier = barrier.__getitem__( *map(lowering._as_index, barrier_indexer.indices) ) + dst_ty = ir.MemRefType(dst.type) + bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) + if bytes % WARPGROUP_SIZE: + raise NotImplementedError("Only aligned copies are supported") + # We arrive uniformly from each thread in the WG, so we need to divide the + # number of bytes by the number of threads in the WG. + # TODO: apaszke - Relax this. We can just select the WG leader and have it + # arrive with the whole transfer size, while everyone else arrives with 0. + # But we should continue using this scheme as it's likely to be faster. + bytes //= WARPGROUP_SIZE + barrier.arrive_expect_tx(bytes) ctx.launch_ctx.async_copy( - src_ref=src, dst_ref=dst, barrier=barrier, **copy_params + src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params ) return () diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 8ee1bda2f41e..d35340340695 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -355,6 +355,8 @@ def async_copy( f"Expected same element type, got {element_type} and" f" {dst_ref_ty.element_type}" ) + if predicate is not None and not uniform: + raise ValueError("Predicate can only be defined when uniform is True") if not isinstance(gmem_transform, tuple): gmem_transform = (gmem_transform,) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 5c546ad17be3..6a2ba17bb34b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1087,6 +1087,26 @@ def kernel(): f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) + def test_cross_wg_barrier(self): + mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) + + @jax.jit + def f(): + @pl.run_state + def inner(y_ref): + @pl.core_map(mesh) + def kernel(): + def scoped(barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) + # Each warpgroup is a single logical thread! + pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) + y_init = jnp.zeros((2, 128), np.int32) + return inner(y_init) + np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + if __name__ == "__main__": absltest.main() From 717467a82f32e8f9b9e2b2843fddcf4ed3942473 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 24 Oct 2024 04:17:16 -0700 Subject: [PATCH 032/698] [pallas] `input_output_aliases` now only include refs which have been written to PiperOrigin-RevId: 689323778 --- jax/_src/pallas/core.py | 48 ++++++++++++++++++++++++++++++ jax/_src/pallas/mosaic/core.py | 31 +++++++------------ jax/_src/pallas/mosaic_gpu/core.py | 37 ++++++++--------------- 3 files changed, 71 insertions(+), 45 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index fd7883109b9c..dad45bbae207 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -37,6 +37,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge +from jax._src.state import types as state_types from jax._src.state.types import TransformedRef import jax.numpy as jnp @@ -1070,6 +1071,53 @@ def _core_map_abstract_eval(*args, jaxpr, mesh): _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {} + + +def default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + grid, + compiler_params, + backend, + jaxpr, +): + """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" + del out_avals # Unused. + + def body(*args): + # Due to aliasing, ``args`` contains aliased inputs and outputs so we + # remove outputs. + in_refs = args[:len(in_avals)] + jax_core.eval_jaxpr(jaxpr, in_refs) + + assert len(jaxpr.outvars) == 0 + modified_idxs = sorted( + eff.input_index + for eff in jaxpr.effects + if isinstance(eff, state_types.WriteEffect) + ) + any_spec = BlockSpec(memory_space=MemorySpace.ANY) + from jax._src.pallas import pallas_call # Avoid circular dependency. + outs = pallas_call.pallas_call( + body, + out_shape=[in_avals[idx] for idx in modified_idxs], + in_specs=[any_spec] * len(in_avals), + out_specs=[any_spec] * len(modified_idxs), + input_output_aliases={ + in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) + }, + grid=grid, + compiler_params=compiler_params, + backend=backend, + )(*args) + # ``outs`` lacks the unmodified inputs. Add them back in. + all_outs = [*args] + for out_idx, in_idx in enumerate(modified_idxs): + all_outs[in_idx] = outs[out_idx] + return all_outs, () + + @state_discharge.register_discharge_rule(core_map_p) def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwargs): if type(mesh) not in _core_map_mesh_rules: diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 60207256fd05..6e16df2e54de 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -32,6 +32,9 @@ import jax.numpy as jnp import numpy as np +# TODO(b/375357542): Remove the import once the bug is fixed. +_ = pallas_call + map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -250,33 +253,19 @@ def _tensorcore_mesh_discharge_rule( mesh, jaxpr, ): - del out_avals assert isinstance(mesh, TensorCoreMesh) if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") core_axis_name, num_cores = list(mesh.shape.items())[0] - def body(*args): - # Due to aliasing, args contains aliased inputs and outputs so we remove - # outputs. - in_refs = args[:len(in_avals)] - jax_core.eval_jaxpr(jaxpr, in_refs) - assert len(jaxpr.outvars) == 0 - out = pallas_call.pallas_call( - body, - out_shape=in_avals, - in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)] - * len(in_avals), - out_specs=[pallas_core.BlockSpec( - memory_space=pallas_core.MemorySpace.ANY)] - * len(in_avals), - input_output_aliases={i: i for i in range(len(in_avals))}, + return pallas_core.default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + jaxpr=jaxpr, grid=((core_axis_name, num_cores),), - compiler_params=dict( - mosaic=dict(dimension_semantics=("parallel",)), - ), + compiler_params=TPUCompilerParams(dimension_semantics=("parallel",)), backend="mosaic_tpu", - )(*args) - return out, () + ) pallas_core._core_map_mesh_rules[TensorCoreMesh] = ( _tensorcore_mesh_discharge_rule diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 9d28d72f08ad..ff22f276001f 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -29,8 +29,6 @@ from jax._src import effects from jax._src import tree_util from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call -from jax._src.state.types import Transform from jax._src.state import indexing from jax._src.state import types as state_types import jax.experimental.mosaic.gpu as mgpu @@ -155,7 +153,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) -class UntileRef(Transform): +class UntileRef(state_types.Transform): tiling: tuple[int, ...] def transform_shape(self, shape): @@ -173,7 +171,7 @@ def transform_dtype(self, dtype): def untransform_index( self, idxs: tuple[Index, ...] - ) -> tuple[tuple[Index, ...], Transform]: + ) -> tuple[tuple[Index, ...], state_types.Transform]: untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] idxs_after_tiling = [] @@ -231,7 +229,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) -class TransposeRef(Transform): +class TransposeRef(state_types.Transform): permutation: tuple[int, ...] def transform_shape(self, shape): @@ -244,7 +242,7 @@ def transform_dtype(self, dtype): def untransform_index( self, idxs: tuple[Index, ...] - ) -> tuple[tuple[Index, ...], Transform]: + ) -> tuple[tuple[Index, ...], state_types.Transform]: removed_dims = [ i for i, idx in enumerate(idxs) if not isinstance(idx, slice) ] @@ -316,12 +314,12 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) -class UnswizzleRef(Transform): +class UnswizzleRef(state_types.Transform): swizzle: int def untransform_index( self, idxs: tuple[Index, ...] - ) -> tuple[tuple[Index, ...], Transform]: + ) -> tuple[tuple[Index, ...], state_types.Transform]: if not idxs: return idxs, self if not all(isinstance(idx, slice) for idx in idxs[-2:]): @@ -523,29 +521,20 @@ def _gpu_mesh_discharge_rule( mesh, jaxpr, ): - del out_avals assert isinstance(mesh, GPUMesh) if mesh.cluster: raise NotImplementedError if mesh.num_threads is None: raise NotImplementedError - def body(*args): - # Due to aliasing, args contains aliased inputs and outputs so we remove - # outputs. - in_refs = args[:len(in_avals)] - jax_core.eval_jaxpr(jaxpr, in_refs) - assert len(jaxpr.outvars) == 0 - any_spec = pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY) - out = pallas_call.pallas_call( - body, - out_shape=in_avals, - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(in_avals), - input_output_aliases={i: i for i in range(len(in_avals))}, + return pallas_core.default_mesh_discharge_rule( + in_avals, + out_avals, + *args, + jaxpr=jaxpr, grid=tuple(mesh.shape.items()), backend="mosaic_gpu", - )(*args) - return out, () + compiler_params=GPUCompilerParams(), + ) pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule From e5f4be55641235837d8601bbf54091cdb0cd2ffe Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 24 Oct 2024 13:07:33 +0200 Subject: [PATCH 033/698] [shape_poly] Expands support for random.choice `random.choice` uses `np.insert(arr.shape, new_shape)` which attempts to coerce all the values in `new_shape` to constants when `arr.shape` is constant. Replace use of `np.insert` with tuple slicing and concatenation. The case when the sampled axis has non-constant size and `replace=False` is not supported, because `permutation` on arrays with non-constant size is not supported. Adds tests for many combinations of arguments for `random.choice`. Improves a few error messages. --- jax/_src/prng.py | 5 +-- jax/_src/random.py | 14 ++++++-- .../jax2tf/tests/shape_poly_test.py | 2 +- tests/shape_poly_test.py | 36 ++++++++++++++++++- 4 files changed, 50 insertions(+), 7 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7ca7db022d89..039b0a309775 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1067,8 +1067,9 @@ def threefry_2x32(keypair, count): odd_size = count.size % 2 if not isinstance(odd_size, int): - msg = ("jax.random functions have limited support for shape polymorphism. " - "In particular, the product of the known dimensions must be even.") + msg = ("jax.random functions have limited support for shape polymorphism " + "when using threefry. " + f"In particular, the array size ({count.size}) must be even.") raise core.InconclusiveDimensionOperation(msg) if odd_size: diff --git a/jax/_src/random.py b/jax/_src/random.py index 203f72d406e5..dc9fc18aff38 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -581,6 +581,10 @@ def _shuffle(key, x, axis) -> Array: # another analysis (where the keys are generated one bit at a time). exponent = 3 # see tjablin@'s analysis for explanation of this parameter uint32max = jnp.iinfo(np.uint32).max + if not core.is_constant_dim(x.size): + raise NotImplementedError( + "shape polymorphism for `permutation` or `shuffle`" + f" for arrays of non-constant size: {x.size}") num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max))) for _ in range(num_rounds): @@ -640,7 +644,9 @@ def choice(key: KeyArrayLike, if n_inputs <= 0: raise ValueError("a must be greater than 0 unless no samples are taken") if not replace and n_draws > n_inputs: - raise ValueError("Cannot take a larger sample than population when 'replace=False'") + raise ValueError( + f"Cannot take a larger sample (size {n_draws}) than " + f"population (size {n_inputs}) when 'replace=False'") if p is None: if replace: @@ -653,7 +659,9 @@ def choice(key: KeyArrayLike, check_arraylike("choice", p) p_arr, = promote_dtypes_inexact(p) if p_arr.shape != (n_inputs,): - raise ValueError("p must be None or match the shape of a") + raise ValueError( + "p must be None or a 1D vector with the same size as a.shape[axis]. " + f"p has shape {p_arr.shape} and a.shape[axis] is {n_inputs}.") if replace: p_cuml = jnp.cumsum(p_arr) r = p_cuml[-1] * (1 - uniform(key, shape, dtype=p_cuml.dtype)) @@ -665,7 +673,7 @@ def choice(key: KeyArrayLike, result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) return result.reshape(shape if arr.ndim == 0 else - np.insert(np.delete(arr.shape, axis), axis, shape)) + arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:]) def normal(key: KeyArrayLike, diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 38af6d9d76d5..7fdc6854da23 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2114,7 +2114,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] polymorphic_shapes=[None, "b0, ..."], expect_error=( (core.InconclusiveDimensionOperation, - "the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else (None, None)), + "array size .* must be even") if flags_name == "threefry_non_partitionable" else (None, None)), override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [ diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 1b213a8b5bb4..ead77e2b5053 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2941,6 +2941,40 @@ def test_vmap_error(self): RandArg((3, 5, 0), _f32)], polymorphic_shapes=[None, "b0, b1, ..."], override_jax_config_flags=override_jax_config_flags), # type: ignore + [ + PolyHarness("random_choice", f"{flags_name}_arr_poly={arr_poly}_shape_poly={shape_poly}_replace={replace}_use_p={use_p}", + lambda key, a, res_shape, use_p: jax.random.choice( + jax.random.wrap_key_data(key), + a, + shape=res_shape.shape, + p=jnp.full((a.shape[1],), 0.1, dtype=_f32) if use_p else None, + axis=1, + replace=replace), + arg_descriptors=[RandArg((key_size,), np.uint32), + RandArg((64, 12, 4), _f32), # sample on axis=1 + RandArg((3, 4), _f32), + StaticArg(use_p)], + # TODO(necula): threefry requires even-sized samples. + polymorphic_shapes=[None, + "_, 2*b1, _" if arr_poly else None, + "b3, b4" if shape_poly else None], + # The array sampled dimension must be larger than res_shape.size + symbolic_constraints=[ + "2*b1 >= 12" if arr_poly else "1 >= 0", + "2*b1 >= b3*b4" if arr_poly and shape_poly else "1 >= 0", + "12 >= b3*b4" if shape_poly else "1 >= 0" + ], + override_jax_config_flags=override_jax_config_flags, + expect_error=( + (NotImplementedError, "permutation") + if arr_poly and not use_p else None)) # type: ignore + # np.insert used in random.choice tries to coerce shape_poly to + # integer arrays, but only when the arr_poly is False. + for arr_poly in [True, False] + for shape_poly in [True, False] + for replace in [True, False] + for use_p in [True, False] + ], PolyHarness("random_split", f"{flags_name}", lambda key, a: jax.random.key_data( jax.random.split(jax.random.wrap_key_data(key), @@ -2971,7 +3005,7 @@ def test_vmap_error(self): polymorphic_shapes=[None, "b0, ..."], expect_error=( (core.InconclusiveDimensionOperation, - "the product of the known dimensions must be even") if flags_name == "threefry_non_partitionable" else None), + "array size .* must be even") if flags_name == "threefry_non_partitionable" else None), override_jax_config_flags=override_jax_config_flags) # type: ignore ] for key_size, flags_name, override_jax_config_flags in [ From e5bbf3dca16bc59bd23cf6f7fba56f8367d85d21 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 24 Oct 2024 10:12:54 +0200 Subject: [PATCH 034/698] [jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf. Consider the use case when we call_tf a restored saved model that includes parameters (hence functions closing over tf.Variable), and then we jax2tf.convert it with native serialization, under tf.function (or for saving to saved model). The lowering for call_tf in presence of functions with captured inputs requires looking up the tf.Variable and reading its value. This fails with an error that `v.numpy()` is not allowd in graph mode. The fix is to use `tf.init_scope()` to lift out of graph building mode, so that we can read the value of the variables. --- jax/experimental/jax2tf/call_tf.py | 15 +++++++++---- jax/experimental/jax2tf/tests/call_tf_test.py | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index baae52403053..2321a8a035f7 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -510,10 +510,17 @@ def _call_tf_lowering( else: captured_inputs.append(inp) - captured_ops = tuple( - mlir.ir_constant(np.asarray(inp)) - for inp in captured_inputs - ) + # The following use case happens when we call_tf a restored saved model that + # includes parameters (hence functions closing over tf.Variable), and then + # we jax2tf.convert it with native serialization, under tf.function (or + # for saving to saved model). The `np.asarray(inp)` fails because it thinks + # it is in TF graph mode. The `tf.init_scope()` lifts out of function-building + # graph scopes, and allows us to read the values of the variables + with tf.init_scope(): + captured_ops = tuple( + mlir.ir_constant(np.asarray(inp)) + for inp in captured_inputs + ) if call_tf_graph: with jax2tf_internal.inside_call_tf(): diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 492dfad4c855..e8d284178691 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -391,6 +391,20 @@ def fun_tf(x): res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x) self.assertAllClose((x * 3. + 4. + 2.) * 3. + 5., res, check_dtypes=False) + def test_with_capture_then_convert_again(self): + captured_by_tf = tf.Variable(np.arange(1024, dtype=np.float32)) + def tf_fn(x): + return tf.math.add(x, captured_by_tf) + + x = np.arange(1024, dtype=np.float32) + res = jax2tf.convert(jax2tf.call_tf(tf_fn))(x) + self.assertAllClose(res, 2 * x) + + # The bug appears only when we use non-eager mode on the converted func + res = tf.function(jax2tf.convert(jax2tf.call_tf(tf_fn)), + autograph=False)(x) + self.assertAllClose(res, 2 * x) + @_parameterized_jit def test_grad(self, with_jit=False): x = np.float32(3.) @@ -957,6 +971,13 @@ def f_jax(param, x): restored_jax = jax2tf.call_tf(restored_model.f) self.assertAllClose(f_jax(param, x), restored_jax(x)) self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x)) + self.assertAllClose(f_jax(param, x), jax2tf.convert(restored_jax)(x)) + self.assertAllClose(f_jax(param, x), + tf.function(jax2tf.convert(restored_jax), + autograph=False)(x)) + self.assertAllClose(f_jax(param, x), + tf.function(jax2tf.convert(restored_jax), + autograph=True)(x)) def test_saved_model_shape_poly(self): tracing_count = 0 From 6c8e56f43f1bd139869b600311c36bc0923e6b58 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 24 Oct 2024 08:45:09 -0700 Subject: [PATCH 035/698] Finish 0.4.35 release by removing dead code PiperOrigin-RevId: 689396609 --- jax/_src/config.py | 80 ++++++++++++-------------- jax/_src/dispatch.py | 60 +++++-------------- jax/_src/sharding_impls.py | 16 ++---- tests/garbage_collection_guard_test.py | 7 --- tests/memories_test.py | 3 - tests/pjit_test.py | 6 -- tests/shard_map_test.py | 2 - 7 files changed, 55 insertions(+), 119 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 8bebd7d904a6..a05e6e190d44 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1719,51 +1719,43 @@ def transfer_guard(new_val: str) -> Iterator[None]: yield -if lib.xla_extension_version < 293: - - def array_garbage_collection_guard(_val): - raise NotImplementedError( - 'jaxlib version is too low for garbage collection guard' - ) +def _update_garbage_collection_guard(state, key, val): + """Applies the transfer guard level within guard_lib.""" + if val is None: + setattr(state, key, None) + elif val == 'allow': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW) + elif val == 'log': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG) + elif val == 'fatal': + setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL) + else: + assert False, f'Invalid garbage collection guard level {val}' -else: - def _update_garbage_collection_guard(state, key, val): - """Applies the transfer guard level within guard_lib.""" - if val is None: - setattr(state, key, None) - elif val == 'allow': - setattr(state, key, guard_lib.GarbageCollectionGuardLevel.ALLOW) - elif val == 'log': - setattr(state, key, guard_lib.GarbageCollectionGuardLevel.LOG) - elif val == 'fatal': - setattr(state, key, guard_lib.GarbageCollectionGuardLevel.FATAL) - else: - assert False, f'Invalid garbage collection guard level {val}' - - array_garbage_collection_guard = optional_enum_state( - name='jax_array_garbage_collection_guard', - enum_values=['allow', 'log', 'fatal'], - # The default is applied by guard_lib. - default=None, - help=( - 'Select garbage collection guard level for "jax.Array" objects.\nThis' - ' option can be used to control what happens when a "jax.Array"' - ' object is garbage collected. It is desirable for "jax.Array"' - ' objects to be freed by Python reference couting rather than garbage' - ' collection in order to avoid device memory being held by the arrays' - ' until garbage collection occurs.\n\nValid values are:\n * "allow":' - ' do not log garbage collection of "jax.Array" objects.\n * "log":' - ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' - ' fatal error if a "jax.Array" is garbage collected.\nDefault is' - ' "allow".' - ), - update_global_hook=lambda val: _update_garbage_collection_guard( - guard_lib.global_state(), 'garbage_collect_array', val - ), - update_thread_local_hook=lambda val: _update_garbage_collection_guard( - guard_lib.thread_local_state(), 'garbage_collect_array', val - ), - ) +array_garbage_collection_guard = optional_enum_state( + name='jax_array_garbage_collection_guard', + enum_values=['allow', 'log', 'fatal'], + # The default is applied by guard_lib. + default=None, + help=( + 'Select garbage collection guard level for "jax.Array" objects.\nThis' + ' option can be used to control what happens when a "jax.Array"' + ' object is garbage collected. It is desirable for "jax.Array"' + ' objects to be freed by Python reference couting rather than garbage' + ' collection in order to avoid device memory being held by the arrays' + ' until garbage collection occurs.\n\nValid values are:\n * "allow":' + ' do not log garbage collection of "jax.Array" objects.\n * "log":' + ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' + ' fatal error if a "jax.Array" is garbage collected.\nDefault is' + ' "allow".' + ), + update_global_hook=lambda val: _update_garbage_collection_guard( + guard_lib.global_state(), 'garbage_collect_array', val + ), + update_thread_local_hook=lambda val: _update_garbage_collection_guard( + guard_lib.thread_local_state(), 'garbage_collect_array', val + ), +) def _update_debug_log_modules(module_names_str: str | None): logging_config.disable_all_debug_logging() diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 179f8430febe..e1e4bce2743a 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -48,13 +48,12 @@ from jax._src import lib from jax._src.mesh import AbstractMesh, Mesh from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - SingleDeviceSharding, NamedSharding, - GSPMDSharding, TransferToMemoryKind, is_single_device_sharding) + SingleDeviceSharding, NamedSharding, TransferToMemoryKind, + is_single_device_sharding) from jax._src.layout import Layout, DeviceLocalLayout @@ -361,50 +360,21 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): f"platform {inp_plat} and target sharding's device set " f"ids: {target_ids} on platform {target_plat}") - if xla_extension_version >= 292: - if inp_sharding.is_fully_replicated: - permute_order = None - else: - permute_order = np.vectorize(target_sharding._device_assignment.index, - otypes=[int])(inp_sharding._device_assignment) - new_mesh = Mesh( - target_sharding.mesh.devices.reshape(inp_sharding.mesh.axis_sizes), - inp_sharding.mesh.axis_names) - new_s = NamedSharding( - new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, - _logical_device_ids=(None if permute_order is None else - tuple(permute_order.tolist()))) - new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) - return api.jit(_identity_fn, out_shardings=target_sharding, - donate_argnums=donate_argnums)(new_x) + if inp_sharding.is_fully_replicated: + permute_order = None else: - old_hlo_sharding = inp_sharding._to_xla_hlo_sharding(x.ndim) - if old_hlo_sharding.is_replicated(): - new_hlo_sharding = old_hlo_sharding - else: - permute_order = np.vectorize(target_sharding._device_assignment.index, + permute_order = np.vectorize(target_sharding._device_assignment.index, otypes=[int])(inp_sharding._device_assignment) - # Unfortunately need to fallback to V1 sharding here. - new_op_sharding = old_hlo_sharding.to_proto() - new_op_sharding.iota_reshape_dims = [] - new_op_sharding.iota_transpose_perm = [] - new_op_sharding.tile_assignment_devices = np.take( - permute_order, old_hlo_sharding.tile_assignment_devices() - ) - new_hlo_sharding = xc.HloSharding.from_proto(new_op_sharding) - assert (list(np.take(inp_sharding._device_assignment, - old_hlo_sharding.tile_assignment_devices())) - == list(np.take(target_sharding._device_assignment, - new_op_sharding.tile_assignment_devices))) - - new_x = array.make_array_from_single_device_arrays( - x.shape, - GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding, - memory_kind=target_sharding.memory_kind), - x._arrays, - ) - return api.jit(_identity_fn, out_shardings=target_sharding, - donate_argnums=donate_argnums)(new_x) + new_mesh = Mesh( + target_sharding.mesh.devices.reshape(inp_sharding.mesh.axis_sizes), + inp_sharding.mesh.axis_names) + new_s = NamedSharding( + new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, + _logical_device_ids=(None if permute_order is None else + tuple(permute_order.tolist()))) + new_x = array.make_array_from_single_device_arrays(x.shape, new_s, x._arrays) + return api.jit(_identity_fn, out_shardings=target_sharding, + donate_argnums=donate_argnums)(new_x) @dataclasses.dataclass(frozen=True) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index f54c39efebce..78a59a4fcd5c 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -37,7 +37,6 @@ are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method -from jax._src.lib import xla_extension_version import numpy as np @@ -242,8 +241,6 @@ class NamedSharding(sharding.Sharding): _parsed_pspec: ParsedPartitionSpec _manual_axes: frozenset[MeshAxisName] _logical_device_ids: tuple[int, ...] | None - if xla_extension_version < 292: - _logical_device_ids = None @use_cpp_method() def __init__( @@ -308,15 +305,10 @@ def _from_parsed_pspec( cls, mesh, parsed_pspec, *, memory_kind=None, _manual_axes=frozenset(), _logical_device_ids=None, ): - if xla_extension_version >= 292: - return cls(mesh, parsed_pspec.get_partition_spec(), - memory_kind=memory_kind, _parsed_pspec=parsed_pspec, - _manual_axes=_manual_axes, - _logical_device_ids=_logical_device_ids) - else: - return cls(mesh, parsed_pspec.get_partition_spec(), - memory_kind=memory_kind, _parsed_pspec=parsed_pspec, - _manual_axes=_manual_axes) + return cls(mesh, parsed_pspec.get_partition_spec(), + memory_kind=memory_kind, _parsed_pspec=parsed_pspec, + _manual_axes=_manual_axes, + _logical_device_ids=_logical_device_ids) @property def num_devices(self) -> int: diff --git a/tests/garbage_collection_guard_test.py b/tests/garbage_collection_guard_test.py index d23d239dda1b..8d534f22801a 100644 --- a/tests/garbage_collection_guard_test.py +++ b/tests/garbage_collection_guard_test.py @@ -20,7 +20,6 @@ from absl.testing import absltest import jax from jax._src import config -from jax._src.lib import xla_extension_version import jax._src.test_util as jtu import jax.numpy as jnp @@ -46,9 +45,6 @@ def _create_array_cycle(): class GarbageCollectionGuardTest(jtu.JaxTestCase): def test_gced_array_is_not_logged_by_default(self): - if xla_extension_version < 293: - self.skipTest("Requires xla_extension_version >= 293") - # Create a reference cycle of two jax.Arrays. _create_array_cycle() @@ -66,9 +62,6 @@ def test_gced_array_is_not_logged_by_default(self): ) def test_gced_array_is_logged(self): - if xla_extension_version < 293: - self.skipTest("Requires xla_extension_version >= 293") - # Use mock_stderr to be able to inspect stderr. mock_stderr = io.StringIO() diff --git a/tests/memories_test.py b/tests/memories_test.py index 7f05ac424127..5f0ab04612e2 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -26,7 +26,6 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.layout import DeviceLocalLayout as DLL, Layout -from jax._src.lib import xla_extension_version from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp @@ -655,8 +654,6 @@ def f(): @jtu.run_on_devices('tpu') def test_ragged_copy_on_host(self): - if xla_extension_version < 290: - self.skipTest('Requires xla_extension_version >= 290') mesh = jtu.create_mesh((2,), ('x')) sharding = jax.sharding.NamedSharding(mesh, P(('x'))) cpu_sharding = sharding.with_memory_kind('pinned_host') diff --git a/tests/pjit_test.py b/tests/pjit_test.py index fd65c79538cf..9e3fdde3d13e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -38,7 +38,6 @@ from jax import stages from jax import lax from jax._src.lax import lax as lax_internal -from jax._src.lib import xla_extension_version from jax.lax import with_sharding_constraint from jax._src import prng from jax.sharding import PartitionSpec as P, Mesh @@ -5604,9 +5603,6 @@ def test_hlo_sharding_manual_replicated(self): self.assertTrue(hs4.is_tiled()) def test_hlo_sharding_with_device_ordering(self): - if xla_extension_version < 291: - self.skipTest('Requires xla_extension_version >= 291') - hs1 = xc.HloSharding.subgroup_with_device_ordering( np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.int64), subgroup_types=[xc.OpSharding.Type.REPLICATED], @@ -5718,7 +5714,6 @@ def f(x): self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text()) - @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292") def test_lowering_with_sharding_constraint(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) arr = np.arange(16).reshape(4, 2, 2) @@ -5744,7 +5739,6 @@ def f(x): self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str) # TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline. - @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292") @jtu.skip_on_devices('cpu') def test_compile_with_inferred_out_sharding(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 364a90621fa9..f1da6b927cc6 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -37,7 +37,6 @@ from jax._src import core from jax._src import prng from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals @@ -2642,7 +2641,6 @@ def fwd(a): @unittest.skipIf(sdy is None, "shardy is not enabled") class SdyIntegrationTest(jtu.JaxTestCase): - @unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292") # Verify we can lower to a `ManualComputationOp`. def test_shardy_collective_permute(self): mesh = jtu.create_mesh((2,), ('x',)) From 6634f5a34825638ba6e1433bffc694cb65620970 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 24 Oct 2024 09:48:53 -0700 Subject: [PATCH 036/698] [Mosaic GPU] Use absl::StrCat instead std::string::operator+ Repeated string addition is apparently a bit of an anti-pattern. Not that it matters much in this place, but why not do it properly. PiperOrigin-RevId: 689416587 --- jaxlib/mosaic/gpu/custom_call.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 05c3725fc3a0..2eeac946afb1 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -31,6 +31,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/gpu/target.h" #include "absl/base/optimization.h" @@ -143,7 +144,7 @@ mlir::FailureOr GetPassPipeline( return true; }(); (void)register_once; - return mlir::parsePassPipeline( + return mlir::parsePassPipeline(absl::StrCat( R"( builtin.module( canonicalize, @@ -154,8 +155,8 @@ mlir::FailureOr GetPassPipeline( convert-scf-to-cf, convert-nvvm-to-llvm, expand-strided-metadata, - nvvm-attach-target{O=3 chip=)" + - sm + R"( fast=false features=+)" + ptx_isa + + nvvm-attach-target{O=3 chip=)", + sm, R"( fast=false features=+)", ptx_isa, R"( ftz=false module= triple=nvptx64-nvidia-cuda}, lower-affine, convert-arith-to-llvm{index-bitwidth=0}, @@ -169,19 +170,19 @@ mlir::FailureOr GetPassPipeline( gpu.module(mosaic-byval-insertion), gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, - gpu-module-to-binary{format=)" + - mlir::gpu::stringifyCompilationTarget(target).str() + R"(}, + gpu-module-to-binary{format=)", + mlir::gpu::stringifyCompilationTarget(target).str(), R"(}, convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, - )" + + )", (target != mlir::gpu::CompilationTarget::Assembly ? "gpu-launch-lowering," - : "") + + : ""), R"( convert-to-llvm, reconcile-unrealized-casts ) - )"); + )")); } mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes, From afc78524e172f88f170a34967d1a3f895ea0e16d Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Oct 2024 10:58:53 -0700 Subject: [PATCH 037/698] Remove silent data corruption runtime flags from persistent cache key. These flags have no effect on the compiled executable, just the runtime execution. PiperOrigin-RevId: 689442580 --- jax/_src/cache_key.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 9bce9d0e4308..6e025653b81d 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -274,6 +274,8 @@ def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]): "--xla_dump_hlo_pipeline_re", "--xla_tpu_sdc_checker_streamz_metric", "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", + "--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks", + "--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present", "--xla_gpu_cuda_data_dir", ] # LINT.ThenChange(:debug_options) From 423112853548b817a5875b084ad33c517517c761 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 24 Oct 2024 17:44:26 +0000 Subject: [PATCH 038/698] improve concreteness error message in remat --- jax/_src/ad_checkpoint.py | 20 +++++++++----------- tests/api_test.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 5160104e2141..f5d5be6a2751 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -410,17 +410,15 @@ def _trace_to_jaxpr(fun, in_tree, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) except core.ConcretizationTypeError as e: msg, = e.args - if 'for checkpoint' not in msg: - raise - new_msg = msg + "\n\n" + ( - "Consider using the `static_argnums` parameter for `jax.remat` or " - "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example " - "involving `static_argnums`:\n" - "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html" - "\n") - new_e = core.ConcretizationTypeError.__new__(core.ConcretizationTypeError) - new_e.args = (new_msg,) - raise new_e from None + if 'for checkpoint' in msg: + msg += "\n\n" + ( + "Consider using the `static_argnums` parameter for `jax.remat` or " + "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example " + "involving `static_argnums`:\n" + "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html" + "\n") + e.args = msg, + raise return pe.convert_constvars_jaxpr(jaxpr), consts, out_tree() diff --git a/tests/api_test.py b/tests/api_test.py index d0a711f4a617..d0396a9a20a7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -33,6 +33,7 @@ import re import subprocess import sys +import traceback import types from typing import NamedTuple import unittest @@ -6423,6 +6424,21 @@ def f(x): y_, = vjp(jnp.ones_like(y)) self.assertAllClose(y, y_, atol=0, rtol=0) + def test_concreteness_error_includes_user_code(self): + @jax.remat + def f(x): + if x > 0: + return x + else: + return jnp.sin(x) + + try: + f(3.) + except TracerBoolConversionError: + self.assertIn('x > 0', traceback.format_exc()) + else: + assert False + @jtu.with_config(jax_pprint_use_color=False) class JaxprTest(jtu.JaxTestCase): From bd417ba6d0b5bb6c080fe948029831f284eb6c9b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Oct 2024 11:36:11 -0700 Subject: [PATCH 039/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1f6bd971dd2f531554eac88c0868b952d2543491. PiperOrigin-RevId: 689457454 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 8777af77c673..b8dcce3a5396 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ffcd64e30efda9610a89af347f2023d050d788a3" -XLA_SHA256 = "02bd9cccb4bf1b1616ca6585942679d9653e125a497f560fd72f1fa0c572cdd1" +XLA_COMMIT = "1f6bd971dd2f531554eac88c0868b952d2543491" +XLA_SHA256 = "30225604bae42819e4212a82e6c17721c7d3bf146e6fb01dfed7f378c7ff6c49" def repo(): tf_http_archive( From 8c9dc21e30751f2a9934da3ae97871307e70ba7a Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 24 Oct 2024 11:50:24 -0700 Subject: [PATCH 040/698] Update hermetic CUDA docs. PiperOrigin-RevId: 689463215 --- docs/developer.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/developer.md b/docs/developer.md index 68e8e931e2e5..cbb60382b7f1 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -109,6 +109,11 @@ current directory. --cuda_version=12.3.2 --cudnn_version=9.1.1 ``` + Please note that these parameters are optional: by default Bazel will + download CUDA and CUDNN redistribution versions provided in `.bazelrc` in the + environment variables `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` + respectively. + To point to CUDA/CUDNN/NCCL redistributions on local file system, you can use the following command: From 9500bd451a20e9e07292ada2590639b63d6e60ab Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 24 Oct 2024 14:15:04 -0700 Subject: [PATCH 041/698] Fix float0 behavior inside shard_map transpose under scan. PiperOrigin-RevId: 689512880 --- jax/experimental/shard_map.py | 2 +- tests/shard_map_test.py | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 24bed503491b..03f3c96005ec 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1652,7 +1652,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite + else x if rewrite or dtypes.dtype(x) == dtypes.float0 else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns)))) for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f1da6b927cc6..c5df1ca7c872 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1495,6 +1495,55 @@ def f(x): self.assertEqual(str(e1.primitive), 'psum2') self.assertEqual(str(e2.primitive), 'pbroadcast') + def test_transpose_float0(self): + mesh = jtu.create_mesh((4,), ('x',)) + + s = jax.sharding.NamedSharding(mesh, P(None, 'x')) + + # vjp that triggers float0 + @jax.custom_vjp + def f(x, _): + return x + def f_fwd(x, y): + return x, jnp.zeros(shape=y.shape, dtype=np.int32) + def f_rev(tmp, g): + return (g, tmp) + f.defvjp(f_fwd, f_rev) + + # trivial vjp that consumes float0 + @jax.custom_vjp + def g(x, y): + return x, y + def g_fwd(x, y): + return jax.vjp(lambda x, y: (x, y), x, y) + def g_bwd(vjp_fn, result): + return vjp_fn(result) + g.defvjp(g_fwd, g_bwd) + + @partial(shard_map, mesh=mesh, in_specs=(P('x'), P()), out_specs=P()) + def f_shmapped(x, y): + return jax.lax.psum(f(x, y).sum(), axis_name=('x')) + + @partial(shard_map, mesh=mesh, check_rep=False, + in_specs=P('x'), out_specs=(P('x'), P())) + def f_shmapped2(x, y): + return g(x, y) + + def f_wrapper(x, y): + x, y = jax.lax.map(lambda xs: f_shmapped2(xs[0], xs[1]), (x, y)) + return jax.lax.map(lambda xs: f_shmapped(xs[0], xs[1]), (x, y)).sum() + + @partial(jax.jit, in_shardings=s, + out_shardings=jax.sharding.NamedSharding(mesh, P())) + def example(x, y): + return jax.grad(f_wrapper, allow_int=True, argnums=(0, 1))(x, y) + + x = np.zeros(shape=(8,16), dtype=np.float32) + y = np.zeros(shape=(8,16), dtype=np.int32) + # Doesn't crash. + dx, dy = example(x, y) + self.assertEqual(dy.dtype, jax.dtypes.float0) + def test_rewrite_binops(self): mesh = jtu.create_mesh((4,), ('x',)) From af28595909d0e45986bd78ae9863ad2ed0e445ae Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Thu, 24 Oct 2024 14:20:03 -0700 Subject: [PATCH 042/698] Add a jax_wheel Bazel rule to build jax pip packages PiperOrigin-RevId: 689514531 --- jax/tools/build_utils.py | 5 ++ jaxlib/jax.bzl | 89 ++++++++++++++++++++ jaxlib/tools/BUILD.bazel | 107 +++++++++++++++++++----- jaxlib/tools/build_gpu_kernels_wheel.py | 3 +- jaxlib/tools/build_gpu_plugin_wheel.py | 3 +- jaxlib/tools/build_wheel.py | 3 +- 6 files changed, 187 insertions(+), 23 deletions(-) diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 84cc697d1894..83d0b4b25923 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -62,6 +62,11 @@ def platform_tag(cpu: str) -> str: }[(platform.system(), cpu)] return f"{platform_name}_{cpu_name}" +def get_githash(jaxlib_git_hash): + if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): + with open(jaxlib_git_hash, "r") as f: + return f.readline().strip() + return jaxlib_git_hash def build_wheel( sources_path: str, output_path: str, package_name: str, git_hash: str = "" diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 3c812d62cfae..d6811bf66b7b 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -305,6 +305,95 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _jax_wheel_impl(ctx): + executable = ctx.executable.wheel_binary + + output = ctx.actions.declare_directory(ctx.label.name) + args = ctx.actions.args() + args.add("--output_path", output.path) # required argument + args.add("--cpu", ctx.attr.platform_tag) # required argument + jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path + args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + + if ctx.attr.enable_cuda: + args.add("--enable-cuda", "True") + if ctx.attr.platform_version == "": + fail("platform_version must be set to a valid cuda version for cuda wheels") + args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + if ctx.attr.enable_rocm: + args.add("--enable-rocm", "True") + if ctx.attr.platform_version == "": + fail("platform_version must be set to a valid rocm version for rocm wheels") + args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + if ctx.attr.skip_gpu_kernels: + args.add("--skip_gpu_kernels") + + args.set_param_file_format("flag_per_line") + args.use_param_file("@%s", use_always = False) + ctx.actions.run( + arguments = [args], + inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], + outputs = [output], + executable = executable, + ) + return [DefaultInfo(files = depset(direct = [output]))] + +_jax_wheel = rule( + attrs = { + "wheel_binary": attr.label( + default = Label("//jaxlib/tools:build_wheel"), + executable = True, + # b/365588895 Investigate cfg = "exec" for multi platform builds + cfg = "target", + ), + "platform_tag": attr.string(mandatory = True), + "git_hash": attr.label(allow_single_file = True), + "enable_cuda": attr.bool(default = False), + # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. + "platform_version": attr.string(mandatory = True, default = ""), + "skip_gpu_kernels": attr.bool(default = False), + "enable_rocm": attr.bool(default = False), + }, + implementation = _jax_wheel_impl, + executable = False, +) + +def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): + """Create jax artifact wheels. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the name of the wheel + wheel_binary: the binary to use to build the wheel + enable_cuda: whether to build a cuda wheel + platform_version: the cuda version to use for the wheel + + Returns: + A directory containing the wheel + """ + _jax_wheel( + name = name, + wheel_binary = wheel_binary, + enable_cuda = enable_cuda, + platform_version = platform_version, + # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to + # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to + # the git hash file needs to be created first. + git_hash = select({ + "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", + "//conditions:default": None, + }), + # Following the convention in jax/tools/build_utils.py. + # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. + platform_tag = select({ + "//jaxlib/tools:macos_arm64": "arm64", + "//jaxlib/tools:win_amd64": "AMD64", + "//jaxlib/tools:arm64": "aarch64", + "@platforms//cpu:x86_64": "x86_64", + }), + ) + jax_test_file_visibility = [] def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 4553dc1e3ea8..48dc03cfb7d6 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -14,9 +14,11 @@ # JAX is Autograd and XLA +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows", "jax_py_test") +load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") licenses(["notice"]) # Apache 2 @@ -30,11 +32,11 @@ py_binary( "//jaxlib", "//jaxlib:README.md", "//jaxlib:setup.py", - "@xla//xla/python:xla_client.py", - "@xla//xla/python:xla_extension", - "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:api.h", + "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", + "@xla//xla/python:xla_client.py", + "@xla//xla/python:xla_extension", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]) + if_cuda([ @@ -44,11 +46,11 @@ py_binary( "//jaxlib/rocm:rocm_gpu_support", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", ], ) @@ -57,7 +59,7 @@ jax_py_test( srcs = ["build_wheel_test.py"], data = [":build_wheel"], deps = [ - "@bazel_tools//tools/python/runfiles", + "@bazel_tools//tools/python/runfiles", ], ) @@ -102,11 +104,11 @@ py_binary( "//jax_plugins/rocm:__init__.py", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", ], ) @@ -131,10 +133,75 @@ py_binary( "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ - "//jax/tools:build_utils", - "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_wheel//:pkg", - "@pypi_setuptools//:pkg", + "//jax/tools:build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + +selects.config_setting_group( + name = "macos", + match_any = [ + "@platforms//os:osx", + "@platforms//os:macos", + ], +) + +selects.config_setting_group( + name = "arm64", + match_any = [ + "@platforms//cpu:aarch64", + "@platforms//cpu:arm64", + ], +) + +selects.config_setting_group( + name = "macos_arm64", + match_all = [ + ":arm64", + ":macos", + ], +) + +selects.config_setting_group( + name = "win_amd64", + match_all = [ + "@platforms//cpu:x86_64", + "@platforms//os:windows", ], ) + +string_flag( + name = "jaxlib_git_hash", + build_setting_default = "", +) + +config_setting( + name = "jaxlib_git_hash_nightly_or_release", + flag_values = { + ":jaxlib_git_hash": "nightly", + }, +) + +jax_wheel( + name = "jaxlib_wheel", + wheel_binary = ":build_wheel", +) + +jax_wheel( + name = "jax_cuda_plugin_wheel", + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_kernels_wheel", +) + +jax_wheel( + name = "jax_cuda_pjrt_wheel", + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_plugin_wheel", +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index ced0b76c344c..5b3ac636303a 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -171,11 +171,12 @@ def prepare_wheel_rocm( if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: + git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=args.jaxlib_git_hash, + git_hash=git_hash, ) finally: tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 0e2bba0c74d0..08c2389c292a 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -167,11 +167,12 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: + git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=args.jaxlib_git_hash, + git_hash=git_hash, ) finally: if tmpdir: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 5ebdf6e4c6b6..438cebca2b06 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -410,7 +410,8 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=args.jaxlib_git_hash) + git_hash = build_utils.get_githash(args.jaxlib_git_hash) + build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash) finally: if tmpdir: tmpdir.cleanup() From 5c614470adda0628c38d93d4c06b137018dcb492 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 24 Oct 2024 15:58:57 -0700 Subject: [PATCH 043/698] [Pallas TPU] Add lowerings for scalar `absf` and `rsqrt` This PR is similar to https://github.com/jax-ml/jax/pull/24284 PiperOrigin-RevId: 689546724 --- tests/pallas/ops_test.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 3b68dc839ef1..0988e56d7e41 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -832,11 +832,21 @@ def test_elementwise_scalar(self, fn, dtype): "Scalar population count on TPU is only supported in interpret mode" ) + if ( + jtu.test_device_matches(["tpu"]) + and fn == jnp.abs + and jnp.issubdtype(dtype, jnp.integer) + and not self.INTERPRET + ): + self.skipTest( + "Scalar abs for integers on TPU is only supported in interpret mode" + ) + # TODO(b/370578663): implement these lowerings on TPU if jtu.test_device_matches(["tpu"]) and fn in ( - jnp.abs, jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, + jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh, jnp.cbrt, jnp.cosh, jnp.expm1, - jnp.sinh, lax.rsqrt, + jnp.sinh, ): self.skipTest(f"{fn.__name__} not implemented on TPU") From 0bc70bbd731259020376b8411c674fca85116b6d Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 25 Oct 2024 02:11:50 -0700 Subject: [PATCH 044/698] Disable jax2tf test recently added in cl/688976685. See failure: https://github.com/jax-ml/jax/actions/runs/11514933009/job/32054580529?pr=24183 PiperOrigin-RevId: 689703645 --- jax/experimental/jax2tf/tests/jax2tf_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 27f830511eab..8ef9a1a5dd25 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1690,6 +1690,7 @@ def f_jax(x): res, x + _testing_multi_platform_to_add[tf_device_jax_platform]) + @unittest.skip("TODO(danfm): Test fails at head with segfault in GH") def test_dot_algorithm(self): # ref: https://github.com/jax-ml/jax/issues/24236 if tf.version.VERSION.split(".") <= ["2", "17", "0"]: From 9088adda68f1e802f528523427fc24510950122f Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 25 Oct 2024 02:30:17 -0700 Subject: [PATCH 045/698] [jax2tf] Disable jax2tf with non-native serialization. jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024. This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error. PiperOrigin-RevId: 689708392 --- CHANGELOG.md | 10 ++++ jax/experimental/jax2tf/jax2tf.py | 52 ++++++++++++------- jax/experimental/jax2tf/tests/call_tf_test.py | 6 +-- jax/experimental/jax2tf/tests/jax2tf_test.py | 6 +-- .../jax2tf/tests/shape_poly_test.py | 27 +--------- 5 files changed, 50 insertions(+), 51 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fc683cd2861..2a71144340fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.36 +* Breaking Changes + * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` + or with `enable_xla=False` have been deprecated since July 2024, with + JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` + with native serialization will still be supported. + ## jax 0.4.35 (Oct 22, 2024) * Breaking Changes @@ -21,6 +27,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we removed it. See {jax-issue}`#20385` for a discussion of alternatives. + * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` + or with `enable_xla=False` have been deprecated since July 2024, with + JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` + with native serialization is still supported. * Changes: * `jax.lax.FftType` was introduced as a public name for the enum of FFT diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 29a1034e51ed..b12edf2a37ec 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -119,6 +119,10 @@ def _sanitize_scope_name(name): # Line below is different externally and internally. allow_enable_xla_false = lambda: True +# TODO(b/353437398): Deprecate support for `native_serialization=False`. +# Line below is different externally and internally. +allow_native_serialization_false = lambda: True + # A value suitable in a TF tracing context: tf.Tensor, tf.Variable, # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.) TfVal = Any @@ -294,8 +298,8 @@ def convert(fun_jax: Callable, See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) for more details. - polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of - the form `e1 >= e2` or `e1 <= e2`. + polymorphic_constraints: a sequence of constraints on symbolic dimension + expressions, of the form `e1 >= e2` or `e1 <= e2`. See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. with_gradient: if set (default), add a tf.custom_gradient to the lowered function, by converting the ``jax.vjp(fun)``. This means that reverse-mode @@ -332,28 +336,38 @@ def convert(fun_jax: Callable, tuple/lists/dicts thereof), and returns TfVals as outputs, and uses only TensorFlow ops and thus can be called from a TensorFlow program. """ - if not enable_xla: - if allow_enable_xla_false(): - warnings.warn("jax2tf.convert with enable_xla=False is deprecated.", - DeprecationWarning, - stacklevel=2) - else: - raise ValueError("jax2tf.convert with enable_xla=False is not supported.") - if native_serialization is DEFAULT_NATIVE_SERIALIZATION: if not enable_xla: native_serialization = False else: native_serialization = config.jax2tf_default_native_serialization.value - if not native_serialization: - warnings.warn( - "jax2tf.convert with native_serialization=False is deprecated.", - DeprecationWarning, - stacklevel=2) - if native_serialization and not enable_xla: - raise ValueError( - "native_serialization is not supported with enable_xla=False") + if not enable_xla: + if allow_enable_xla_false(): + warnings.warn( + "jax2tf.convert with enable_xla=False has been deprecated " + "since July 2024.", + DeprecationWarning, + stacklevel=2) + if native_serialization: + raise ValueError( + "native_serialization is not supported with enable_xla=False") + else: + raise ValueError( + "jax2tf.convert with enable_xla=False has been deprecated " + "since July 2024 and it is not supported anymore.") + + elif not native_serialization: + if allow_native_serialization_false(): + warnings.warn( + "jax2tf.convert with native_serialization=False has been deprecated " + "since July 2024.", + DeprecationWarning, + stacklevel=2) + else: + raise ValueError( + "jax2tf.convert with native_serialization=False has been deprecated " + "since July 2024 and it is not supported anymore.") if not native_serialization and polymorphic_constraints: raise ValueError( @@ -2188,7 +2202,7 @@ def _dot_general(lhs, rhs, *, dimension_numbers, _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" # TODO(b/293247337): we ought to turn on this safety check, but this leads to - # failures. Since we are going to turn on native serializaton soon, wait + # failures. Since we are going to turn on native serialization soon, wait # until then to turn on this check. # lhs_aval, rhs_aval = _in_avals # if lhs_aval.dtype != rhs_aval.dtype: diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index e8d284178691..6d5efb7b1e66 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -90,7 +90,7 @@ def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( message=( - "(jax2tf.convert with native_serialization=False is deprecated" + "(jax2tf.convert with native_serialization=False has been deprecated" "|Calling from_dlpack with a DLPack tensor is deprecated)" ) ) @@ -897,7 +897,7 @@ def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( message=( - "(jax2tf.convert with native_serialization=False is deprecated" + "(jax2tf.convert with native_serialization=False has been deprecated" "|Calling from_dlpack with a DLPack tensor is deprecated)" ) ) @@ -1203,7 +1203,7 @@ def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( message=( - "(jax2tf.convert with native_serialization=False is deprecated" + "(jax2tf.convert with native_serialization=False has been deprecated" "|Calling from_dlpack with a DLPack tensor is deprecated)" ) ) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 8ef9a1a5dd25..27e001fbdb49 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -79,7 +79,7 @@ def setUpClass(cls): def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) self.warning_ctx.__enter__() @@ -1722,7 +1722,7 @@ class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): def setUp(self): super().setUp() self.warning_ctx = jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) self.warning_ctx.__enter__() @@ -1763,7 +1763,7 @@ def setUp(self): super().setUp() @jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) def test_simple(self): self.ConvertAndCompare(jnp.sin, 0.7) diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 38af6d9d76d5..786b98e339e9 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -1031,7 +1031,7 @@ def f_jax(x): # A function whose gradient is a constant self.assertAllClose(f_jax(x), restored_f(x)) @jtu.ignore_warning( - message="jax2tf.convert with native_serialization=False is deprecated" + message="jax2tf.convert with native_serialization=False has been deprecated" ) def test_readme_examples(self): """Some of the examples from the README.""" @@ -1124,31 +1124,6 @@ def f2_jax(x): # f32[b, b] # JAX with static shapes sees that x.shape[0] != x.shape[1] self.assertEqual(jnp.sum(x45), f2_jax(x45)) - # In graph serialization eager mode, we catch the broken assumption b >= 1 - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, - re.escape( - "Found inconsistency between dimension size args[0].shape[1] (= 5) " - "and the specification 'b' (= 4)")): - jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"], - native_serialization=False)(x45) - - # In graph serialization graph mode we also catch it (except on TPU, where - # the behavior is as for jit_compile=1) - - f2_tf = tf.function( - jax2tf.convert(f2_jax, polymorphic_shapes=["b, b"], - native_serialization=False), - autograph=False, - ).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32)) - if jtu.test_device_matches(["tpu"]): - self.assertEqual(1. + jnp.sum(x45), f2_tf(x45)) - else: - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, - r"Found inconsistency"): - _ = f2_tf(x45) - # We also catch the error with native serialization with self.assertRaisesRegex( tf.errors.InvalidArgumentError, From bb5fbec64bc8f028699adb7ec18911af3fae8930 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 25 Oct 2024 02:30:29 -0700 Subject: [PATCH 046/698] [mosaic] Use .clone() to duplicate a module, rather than printing and parsing it. PiperOrigin-RevId: 689708462 --- jax/_src/tpu_custom_call.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 265d36d62b50..6e7402c20a15 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -417,10 +417,11 @@ def _lower_mosaic_module_to_asm( needs_layout_passes = not device_type # We'll mutate the module, so clone it with module.context as ctx, module.operation.location as _: - module = ir.Module.parse( - module.operation.get_asm(binary=True, enable_debug_info=True) - ) if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: + module = ir.Module.parse( + module.operation.get_asm(binary=True, enable_debug_info=True) + ) + module_op = module.operation some_tpu = jax.devices(backend)[0] device_kind = some_tpu.device_kind if not device_kind.startswith("TPU v"): @@ -435,15 +436,17 @@ def _lower_mosaic_module_to_asm( ) needs_hlo_passes = False needs_layout_passes = False + else: + module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True try: pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})") - pipeline.run(module.operation) + pipeline.run(module_op) finally: ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects bytecode_buffer = io.BytesIO() - module.operation.write_bytecode(bytecode_buffer, desired_version=0) + module_op.write_bytecode(bytecode_buffer, desired_version=0) asm = bytecode_buffer.getvalue() return asm, ( has_communication, From bfd7075c39d47d84f3a936e1e504aeb2264de4a4 Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Fri, 25 Oct 2024 05:01:44 -0500 Subject: [PATCH 047/698] [ROCm] ci build fixes --- build/rocm/ci_build.sh | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 302a0449b19e..0a50b5845d69 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -53,7 +53,7 @@ CUSTOM_INSTALL="" JAX_USE_CLANG="" POSITIONAL_ARGS=() -RUNTIME_FLAG=1 +RUNTIME_FLAG=0 while [[ $# -gt 0 ]]; do case $1 in @@ -113,11 +113,13 @@ function upsearch (){ } # Set up WORKSPACE. -WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" -BUILD_TAG="${BUILD_TAG:-jax}" - -# Determine the docker image name and BUILD_TAG. -DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}" +if [ ${RUNTIME_FLAG} -eq 0 ]; then + DOCKER_IMG_NAME="${BUILD_TAG}" +else + WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" + BUILD_TAG="${BUILD_TAG:-jax}" + DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}" +fi # Under Jenkins matrix build, the build tag may contain characters such as # commas (,) and equal signs (=), which are not valid inside docker image names. From 8948e6de5823390f568908601f9d398d06a3b638 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 23 Oct 2024 15:52:20 -0700 Subject: [PATCH 048/698] sharding cleanup: use inline checks for unimplemented and auto --- jax/_src/export/_export.py | 4 +- jax/_src/interpreters/pxla.py | 55 ++++++++++++++-------------- jax/_src/layout.py | 4 +- jax/_src/pjit.py | 61 +++++++++++++++---------------- jax/_src/sharding_impls.py | 17 +-------- jax/_src/stages.py | 4 +- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/interpreters/pxla.py | 1 - 8 files changed, 66 insertions(+), 82 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index b1bb797d538c..99794c8cc23c 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -801,9 +801,9 @@ def _export_lowered( nr_devices = len(lowering.compile_args["device_assignment"]) def export_sharding(s: LoweringSharding, aval: core.ShapedArray) -> HloSharding | None: - if sharding_impls.is_unspecified(s): + if isinstance(s, sharding_impls.UnspecifiedValue): return None - return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + return s._to_xla_hlo_sharding(aval.ndim) all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], module_kept_var_idx, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a14ca3dcabd8..b81cb9ef9238 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -68,8 +68,8 @@ from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED, - UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto, - is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources, + UnspecifiedValue, get_array_mapping as _get_array_mapping, + array_mapping_to_axis_resources, SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, @@ -149,7 +149,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, @lru_cache(maxsize=2048) def is_default_layout(curr_layout, sharding, aval): - if curr_layout is None or sharding is None or is_unspecified(sharding): + if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue): return True if (aval is core.abstract_token or aval.dtype == dtypes.float0 or dtypes.issubdtype(aval.dtype, dtypes.extended)): @@ -1643,7 +1643,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp def check_if_any_auto( shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool: for s in shardings: - if is_auto(s): + if isinstance(s, AUTO): return True return False @@ -1727,14 +1727,14 @@ def _get_and_check_device_assignment( devices = tuple(devices) for i, s_type, source_info in shardings: - if is_unspecified(i): + if isinstance(i, UnspecifiedValue): continue if first_sharding_info is None: first_sharding_info = ( - (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore - else (i._device_assignment, s_type, source_info)) # type: ignore - arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore + (i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO) + else (i._device_assignment, s_type, source_info)) + arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment if not devices: if first_sharding_info[0] != arr_device_assignment: raise DeviceAssignmentMismatchError([ @@ -1836,7 +1836,8 @@ class SemanticallyEqualShardings: def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], avals: tuple[core.AbstractValue]): gspmd_shardings = [ - s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim) # type: ignore + s if isinstance(s, (UnspecifiedValue, AUTO)) + else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error for s, a in zip(shardings, avals)] self._gspmd_shardings = gspmd_shardings self.shardings = shardings @@ -2004,7 +2005,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings): except: return True for i in shardings: - if is_unspecified_or_auto(i): + if isinstance(i, (UnspecifiedValue, AUTO)): continue if i.memory_kind is None: # pytype: disable=attribute-error continue @@ -2034,7 +2035,7 @@ def _default_rule(prim, num_outvars, *_, **__): if in_shardings is None: invar_mem_kind = [None] * len(jaxpr.invars) else: - invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind + invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind for s in in_shardings] safe_map(write, jaxpr.invars, invar_mem_kind) safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars)) @@ -2129,7 +2130,7 @@ def _abstract_to_concrete_mesh(abstract_mesh): out = [] for s, a in zip(shardings, avals): - if is_unspecified(s) and a.sharding is not None: + if isinstance(s, UnspecifiedValue) and a.sharding is not None: out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh), a.sharding.spec)) else: @@ -2216,9 +2217,9 @@ def lower_sharding_computation( committed = bool( devices_from_context or len(device_assignment) > 1 or - any(not is_unspecified(i) for i in unique_in_shardings) or - any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or - any(not is_unspecified(o) for o in unique_out_shardings)) + any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or + any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or + any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings)) da_object = _create_da_object(tuple(device_assignment)) @@ -2690,7 +2691,7 @@ def _maybe_get_and_check_in_shardings( new_in_shardings = [] for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings, global_in_avals): - if is_unspecified(orig): + if isinstance(orig, UnspecifiedValue): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) @@ -2726,7 +2727,7 @@ def _maybe_get_and_check_out_shardings( new_out_shardings = [] for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings, global_out_avals): - if is_unspecified(orig): + if isinstance(orig, UnspecifiedValue): if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) @@ -2839,16 +2840,16 @@ def from_hlo(name: str, da = _create_da_object(tuple(device_assignment)) del device_assignment - allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i) + allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings) - allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o) + allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO)) for o in out_shardings) mesh = None if auto_spmd_lowering: for i in it.chain.from_iterable([in_shardings, out_shardings]): - if is_auto(i): - mesh = i.mesh # type: ignore + if isinstance(i, AUTO): + mesh = i.mesh break xla_executable = _cached_compilation( @@ -2861,9 +2862,9 @@ def from_hlo(name: str, assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( xla_executable, mesh) - in_shardings = [x if is_auto(i) else i + in_shardings = [x if isinstance(i, AUTO) else i for x, i in safe_zip(in_shardings_xla, in_shardings)] - out_shardings = [x if is_auto(o) else o + out_shardings = [x if isinstance(o, AUTO) else o for x, o in safe_zip(out_shardings_xla, out_shardings)] else: if pmap_nreps == 1: @@ -2954,8 +2955,8 @@ def contains_explicit_attributes(self): self.donate_argnames is not None or self.device is not None or self.backend is not None or - any(not is_unspecified(i) for i in self.in_shardings_leaves) or - any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or + any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or any(i is not None for i in self.in_layouts_leaves) or any(o is not None for o in self.out_layouts_leaves)) @@ -3130,7 +3131,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): def check_device_backend_on_shardings(shardings) -> bool: for i in shardings: - if is_unspecified(i) or is_auto(i): + if isinstance(i, (UnspecifiedValue, AUTO)): continue if getattr(i, '_device_backend', False): return True @@ -3156,7 +3157,7 @@ def check_array_xla_sharding_layout_match( args_after_dce, in_xla_shardings, in_xla_layouts, arg_names): if not isinstance(arg, ArrayImpl): continue - if is_unspecified_or_auto(xs): + if isinstance(xs, (UnspecifiedValue, AUTO)): continue db_xs = check_device_backend_on_shardings([xs]) diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 64bbd3268b16..5309f0b1fd9c 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -19,7 +19,7 @@ import numpy as np from jax._src.dtypes import iinfo, issubdtype from jax._src.sharding import Sharding -from jax._src.sharding_impls import AUTO as AutoSharding, is_auto +from jax._src.sharding_impls import AUTO as AutoSharding from jax._src.lib import xla_client as xc Shape = tuple[int, ...] @@ -101,7 +101,7 @@ def __init__(self, device_local_layout: LayoutOptions = None, sharding: ShardingOptions = None): # If layout is concrete and sharding is not, error. if (isinstance(device_local_layout, DeviceLocalLayout) and - (sharding is None or is_auto(sharding))): + (sharding is None or isinstance(sharding, AutoSharding))): raise ValueError( 'Sharding has to be concrete when layout is of type' f' {type(device_local_layout)}. Please pass a' diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a69e8987b2d8..2abf81f26aa4 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -67,8 +67,7 @@ from jax._src.sharding_impls import ( NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, - ParsedPartitionSpec, get_single_pspec, is_unspecified, - is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) + ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding) from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef from jax._src.traceback_util import api_boundary @@ -418,10 +417,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") - if in_shardings is not None and not is_unspecified(in_shardings): + if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue): raise ValueError('If backend or device is specified on jit, then ' 'in_shardings should not be specified.') - if out_shardings is not None and not is_unspecified(out_shardings): + if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue): raise ValueError('If backend or device is specified on jit, then ' 'out_shardings should not be specified.') @@ -440,7 +439,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') user_specified_in_shardings = (in_shardings is not None and - not is_unspecified(in_shardings)) + not isinstance(in_shardings, UnspecifiedValue)) in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings) out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings) @@ -483,7 +482,7 @@ def lower(*args, **kwargs): @api_boundary def eval_shape(*args, **kwargs): p, _ = _infer_params(fun, jit_info, args, kwargs) - out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] + out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] # TODO(yashkatariya): Add `Layout` to SDS. out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, weak_type=x.weak_type) @@ -1001,7 +1000,7 @@ def hashable_pytree(pytree): def _create_sharding_for_array(mesh, x, name, api_name): if x is None and (mesh is None or mesh.empty): return UNSPECIFIED - if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x): + if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)): return x if mesh is None: msg = ('jax.jit only supports `Sharding`s being passed to' @@ -1110,7 +1109,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves) # Only do this if original in_shardings are unspecified. If it is AUTO, go # via flatten_axis_resources. - if is_unspecified(orig_in_shardings): + if isinstance(orig_in_shardings, UnspecifiedValue): in_shardings_flat = (orig_in_shardings,) * len(in_avals) else: in_shardings_flat = flatten_axis_resources( @@ -1312,8 +1311,7 @@ def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) - if (is_unspecified(orig_out_shardings) or - isinstance(orig_out_shardings, sharding.Sharding)): + if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)): out_shardings_flat = (orig_out_shardings,) * len(out_avals) else: out_shardings_flat = flatten_axis_resources( @@ -1391,7 +1389,7 @@ def pjit_check_aval_sharding( what_aval: str, allow_uneven_sharding: bool): new_names = [''] * len(shardings) if names is None else names for aval, s, name in zip(flat_avals, shardings, new_names): - if is_unspecified_or_auto(s): + if isinstance(s, (UnspecifiedValue, AUTO)): continue name_str = f' with pytree key path {name}' if name else '' shape = aval.shape @@ -1466,7 +1464,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): else: arg_layout, dispatch_arg_layout = None, None # Sharding can be unspecified when array is committed if it's a PmapSharding. - is_pmap_sharding = (is_unspecified(rs) or + is_pmap_sharding = (isinstance(rs, UnspecifiedValue) or isinstance(getattr(arg, 'sharding', None), PmapSharding)) if jit_in_l is None: if committed: @@ -1527,15 +1525,15 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if getattr(a, '_committed', True): committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) - resolved_in_shardings = [] + resolved_in_shardings: list[PjitSharding] = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does # not allow None as the sharding. arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True)) if hasattr(arg, 'sharding') and arg.sharding is not None else (UNSPECIFIED, False)) - if is_unspecified(pjit_in_s): - if is_unspecified(arg_s): + if isinstance(pjit_in_s, UnspecifiedValue): + if isinstance(arg_s, UnspecifiedValue): resolved_in_shardings.append(arg_s) else: if committed: @@ -1553,7 +1551,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'multiple devices is not supported.') else: if (isinstance(arg, np.ndarray) and - not pjit_in_s.is_fully_replicated and # type: ignore + not pjit_in_s.is_fully_replicated and # type: ignore[union-attr] xb.process_count() > 1): raise ValueError( 'Passing non-trivial shardings for numpy ' @@ -1572,16 +1570,16 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] # jax.jit does not allow resharding across different memory kinds even # if the argument is uncommitted. Use jax.device_put for those cases, # either outside or inside jax.jit. - if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore + if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore[union-attr] raise ValueError( 'Memory kinds passed to jax.jit does not match memory kind on the' - f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore + f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore[union-attr] f'arg memory kind: {arg_s.memory_kind} for ' f'arg shape: {shaped_abstractify(arg).str_short()}') if (committed and not isinstance(arg_s, PmapSharding) and not op_shardings.are_op_shardings_equal( - pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore + pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore[union-attr] arg_s._to_xla_hlo_sharding(arg.ndim))): raise ValueError('Sharding passed to pjit does not match the sharding ' 'on the respective arg. ' @@ -1780,8 +1778,8 @@ def pjit_staging_rule(trace, *args, **params): params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) if (params["inline"] and - all(is_unspecified(i) for i in params["in_shardings"]) and - all(is_unspecified(o) for o in params["out_shardings"]) and + all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and + all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): if config.dynamic_shapes.value: @@ -1830,7 +1828,7 @@ def pjit_staging_rule(trace, *args, **params): def _pjit_forwarding(jaxpr, out_shardings, out_layouts): in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol + in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) @@ -1896,8 +1894,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, func = mod_ctx.cached_primitive_lowerings.get(key, None) if func is None: - arg_shardings = [None if is_unspecified(i) else i for i in in_shardings] - result_shardings = [None if is_unspecified(o) else o for o in out_shardings] + arg_shardings = [None if isinstance(i, UnspecifiedValue) else i for i in in_shardings] + result_shardings = [None if isinstance(o, UnspecifiedValue) else o for o in out_shardings] # TODO(b/228598865): inlined calls cannot have shardings set directly on the # inputs or outputs because they are lost during MLIR->HLO conversion. # using_sharding_annotation=False means we add an identity operation instead. @@ -1990,9 +1988,9 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, def _pjit_batcher_for_sharding( s: sharding.Sharding | UnspecifiedValue, dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): - if is_unspecified(s): + if isinstance(s, UnspecifiedValue): return s - hlo_s = s._to_xla_hlo_sharding(ndim) # type: ignore + hlo_s = s._to_xla_hlo_sharding(ndim) if spmd_axis_name is None: if sharding_impls.is_op_sharding_replicated(hlo_s): return s @@ -2004,7 +2002,7 @@ def _pjit_batcher_for_sharding( tad.insert(dim, 1) new_op.tile_assignment_dimensions = tad new_gs = GSPMDSharding( - s._device_assignment, new_op, # type: ignore + s._device_assignment, new_op, _device_list=getattr(s, '_internal_device_list', None)) return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] else: @@ -2107,7 +2105,7 @@ def keep_where(l, should_keep): # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) in_fwd = [ - fwd if is_unspecified(os) and ol is None else None + fwd if isinstance(os, UnspecifiedValue) and ol is None else None for os, ol, fwd in zip( keep_where(out_shardings, known_outs), keep_where(out_layouts, known_outs), in_fwd_primal) @@ -2358,9 +2356,9 @@ def _pjit_pp_rule(eqn, context, settings): del params['inline'] if not any(params['donated_invars']): del params['donated_invars'] - if all(is_unspecified(s) for s in params['in_shardings']): + if all(isinstance(s, UnspecifiedValue) for s in params['in_shardings']): del params['in_shardings'] - if all(is_unspecified(s) for s in params['out_shardings']): + if all(isinstance(s, UnspecifiedValue) for s in params['out_shardings']): del params['out_shardings'] if all(l is None for l in params['in_layouts']): del params['in_layouts'] @@ -2382,8 +2380,7 @@ def _pjit_pp_rule(eqn, context, settings): def _pjit_state_discharge_rule( in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, **params): - if not (all(map(is_unspecified, in_shardings)) and - all(map(is_unspecified, out_shardings))): + if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)): raise NotImplementedError if not (all(l is None for l in in_layouts) and diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 78a59a4fcd5c..047571754d6f 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -957,21 +957,11 @@ def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding: return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) -def is_auto(x): - return isinstance(x, AUTO) - - class UnspecifiedValue: def __repr__(self): return "UnspecifiedValue" UNSPECIFIED = UnspecifiedValue() -def is_unspecified(x): - return isinstance(x, UnspecifiedValue) - -def is_unspecified_or_auto(x): - return is_auto(x) or is_unspecified(x) - MeshAxisName = Any @@ -1014,8 +1004,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): def get_array_mapping( axis_resources: ParsedPartitionSpec | AUTO | UnspecifiedValue ) -> ArrayMappingOrAutoOrUnspecified: - # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported. - # Don't use `is_auto` here to satisfy pytype and mypy. if isinstance(axis_resources, (AUTO, UnspecifiedValue)): return axis_resources return OrderedDict((axis, i) @@ -1113,7 +1101,7 @@ def prepare_axis_resources(axis_resources, arg_name, new_entries = [] for entry in entries: - if is_unspecified_or_auto(entry) or entry is None: + if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None: new_entries.append(entry) elif isinstance(entry, sharding.Sharding): if isinstance(entry, PmapSharding): @@ -1131,8 +1119,7 @@ def prepare_axis_resources(axis_resources, arg_name, def _check_unique_resources(axis_resources, arg_name): for arg_axis_resources in axis_resources: if not arg_axis_resources: continue - if (is_unspecified_or_auto(arg_axis_resources) or - isinstance(arg_axis_resources, sharding.Sharding)): + if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)): continue constrained_dims = [d for d in arg_axis_resources if d is not None] resource_counts = collections.Counter( diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 3a2c375b64db..92c680009c93 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -43,7 +43,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.sharding_impls import is_unspecified_or_auto +from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.layout import Layout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -649,7 +649,7 @@ def out_info(self): # PyTree of OutInfo out_avals = self._lowering.compile_args["global_out_avals"] out_shardings = self._lowering.compile_args["out_shardings"] return self.out_tree.unflatten( - [OutInfo(o.shape, o.dtype, None if is_unspecified_or_auto(s) else s) + [OutInfo(o.shape, o.dtype, None if isinstance(s, (UnspecifiedValue, AUTO)) else s) for o, s in zip(out_avals, out_shardings)]) def compile( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index b12edf2a37ec..348baa868f09 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3537,7 +3537,7 @@ def split_to_logical_devices(tensor: TfVal, def _xla_compatible_sharding_to_hlo_sharding( s: sharding.Sharding, aval: core.ShapedArray) -> xla_client.HloSharding | None: - if sharding_impls.is_unspecified(s): + if isinstance(s, sharding_impls.UnspecifiedValue): return None return s._to_xla_hlo_sharding(aval.ndim) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 15c9a2cfe49d..f3fd8bac558c 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -40,7 +40,6 @@ ArrayMapping as ArrayMapping, UNSPECIFIED as _UNSPECIFIED, # noqa: F401 array_mapping_to_axis_resources as array_mapping_to_axis_resources, - is_unspecified as _is_unspecified, # noqa: F401 ) from jax._src.sharding_specs import ( From 33a46e8f68d20c1b8df3e0283fc3b8a10e5f7b55 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 25 Oct 2024 08:26:19 -0400 Subject: [PATCH 049/698] Re-enable jax2tf test for dot algorithm with stricter TF version check. --- jax/experimental/jax2tf/tests/jax2tf_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 27e001fbdb49..68c7b15383fe 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1690,11 +1690,10 @@ def f_jax(x): res, x + _testing_multi_platform_to_add[tf_device_jax_platform]) - @unittest.skip("TODO(danfm): Test fails at head with segfault in GH") def test_dot_algorithm(self): # ref: https://github.com/jax-ml/jax/issues/24236 - if tf.version.VERSION.split(".") <= ["2", "17", "0"]: - self.skipTest("This test works only with newer versions of TF") + if tf.version.VERSION.split(".") <= ["2", "18", "0"]: + self.skipTest("Because of an XLA bug this test segfaults with TF v2.18.0") if jtu.test_device_matches(["tpu"]): algorithm = "BF16_BF16_F32" From c62b19883f45420f763d41140cefb01782bbc558 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 25 Oct 2024 16:11:35 +0300 Subject: [PATCH 050/698] Fix copy and paste error in CHANGELOG. --- CHANGELOG.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a71144340fc..3b47f0a1af61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,10 +27,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we removed it. See {jax-issue}`#20385` for a discussion of alternatives. - * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` - or with `enable_xla=False` have been deprecated since July 2024, with - JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` - with native serialization is still supported. * Changes: * `jax.lax.FftType` was introduced as a public name for the enum of FFT From 63c1699ed04a94758b3b63725b14af41dc88ba4b Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 25 Oct 2024 07:39:35 -0700 Subject: [PATCH 051/698] Fix a use-after-free bug in third_party/py/jax/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc The backing array of the initializer_list is destroyed at the end of the full expression. PiperOrigin-RevId: 689783482 --- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 53c8d44f6c32..a1379619d922 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -118,8 +118,8 @@ class MosaicGpuTest : public ::testing::Test { }; TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) { - llvm::ArrayRef shape{1, 2, 3}; - llvm::ArrayRef slice_shape{1, 2}; + std::vector shape{1, 2, 3}; + std::vector slice_shape{1, 2}; mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(&context_); @@ -128,7 +128,7 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) { EXPECT_THAT( FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type, - memref_type, slice_shape), + memref_type, mlir::ArrayRef(slice_shape)), StatusIs( absl::StatusCode::kFailedPrecondition, HasSubstr( @@ -136,8 +136,8 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorRequiresSliceShapeHasTheCorrectRank) { } TEST_F(MosaicGpuTest, InitTmaDescriptorGracefullyRejectsSubByteTypes) { - llvm::ArrayRef shape{1, 2, 3}; - llvm::ArrayRef slice_shape{1, 2, 3}; + std::vector shape{1, 2, 3}; + std::vector slice_shape{1, 2, 3}; mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(&context_); @@ -145,14 +145,14 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorGracefullyRejectsSubByteTypes) { mlir::MemRefType::get(shape, builder_.getI4Type()); EXPECT_THAT(FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type, - memref_type, slice_shape), + memref_type, mlir::ArrayRef(slice_shape)), StatusIs(absl::StatusCode::kUnimplemented, HasSubstr("Sub-byte types are not yet supported"))); } TEST_F(MosaicGpuTest, InitTmaDescriptorProducesACallToRuntime) { - llvm::ArrayRef shape{1, 2, 3}; - llvm::ArrayRef slice_shape{1, 2, 3}; + std::vector shape{1, 2, 3}; + std::vector slice_shape{1, 2, 3}; mlir::LLVM::LLVMPointerType pointer_type = mlir::LLVM::LLVMPointerType::get(&context_); @@ -161,7 +161,7 @@ TEST_F(MosaicGpuTest, InitTmaDescriptorProducesACallToRuntime) { absl::StatusOr fn_or = FromCppFunc(*module_, mosaic_gpu::InitTmaDescriptor, pointer_type, - memref_type, slice_shape); + memref_type, mlir::ArrayRef(slice_shape)); ASSERT_OK(fn_or); llvm::SmallVector call_ops = From d4c46825d6604d8e3bd511ea44c5764921004653 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 25 Oct 2024 08:10:29 -0700 Subject: [PATCH 052/698] Finalize deprecation of xb, xc, & xe symbols in jax.interpreters.xla PiperOrigin-RevId: 689792265 --- CHANGELOG.md | 3 +++ jax/interpreters/xla.py | 29 ++++++++++------------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b47f0a1af61..542a7d417269 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. or with `enable_xla=False` have been deprecated since July 2024, with JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` with native serialization will still be supported. + * In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed + after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, + `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. ## jax 0.4.35 (Oct 22, 2024) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bbd5b65d5d3e..2711bcfb80d5 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -23,26 +23,24 @@ apply_primitive as apply_primitive, ) -from jax._src import xla_bridge as _xb from jax._src.lib import xla_client as _xc - -_xe = _xc._xla -Backend = _xe.Client +Backend = _xc._xla.Client +del _xc # Deprecations _deprecations = { - # Added 2024-06-28 + # Finalized 2024-10-24; remove after 2025-01-24 "xb": ( - "jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.", - _xb + ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " + "Use jax.lib.xla_bridge instead."), None ), "xc": ( - "jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.", - _xc, + ("jax.interpreters.xla.xc was removed in JAX v0.4.36. " + "Use jax.lib.xla_client instead."), None ), "xe": ( - "jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.", - _xe, + ("jax.interpreters.xla.xe was removed in JAX v0.4.36. " + "Use jax.lib.xla_extension instead."), None ), # Finalized 2024-05-13; remove after 2024-08-13 "backend_specific_translations": ( @@ -82,13 +80,6 @@ ), } -import typing from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -if typing.TYPE_CHECKING: - xb = _xb - xc = _xc - xe = _xe -else: - __getattr__ = _deprecation_getattr(__name__, _deprecations) +__getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr -del typing From 4972f84c94ec3af6f5a0437cd84e010ace03d25d Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 25 Oct 2024 09:08:40 -0700 Subject: [PATCH 053/698] [Mosaic] Use max sublane offset per shuffled load to decide whether to avoid bank conflict. PiperOrigin-RevId: 689809024 --- jaxlib/mosaic/dialect/tpu/tpu.td | 1 + jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 1 + .../tpu/transforms/apply_vector_layout.cc | 106 +++++++++++++++--- 3 files changed, 92 insertions(+), 16 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 44199612ea73..faee869663c4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -784,6 +784,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, Option<"vmem_banks", "vmem-banks", "int", /*default=*/"-1", "">, + Option<"max_shuffle_sublane_offset", "max-shuffle-sublane-offset", "int", /*default=*/"-1", "Max sublane offset per shuffled load/store">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index e827faed3d0e..32ccd45f6e49 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -63,6 +63,7 @@ struct ApplyVectorLayoutContext { std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; int64_t vmem_banks = -1; // -1 means "unspecified". + int32_t max_shuffle_sublane_offset = -1; // -1 means "unspecified". }; std::pair mightCommunicateBetweenChips(Operation* op); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index ee95feb4d18e..10585eec2920 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4658,6 +4658,34 @@ const llvm::StringMap &rules() { {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; return *rules; } + +// Determines whether we should handle bank conflict for the given stride and +// max_sublane_offset. +// +// See `handleBankConflict` for how this is done. +bool shouldHandleBankConflict(const ApplyVectorLayoutContext &ctx, + int32_t stride, int max_sublane_offset) { + return ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && + ctx.vmem_banks < stride * ctx.target_shape[0] && + ctx.max_shuffle_sublane_offset > 0 && + ctx.max_shuffle_sublane_offset >= max_sublane_offset; +} + +// Handles load/store bank conflict by adding one extra sublane to stride and +// adjusting sublane offsets accordingly. +// +// For example, when store stride is 4 and load sublane offsets are +// [0, 1, 2, 3, 4, 5, 6, 7], the store bank conflict can be avoided by changing +// stride to 5 and sublane offsets to [0, 1, 2, 3, 5, 6, 7, 8]. +void handleBankConflict(int32_t &stride, absl::Span sublane_offsets) { + // Add one extra sublane to stride to avoid bank conflict. + for (int i = 0; i < sublane_offsets.size(); ++i) { + // Adjust sublane offsets to match the stride. + sublane_offsets[i] += i / stride; + } + ++stride; +} + } // namespace RollVectorsOp assemble(OpBuilder &builder, VectorType vty, @@ -5641,16 +5669,37 @@ LogicalResult retileToLargeTileWithScratch( // The older hardware has limited support for shuffles so even if we have bank // conflicts, we just accept them and will have the lowering unroll the // loads/stores. + int64_t num_offsets = sublane_offsets.num_elements(); + // The max sublane offset before handling bank conflicts is always + // (num_offsets - 1). To avoid bank conflicts, we need to add one extra + // sublane to stride so (num_offsets - 1) / stride is the extra offset needed + // to pad sublanes. + // + // For example, if store stride = 4, sublane_count = 8, and + // load offsets = [0, 1, 2, 3, 4, 5, 6, 7], then the sublane offsets after + // handling bank conflicts will be [0, 1, 2, 3, 5, 6, 7, 8] and the max + // sublane offset will be 7 + (8 - 1) / 4 = 8. + // + // Before + // <-------- sublanes ---------> + // 0 1 ... 32 + // store: x---x---x---x---x---x---x---x + // load: xxxxxxxxx-------------------- + // + // After + // <-------- sublanes ---------> + // 0 5 ... 40 + // store: x----x----x----x----x----x----x----x + // load: xxxx-xxxx--------------------------- + // + // where "x" indicates a sublane that needs to be accessed and "-"" indicates + // a sublane that does not need to be accessed. + int max_sublane_offset = (num_offsets - 1) + (num_offsets - 1) / stride; bool should_handle_bank_confict = - ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && - ctx.vmem_banks < stride * ctx.target_shape[0]; - // Add one extra sublane to stride to avoid bank conflict. + shouldHandleBankConflict(ctx, stride, max_sublane_offset); if (should_handle_bank_confict) { - // Adjust sublane offsets to match the stride. - for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { - *(sublane_offsets.begin() + i) += i / stride; - } - stride += 1; + handleBankConflict(stride, absl::MakeSpan(sublane_offsets.data(), + sublane_offsets.num_elements())); } sublane_offsets.TransposeDimensions({0, 2, 1}); @@ -5773,9 +5822,34 @@ LogicalResult retileToSmallTileWithScratch( // The older hardware has limited support for shuffles so even if we have // bank conflicts, we just accept them and will have the lowering unroll the // loads/stores. + int64_t num_offsets = sublane_offsets.num_elements(); + // The max sublane offset before handling bank conflicts is always + // (num_offsets - 1). To avoid bank conflicts, we need to add one extra + // sublane to stride so (num_offsets - 1) / stride is the extra offset needed + // to pad sublanes. + // + // For example, if store stride = 4, sublane_count = 8, and + // load offsets = [0, 1, 2, 3, 4, 5, 6, 7], then the sublane offsets after + // handling bank conflicts will be [0, 1, 2, 3, 5, 6, 7, 8] and the max + // sublane offset will be 7 + (8 - 1) / 4 = 8. + // + // Before + // <-------- sublanes ---------> + // 0 4 ... + // store: x---x---x---x---x---x---x---x + // load: xxxxxxxxx------------------- + // + // After + // <-------- sublanes ---------> + // 0 5 ... + // store: x----x----x----x----x----x----x----x + // load: xxxx-xxxx--------------------------- + // + // where "x" indicates a sublane that needs to be accessed and "-"" indicates + // a sublane that does not need to be accessed. + int max_sublane_offset = (num_offsets - 1) + (num_offsets - 1) / stride; bool should_handle_bank_confict = - ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && - ctx.vmem_banks < stride * ctx.target_shape[0]; + shouldHandleBankConflict(ctx, stride, max_sublane_offset); bool use_shuffled_load = false; if (ctx.hardware_generation <= 4) { if (src_tile[0] == 8) { @@ -5794,11 +5868,8 @@ LogicalResult retileToSmallTileWithScratch( // Add one extra sublane to stride to avoid bank conflict. if (should_handle_bank_confict) { - // Adjust sublane offsets to match the stride. - for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { - *(sublane_offsets.begin() + i) += i / stride; - } - stride += 1; + handleBankConflict(stride, absl::MakeSpan(sublane_offsets.data(), + sublane_offsets.num_elements())); } sublane_offsets.TransposeDimensions({0, 2, 1}); auto mlirIndexConst = [&](int d) { @@ -6455,6 +6526,7 @@ struct ApplyVectorLayoutPass mxu_noncontracting_size = ctx.mxu_shape[1]; max_sublanes_in_scratch = ctx.max_sublanes_in_scratch; vmem_banks = ctx.vmem_banks; + max_shuffle_sublane_offset = ctx.max_shuffle_sublane_offset; } void runOnOperation() override { // Fail if hardware_generation has not been set from the default value. @@ -6467,7 +6539,9 @@ struct ApplyVectorLayoutPass .target_shape = {sublane_count, lane_count}, .mxu_shape = {mxu_contracting_size, mxu_noncontracting_size}, .max_sublanes_in_scratch = max_sublanes_in_scratch, - .vmem_banks = vmem_banks}; + .vmem_banks = vmem_banks, + .max_shuffle_sublane_offset = max_shuffle_sublane_offset, + }; if (failed(applyLayoutFunc(ctx, getOperation()))) { signalPassFailure(); return; From 21f3353544bde08d55276921ed5e9ed5280af63c Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 21 Oct 2024 11:34:57 -0400 Subject: [PATCH 054/698] Add support for layouts and other advanced features in ffi_call. --- jax/_src/callback.py | 17 ++++ jax/_src/extend/ffi.py | 189 +++++++++++++++++++++++++++++++++++------ tests/extend_test.py | 158 ++++++++++++++++++++++------------ 3 files changed, 286 insertions(+), 78 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 9630418ae76c..0b918c7a994e 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -160,9 +160,22 @@ def callback_batching_rule( batched_result_avals = tuple( core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) for aval in result_avals) + + # For FFI calls we must update the layouts. We handle the output layouts + # here, but the input layout updates depend on the vmap_method parameter. + if vmap_method != "sequential" and kwargs.get("output_layouts") is not None: + kwargs["output_layouts"] = tuple( + None if layout is None else tuple(n + 1 for n in layout) + (0,) + for layout in kwargs["output_layouts"]) + if vmap_method == "legacy_vectorized": # This method is kept to support the behavior that was previously exposed # when using `vectorized=True`. + if kwargs.get("input_layouts") is not None: + kwargs["input_layouts"] = tuple( + layout if d is batching.not_mapped else + (None if layout is None else tuple(n + 1 for n in layout) + (0,)) + for layout, d in zip(kwargs["input_layouts"], dims)) outvals = prim.bind( *new_args, vectorized=vectorized, @@ -175,6 +188,10 @@ def callback_batching_rule( bcast_args = [ lax.broadcast(x, (size,)) if d is batching.not_mapped else x for x, d in zip(new_args, dims)] + if kwargs.get("input_layouts") is not None: + kwargs["input_layouts"] = tuple( + None if layout is None else tuple(n + 1 for n in layout) + (0,) + for layout in kwargs["input_layouts"]) outvals = prim.bind( *bcast_args, vectorized=vectorized, diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 9a45b3f77a93..3012b74cf941 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -116,17 +116,17 @@ def _aval_shape(aval: core.AbstractValue) -> Shape: return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error -def _convert_layout(aval: core.AbstractValue, - layout: FfiLayoutOptions = None) -> Sequence[int]: +def _convert_layout_for_lowering( + aval: core.AbstractValue, layout: FfiLayoutOptions = None) -> Sequence[int]: """Convert a layout to the minor-to-major order used by the custom call API.""" if layout is None: - return list(reversed(range(len(_aval_shape(aval))))) + return tuple(reversed(range(len(_aval_shape(aval))))) elif isinstance(layout, DeviceLocalLayout): if layout._tiling is not None: raise ValueError("The FFI does not support layouts with tiling") return layout.major_to_minor[::-1] else: - return layout + return tuple(layout) def ffi_lowering( @@ -134,7 +134,7 @@ def ffi_lowering( *, operand_layouts: Sequence[FfiLayoutOptions] | None = None, result_layouts: Sequence[FfiLayoutOptions] | None = None, - backend_config: Mapping[str, ir.Attribute] | None = None, + backend_config: Mapping[str, ir.Attribute] | str | None = None, **lowering_args: Any ) -> mlir.LoweringRule: """Build a lowering rule for an foreign function interface (FFI) target. @@ -143,6 +143,10 @@ def ffi_lowering( compute the input and output types and shapes for the custom call, assuming row-major layouts. + Note that layouts passed to this function as tuples should be in + minor-to-major order (as expected by XLA) rather than major-to-minor as used + by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``. + If keyword arguments are passed to the lowering rule, these are treated as attributes, and added to `backend_config`. @@ -163,20 +167,32 @@ def _lowering( ) -> Sequence[ir.Value | Sequence[ir.Value]]: kwargs = dict(lowering_args) kwargs.setdefault("api_version", 4) - kwargs["backend_config"] = dict( - backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) + if kwargs["api_version"] >= 4: + if backend_config is not None and not isinstance(backend_config, dict): + raise ValueError( + "When api_version > 4, backend_config must be a dictionary.") + kwargs["backend_config"] = dict( + backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) + else: + if params: + raise ValueError( + "The use of ffi_call attributes requires a custom call API version " + f"of at least 4; got api_version={kwargs['api_version']}.") + kwargs["backend_config"] = backend_config if "result_types" not in kwargs: kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] if operand_layouts is None: - kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in) + kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in) else: kwargs["operand_layouts"] = [ - _convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)] + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_in, operand_layouts)] if result_layouts is None: - kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out) + kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out) else: kwargs["result_layouts"] = [ - _convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)] + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_out, result_layouts)] if "result_shapes" not in kwargs and not all( core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): kwargs["result_shapes"] = [ @@ -202,12 +218,39 @@ def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue return tuple(avals) +def _check_compatible_avals(a: core.AbstractValue, b: core.AbstractValue) -> bool: + if isinstance(a, core.AbstractToken) and isinstance(b, core.AbstractToken): + return True + if getattr(a, "shape", ()) != getattr(b, "shape", ()): + return False + if getattr(a, "dtype", ()) != getattr(b, "dtype", ()): + return False + return True + + +def _convert_layouts_for_ffi_call( + avals: Sequence[core.AbstractValue], + layouts: Sequence[FfiLayoutOptions]) -> tuple[Sequence[int], ...]: + return tuple( + _convert_layout_for_lowering( + aval, + layout if layout is None or isinstance(layout, DeviceLocalLayout) + else layout[::-1] + ) + for aval, layout in zip(avals, layouts)) + + def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], *deprecated_args: ArrayLike, has_side_effect: bool = False, vmap_method: str | None = None, + input_layouts: Sequence[FfiLayoutOptions] | None = None, + output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = None, + input_output_aliases: dict[int, int] | None = None, + custom_call_api_version: int = 4, + legacy_backend_config: str | None = None, vectorized: bool | DeprecatedArg = DeprecatedArg(), **deprecated_kwargs: Any, ) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]: @@ -227,7 +270,7 @@ def ffi_call( Args: target_name: the name of the XLA FFI custom call target that was registered - using :func:`~jaxlib.xla_client.register_custom_call_target`. + using :func:`~jax.extend.ffi.register_ffi_target`. result_shape_dtypes: an object, or sequence of objects, with ``shape`` and ``dtype`` attributes which are expected to match the shape and dtype of the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often @@ -238,6 +281,32 @@ def ffi_call( outputs are not used. vmap_method: string specifying how the FFI call transforms under :func:`~jax.vmap` as described above. + input_layouts: a sequence of layouts for each input argument. In each case, + the layout can be (a) ``None`` indicating that this input is in default + row-major order, (b) a ``DeviceLocalLayout`` specifying the axis order, + or (c) a sequence of integers specifying the major-to-minor axis + ordering. Users who are familiar with XLA layouts should note that this + function expects layouts in major-to-minor order instead of the + minor-to-major order that XLA uses. For example, a batch of row-major + matrices could be specified using the layout ``[0, 1, 2]``, whereas a + batch of column-major matrices would have layout ``[0, 2, 1]``. In both + of these examples, the leading/batch dimension is the "slowest" axis. The + ``input_layouts`` parameter should be used to request the memory layout + expected by the FFI call target, and XLA will ensure that the buffers + have the correct layouts before the handler is executed. + output_layouts: like ``input_layouts``, but specifying the required layouts + for the output arrays. + input_output_aliases: a dictionary where the keys are input indices and the + values are output indices. This mapping indicates which output arrays + alias specific input arrays. + custom_call_api_version: the version number of the custom call API + implemented by the FFI target ``target_name``. The only formally + supported version is the typed FFI API with ``custom_call_api_version=4``, + but earlier unsupported custom calls can be executed using this argument. + legacy_backend_config: for legacy targets implemented using + ``custom_call_api_version<4``, attributes are passed using the opaque + string representation provided by this argument. This parameter cannot be + used with ``custom_call_api_version>=4``. Returns: A function that can be called with the input arrays as positional arguments @@ -263,14 +332,73 @@ def ffi_call( f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, " f"but got: {vmap_method}") + output_layouts_: Sequence[FfiLayoutOptions] | None if isinstance(result_shape_dtypes, Sequence): + output_layouts_ = output_layouts # type: ignore multiple_results = True result_avals = _result_avals(result_shape_dtypes) else: multiple_results = False result_avals = _result_avals((result_shape_dtypes,)) + output_layouts_ = (output_layouts,) # type: ignore + + if custom_call_api_version >= 4 and legacy_backend_config is not None: + raise ValueError( + "The use of the legacy_backend_config parameter requires " + f"custom_call_api_version < 4; got {custom_call_api_version}.") def wrapped(*args: ArrayLike, **kwargs: Any): + in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args] + + if input_layouts is None: + static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals)) + else: + if len(input_layouts) != len(in_avals): + raise ValueError( + f"The number of input arguments ({len(in_avals)}) must equal the " + f"number of input layouts ({len(input_layouts)}).") + static_input_layouts = _convert_layouts_for_ffi_call(in_avals, + input_layouts) + if output_layouts_ is None: + static_output_layouts = tuple(map(_convert_layout_for_lowering, + result_avals)) + else: + if len(output_layouts_) != len(result_avals): + raise ValueError( + f"The number of outputs ({len(result_avals)}) must equal the " + f"number of output layouts ({len(output_layouts_)}).") + static_output_layouts = _convert_layouts_for_ffi_call(result_avals, + output_layouts_) + + static_input_output_aliases: tuple[tuple[int, int], ...] = () + if input_output_aliases is not None: + for i_idx, o_idx in sorted(input_output_aliases.items()): + i_idx, o_idx = int(i_idx), int(o_idx) + if i_idx >= len(args): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with input index {i_idx} outside the range [0, " + f"{len(args)}).") + if o_idx >= len(result_avals): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with output index {o_idx} outside the range [0, " + f"{len(result_avals)}).") + in_aval = in_avals[i_idx] + out_aval = result_avals[o_idx] + if not _check_compatible_avals(in_aval, out_aval): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to an input with abstract value {in_aval} and an " + f"output with a different abstract value {out_aval}.") + if static_input_layouts[i_idx] != static_output_layouts[o_idx]: + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to an input with layout {static_input_layouts[i_idx]} " + "and an output with a different layout " + f"{static_output_layouts[o_idx]}.") + static_input_output_aliases += ((i_idx, o_idx),) + results = ffi_call_p.bind( *args, result_avals=result_avals, @@ -278,6 +406,11 @@ def wrapped(*args: ArrayLike, **kwargs: Any): vmap_method=vmap_method, target_name=target_name, has_side_effect=has_side_effect, + input_layouts=static_input_layouts, + output_layouts=static_output_layouts, + input_output_aliases=static_input_output_aliases, + custom_call_api_version=custom_call_api_version, + legacy_backend_config=legacy_backend_config, attributes=_wrap_kwargs_hashable(kwargs), ) if multiple_results: @@ -383,26 +516,23 @@ def __str__(self): def ffi_call_abstract_eval( *avals_in, result_avals: tuple[core.AbstractValue, ...], - target_name: str, - vectorized: bool | DeprecatedArg, - vmap_method: str | None, has_side_effect: bool, - attributes: Sequence[tuple[str, Any]], + **_, ): - del avals_in, target_name, vectorized, vmap_method, attributes + del avals_in # unused effects = {_FfiEffect} if has_side_effect else core.no_effects return result_avals, effects -def ffi_call_jvp(*args, target_name, **kwargs): - del args, kwargs +def ffi_call_jvp(*args, target_name, **_): + del args raise ValueError( f"The FFI call to `{target_name}` cannot be differentiated. " "You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.") -def ffi_call_transpose(*args, target_name, **kwargs): - del args, kwargs +def ffi_call_transpose(*args, target_name, **_): + del args raise ValueError( f"The FFI call to `{target_name}` cannot be differentiated. " "You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.") @@ -411,15 +541,22 @@ def ffi_call_transpose(*args, target_name, **kwargs): def ffi_call_lowering( ctx: mlir.LoweringRuleContext, *operands: ir.Value, - result_avals: tuple[core.AbstractValue, ...], target_name: str, - vectorized: bool | DeprecatedArg, - vmap_method: str | None, has_side_effect: bool, + input_layouts: Sequence[Sequence[int]], + output_layouts: Sequence[Sequence[int]], + input_output_aliases: Sequence[tuple[int, int]], + custom_call_api_version: int, + legacy_backend_config: str | None, attributes: Sequence[tuple[str, Any]], + **_, ) -> Sequence[ir.Value]: - del result_avals, vectorized, vmap_method - rule = ffi_lowering(target_name, has_side_effect=has_side_effect) + rule = ffi_lowering(target_name, has_side_effect=has_side_effect, + operand_layouts=input_layouts, + result_layouts=output_layouts, + operand_output_aliases=dict(input_output_aliases), + api_version=custom_call_api_version, + backend_config=legacy_backend_config) return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes)) diff --git a/tests/extend_test.py b/tests/extend_test.py index 0fc8821f1984..d9059d49ecc7 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -14,6 +14,7 @@ import os import unittest +from functools import partial import numpy as np from absl.testing import absltest @@ -34,6 +35,7 @@ from jax._src.interpreters import mlir from jax._src.layout import DeviceLocalLayout from jax._src.lib.mlir.dialects import hlo +from jax._src.lax import linalg as lax_linalg_internal jax.config.parse_flags_with_absl() @@ -122,7 +124,6 @@ def testLoweringLayouts(self, layout_spec, expected_layout): # layouts. def lowering_rule(ctx, x): aval, = ctx.avals_in - ndim = len(aval.shape) return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec], result_layouts=[layout_spec])(ctx, x) prim = core.Primitive("test_ffi") @@ -228,51 +229,42 @@ def fun(x): fun(jnp.ones(5)) self.assertNotIsInstance(manager.exception, TypeError) - @jtu.sample_product( - shape=[(1,), (4,), (5,)], - dtype=(np.int32,), - ) - @jtu.run_on_devices("gpu") - def testFfiCall(self, shape, dtype): - pivots_size = shape[-1] - permutation_size = 2 * pivots_size - pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) - pivots = jnp.broadcast_to(pivots, shape) - expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) - actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size) - self.assertArraysEqual(actual, expected) + @jtu.sample_product(shape=[(6, 5), (4, 5, 6)]) + @jtu.run_on_devices("gpu", "cpu") + def testFfiCall(self, shape): + x = self.rng().randn(*shape).astype(np.float32) + expected = lax_linalg_internal.geqrf(x) + actual = ffi_call_geqrf(x) + for a, b in zip(actual, expected): + self.assertArraysEqual(a, b) @jtu.sample_product( - shape=[(1,), (4,), (5,)], - dtype=(np.int32,), - vmap_method=("expand_dims", "broadcast_all", "sequential", - "legacy_vectorized"), + shape=[(6, 5), (4, 5, 6)], + vmap_method=["expand_dims", "broadcast_all", "sequential"], ) - @jtu.run_on_devices("gpu") - def testFfiCallBatching(self, shape, dtype, vmap_method): + @jtu.run_on_devices("gpu", "cpu") + def testFfiCallBatching(self, shape, vmap_method): shape = (10,) + shape - pivots_size = shape[-1] - permutation_size = 2 * pivots_size - pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype) - pivots = jnp.broadcast_to(pivots, shape) - expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size) - actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation( - x, permutation_size, vmap_method=vmap_method))(pivots) - self.assertArraysEqual(actual, expected) - - @jtu.run_on_devices("gpu") + x = self.rng().randn(*shape).astype(np.float32) + expected = lax_linalg_internal.geqrf(x) + actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x) + for a, b in zip(actual, expected): + if vmap_method == "sequential" and len(shape) == 3: + # On GPU, the batched FFI call to geqrf uses an algorithm with + # different numerics than the unbatched version (which is used when + # vmap_method="sequential"). Therefore, we need to include floating + # point tolerance for this check. + self.assertArraysAllClose(a, b) + else: + self.assertArraysEqual(a, b) + + @jtu.run_on_devices("gpu", "cpu") def testVectorizedDeprecation(self): - pivots_size = 4 - shape = (10, pivots_size) - permutation_size = 2 * pivots_size - pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, - dtype=np.int32) - pivots = jnp.broadcast_to(pivots, shape) + x = self.rng().randn(3, 5, 4).astype(np.float32) with self.assertWarns(DeprecationWarning): - ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) + ffi_call_geqrf(x, vectorized=True) with self.assertWarns(DeprecationWarning): - jax.vmap( - lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots) + jax.vmap(ffi_call_geqrf)(x) def testBackwardCompatSyntax(self): def fun(x): @@ -280,20 +272,82 @@ def fun(x): with self.assertWarns(DeprecationWarning): jax.jit(fun).lower(jnp.ones(5)) + def testInputOutputAliases(self): + def fun(x): + return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]") -# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` -# custom call target because that's the only one in jaxlib that uses the -# new FFI interface. Once more are available, consider using something that -# can be run on multiple platforms. -def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs): - return jex.ffi.ffi_call( - "cu_lu_pivots_to_permutation", - jax.ShapeDtypeStruct( - shape=pivots.shape[:-1] + (permutation_size,), - dtype=pivots.dtype, - ), - **kwargs, - )(pivots) + def testInvalidInputOutputAliases(self): + def fun(x): + return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x) + with self.assertRaisesRegex(ValueError, "with input index"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x) + with self.assertRaisesRegex(ValueError, "with output index"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32), + input_output_aliases={0: 0})(x) + with self.assertRaisesRegex(ValueError, + "referring to an input with abstract value"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape, + x.dtype), + input_output_aliases={0: 0})(x) + with self.assertRaisesRegex(ValueError, + "referring to an input with abstract value"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def testLegacyBackendConfig(self): + def fun(x): + return jex.ffi.ffi_call("test", x, custom_call_api_version=2, + legacy_backend_config="12345")(x) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertRegex(hlo, 'backend_config = "12345"') + + def testInvalidBackendConfig(self): + def fun(x): + return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x) + with self.assertRaisesRegex(ValueError, + "The use of the legacy_backend_config"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jex.ffi.ffi_call("test", x, + custom_call_api_version=2)(x, attribute=1) + with self.assertRaisesRegex(ValueError, + "The use of ffi_call attributes requires"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + +def ffi_call_geqrf(x, **kwargs): + assert x.dtype == np.float32 + ndim = x.ndim + x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2) + output_types = [ + x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)] + + def call(platform, x): + target_name = dict( + cpu="lapack_sgeqrf_ffi", + rocm="hipsolver_geqrf_ffi", + cuda="cusolver_geqrf_ffi", + )[platform] + return jex.ffi.ffi_call( + target_name, output_types, input_output_aliases={0: 0}, + input_layouts=[x_major_to_minor], + output_layouts=[x_major_to_minor, None], + **kwargs)(x) + + return lax.platform_dependent( + x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"), + cuda=partial(call, "cuda")) class MlirRegisterLoweringTest(jtu.JaxTestCase): From 34611be53d72b516447683526baa1f8b975df408 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 25 Oct 2024 10:34:33 -0700 Subject: [PATCH 055/698] Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here: * Handled transpose of `dot_general` correctly with shardings * Handled transpose of `reduce_sum` correctly with shardings * `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet). * `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute. * (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed. * Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer. * Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`. PiperOrigin-RevId: 689837320 --- jax/_src/core.py | 41 ++- jax/_src/internal_test_util/test_harnesses.py | 3 +- jax/_src/lax/lax.py | 238 ++++++++++++------ jax/_src/pallas/mosaic/lowering.py | 3 +- jax/_src/pallas/mosaic_gpu/lowering.py | 2 + jax/_src/pallas/triton/lowering.py | 3 +- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/sparse/bcoo.py | 3 +- jax/experimental/sparse/bcsr.py | 3 +- tests/pjit_test.py | 163 +++++++++++- 10 files changed, 365 insertions(+), 96 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7674ba76da38..8379ce5e070f 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1801,8 +1801,12 @@ def __hash__(self): getattr(self, 'sharding', None))) def to_tangent_aval(self): - return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type) + if config.sharding_in_types.value: + return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type, self.sharding) + else: + return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type) def join(self, other): if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype: @@ -1845,6 +1849,14 @@ def _get_shape_sharding_str(shape, spec): def _forward_to_value(self, fun, ignored_tracer, *args): return fun(self.val, *args) +def _get_abstract_sharding(val): + from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error + + if (config.sharding_in_types.value and hasattr(val, 'sharding') and + isinstance(val.sharding, NamedSharding)): + return NamedSharding(val.sharding.mesh.abstract_mesh, + val.sharding._normalized_spec(val.ndim)) + return None class ConcreteArray(ShapedArray): __slots__ = ['val'] @@ -1853,7 +1865,8 @@ class ConcreteArray(ShapedArray): def __init__(self, dtype, val, weak_type=None): super().__init__( np.shape(val), dtype, - weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type) + weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type, + sharding=_get_abstract_sharding(val)) dtypes.check_valid_dtype(self.dtype) # Note: canonicalized self.dtype doesn't necessarily match self.val assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype) @@ -2132,12 +2145,16 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None): if handler: return handler(aval, weak_type) raise TypeError(type(aval)) +def _shaped_array_mapping(aval, weak_type): + if config.sharding_in_types.value: + return ShapedArray(aval.shape, aval.dtype, weak_type, sharding=aval.sharding) + return ShapedArray(aval.shape, aval.dtype, weak_type) + raise_to_shaped_mappings: dict[type, Callable] = { AbstractToken: lambda aval, _: aval, Bot: lambda aval, _: aval, UnshapedArray: lambda aval, _: aval, - ShapedArray: lambda aval, weak_type: ShapedArray( - aval.shape, aval.dtype, weak_type), + ShapedArray: _shaped_array_mapping, DConcreteArray: lambda aval, weak_type: DShapedArray( aval.shape, aval.dtype, weak_type ), @@ -3073,10 +3090,16 @@ def substitute(aval: AbstractValue): return aval for v, x in zip(call_jaxpr.invars, in_atoms): if not typecompat(substitute(v.aval), x.aval): - # TODO(mattjj): vars in error message are confusing b/c of Var.__repr__ - raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " - f"{x.aval} to jaxpr expecting type " - f"{substitute(v.aval)}") + # TODO(yashkatariya): Remove this once numpy array's aval has a sharding + # on it. + if (config.sharding_in_types.value and isinstance(x, Literal) and + v.aval.sharding is not None and x.val.ndim == 0): + pass + else: + # TODO(mattjj): vars in error message are confusing b/c of Var.__repr__ + raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " + f"{x.aval} to jaxpr expecting type " + f"{substitute(v.aval)}") env[v] = x if type(x) is Var else x.val _check_jaxpr(ctx_factory, call_jaxpr) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 1c03158953f0..48c645c4d033 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -1058,7 +1058,8 @@ def _make_broadcast_in_dim_harness(name, lax.broadcast_in_dim_p, f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{outshape=}_broadcastdimensions={broadcast_dimensions}", lambda operand: lax.broadcast_in_dim_p.bind( - operand, shape=outshape, broadcast_dimensions=broadcast_dimensions), + operand, shape=outshape, broadcast_dimensions=broadcast_dimensions, + sharding=None), [RandArg(shape, dtype)], shape=shape, dtype=dtype, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a41c4c4cec5a..ebcc5aac412d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -563,6 +563,10 @@ def _convert_element_type( new_dtype = np.dtype(new_dtype) new_dtype = dtypes.dtype(new_dtype, canonicalize=True) + if (config.sharding_in_types.value and sharding is None and + isinstance(operand, Array)): + sharding = operand.sharding + if (dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" @@ -1136,7 +1140,7 @@ def ragged_dot( group_offset=group_offset) -def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array: +def broadcast(operand: ArrayLike, sizes: Sequence[int], sharding=None) -> Array: """Broadcasts an array, adding new leading dimensions Args: @@ -1150,13 +1154,14 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int]) -> Array: See Also: jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape. """ - if len(sizes) == 0: + if len(sizes) == 0 and sharding is None: return asarray(operand) dims = tuple(range(len(sizes), len(sizes) + np.ndim(operand))) - return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims) + return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims, + sharding=sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int]) -> Array: + broadcast_dimensions: Sequence[int], sharding=None) -> Array: """Wraps XLA's `BroadcastInDim `_ operator. @@ -1174,7 +1179,11 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. """ - if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array): + if not config.sharding_in_types.value and sharding is not None: + raise NotImplementedError("sharding argument to broadcast_in_dim is only " + "allowed when sharding_in_types config is on.") + if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and + isinstance(operand, Array) and sharding is None): return operand if config.dynamic_shapes.value: # We must gate this behavior under a flag because otherwise the errors @@ -1184,7 +1193,8 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, dyn_shape, static_shape = [], shape # type: ignore return broadcast_in_dim_p.bind( operand, *dyn_shape, shape=tuple(static_shape), - broadcast_dimensions=tuple(broadcast_dimensions)) + broadcast_dimensions=tuple(broadcast_dimensions), + sharding=sharding) def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.""" @@ -1613,17 +1623,16 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, weak_type = dtype is None and dtypes.is_weakly_typed(fill_value) dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) fill_value = _convert_element_type(fill_value, dtype, weak_type) - # In tracing mode we can't set sharing explictly and PmapShardng is not - # supported. - # NB: Consider using with_sharding_constraint in jitted computation - # if needed? if (sharding is not None and not isinstance(sharding, PmapSharding) and isinstance(fill_value, array.ArrayImpl)): broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) return array.make_array_from_callback(shape, sharding, lambda _: shard) - return broadcast(fill_value, shape) + if config.sharding_in_types.value and sharding is not None: + return broadcast(fill_value, shape, sharding=sharding) + else: + return broadcast(fill_value, shape) def zeros_like_shaped_array(aval: ShapedArray) -> Array: assert isinstance(aval, ShapedArray) @@ -1821,22 +1830,26 @@ def full_like(x: ArrayLike | DuckTypedArray, if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] - # If `x` has a sharding but no `_committed` attribute - # (in case of ShapeDtypeStruct), default it to True. - use_x_sharding = ( - sharding is None - # Tracer have special logic in handling sharding and even - # though hasattr(x, 'sharding') returns False, it is very slow. - # This bypasses the check. - and not isinstance(x, core.Tracer) - and hasattr(x, 'sharding') - and getattr(x, '_committed', True) - and not weak_type - and fill_shape == np.shape(x) # type: ignore[arg-type] - ) - if use_x_sharding: - # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. - sharding = x.sharding # type: ignore + if (config.sharding_in_types.value and sharding is None and + isinstance(x, Array)): + sharding = x.sharding + else: + # If `x` has a sharding but no `_committed` attribute + # (in case of ShapeDtypeStruct), default it to True. + use_x_sharding = ( + sharding is None + # Tracer have special logic in handling sharding and even + # though hasattr(x, 'sharding') returns False, it is very slow. + # This bypasses the check. + and not isinstance(x, core.Tracer) + and hasattr(x, 'sharding') + and getattr(x, '_committed', True) + and not weak_type + and fill_shape == np.shape(x) # type: ignore[arg-type] + ) + if use_x_sharding: + # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. + sharding = x.sharding # type: ignore val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type), sharding=sharding) return val @@ -2076,18 +2089,9 @@ def broadcasting_shape_rule(name, *avals): def broadcasting_sharding_rule(name, *avals): - shapes = [aval.shape for aval in avals if aval.shape] - if not shapes: - return () - if len({len(shape) for shape in shapes}) != 1: - msg = '{}: arrays must have same number of dimensions, got {}.' - raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes))))) - - specs = [a.sharding.spec for a in avals if a.shape] - mesh = None for a in avals: - if a.shape: + if a.sharding is not None: mesh = a.sharding.mesh if mesh is not None and mesh != a.sharding.mesh: raise ValueError( @@ -2095,6 +2099,15 @@ def broadcasting_sharding_rule(name, *avals): f' another mesh: {a.sharding.mesh}') assert mesh is not None + shapes = [aval.shape for aval in avals if aval.shape] + if not shapes: + return NamedSharding(mesh, P()) + if len({len(shape) for shape in shapes}) != 1: + msg = '{}: arrays must have same number of dimensions, got {}.' + raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes))))) + + specs = [a.sharding.spec for a in avals if a.shape] + result_specs = [None] * len(shapes[0]) for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))): if all(s == ss[0] for s in ss[1:]): @@ -2981,7 +2994,7 @@ def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): operand = core.Primitive.bind(convert_element_type_p, operand, new_dtype=new_dtype, weak_type=weak_type, sharding=sharding) - if sharding is not None: + if sharding is not None and not config.sharding_in_types.value: operand = pjit.with_sharding_constraint(operand, sharding) return operand convert_element_type_p.def_custom_bind(_convert_element_type_bind) @@ -3014,6 +3027,8 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype)) out = mlir.convert_hlo(ctx, operand, aval_in, aval_out) if config.sharding_in_types.value: + if sharding is not None: + assert aval_out.sharding == sharding proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] return [out] @@ -3353,8 +3368,6 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width): def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, out_type, swap_ans=False): - if out_type is not None: - raise NotImplementedError (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim x_kept = remaining(range(x_ndim), x_contract, x_batch) @@ -3365,10 +3378,18 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept) dims = ((ans_y, y_kept), (ans_batch, y_batch)) x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) - out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) - x_bar = transpose(dot_general(g, y, dims, precision=precision, - preferred_element_type=preferred_element_type), - tuple(out_axes)) + unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y + out_axes = np.argsort(unsorted_axes) + if config.sharding_in_types.value: + xs = x.aval.sharding + inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) + ds = NamedSharding(xs.mesh, P(*inverse_spec)) + else: + ds = None + dot_general_out = dot_general(g, y, dims, precision=precision, + preferred_element_type=preferred_element_type, + out_type=ds) + x_bar = transpose(dot_general_out, tuple(out_axes)) if x_bar.dtype != x.aval.dtype: x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type) return x_bar @@ -3376,8 +3397,6 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, out_type): - if out_type is not None: - raise NotImplementedError (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) y_bar = _dot_general_transpose_lhs( @@ -3953,7 +3972,8 @@ def _ragged_dot_impl( mlir.register_lowering(ragged_dot_p, mlir.lower_fun(_ragged_dot_impl, multiple_results=False)) -def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): +def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, + sharding): _check_shapelike('broadcast_in_dim', 'shape', shape) _check_shapelike('broadcast_in_dim', 'broadcast_dimensions', broadcast_dimensions) @@ -3988,7 +4008,10 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions): raise TypeError(msg.format(broadcast_dimensions)) return shape -def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions): +def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, + sharding): + if sharding is not None: + return sharding bds = set(broadcast_dimensions) orig_spec = iter(operand.sharding.spec) new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))] @@ -3996,10 +4019,11 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions): return NamedSharding(operand.sharding.mesh, P(*new_spec)) def _broadcast_in_dim_typecheck_rule( - _, operand, *dyn_shape, shape, broadcast_dimensions): + _, operand, *dyn_shape, shape, broadcast_dimensions, sharding): if not dyn_shape: out_aval, effects = broadcast_in_dim_p.abstract_eval( - operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions) + operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) return [out_aval], effects else: # TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule @@ -4010,7 +4034,7 @@ def _broadcast_in_dim_typecheck_rule( return [out_aval], core.no_effects def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, - shape, broadcast_dimensions): + shape, broadcast_dimensions, sharding): if type(ct) is ad_util.Zero: return [ad_util.Zero(operand.aval)] unit_dims = [i for i, s in enumerate(operand.aval.shape) @@ -4021,7 +4045,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, [None] * len(dyn_shape)) def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape, - broadcast_dimensions): + broadcast_dimensions, sharding): # `dyn_shape` is the dynamic portion of the target shape. `shape` # is the target shape, with `None` for dynamic sections. # broadcast_dimensions gives indices where dimensions of the input @@ -4067,6 +4091,8 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape, assert len(sizes) == stacked_size, msg dyn_limits.append(bound) new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits) + if sharding is not None: + raise NotImplementedError('Implement broadcast_in_dim_batch_rule') result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions) out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None] out_bdim = batching.make_batch_axis( @@ -4081,8 +4107,9 @@ def _broadcast_in_dim_fwd_rule(eqn): return [None], eqn def _broadcast_in_dim_staging_rule( - trace, x, *dyn, shape, broadcast_dimensions): - params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions) + trace, x, *dyn, shape, broadcast_dimensions, sharding): + params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) if not dyn: return trace.default_process_primitive(broadcast_in_dim_p, (x,), params) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) @@ -4107,24 +4134,28 @@ def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape, return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape), broadcast_dimensions=broadcast_dimensions)] -def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions): +def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions, + sharding): operand, *dyn_shape = primals operand_dot, *_ = tangents y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape, - broadcast_dimensions=broadcast_dimensions) + broadcast_dimensions=broadcast_dimensions, + sharding=sharding) if type(operand_dot) is ad_util.Zero: y_dot = ad_util.Zero.from_primal_value(y) else: y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape, - broadcast_dimensions=broadcast_dimensions) + broadcast_dimensions=broadcast_dimensions, + sharding=sharding) return y, y_dot def _broadcast_in_dim_partial_eval( - trace, operand, *dyn_shape, shape, broadcast_dimensions): + trace, operand, *dyn_shape, shape, broadcast_dimensions, sharding): if not dyn_shape: return trace.default_process_primitive( broadcast_in_dim_p, (operand, *dyn_shape), - dict(shape=shape, broadcast_dimensions=broadcast_dimensions)) + dict(shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding)) assert all(t.pval.is_known() for t in dyn_shape) operand_tracer = trace.instantiate_const(operand) dyn_shape_tracers = map(trace.instantiate_const, dyn_shape) @@ -4134,34 +4165,40 @@ def _broadcast_in_dim_partial_eval( out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) eqn = pe.new_eqn_recipe( [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, - dict(shape=shape, broadcast_dimensions=broadcast_dimensions), + dict(shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=None), core.no_effects, source_info_util.current()) out_tracer.recipe = eqn return out_tracer -def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) -> Sequence[ir.Value]: +def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, + sharding) -> Sequence[ir.Value]: aval_out, = ctx.avals_out if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=broadcast_dimensions) if config.sharding_in_types.value: + if sharding is not None: + assert sharding == aval_out.sharding proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] return [out] -def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions): +def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, + sharding): if (not dyn_shape and not any(isinstance(d, core.DArray) and type(core.get_aval(d).dtype) is core.bint for d in shape)): shape = _broadcast_in_dim_shape_rule( # error checking - x, shape=shape, broadcast_dimensions=broadcast_dimensions) + x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None) if config.sharding_in_types.value: - sharding = _broadcast_in_dim_sharding_rule( - x, shape=shape, broadcast_dimensions=broadcast_dimensions) + new_sharding = _broadcast_in_dim_sharding_rule( + x, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) else: - sharding = None - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=sharding) + new_sharding = None + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code @@ -4499,6 +4536,23 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions): raise TypeError(msg.format(dimensions, np.shape(operand))) return tuple(new_sizes) +def _reshape_sharding_rule(operand, *, new_sizes, dimensions): + filtered_spec = [ + (sh, sp) for sh, sp in zip(operand.shape, operand.sharding.spec) + if sh != 1 + ] + fs = iter(filtered_spec) + new_spec = [] + for n in new_sizes: + if n == 1: + new_spec.append(None) + else: + sh, sp = next(fs) + if n != sh: + raise NotImplementedError + new_spec.append(sp) + return NamedSharding(operand.sharding.mesh, P(*new_spec)) + def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions): if not dyn_shape: out_aval, effects = reshape_p.abstract_eval( @@ -4539,7 +4593,11 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): x = hlo.transpose(x, mlir.dense_int_array(dimensions)) if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) - return [mlir.reshape(ctx, x, aval_out)] + out = mlir.reshape(ctx, x, aval_out) + if config.sharding_in_types.value: + proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [out] def _reshape_staging_rule( trace, x, *dyn, new_sizes, dimensions): @@ -4550,7 +4608,7 @@ def _reshape_staging_rule( return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape') + 'reshape', sharding_rule=_reshape_sharding_rule) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.primitive_batchers[reshape_p] = _reshape_batch_rule mlir.register_lowering(reshape_p, _reshape_lower) @@ -4644,6 +4702,18 @@ def _select_shape_rule(which, *cases): raise TypeError(msg.format(which.shape, cases[0].shape)) return cases[0].shape +def _select_sharding_rule(which, *cases): + if any(case.sharding != cases[0].sharding for case in cases[1:]): + msg = "select cases must have the same shardings, got [{}]." + raise TypeError(msg.format(", ".join([str(c.sharding) for c in cases]))) + if which.shape and which.sharding != cases[0].sharding: + raise TypeError( + 'select `which` must be scalar or have the same sharding as cases, got' + f' `which` sharding {which.sharding} but case sharding' + f' {cases[0].sharding}.') + return cases[0].sharding + + def _select_dtype_rule(which, *cases): check_same_dtypes("select", *cases) if (not dtypes.issubdtype(which.dtype, np.bool_) and @@ -4746,18 +4816,25 @@ def _select_hlo_lowering_opaque(ctx, which, *cases): avals_in=[aval_which_bcast, *physical_avals_cases], avals_out=[physical_aval_out])[0] +def _add_shit_to_select(ctx, op, aval_out): + if config.sharding_in_types.value: + proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return mlir.wrap_with_sharding_op(ctx, op, aval_out, proto) + return op def _select_hlo_lowering(ctx, which, *cases): which_aval = ctx.avals_in[0] aval_out, = ctx.avals_out if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - return [_select_hlo_lowering_opaque(ctx, which, *cases)] + op = _select_hlo_lowering_opaque(ctx, which, *cases) + return [_add_shit_to_select(ctx, op, aval_out)] if which_aval.dtype == np.dtype(np.bool_): assert len(cases) <= 2 if len(cases) == 1: return cases - return [hlo.select(which, cases[1], cases[0])] + op = hlo.select(which, cases[1], cases[0]) + return [_add_shit_to_select(ctx, op, aval_out)] if dtypes.issubdtype(which_aval.dtype, np.signedinteger): compare_type = 'SIGNED' @@ -4776,11 +4853,12 @@ def _select(offset, cases): return hlo.select(pred, _select(offset, cases[:mid]), _select(offset + mid, cases[mid:])) - return [_select(0, cases)] + op = _select(0, cases) + return [_add_shit_to_select(ctx, op, aval_out)] select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', - weak_type_rule=_select_weak_type_rule) + weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule) ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule @@ -4918,7 +4996,11 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes): assert ad.is_undefined_primal(operand) input_shape = operand.aval.shape broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes)) - result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions) + if config.sharding_in_types.value: + result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions, + sharding=operand.aval.sharding) + else: + result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions) assert result.shape == input_shape return [result] @@ -6015,7 +6097,13 @@ def _const(example, val): _zeros: Callable = partial(full_like, fill_value=0) _zero: Callable = partial(full_like, shape=(), fill_value=0) _ones: Callable = partial(full_like, fill_value=1) -_one: Callable = partial(full_like, shape=(), fill_value=1) + +def _one(x): + if config.sharding_in_types.value: + return full_like(x, shape=(), fill_value=1, + sharding=NamedSharding(x.sharding.mesh, P())) + return full_like(x, shape=(), fill_value=1) + _twos: Callable = partial(full_like, fill_value=2) _two: Callable = partial(full_like, shape=(), fill_value=2) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 8b2c18165f61..85940b4ee975 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1521,8 +1521,9 @@ def _proxy_reduce(arg, *, axes): def _broadcast_in_dim_lowering_rule( - ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions + ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding ): + del sharding (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 7517597d637f..197e189adc2c 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1037,7 +1037,9 @@ def _broadcast_in_dim_lowering_rule( *, broadcast_dimensions, shape, + sharding, ): + del sharding [x_aval] = ctx.avals_in [y_aval] = ctx.avals_out x = _ensure_fa(x, x_aval.dtype) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 79919e638d1d..605b975fce25 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1586,8 +1586,9 @@ def select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, y): @register_lowering(lax.broadcast_in_dim_p) def _broadcast_in_dim_lowering_rule( - ctx: LoweringRuleContext, x, *, broadcast_dimensions, shape + ctx: LoweringRuleContext, x, *, broadcast_dimensions, shape, sharding ): + del sharding x = _ensure_ir_value(x, *ctx.avals_in) if not ir.RankedTensorType.isinstance(x.type): return _bcast_to(x, shape) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 348baa868f09..42d3cb57f4e5 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2267,7 +2267,7 @@ def _dot_general_convert_to_common_dtype( convert_result = lambda res: res return (lhs, rhs, convert_result) -def _broadcast_in_dim(operand, *, shape, broadcast_dimensions, +def _broadcast_in_dim(operand, *, shape, broadcast_dimensions, sharding=None, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): # for i in range(len(operand.shape)): diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 477f634744ed..9dcb0cadc1b2 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1695,7 +1695,8 @@ def _update(d, i): return BCOO((new_data, new_indices), shape=shape) -def bcoo_broadcast_in_dim(mat: BCOO, *, shape: Shape, broadcast_dimensions: Sequence[int]) -> BCOO: +def bcoo_broadcast_in_dim(mat: BCOO, *, shape: Shape, broadcast_dimensions: Sequence[int], + sharding=None) -> BCOO: """Expand the size and rank of a BCOO array by duplicating the data. A BCOO equivalence to jax.lax.broadcast_in_dim. diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 372bce0344ba..ed7e53d4c64e 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -714,7 +714,8 @@ def _bcsr_dot_general_gpu_lowering( #---------------------------------------------------------------------- # BCOO functions that maybe should be primitives? -def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequence[int]) -> BCSR: +def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequence[int], + sharding=None) -> BCSR: result_bcoo = bcoo.bcoo_broadcast_in_dim( mat.to_bcoo(), shape=shape, broadcast_dimensions=broadcast_dimensions) return BCSR.from_bcoo(result_bcoo) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9e3fdde3d13e..5f44d5892044 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4598,7 +4598,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): def test_basic_mul(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(8, 2) + np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4621,6 +4621,17 @@ def f(x): else: self.assertEqual(lowered_text.count('@Sharding'), 2) + @jax.jit + def g(x): + x = f(x) + return jnp.sum(x) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + def test_fully_replicated_array_mul(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp1 = np.arange(16).reshape(8, 2) @@ -4670,9 +4681,9 @@ def g(x, y): ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'), ('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce') ) - def test_dot_general_basic(self, spec1, spec2, out_spec, collective_name): + def test_dot_general(self, spec1, spec2, out_spec, collective_name): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp1 = np.arange(16).reshape(8, 2) + np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2)) @@ -4693,6 +4704,55 @@ def f(x, y): if collective_name is not None and compiled_text is not None: self.assertIn(collective_name, compiled_text) + @jax.jit + def g(x, y): + out = f(x, y) + return jnp.sum(out) + + out = jax.grad(g, argnums=(0, 1))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + out = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + def test_dot_general_out_type(self): + mesh = jtu.create_mesh((4,), ('x',)) + np_inp1 = np.arange(16.).reshape(8, 2) + arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) + arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x'))) + + @jax.jit + def f(x, y): + out = jnp.einsum('xy,yz->xz', x, y, + out_type=NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(out.sharding.spec, P('x', None)) + return jnp.sum(out) + + out = f(arr1, arr2) + self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T)) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + out = jax.grad(f, argnums=(0, 1))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + jitted_grad = jax.jit(jax.grad(f, argnums=(0, 1))) + out = jitted_grad(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + jaxpr = jitted_grad.trace(arr1, arr2).jaxpr + bwd_jaxpr = jaxpr.eqns[1] + expected_spec = [('broadcast_in_dim', P('x', None)), + ('dot_general', P('x', None)), + ('transpose', P(None, 'x')), + ('dot_general', P('x', None))] + for eqn, spec in zip(bwd_jaxpr.params['jaxpr'].eqns, expected_spec): + self.assertEqual(eqn.primitive.name, spec[0]) + self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec[1]) + @parameterized.named_parameters( ('fail1', P('x', 'y'), P('y', 'x'), "PartitionSpec.*x.*x.*has duplicate entries", ValueError), @@ -4791,7 +4851,7 @@ def f(x): ) def test_reduce_max(self, axis, in_spec, out_spec, reduce=True): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(8, 2) + np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, in_spec) arr = jax.device_put(np_inp, s) @@ -4813,6 +4873,17 @@ def f(x): if reduce and compiled_text is not None: self.assertIn('all-reduce', compiled_text) + @jax.jit + def g(x): + out = f(x) + return jnp.mean(out) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + @parameterized.named_parameters( ('0', 0, P(None, 'x', 'y')), ('1', 1, P('x', None, 'y')), @@ -4946,7 +5017,7 @@ def g(x): def test_einsum_with_out_type(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(8, 2) + np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) @@ -4977,9 +5048,22 @@ def g(x, y): self.assertArraysEqual(out2, np_inp @ np_inp.T) self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None))) + @jax.jit + def h2(x, y): + out = g(x, y) + return jnp.sum(out) + + out = jax.grad(h2, argnums=(0, 1))(arr3, arr4) + self.assertEqual(out[0].sharding, arr3.sharding) + self.assertEqual(out[1].sharding, arr4.sharding) + + out = jax.jit(jax.grad(h2, argnums=(0, 1)))(arr3, arr4) + self.assertEqual(out[0].sharding, arr3.sharding) + self.assertEqual(out[1].sharding, arr4.sharding) + def test_einsum_inverse(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(64) + np_inp = np.arange(64.) @jax.jit def h(x, y): @@ -4998,6 +5082,73 @@ def h(x, y): lowered_text = h.lower(arr1, arr2).as_text() self.assertIn('@Sharding', lowered_text) + @jax.jit + def h2(x, y): + out = h(x, y) + return jnp.sum(out) + + out = jax.grad(h2, argnums=(0, 1))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + out = jax.jit(jax.grad(h2, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) + + @parameterized.named_parameters( + ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False), + ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), + ('3', (8, 1), (1, 4, 2), P('x', None), P(None, 'x', None), True) + ) + def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, will_error): + mesh = jtu.create_mesh((2,), ('x',)) + np_inp = np.arange(math.prod(src_shape), + dtype=np.float32).reshape(src_shape) + arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec)) + + @jax.jit + def f(x): + y = jnp.reshape(x, dst_shape) + y = y * 2 + self.assertEqual(y.sharding.spec, dst_spec) + return y + + if will_error: + with self.assertRaises(NotImplementedError): + f(arr) + else: + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec)) + self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2) + + lowered_text = f.lower(arr).as_text() + self.assertIn('@Sharding', lowered_text) + + def test_select(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np_inp, s) + + @jax.jit + def f(pred, on_true, on_false): + y = lax.select(pred, on_true, on_false) + self.assertEqual(y.sharding.spec, s.spec) + return y + + out = f(arr1 == arr2, arr1, arr2) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, arr1) + + lowered_text = f.lower(arr1 == arr2, arr1, arr2).as_text() + self.assertIn('@Sharding', lowered_text) + + arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x'))) + with self.assertRaisesRegex( + TypeError, "select cases must have the same shardings"): + f(arr1 == arr2, arr1, arr3) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 2a671e25a78c74a03a4d568b4c2ad46394751f04 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Fri, 25 Oct 2024 11:21:39 -0700 Subject: [PATCH 056/698] [Mosaic TPU] Remove extra check PiperOrigin-RevId: 689852989 --- jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 6bec22403724..47e3baa0ff09 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1620,7 +1620,7 @@ class VectorLayoutInferer { // TODO(jevinjiang): We can fold the sublane offset into the 2nd minor // index. But we need to handle negative index in lower-to-llo. For // now, we just force the sublane offset to be 0. - if (offsets[1].value_or(0) < 0 || offsets[1].value_or(0) >= tiling[1]) { + if (offsets[1].value_or(0) >= tiling[1]) { offsets[1] = 0; } store_layout = VectorLayout(bitwidth, {0, offsets[1]}, From 7db4b254e0b84730b9816d8857d19b82abd36de3 Mon Sep 17 00:00:00 2001 From: Brian Wieder Date: Fri, 25 Oct 2024 11:34:03 -0700 Subject: [PATCH 057/698] Clear extra_jit_context when exiting. In for some reason, extra_jit_context was leaking when `pallas.core` no longer imported `pallas.pallas_call`, leading to leaking XLA Clients. PiperOrigin-RevId: 689857071 --- jax/_src/api.py | 1 + jax/_src/pallas/mosaic/core.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index d2ac5465eded..92f8a057f9f7 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2770,6 +2770,7 @@ def clear_backends(): pjit._cpp_pjit_cache_fun_only.clear() pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() + xc._xla.jax_jit.thread_local_state().extra_jit_context = None @atexit.register def clean_up(): diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 6e16df2e54de..12ae5350e725 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -28,12 +28,9 @@ from jax._src import dtypes from jax._src import util from jax._src.pallas import core as pallas_core -from jax._src.pallas import pallas_call import jax.numpy as jnp import numpy as np -# TODO(b/375357542): Remove the import once the bug is fixed. -_ = pallas_call map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip From 5afdbcbae7e69e595acd829f10734ca2c784d432 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 25 Oct 2024 12:04:32 -0700 Subject: [PATCH 058/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/79142a8fdd3bb7556ab0c6d01e691ecb5f7805c9. PiperOrigin-RevId: 689867508 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b8dcce3a5396..c1b1e9e8f894 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "1f6bd971dd2f531554eac88c0868b952d2543491" -XLA_SHA256 = "30225604bae42819e4212a82e6c17721c7d3bf146e6fb01dfed7f378c7ff6c49" +XLA_COMMIT = "79142a8fdd3bb7556ab0c6d01e691ecb5f7805c9" +XLA_SHA256 = "ad57f05faac50fd67ccfb22816905095cb7abcda36d3dcbc4844d8ef9b61efe8" def repo(): tf_http_archive( From 6f371212d972a2017fb58e621268e446d33e3235 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 25 Oct 2024 12:06:59 -0700 Subject: [PATCH 059/698] Implements an alternate version of ragged_attention, wherein, the actual attention kernel itself is dense. Meaning, this kernel does not have the compute saving (@when wrapped kernel) or prefetch/index skipping (via index rewriting) as part of the kernel. Rather, the kernel is invoked with a Jumble (A ragged type representation) and pallas takes care of applying the correct work skipping and index rewriting. Performance wise, we should be at parity, although this has not yet been tested. Authoring wise, the new kernel is significantly smaller and simpler to write. A major known limitation of this approach, which we have a plan to fix, is the invariant that the `seq_len % grid_size == 0` - we plan to relax this limitation in following CLs. PiperOrigin-RevId: 689868468 --- jax/_src/interpreters/batching.py | 21 ++++++++++----- jax/_src/lax/lax.py | 43 ++++++++++++++++++++++++++++--- jax/_src/pallas/pallas_call.py | 10 ++++++- jax/_src/pjit.py | 1 + jax/_src/state/primitives.py | 2 +- 5 files changed, 65 insertions(+), 12 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index eb174cc5c052..b40a3807dea2 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -264,10 +264,17 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: + if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis): + # TODO(mvoz): A vaguely questionable assumption that it is always + # sound to have a 0 axis here. This is true for the current use cases + # and comes from how we handle intermediary products of jumbles in + # vmap. + return BatchTracer(trace, x, 0, source_info_util.current()) # TODO(mvoz): This is a terrible place to fall into if you pass # a non jumble type in, make it clearer what went wrong. assert False, f'Unexpected type in ELT? {type(x)}' + to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, @@ -328,16 +335,16 @@ def flatten_fun_for_vmap(in_tree, *args_flat): yield tree_flatten(ans, is_leaf=is_vmappable) # Propagate ragged masking rules from invars to outvars -# rule([raggedness_per_invar], outvars) -> +# rule([params], [raggedness_per_invar], outvars) -> # [raggedness_per_invar, raggedness_per_outvar] RaggedMaskingRule = Callable[ - [list[Any], list[Any]], tuple[list[Any], list[Any]] + [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]] ] ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {} -def ragged_mask_elementwise_rule(invar_raggedness, outvars): +def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars): # TODO(mvoz): A util for getting the ragged representations first_invar_raggedness = invar_raggedness[0] for other_invar_raggedness in invar_raggedness[1:]: @@ -348,17 +355,19 @@ def ragged_mask_elementwise_rule(invar_raggedness, outvars): return invar_raggedness, outvar_raggedness -def ragged_mask_assert_no_op_rule(invar_raggedness, outvars): +def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars): if any(invar_raggedness): raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}') return invar_raggedness, [None] * len(outvars) -def ragged_mask_no_op_rule(invar_raggedness, outvars): +def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars): return invar_raggedness, [None] * len(outvars) -def ragged_mask_transfer_identity(invar_raggedness, outvar_raggedness): +def ragged_mask_transfer_identity( + eqn_params, invar_raggedness, outvar_raggedness +): assert len(invar_raggedness) == 1, invar_raggedness outvar_raggedness = invar_raggedness return invar_raggedness, outvar_raggedness diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ebcc5aac412d..9ed1d55cd936 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2296,6 +2296,7 @@ def _round_lower(ctx, x, *, rounding_method): exp_p = standard_unop(_float | _complex, 'exp') ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) +batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule exp2_p = standard_unop(_float | _complex, 'exp2') ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) @@ -2746,6 +2747,7 @@ def _sub_transpose(t, x, y): ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract)) +batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule def _mul_transpose(ct, x, y): @@ -2767,6 +2769,7 @@ def _mul_transpose(ct, x, y): lambda ydot, x, y: mul(x, ydot)) ad.primitive_transposes[mul_p] = _mul_transpose mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply)) +batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule def _div_transpose_rule(cotangent, x, y): assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y) @@ -2780,6 +2783,7 @@ def _div_transpose_rule(cotangent, x, y): lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2))) ad.primitive_transposes[div_p] = _div_transpose_rule mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide)) +batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( @@ -2803,12 +2807,14 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo)) +batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule min_p: core.Primitive = standard_naryop([_any, _any], 'min') ad.defjvp2(min_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo)) +batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule shift_left_p = standard_naryop([_int, _int], 'shift_left') ad.defjvp_zero(shift_left_p) @@ -2895,6 +2901,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt') ad.defjvp_zero(lt_p) mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False)) +batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to') ad.defjvp_zero(eq_to_p) @@ -3536,12 +3543,37 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: return core._pp_eqn(eqn.replace(params=printed_params), context, settings) -def _dot_general_ragged_prop_rule(invar_raggedness, outvars): +def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 2 assert len(outvars) == 1 invar_raggedness_lhs = invar_raggedness[0] invar_raggedness_rhs = invar_raggedness[1] + dimension_numbers = eqn_params['dimension_numbers'] + (lhs_contracting, rhs_contracting), (_, _) = dimension_numbers + + if not invar_raggedness_lhs and not invar_raggedness_rhs: + # Both are dense - it is valid to reach here, because dense operations + # are legal in code running under ragged prop. + return invar_raggedness, [None] + + if not invar_raggedness_lhs or not invar_raggedness_rhs: + # One ragged, one dense + if not invar_raggedness_lhs: + # left is dense, right is ragged + _, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs + if rhs_contracting != ragged_axis_dim_rhs: + # Contraction is on a dense dimension, this is valid! + return invar_raggedness, [None] + if not invar_raggedness_rhs: + # left is ragged, right is dense + _, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs + if lhs_contracting != ragged_axis_dim_lhs: + # Contraction is on a dense dimension, this is valid! + return invar_raggedness, [None] + + raise NotImplementedError('NYI - dense and ragged dim contraction') + stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs @@ -3560,9 +3592,8 @@ def _dot_general_ragged_prop_rule(invar_raggedness, outvars): assert len(outvars) == 1 # TODO(mvoz): A constant on batching.* ? - dense_jumble_raggedness = None # Dense (m, n) - no jumble only atm - return invar_raggedness, [dense_jumble_raggedness] + return invar_raggedness, [None] dot_general_p = standard_primitive( @@ -4205,7 +4236,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type) -def _broadcast_in_dim_ragged_prop_rule(invar_raggedness, outvars): +def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 1 assert not isinstance(invar_raggedness[0], core.Var) return invar_raggedness, [None] * len(outvars) @@ -5040,6 +5071,7 @@ def _reduce_op_sharding_rule(operand, *, axes): batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum, _get_sum_identity) +batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule def _reduce_prod_jvp_rule(primals, tangents, *, axes): reducer = lambda x, y: [mul(x, y)] @@ -5074,6 +5106,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max, _get_max_identity) +batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype, @@ -5854,9 +5887,11 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) + iota_p = Primitive('iota') iota_p.def_impl(partial(dispatch.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) +batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding): params = dict(dtype=dtype, shape=shape, dimension=dimension, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 2bed4a0830ea..e20c7783439e 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -945,7 +945,15 @@ def get_size(i, x, d): ) for invar in eqn.invars ] - invar_raggedness, outvar_raggedness = rule(invar_raggedness, eqn.outvars) + try: + invar_raggedness, outvar_raggedness = rule( + eqn.params, invar_raggedness, eqn.outvars # type: ignore[arg-type] + ) + except Exception as e: + raise RuntimeError( + f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:" + f" {eqn.outvars}. Underlying reason: {e}" + ) from e for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment] if isinstance(invar, jax_core.Var): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 2abf81f26aa4..5b8856aa619e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1984,6 +1984,7 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None) +batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule def _pjit_batcher_for_sharding( s: sharding.Sharding | UnspecifiedValue, diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index a0f70a126c8e..7724466d3110 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -127,7 +127,7 @@ def ref_get( swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) -def swap_ragged_prop_rule(invar_raggedness, outvars): +def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 2 invar_raggedness_lhs = invar_raggedness[0] invar_raggedness_rhs = invar_raggedness[1] From 5a2128e44bf31dadc2a0484336f400b308b57bff Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 25 Oct 2024 12:16:18 -0700 Subject: [PATCH 060/698] [pallas] Removed deprecated aliases to `CostEstimate` and `run_scoped` PiperOrigin-RevId: 689871787 --- docs/pallas/CHANGELOG.md | 11 +++++++++- jax/BUILD | 1 + jax/_src/pallas/mosaic/BUILD | 7 ++++++ .../pallas/mosaic/pallas_call_registration.py | 22 +++---------------- jax/experimental/pallas/tpu.py | 6 ----- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index d7ed91011a95..bd86741c9165 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -11,7 +11,16 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c Remember to align the itemized text with the first line of an item within a list. --> -## Released with jax 0.4.34 +## Released with jax 0.4.35 + +* Removals + + * Removed previously deprecated aliases + {class}`jax.experimental.pallas.tpu.CostEstimate` and + {func}`jax.experimental.tpu.run_scoped`. Both are now available in + {mod}`jax.experimental.pallas`. + +## Released with jax 0.4.34 (October 4, 2024) * Changes diff --git a/jax/BUILD b/jax/BUILD index 12c239a2d63e..7a970b13aef3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -931,6 +931,7 @@ pytype_strict_library( ":mlir", ":sharding_impls", "//jax/_src/lib", + "//jax/_src/pallas", ] + if_building_jaxlib([ "//jaxlib/mlir:ir", "//jaxlib/mlir:mhlo_dialect", diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index f52ba9ddd6cd..bf0d83bb3dc9 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -74,10 +74,17 @@ py_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", + ":verification", "//jax", + "//jax:config", + "//jax:core", + "//jax:mlir", "//jax:mosaic", + "//jax:sharding_impls", "//jax:source_info_util", + "//jax:tpu_custom_call", "//jax/_src/lib", "//jax/_src/pallas", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 2bf96511b64e..e34b5dbdd162 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -19,7 +19,6 @@ import os import tempfile from typing import Any -import warnings import jax from jax import core as jax_core @@ -27,16 +26,16 @@ from jax._src import config from jax._src import core as jax_src_core from jax._src import sharding_impls +from jax._src import tpu_custom_call from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering from jax._src.pallas.mosaic import verification -from jax._src import tpu_custom_call from jax.experimental import mosaic from jax.experimental.mosaic.dialects import tpu -from jax.experimental.pallas import tpu as pltpu + def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): """Casts boolean values to integers. @@ -126,21 +125,6 @@ def pallas_call_tpu_lowering_rule( else: mosaic_params = {} - if "cost_estimate" in mosaic_params: - # TODO(amagni): Remove this branch after October 22th 2024. - if cost_estimate is not None: - raise ValueError( - "Passing cost estimate via both compiler_params=dict(mosaic=...) and" - " pallas_call(..., cost_estimate=...) is not supported." - ) - - warnings.warn( - "Passing cost estimate via compiler_params=dict(cost_estimate=...) is" - " deprecated. Use pallas_call(..., cost_estimate=...) instead.", - DeprecationWarning, - ) - cost_estimate = mosaic_params["cost_estimate"] - mesh = None axis_context = ctx.module_context.axis_context if axis_context is not None: @@ -222,7 +206,7 @@ def _maybe_cast_inputs(*args): kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) output_memory_spaces = _get_memory_spaces_from_avals(out_avals) if cost_estimate is not None: - mosaic_cost_estimate = pltpu.CostEstimate( + mosaic_cost_estimate = tpu_custom_call.CostEstimate( flops=cost_estimate.flops, bytes_accessed=cost_estimate.bytes_accessed, transcendentals=cost_estimate.transcendentals, diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index d00e0e90cd3e..41f0de3a0f61 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -49,12 +49,6 @@ from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key -# Remove this import after October 22th 2024. -from jax._src.tpu_custom_call import CostEstimate as CostEstimate - -# TODO(cperivol): Temporary alias to the global run_scoped. Remove -# this once everyone has migrated to the pallas core one. -from jax._src.pallas.primitives import run_scoped as run_scoped import types from jax._src.pallas.mosaic.verification import assume From adf14928439a1aced2a1d36c4a4c044e861eede0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 25 Oct 2024 13:14:44 -0700 Subject: [PATCH 061/698] Add some missing jax.numpy documentation --- jax/_src/dtypes.py | 32 +++++++++++++++++++++++++++++++- jax/_src/numpy/lax_numpy.py | 6 ++++++ jax/_src/numpy/ufunc_api.py | 23 +++++++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 82be38d1cb57..d2a55933cad9 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -651,7 +651,8 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy def promote_types(a: DTypeLike, b: DTypeLike) -> DType: """Returns the type to which a binary operation should cast its arguments. - For details of JAX's type promotion semantics, see :ref:`type-promotion`. + JAX implementation of :func:`numpy.promote_types`. For details of JAX's + type promotion semantics, see :ref:`type-promotion`. Args: a: a :class:`numpy.dtype` or a dtype specifier. @@ -659,6 +660,35 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType: Returns: A :class:`numpy.dtype` object. + + Examples: + Type specifiers may be strings, dtypes, or scalar types, and the return + value is always a dtype: + + >>> jnp.promote_types('int32', 'float32') # strings + dtype('float32') + >>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes + dtype('float32') + >>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types + dtype('float32') + + Built-in scalar types (:type:`int`, :type:`float`, or :type:`complex`) are + treated as weakly-typed and will not change the bit width of a strongly-typed + counterpart (see discussion in :ref:`type-promotion`): + + >>> jnp.promote_types('uint8', int) + dtype('uint8') + >>> jnp.promote_types('float16', float) + dtype('float16') + + This differs from the NumPy version of this function, which treats built-in scalar + types as equivalent to 64-bit types: + + >>> import numpy + >>> numpy.promote_types('uint8', int) + dtype('int64') + >>> numpy.promote_types('float16', float) + dtype('float64') """ # Note: we deliberately avoid `if a in _weak_types` here because we want to check # object identity, not object equality, due to the behavior of np.dtype.__eq__ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ffdeca84aaad..ee33be8a10d8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -194,6 +194,12 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), {"dtype": np.dtype(np_scalar_type)}) meta.__module__ = _PUBLIC_MODULE_NAME + meta.__doc__ =\ + f"""A JAX scalar constructor of type {np_scalar_type.__name__}. + + While NumPy defines scalar types for each data type, JAX represents + scalars as zero-dimensional arrays. + """ return meta bool_ = _make_scalar_type(np.bool_) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 3473e8a7468a..27e2973b212b 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -598,5 +598,28 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, Returns: wrapped : jax.numpy.ufunc wrapper of func. + + Examples: + Here is an example of creating a ufunc similar to :obj:`jax.numpy.add`: + + >>> import operator + >>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0) + + Now all the standard :class:`jax.numpy.ufunc` methods are available: + + >>> x = jnp.arange(4) + >>> add(x, 10) + Array([10, 11, 12, 13], dtype=int32) + >>> add.outer(x, x) + Array([[0, 1, 2, 3], + [1, 2, 3, 4], + [2, 3, 4, 5], + [3, 4, 5, 6]], dtype=int32) + >>> add.reduce(x) + Array(6, dtype=int32) + >>> add.accumulate(x) + Array([0, 1, 3, 6], dtype=int32) + >>> add.at(x, 1, 10, inplace=False) + Array([ 0, 11, 2, 3], dtype=int32) """ return ufunc(func, nin, nout, identity=identity) From 94440c74c889e14cffe55a45e6cf8956b3ed87c3 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 25 Oct 2024 13:19:50 -0700 Subject: [PATCH 062/698] Register acos primitive to lower to CHLO acos. Related: https://github.com/openxla/stablehlo/pull/2496 PiperOrigin-RevId: 689890774 --- jax/_src/lax/lax.py | 20 +------------------ jax/experimental/jax2tf/jax2tf.py | 33 +++++++++++++++++++++++++++++-- tests/filecheck/math.filecheck.py | 2 +- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9ed1d55cd936..c0c594c4abdc 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2412,27 +2412,9 @@ def asin_impl(x): ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin)) -def acos_impl(x): - if dtypes.issubdtype(_dtype(x), np.complexfloating): - result = mul(_const(x, 1j), acosh(x)) - # By convention, numpy chooses the branch with positive real part. - rpart = real(result) - return select( - gt(rpart, _const(rpart, 0)), - result, - neg(result) - ) - else: - return select( - ne(x, _const(x, -1.0)), - mul(_const(x, 2), - atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x))), - full_like(x, np.pi)) - acos_p = standard_unop(_float | _complex, 'acos') ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) -mlir.register_lowering(acos_p, - mlir.lower_fun(acos_impl, multiple_results=False)) +mlir.register_lowering(acos_p, partial(_nary_lower_hlo, chlo.acos)) def atan_impl(x): return atan2(x, _const(x, 1)) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 42d3cb57f4e5..273f756fe634 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1700,8 +1700,37 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], tf_impl[lax.sinh_p] = tf.math.sinh tf_impl[lax.cos_p] = tf.math.cos tf_impl[lax.cosh_p] = tf.math.cosh -tf_impl_with_avals[lax.acos_p] = _convert_jax_impl( - lax_internal.acos_impl, multiple_results=False) + + +def _acos_impl(x): + if x.dtype.is_complex: + result = tf.multiply(tf.constant(1j, dtype=x.dtype), tf.math.acosh(x)) + # By convention, numpy chooses the branch with positive real part. + rpart = tf.math.real(result) + return tf.where( + tf.math.greater(rpart, tf.constant(0, dtype=rpart.dtype)), + result, + tf.math.negative(result), + ) + else: + return tf.where( + tf.math.not_equal(x, tf.constant(-1.0, dtype=x.dtype)), + tf.multiply( + tf.constant(2, dtype=x.dtype), + tf.math.atan2( + tf.math.sqrt( + tf.math.subtract( + tf.constant(1, dtype=x.dtype), tf.math.square(x) + ) + ), + tf.math.add(tf.constant(1, dtype=x.dtype), x), + ), + ), + tf.broadcast_to(tf.constant(np.pi, dtype=x.dtype), tf.shape(x)), + ) + + +tf_impl_with_avals[lax.acos_p] = _acos_impl tf_impl_with_avals[lax.asin_p] = _convert_jax_impl( lax_internal.asin_impl, multiple_results=False) tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index f34b8211eb33..53f69cdb3f19 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -41,7 +41,7 @@ def main(_): print_ir(np.float32(1), np.float32(2))(lax.add) # CHECK-LABEL: TEST: acos float32[] - # CHECK: hlo.atan2 + # CHECK: chlo.acos # CHECK-SAME: tensor print_ir(np.float32(1))(lax.acos) From 02daf75f9759938db860814def47950d9bc71f6c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 25 Oct 2024 13:45:54 -0700 Subject: [PATCH 063/698] Add new jnp.cumulative_prod function. This follows the API of the similar function added in NumPy 2.1.0 --- docs/jax.numpy.rst | 1 + jax/_src/numpy/reductions.py | 65 +++++++++++++++++++++++++++++- jax/numpy/__init__.py | 3 +- jax/numpy/__init__.pyi | 3 ++ tests/lax_numpy_reducers_test.py | 69 +++++++++++++++++++++++++++----- 5 files changed, 127 insertions(+), 14 deletions(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 9eb518464b4e..3922c92d98de 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -138,6 +138,7 @@ namespace; they are listed below. csingle cumprod cumsum + cumulative_prod cumulative_sum deg2rad degrees diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 1c2a4689cb85..fa8d73361e2b 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -1838,7 +1838,7 @@ def __call__(self, a: ArrayLike, axis: Axis = None, def _cumulative_reduction( name: str, reduction: Callable[..., Array], - a: ArrayLike, axis: int | None, dtype: DTypeLike | None, out: None, + a: ArrayLike, axis: int | None, dtype: DTypeLike | None, out: None = None, fill_nan: bool = False, fill_value: ArrayLike = 0, promote_integers: bool = False) -> Array: """Helper function for implementing cumulative reductions.""" @@ -2064,7 +2064,7 @@ def cumulative_sum( Args: x: N-dimensional array axis: integer axis along which to accumulate. If ``x`` is one-dimensional, - this argument is optional. + this argument is optional and defaults to zero. dtype: optional dtype of the output. include_initial: if True, then include the initial value in the cumulative sum. Default is False. @@ -2113,6 +2113,67 @@ def cumulative_sum( dimension=axis) return out + +def cumulative_prod( + x: ArrayLike, /, *, axis: int | None = None, + dtype: DTypeLike | None = None, + include_initial: bool = False) -> Array: + """Cumulative product along the axis of an array. + + JAX implementation of :func:`numpy.cumulative_prod`. + + Args: + x: N-dimensional array + axis: integer axis along which to accumulate. If ``x`` is one-dimensional, + this argument is optional and defaults to zero. + dtype: optional dtype of the output. + include_initial: if True, then include the initial value in the cumulative + product. Default is False. + + Returns: + An array containing the accumulated values. + + See Also: + - :func:`jax.numpy.cumprod`: alternative API for cumulative product. + - :func:`jax.numpy.nancumprod`: cumulative product while ignoring NaN values. + - :func:`jax.numpy.multiply.accumulate`: cumulative product via the ufunc API. + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.cumulative_prod(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + >>> jnp.cumulative_prod(x, axis=1, include_initial=True) + Array([[ 1, 1, 2, 6], + [ 1, 4, 20, 120]], dtype=int32) + """ + check_arraylike("cumulative_prod", x) + x = lax_internal.asarray(x) + if x.ndim == 0: + raise ValueError( + "The input must be non-scalar to take a cumulative product, however a " + "scalar value or scalar array was given." + ) + if axis is None: + axis = 0 + if x.ndim > 1: + raise ValueError( + f"The input array has rank {x.ndim}, however axis was not set to an " + "explicit value. The axis argument is only optional for one-dimensional " + "arrays.") + + axis = _canonicalize_axis(axis, x.ndim) + dtypes.check_user_dtype_supported(dtype) + out = _cumulative_reduction("cumulative_prod", lax.cumprod, x, axis, dtype) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = lax_internal.concatenate( + [lax_internal.full(zeros_shape, 1, dtype=out.dtype), out], + dimension=axis) + return out + # Quantiles # TODO(jakevdp): interpolation argument deprecated 2024-05-16 diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index d50d55033c33..93405cc03ef7 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -307,8 +307,9 @@ all as all, average as average, count_nonzero as count_nonzero, - cumsum as cumsum, cumprod as cumprod, + cumsum as cumsum, + cumulative_prod as cumulative_prod, cumulative_sum as cumulative_sum, max as max, mean as mean, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 30363c8f4e47..d391abd46e13 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -276,6 +276,9 @@ def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... +def cumulative_prod(x: ArrayLike, /, *, axis: int | None = ..., + dtype: DTypeLike | None = ..., + include_initial: builtins.bool = ...) -> Array: ... def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., include_initial: builtins.bool = ...) -> Array: ... diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 623c11a51998..4dc0ff5f481f 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -820,15 +820,6 @@ def test_f16_mean(self, dtype): def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): rng = jtu.rand_some_zero(self.rng()) - def np_mock_op(x, axis=None, dtype=None, include_initial=False): - axis = axis or 0 - out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype) - if include_initial: - zeros_shape = list(x.shape) - zeros_shape[axis] = 1 - out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis) - return out - # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as # input because we rely on JAX-specific casting behavior def args_maker(): @@ -836,10 +827,20 @@ def args_maker(): if out_dtype in unsigned_dtypes: x = 10 * jnp.abs(x) return [x] - - np_op = getattr(np, "cumulative_sum", np_mock_op) kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial) + if jtu.numpy_version() >= (2, 1, 0): + np_op = np.cumulative_sum + else: + def np_op(x, axis=None, dtype=None, include_initial=False): + axis = axis or 0 + out = np.cumsum(x, axis=axis, dtype=dtype or x.dtype) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis) + return out + np_fun = lambda x: np_op(x, **kwargs) jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @@ -866,5 +867,51 @@ def testCumulativeSumBool(self): dtype=jnp.bool_) np.testing.assert_array_equal(np.array([[True], [True], [False]]), out) + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list( + range(-len(shape), len(shape)) + ) + ([None] if len(shape) == 1 else [])], + [dict(dtype=dtype, out_dtype=out_dtype) + for dtype in (all_dtypes+[None]) + for out_dtype in ( + complex_dtypes if np.issubdtype(dtype, np.complexfloating) + else all_dtypes + ) + ], + include_initial=[False, True], + ) + @jtu.ignore_warning(category=NumpyComplexWarning) + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): + rng = jtu.rand_some_zero(self.rng()) + + # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as + # input because we rely on JAX-specific casting behavior + def args_maker(): + x = jnp.array(rng(shape, dtype)) + if out_dtype in unsigned_dtypes: + x = 10 * jnp.abs(x) + return [x] + kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial) + + if jtu.numpy_version() >= (2, 1, 0): + np_op = np.cumulative_prod + else: + def np_op(x, axis=None, dtype=None, include_initial=False): + axis = axis or 0 + out = np.cumprod(x, axis=axis, dtype=dtype or x.dtype) + if include_initial: + ones_shape = list(x.shape) + ones_shape[axis] = 1 + out = jnp.concat([jnp.ones(ones_shape, dtype=out.dtype), out], axis=axis) + return out + + np_fun = lambda x: np_op(x, **kwargs) + jnp_fun = lambda x: jnp.cumulative_prod(x, **kwargs) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From e4eca9ec5975982f89903082532b49ec4d56da9d Mon Sep 17 00:00:00 2001 From: Jim Lin Date: Fri, 25 Oct 2024 14:13:28 -0700 Subject: [PATCH 064/698] #jax Adds a missing comma to Pallas Quickstart PiperOrigin-RevId: 689907976 --- docs/pallas/quickstart.ipynb | 2 +- docs/pallas/quickstart.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 50464ce8ffd4..af34d167400b 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -412,7 +412,7 @@ "\n", "For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this\n", "carves `x` up into \"row\" blocks.\n", - "To see this see how both program instances\n", + "To see this, see how both program instances\n", "`(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`.\n", "For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`.\n", "Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`.\n", diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index b9acd6497fb5..e11868f5f671 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -294,7 +294,7 @@ To express this, we'd first use a `(2, 2)` grid (one block for each program). For `x`, we use `BlockSpec((512, 1024), lambda i, j: (i, 0))` -- this carves `x` up into "row" blocks. -To see this see how both program instances +To see this, see how both program instances `(1, 0)` and `(1, 1)` pick the `(1, 0)` block in `x`. For `y`, we use a transposed version `BlockSpec((1024, 512), lambda i, j: (0, j))`. Finally, for `z` we use `BlockSpec((512, 512), lambda i, j: (i, j))`. From 9f7f08eccb09cc423a02e7d4e74bd8d225a3a0c9 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Mon, 30 Sep 2024 14:04:07 -0700 Subject: [PATCH 065/698] Fix vmap error message when args passed by keyword See the new test for a case that used to produce the wrong message. Fixes: #24406 --- jax/_src/api.py | 38 +++++++++++++++++++++++--------------- tests/api_test.py | 19 ++++++++++++++++++- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index d2ac5465eded..d85ea46a2c8a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1038,27 +1038,35 @@ def _get_axis_size(name: str, shape: tuple[core.AxisSize, ...], axis: int def _get_argument_type(x): try: return shaped_abstractify(x).str_short() - except TypeError: #Catch all for user specified objects that can't be interpreted as a data type + except TypeError: # Catch all for user specified objects that can't be interpreted as a data type return "unknown" msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"] args, kwargs = tree_unflatten(tree, vals) try: ba = inspect.signature(fn).bind(*args, **kwargs) + signature_parameters: list[str] = list(ba.signature.parameters.keys()) except (TypeError, ValueError): - ba = None - if ba is None: - args_paths = [f'args{keystr(p)} ' - f'of type {_get_argument_type(x)}' - for p, x in generate_key_paths(args)] - kwargs_paths = [f'kwargs{keystr(p)} ' - f'of type {_get_argument_type(x)}' - for p, x in generate_key_paths(kwargs)] - key_paths = [*args_paths, *kwargs_paths] - else: - key_paths = [f'argument {name}{keystr(p)} ' - f'of type {_get_argument_type(x)}' - for name, arg in ba.arguments.items() - for p, x in generate_key_paths(arg)] + signature_parameters = None + + def arg_name(key_path): + if signature_parameters is None: + return f"args{keystr(key_path)}" + # args is a tuple, so key_path[0].idx is the index into args. + i = key_path[0].idx + res = f"argument {signature_parameters[i]}" + if len(key_path) > 1: + res += keystr(key_path[1:]) + return res + + args_paths = [ + f"{arg_name(p)} of type {_get_argument_type(x)}" + for (p, x) in generate_key_paths(args) + ] + kwargs_paths = [ + f"kwargs{keystr(p)} of type {_get_argument_type(x)}" + for p, x in generate_key_paths(kwargs) + ] + key_paths = [*args_paths, *kwargs_paths] all_sizes = [_get_axis_size(name, np.shape(x), d) if d is not None else None for x, d in zip(vals, dims)] size_counts = collections.Counter(s for s in all_sizes if s is not None) diff --git a/tests/api_test.py b/tests/api_test.py index d0a711f4a617..41268a83304d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1919,6 +1919,23 @@ def f(x1, x2, g): ): jax.vmap(f, (0, 0, None))(jnp.ones(2), jnp.ones(3), jnp.add) + def test_vmap_inconsistent_sizes_constructs_proper_error_message_kwargs(self): + # regression test for https://github.com/jax-ml/jax/issues/24406 + def f(x1, x2, a3): + return x1 + x2 + a3 + + with self.assertRaisesRegex( + ValueError, + "vmap got inconsistent sizes for array axes to be mapped:\n" + r" \* most axes \(2 of them\) had size 2, e.g. axis 0 of argument x1 of type float32\[2\];\n" + r" \* one axis had size 1: axis 0 of kwargs\['a3'\] of type float32\[1\]", + ): + jax.vmap(f)( + jnp.ones(2, dtype=jnp.float32), + a3=jnp.ones(1, dtype=jnp.float32), + x2=jnp.ones(2, dtype=jnp.float32) + ) + def test_device_get_scalar(self): x = np.arange(12.).reshape((3, 4)).astype("float32") x = api.device_put(x) @@ -3071,7 +3088,7 @@ def f(x, y): "vmap got inconsistent sizes for array axes to be mapped:\n" r" \* one axis had size 1: axis 0 of argument x of type int32\[1\];" "\n" - r" \* one axis had size 2: axis 0 of argument y of type int32\[2\]"): + r" \* one axis had size 2: axis 0 of kwargs\['y'\] of type int32\[2\]"): f(jnp.array([1], 'int32'), y=jnp.array([1, 2], 'int32')) def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): From 6b065579d4606cfe6d50d78733f211d0a723514c Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 25 Oct 2024 17:00:19 -0700 Subject: [PATCH 066/698] Support None in PmapSharding as a replacement for device_put_replicated. eg: `jax.device_put(x, PmapSharding.default(x.shape, None, jax.local_devices()))` PiperOrigin-RevId: 689956669 --- jax/_src/sharding_impls.py | 14 ++++++++------ tests/array_test.py | 8 ++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 047571754d6f..9cb1e49299ea 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -535,7 +535,7 @@ def is_equivalent_to(self: PmapSharding, other: PmapSharding, # type: ignore # TODO(yashkatariya): Expose `sharded_dim_size` in the API if required. @classmethod - def default(cls, shape: Shape, sharded_dim: int = 0, + def default(cls, shape: Shape, sharded_dim: int | None = 0, devices: Sequence[xc.Device] | None = None) -> PmapSharding: """Creates a :class:`PmapSharding` which matches the default placement used by :func:`jax.pmap`. @@ -547,6 +547,13 @@ def default(cls, shape: Shape, sharded_dim: int = 0, device order used by pmap is used, which is the order of :func:`jax.local_devices`. """ + if sharded_dim is None: + if devices is None: + raise ValueError("One of sharded_dim or devices must be set.") + nrep = len(devices) + return cls(np.array(devices), + sharding_specs.pmap_sharding_spec(nrep, nrep, shape, None)) + # The dtype doesn't matter here. Its only used for creating the # sharding_spec. sharding_spec = sharding_specs.create_pmap_sharding_spec( @@ -565,11 +572,6 @@ def default(cls, shape: Shape, sharded_dim: int = 0, raise NotImplementedError( 'Multiple chunks in Chunked dimension not supported.') - if num_ways_sharded is None: - raise NotImplementedError( - '`None` to sharded_dim is not supported. Please file a jax ' - 'issue if you need this feature.') - if devices is None: pmap_devices: np.ndarray = np.array( xla_bridge.local_devices()[:num_ways_sharded]) diff --git a/tests/array_test.py b/tests/array_test.py index b3492e4d152f..e7aad59b1ad5 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1133,6 +1133,14 @@ def test_default_pmap_sharding_with_devices(self): ps = jax.sharding.PmapSharding.default((4, 2), devices=new_order) self.assertEqual(ps._device_assignment, new_order) + def test_default_pmap_sharding_replicated(self): + x = np.zeros((len(jax.local_devices()), 8), dtype=np.float32) + x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(x) + ps = jax.sharding.PmapSharding.default( + shape=(8,), sharded_dim=None, + devices=jax.local_devices()) + self.assertEqual(x.sharding, ps) + def test_mesh_repr(self): mesh = jtu.create_mesh((1, 1), ('x', 'y')) mesh_repr = repr(mesh) From ad1d864b056c9c5a57b2482cd5d47808523c9419 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 25 Oct 2024 21:23:07 -0400 Subject: [PATCH 067/698] Fix lint at head --- jax/_src/lib/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index ea9191b2cc7f..9cc54a59f259 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -155,9 +155,5 @@ def _try_cuda_nvcc_import() -> str | None: cuda_path = _cuda_path() -if version >= (0, 4, 35): - guard_lib = xla_client._xla.guard_lib -else: - guard_lib = xla_client._xla.transfer_guard_lib - +guard_lib = xla_client._xla.guard_lib Device = xla_client._xla.Device From 56dc89f1a295c737790b028bc389de8bd8a1f519 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 26 Oct 2024 11:27:20 -0700 Subject: [PATCH 068/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/4d0fe880a542470c79e647580fc30aa35c576cfd. PiperOrigin-RevId: 690147893 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c1b1e9e8f894..5af3576c041d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "79142a8fdd3bb7556ab0c6d01e691ecb5f7805c9" -XLA_SHA256 = "ad57f05faac50fd67ccfb22816905095cb7abcda36d3dcbc4844d8ef9b61efe8" +XLA_COMMIT = "4d0fe880a542470c79e647580fc30aa35c576cfd" +XLA_SHA256 = "4e59a6bfcb9ccd4a6b3b1e04927c9e752b073024216bc29dcde7700f8a58941a" def repo(): tf_http_archive( From 6f3c01238e9b9a4ca01c6e442d40217f0909179d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 26 Oct 2024 16:58:07 -0700 Subject: [PATCH 069/698] [mosaic] Directly build IR in _device_id_to_logical, rather than using lower_fun. This is just as simple and faster. PiperOrigin-RevId: 690196495 --- jax/_src/pallas/mosaic/lowering.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 85940b4ee975..74245b6d9f5b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2801,16 +2801,15 @@ def _device_id_to_logical( # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) mesh_strides = ctx.lowering_context.mesh_context.mesh_strides - def _linearize_mesh_indices(*indices): - return sum(a * b for a, b in zip(indices, mesh_strides)) - lower_ctx = LoweringRuleContext( - lowering_context=ctx.lowering_context, - avals_in=[pallas_core.index_map_grid_aval] * len(device_ids), - avals_out=[pallas_core.index_map_grid_aval], - block_shapes=(None,) * len(device_ids), + + i32 = ir.IntegerType.get_signless(32) + return functools.reduce( + arith.addi, + ( + arith.muli(a, arith.constant(i32, b)) + for a, b in zip(device_ids, mesh_strides) + ), ) - return lower_fun(_linearize_mesh_indices, multiple_results=False)( - lower_ctx, *device_ids) elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: return device_id raise NotImplementedError(f"Unsupported device id type: {device_id_type}") From ef83ac6443ffea96e8435781f0942acd602a71ec Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 27 Oct 2024 11:00:41 -0700 Subject: [PATCH 070/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a7bd6795dbdfb0845ae6b14b84a035ee357e35d0. PiperOrigin-RevId: 690363812 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 5af3576c041d..e3ed1f3a86d1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "4d0fe880a542470c79e647580fc30aa35c576cfd" -XLA_SHA256 = "4e59a6bfcb9ccd4a6b3b1e04927c9e752b073024216bc29dcde7700f8a58941a" +XLA_COMMIT = "a7bd6795dbdfb0845ae6b14b84a035ee357e35d0" +XLA_SHA256 = "2dcea3c95e35f966dbeb95f8de6e2ab30a9992b5c00ac2af88176b38425b5fcd" def repo(): tf_http_archive( From 343cf18e09a6777709b9991c13c045ab3e196120 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Oct 2024 05:39:27 -0700 Subject: [PATCH 071/698] [Pallas:MGPU] Wire up the Mosaic GPU profiler into Pallas PiperOrigin-RevId: 690574747 --- jax/_src/pallas/mosaic_gpu/BUILD | 1 + jax/_src/pallas/mosaic_gpu/core.py | 12 ++ jax/_src/pallas/mosaic_gpu/lowering.py | 110 +++++++++++++----- .../mosaic_gpu/pallas_call_registration.py | 39 ++++++- jax/experimental/mosaic/gpu/core.py | 2 +- jax/experimental/mosaic/gpu/profiler.py | 3 +- tests/mosaic/gpu_test.py | 11 ++ tests/pallas/mosaic_gpu_test.py | 29 +++++ 8 files changed, 174 insertions(+), 33 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 9ee7c04c3c3e..681c3996359d 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -61,6 +61,7 @@ pytype_strict_library( "//jax:mosaic_gpu", "//jax:pallas", "//jax:partial_eval", + "//jax:source_info_util", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index ff22f276001f..8e760646ef05 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -72,12 +72,24 @@ class GPUCompilerParams(pallas_core.CompilerParams): references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. + profile_space: The number of profiler events that can be collected in a + single invocation. It is undefined behavior if a thread collects more + events than this. + profile_dir: The directory to which profiling traces will be written to. """ PLATFORM: ClassVar[str] = "mosaic_gpu" approx_math: bool = False dimension_semantics: Sequence[DimensionSemantics] | None = None max_concurrent_steps: int = 1 delay_release: int = 0 + profile_space: int = 0 + profile_dir: str = "" + + def __post_init__(self): + if bool(self.profile_space) ^ bool(self.profile_dir): + raise ValueError( + "Either both profile_space and profile_dir must be set, or neither." + ) class GPUMemorySpace(enum.Enum): diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 197e189adc2c..dd576704fa75 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -30,6 +30,7 @@ from jax._src import core as jax_core from jax._src import pjit from jax._src import util +from jax._src import source_info_util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir @@ -50,6 +51,7 @@ import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import profiler as mgpu_profiler import jax.numpy as jnp import numpy as np @@ -196,6 +198,8 @@ class ModuleContext: runtime_barriers: MutableMapping[ mgpu.Barrier, MutableSequence[mgpu.BarrierRef] ] + name_stack: source_info_util.NameStack + traceback_caches: mlir.TracebackCaches def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: """Reserves a barrier. @@ -269,7 +273,15 @@ class LoweringRuleContext: class LoweringResult: module: ir.Module grid: tuple[int, ...] + block: tuple[int, ...] out_structs: tuple[jax.ShapeDtypeStruct, ...] + profiler_context: ProfilerContext | None + + +@dataclasses.dataclass(frozen=True) +class ProfilerContext: + dump_path: str + spec: mgpu_profiler.ProfilerSpec class LoweringError(Exception): # pylint: disable=g-bad-exception-name @@ -505,6 +517,8 @@ def make_program_ids(step: ir.Value): runtime_smem, smem_used_bytes=0, runtime_barriers=grouped_barriers, + name_stack=source_info_util.NameStack(), + traceback_caches=mlir.TracebackCaches(), ) del runtime_smem, grouped_barriers, runtime_barriers @@ -758,14 +772,19 @@ def _(step, carry): if not isinstance(aval.dtype, gpu_core.BarrierType) and aval.memory_space == gpu_core.SMEM ] - smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") + smem_scratch_bytes = params.get("smem_scratch_bytes") if smem_scratch_bytes is None: smem_scratch_bytes = rs.smem_scratch_bytes extra_smem_scratch.append( jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) ) - module, out_structs_smem, _ = mgpu_core._lower_as_gpu_kernel( + prof_ctx = prof_spec = None + if prof_space := params.get("profile_space", 0): + # Each range is 2 events, each event is 4 bytes. + prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) + prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) + module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel( body, grid=grid, cluster=(), @@ -782,9 +801,10 @@ def _(step, carry): ), ), module_name=name_and_src_info.name, + prof_spec=prof_spec, ) - return LoweringResult(module, grid, out_structs_smem) + return LoweringResult(module, grid, block, out_structs_gmem, prof_ctx) mosaic_lowering_rules = {} @@ -798,6 +818,19 @@ def deco(fn): return deco +def _compute_name_stack_updates( + old_name_stack: list[str], + new_name_stack: list[str] +) -> tuple[list[str], list[str]]: + common_prefix_idx = 0 + for i, (old, new) in enumerate(unsafe_zip(old_name_stack, new_name_stack)): + if old == new: + common_prefix_idx = i+1 + else: + break + return old_name_stack[common_prefix_idx:], new_name_stack[common_prefix_idx:] + + def lower_jaxpr_to_mosaic_gpu( module_ctx: ModuleContext, launch_ctx: mgpu.LaunchContext, @@ -815,35 +848,54 @@ def write_env(var: jax_core.Var, val): map(write_env, jaxpr.constvars, consts) map(write_env, jaxpr.invars, args) + # TODO(justinfu): Handle transform scopes. + last_local_name_stack: list[str] = [] + named_regions = [] for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) - if eqn.primitive not in mosaic_lowering_rules: - raise NotImplementedError( - "Unimplemented primitive in Pallas Mosaic GPU lowering: " - f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/jax-ml/jax/issues." - ) - rule = mosaic_lowering_rules[eqn.primitive] - rule_ctx = LoweringRuleContext( - module_ctx, - launch_ctx, - avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], - avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], + source_info = eqn.source_info.replace( + name_stack=module_ctx.name_stack + eqn.source_info.name_stack ) - try: - outvals = rule(rule_ctx, *invals, **eqn.params) - except LoweringError: - raise # We only add the extra info to the innermost exception. - except Exception as e: - inval_types = map(lambda t: getattr(t, "type", None), invals) - raise LoweringError( - f"Exception while lowering eqn:\n {eqn}\nWith context:\n " - f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}" - ) from e - if eqn.primitive.multiple_results: - map(write_env, eqn.outvars, outvals) - else: - write_env(eqn.outvars[0], outvals) + loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) + with source_info_util.user_context(eqn.source_info.traceback), loc: + if eqn.primitive not in mosaic_lowering_rules: + raise NotImplementedError( + "Unimplemented primitive in Pallas Mosaic GPU lowering: " + f"{eqn.primitive.name}. " + "Please file an issue on https://github.com/jax-ml/jax/issues." + ) + new_local_name_stack = [scope.name for scope in eqn.source_info.name_stack.stack] + popped, pushed = _compute_name_stack_updates(last_local_name_stack, new_local_name_stack) + last_local_name_stack = new_local_name_stack + for _ in popped: + named_regions.pop().close() + for name in pushed: + wrapper_stack = contextlib.ExitStack() + wrapper_stack.enter_context(launch_ctx.named_region(name)) + named_regions.append(wrapper_stack) + rule = mosaic_lowering_rules[eqn.primitive] + rule_ctx = LoweringRuleContext( + module_ctx, + launch_ctx, + avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], + avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], + ) + try: + outvals = rule(rule_ctx, *invals, **eqn.params) + except LoweringError: + raise # We only add the extra info to the innermost exception. + except Exception as e: + inval_types = map(lambda t: getattr(t, "type", None), invals) + raise LoweringError( + f"Exception while lowering eqn:\n {eqn}\nWith context:\n " + f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}" + ) from e + if eqn.primitive.multiple_results: + map(write_env, eqn.outvars, outvals) + else: + write_env(eqn.outvars[0], outvals) + while named_regions: # Drain the name stack. + named_regions.pop().close() return map(read_env, jaxpr.outvars) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 960fe7d71856..05785cb511ea 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -17,8 +17,12 @@ from __future__ import annotations +import os +import time from typing import Any +import warnings +import jax from jax import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core @@ -63,10 +67,41 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - return mosaic_core._mosaic_gpu_lowering_rule( - ctx, + new_avals_out = [ + jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs + ] + outs = mosaic_core._mosaic_gpu_lowering_rule( + ctx.replace(avals_out=new_avals_out), *args, module=module.operation.get_asm(binary=True, enable_debug_info=True), out_types=lowering_result.out_structs, input_output_aliases=input_output_aliases, ) + if (prof_ctx := lowering_result.profiler_context) is not None: + *outs, prof_buffer = outs + if (dump_path := prof_ctx.dump_path) == "sponge": + dump_path = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR") # type: ignore + out_file = os.path.join( + dump_path, f"{name_and_src_info.name}-{time.time_ns()}-trace.json" + ) + def dump_profile(prof_buffer): + try: + with open(out_file, "x") as f: + prof_ctx.spec.dump( + prof_buffer, + f, + grid=lowering_result.grid, + block=lowering_result.block, + ) + except FileExistsError: + warnings.warn( + f"Failed to dump profile for pallas_call {name_and_src_info}, " + f"profile already exists at {out_file}" + ) + def do_callback(prof_buffer): + jax.debug.callback(dump_profile, prof_buffer) + return () + mlir.lower_fun(do_callback, multiple_results=True)( + ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer + ) + return outs diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index d35340340695..b08b40aa1860 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -97,7 +97,7 @@ def _mosaic_gpu_lowering_rule( out_types, input_output_aliases: tuple[tuple[int, int], ...] = (), ): - del out_types # Unused. + assert len(out_types) == len(ctx.avals_out) kernel_id = hashlib.sha256(module).digest() # Note that this is technically only a half measure. Someone might load a # compiled module with a hash collision from disk. But that's so unlikely with diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index bf6631cbca16..e4949b325507 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -203,7 +203,8 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): "tid": 1 + wg_idx + warpgroups_per_block * block_idx, }) else: # If we didn't break - events.append(block_events) + if block_events: + events.append(block_events) events = sorted(events, key=lambda x: x[0]["ts"]) flat_events = list(itertools.chain.from_iterable(events)) return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1ece3f62e3a3..f4be2d767dbe 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1481,6 +1481,17 @@ def test_measure(self): x = jnp.arange(1024 * 1024) profiler.measure(lambda x, y: x + y, x, x) # This is just a smoke test + def test_profile(self): + def kernel(ctx, src, dst, _): + mgpu.FragmentedArray.load_strided(src).store_untiled(dst) + x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + spec = profiler.ProfilerSpec(1024) + # This is just a smoke test. + f = jax.jit(mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (), prof_spec=spec + )) + jax.block_until_ready(f(x)) + def test_multigpu(self): if len(jax.devices()) < 2: self.skipTest("Need at least 2 devices") diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6a2ba17bb34b..21e8380a4083 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -14,7 +14,9 @@ import functools import math +import os import re +import tempfile import traceback from absl.testing import absltest @@ -979,6 +981,33 @@ def kernel(o_ref): x = jnp.full(shape, 42.0) np.testing.assert_array_equal(kernel(), x) + def test_profiler(self): + def kernel(x_ref, o_ref): + with jax.named_scope("add"): + with jax.named_scope("load"): + x = x_ref[...] + o = x + x + with jax.named_scope("store"): + o_ref[...] = o + with tempfile.TemporaryDirectory() as tmpdir: + x = jnp.arange(256).astype(jnp.float32) + y = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + profile_space=16, profile_dir=tmpdir + ), + )(x) + jax.block_until_ready(y) + jax.effects_barrier() + [name] = os.listdir(tmpdir) + with open(os.path.join(tmpdir, name), "r") as f: + data = f.read() + self.assertEqual(data.count('"name": "add"'), 2) + self.assertEqual(data.count('"name": "load"'), 2) + self.assertEqual(data.count('"name": "store"'), 2) + np.testing.assert_array_equal(y, x + x) + class PipelineTest(PallasTest): From 2a35b0b9a34fdca7a13bf9fcf3facd5cdcb3e096 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 28 Oct 2024 06:15:46 -0700 Subject: [PATCH 072/698] Remove int4 from jtu.dtypes.all_dtypes. Why? It's not included in the supported() enumeration on any platforms, so there is no need to mention it here. I tried fixing this by including it in supported(), but this led to many errors. It's better to not list it here, because it might mislead us into thinking it's being tested. --- jax/_src/test_util.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 81737f27540b..4ec3123bd3e6 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1452,8 +1452,7 @@ def integer(self): @_cached_property def all_integer(self): - return self.supported([ - _dtypes.int4, np.int8, np.int16, np.int32, np.int64]) + return self.supported([np.int8, np.int16, np.int32, np.int64]) @_cached_property def unsigned(self): @@ -1461,8 +1460,7 @@ def unsigned(self): @_cached_property def all_unsigned(self): - return self.supported([ - _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64]) + return self.supported([np.uint8, np.uint16, np.uint32, np.uint64]) @_cached_property def complex(self): From 7763d149130ec9b09673320f7e78472b0e58c6bb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Oct 2024 07:01:41 -0700 Subject: [PATCH 073/698] [Mosaic GPU] Don't assume register vectors are exactly 2 in length This is not true of the WGMMA strided layout. We were previously emitting incorrect MLIR in our rsqrt test case which led to miscompiles and flaky executions. PiperOrigin-RevId: 690595119 --- jax/experimental/mosaic/gpu/fragmented_array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 98c56de9ccda..6bc6fbef86d9 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -571,7 +571,8 @@ def fast_instr(x): elif ir.VectorType.isinstance(x.type): index = ir.IndexType.get() result = llvm.mlir_undef(x.type) - for i in range(2): + [vec_len] = ir.VectorType(x.type).shape + for i in range(vec_len): v = vector.extractelement(x, position=c(i, index)) vr = fast_instr(v) result = vector.insertelement(vr, result, position=c(i, index)) From 321fa007412bf7eccfc459a3658b217906862eab Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 28 Oct 2024 07:14:34 -0700 Subject: [PATCH 074/698] Skip testVectorizedDeprecation on Python 3.13 to unblock the CI PiperOrigin-RevId: 690598772 --- tests/extend_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/extend_test.py b/tests/extend_test.py index d9059d49ecc7..69b4591f3e85 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import sys import unittest from functools import partial @@ -260,6 +261,10 @@ def testFfiCallBatching(self, shape, vmap_method): @jtu.run_on_devices("gpu", "cpu") def testVectorizedDeprecation(self): + if sys.version_info.major == 3 and sys.version_info.minor == 13: + # TODO(b/376025274): Remove the skip once the bug is fixed. + raise unittest.SkipTest("Crashes on Python 3.13") + x = self.rng().randn(3, 5, 4).astype(np.float32) with self.assertWarns(DeprecationWarning): ffi_call_geqrf(x, vectorized=True) From 1336c2d5c460682b9cfd80d00bc5a91bc38e376d Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 28 Oct 2024 07:49:30 -0700 Subject: [PATCH 075/698] Fix breaking PGLE test-cases PiperOrigin-RevId: 690608075 --- tests/BUILD | 1 - tests/pgle_test.py | 79 ++++++++++++++++++++++++++++------------------ 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index dc5d5b37316b..47234adeddbf 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -293,7 +293,6 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_backends = ["gpu"], - env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"}, tags = [ "config-cuda-only", "multiaccelerator", diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 8f3f7b2d3c0f..cf248d1c2cea 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -41,10 +41,17 @@ jax.config.parse_flags_with_absl() +dump_dir = tempfile.TemporaryDirectory().name +os.environ['XLA_FLAGS'] = ( + f'--xla_dump_to={dump_dir}' + ' --xla_gpu_experimental_dump_fdo_profiles=true' + ' --xla_gpu_enable_latency_hiding_scheduler=true' +) @jtu.pytest_mark_if_available('multiaccelerator') class PgleTest(jtu.JaxTestCase): def setUp(self): + super().setUp() cc.set_cache_dir(None) cc.reset_cache() @@ -52,7 +59,6 @@ def tearDown(self): cc.set_cache_dir(None) super().tearDown() - @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfile(self): mesh = jtu.create_mesh((2,), ('x',)) @@ -171,7 +177,6 @@ def f(x): self.assertArraysEqual(compiled(x), expected) self.assertEqual(cache_miss_count[0], 0) - @unittest.skip("Test failing in CI") def testAutoPgleWithPersistentCache(self): its = 50 mesh = jtu.create_mesh((2,), ('x',)) @@ -199,56 +204,68 @@ def f(x): config.persistent_cache_min_entry_size_bytes(0), config.persistent_cache_min_compile_time_secs(0), config.pgle_profiling_runs(2), - tempfile.TemporaryDirectory() as tmpdir): - cc.set_cache_dir(tmpdir) + tempfile.TemporaryDirectory() as cache_dir): + cc.set_cache_dir(cache_dir) # Run 1: Module should be compiled without FDO with jtu.count_cached_compilation_cache_miss() as cache_miss_count: f(x) self.assertEqual(cache_miss_count[0], 1) # Non-pgle profiled version of module should be saved - non_pgle_profiled_files = os.listdir(tmpdir) - if len(non_pgle_profiled_files) > 1: - non_pgle_profiled_files = [ - f for f in non_pgle_profiled_files if 'cache' in f - ] - - self.assertLen(non_pgle_profiled_files, 1) + non_pgle_profiled_files = os.listdir(cache_dir) + self.assertNotEmpty(non_pgle_profiled_files) # Run 2: Compilation should not be called with jtu.count_cached_compilation_cache_miss() as cache_miss_count: f(x) self.assertEqual(cache_miss_count[0], 0) + module_before_pgle = os.listdir(dump_dir) + print(module_before_pgle) + self.assertNotEmpty(module_before_pgle) # Run 3: Module should be compiled with FDO and stored to persistent cache with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + # Add xla_dump_to to env flags f(x) self.assertEqual(cache_miss_count[0], 1) + # Check if FDO profile file of the biggest module is not empty + module_after_pgle = [ + x + for x in os.listdir(dump_dir) + if x not in module_before_pgle + ] + self.assertNotEmpty(module_after_pgle) + biggest_module_after_pgle = max( + module_after_pgle, + key=lambda x: os.path.getsize( + os.path.join(dump_dir, x) + ), + ) + base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) + + # Check if FDO profile file in dump directory is not empty + for module in module_after_pgle: + if module.startswith(base_module_name) and module.endswith( + '.fdo_profile' + ): + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, module)), 0 + ) + for pgle_profiler in profilers_dict.values(): self.assertTrue(pgle_profiler.is_enabled()) self.assertTrue(pgle_profiler.is_fdo_consumed()) - # One module is PGLEd version another one is not PGLEd - files_after_pgle_profile = os.listdir(tmpdir) - if len(files_after_pgle_profile) > 2: - files_after_pgle_profile = [ - f for f in files_after_pgle_profile if 'cache' in f - ] - self.assertLen(os.listdir(tmpdir), 2) - - self.assertLen(files_after_pgle_profile, 2) - non_pgled_file_size = os.path.getsize( - os.path.join(tmpdir, files_after_pgle_profile[0]) - ) - pgled_file_size = os.path.getsize( - os.path.join(tmpdir, files_after_pgle_profile[1]) + + files_after_pgle_profile = os.listdir(cache_dir) + self.assertGreater( + len(files_after_pgle_profile), len(non_pgle_profiled_files) ) - # Make sure that FDO profile were applied to the module - self.assertNotEqual(pgled_file_size, non_pgled_file_size) # Removing non-pgle profiled module from cache to check that later pgle # profiled version will be used. - os.remove(os.path.join(tmpdir, non_pgle_profiled_files[0])) + for non_pgle_file in non_pgle_profiled_files: + os.remove(os.path.join(cache_dir, non_pgle_file)) api.clear_caches() profilers_dict.clear() @@ -286,11 +303,11 @@ def f(x, y): f_lowered = f.lower(x, y) compiled = f_lowered.compile() - with tempfile.TemporaryDirectory() as tmpdir: - jax.profiler.start_trace(tmpdir) + with tempfile.TemporaryDirectory() as cache_dir: + jax.profiler.start_trace(cache_dir) compiled(x, y) jax.profiler.stop_trace() - directories = glob.glob(os.path.join(tmpdir, 'plugins/profile/**/')) + directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) directories = [d for d in directories if os.path.isdir(d)] rundir = directories[-1] logging.info('rundir: %s', rundir) From 04bdd07f6685ca62b225bc8cf330221a19892835 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 28 Oct 2024 07:51:43 -0700 Subject: [PATCH 076/698] [mosaic_gpu] `mgpu.FragmentedArray` now supports `//` This is needed to compute grid index from the iteration step counter in `emit_pipeline`. PiperOrigin-RevId: 690608581 --- .../mosaic/gpu/fragmented_array.py | 26 ++++++++++++ jaxlib/mosaic/gpu/BUILD | 1 + jaxlib/mosaic/gpu/custom_call.cc | 9 +++-- tests/mosaic/gpu_test.py | 40 +++++++++++++------ 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 6bc6fbef86d9..1767f3edb976 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -430,6 +430,32 @@ def __rtruediv__(self, other): return NotImplemented return self._pointwise(lambda s, o: arith.divf(o, s), other) + def __floordiv__(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise( + lambda s, o: mlir_math.floor(arith.divf(s, o)), other + ) + elif ir.IntegerType.isinstance(self.mlir_dtype): + if self.is_signed: + return self._pointwise(arith.floordivsi, other) + else: + return self._pointwise(arith.divui, other) + else: + return NotImplemented + + def __rfloordiv__(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise( + lambda s, o: mlir_math.floor(arith.divf(o, s)), other + ) + elif ir.IntegerType.isinstance(self.mlir_dtype): + if self.is_signed: + return self._pointwise(lambda s, o: arith.floordivsi(o, s), other) + else: + return self._pointwise(lambda s, o: arith.divui(o, s), other) + else: + return NotImplemented + def __mod__(self, other): if not ir.IntegerType.isinstance(self.mlir_dtype): return NotImplemented diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 49c488e73850..6bd98fe35880 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -129,6 +129,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:ControlFlowToLLVM", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 2eeac946afb1..e3bbcf0cd0e3 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -31,14 +31,12 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/mosaic/gpu/target.h" #include "absl/base/optimization.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "llvm/include/llvm/ADT/SmallVector.h" #include "llvm/include/llvm/Support/CodeGen.h" @@ -54,6 +52,7 @@ limitations under the License. #include "mlir/include/mlir/Conversion/Passes.h" #include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" @@ -81,8 +80,10 @@ limitations under the License. #include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/include/mlir/Transforms/Passes.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" #include "jaxlib/mosaic/gpu/passes.h" +#include "jaxlib/mosaic/gpu/target.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -141,12 +142,14 @@ mlir::FailureOr GetPassPipeline( mosaic::gpu::registerGpuLaunchLoweringPass(); mosaic::gpu::registerConvertGpuToLLVMPass(); mosaic::gpu::registerByvalInsertionPass(); + mlir::arith::registerArithExpandOpsPass(); return true; }(); (void)register_once; return mlir::parsePassPipeline(absl::StrCat( R"( builtin.module( + arith-expand, canonicalize, gpu-launch-sink-index-computations, convert-nvgpu-to-nvvm, diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f4be2d767dbe..acd06e4e258f 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1209,27 +1209,21 @@ class FragmentedArrayTest(TestCase): operator.add, operator.mul, operator.sub, - operator.truediv, - operator.mod, (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), ), dtype=[jnp.float32, jnp.int32, jnp.uint32], m=(64, 128), n=(8, 16, 32, 64, 80, 128, 256), ) - @jtu.ignore_warning(message="(invalid value|divide by zero)", - category=RuntimeWarning) + @jtu.ignore_warning( + message="(invalid value|divide by zero)", category=RuntimeWarning + ) def test_binary(self, op, dtype, m=64, n=32): if isinstance(op, tuple): op, np_op = op else: np_op = op - if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv: - self.skipTest("Unsupported for integer types") - if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod: - self.skipTest("Unsupported for floating types") - for scalar_rhs in [None, 2]: def kernel(ctx, dst, _): mlir_dtype = utils.dtype_to_ir_type(dtype) @@ -1242,10 +1236,30 @@ def kernel(ctx, dst, _): )() ref_x = np.arange(m * n, dtype=dtype).reshape(m, n) ref_rhs = scalar_rhs or ref_x - if op is operator.truediv: - np.testing.assert_allclose(result, np_op(ref_x, ref_rhs), atol=2e-7) - else: - np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + + @parameterized.product( + op=[operator.truediv, operator.floordiv, operator.mod], + dtype=[jnp.float32, jnp.int32, jnp.uint32], + ) + def test_division(self, op, dtype, m=64, n=32): + if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv: + self.skipTest("Unsupported for integer types") + if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod: + self.skipTest("Unsupported for floating types") + + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, dtype) + op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_allclose( + result, op(dtype(4.2).item() * iota, iota + 1), atol=2e-7 + ) @parameterized.product( op=[ From 232b63d765975d7231863d9a8de304e5f10b431f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 28 Oct 2024 07:57:52 -0700 Subject: [PATCH 077/698] Add base case to _device_id_to_logical. PiperOrigin-RevId: 690610435 --- jax/_src/pallas/mosaic/lowering.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 74245b6d9f5b..13d321754b23 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2803,6 +2803,8 @@ def _device_id_to_logical( mesh_strides = ctx.lowering_context.mesh_context.mesh_strides i32 = ir.IntegerType.get_signless(32) + if len(device_ids) == 0: + return arith.constant(i32, 0) return functools.reduce( arith.addi, ( From dfa6fcd56bd5da21a71498db7bf8abfcedf27edc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 28 Oct 2024 08:25:08 -0700 Subject: [PATCH 078/698] [pallas:mosaic_gpu] Extracted a basic `emit_pipeline` API from the in kernel pipelining test PiperOrigin-RevId: 690619853 --- docs/jax.experimental.pallas.mosaic_gpu.rst | 1 + jax/BUILD | 1 + jax/_src/pallas/mosaic_gpu/BUILD | 13 ++ jax/_src/pallas/mosaic_gpu/core.py | 4 +- jax/_src/pallas/mosaic_gpu/lowering.py | 9 +- jax/_src/pallas/mosaic_gpu/pipeline.py | 183 ++++++++++++++++++++ jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 52 +++++- 8 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 jax/_src/pallas/mosaic_gpu/pipeline.py diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 82c9f08145eb..71bf9c3ffae4 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -29,6 +29,7 @@ Functions barrier_wait copy_gmem_to_smem copy_smem_to_gmem + emit_pipeline layout_cast set_max_registers wait_smem_to_gmem diff --git a/jax/BUILD b/jax/BUILD index 7a970b13aef3..ccdc27392843 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -688,6 +688,7 @@ pytype_strict_library( deps = [ "//jax/_src/pallas/mosaic_gpu:core", "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/mosaic_gpu:pipeline", "//jax/_src/pallas/mosaic_gpu:primitives", ], ) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 681c3996359d..6f98c83fdfd8 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -99,3 +99,16 @@ pytype_strict_library( "//jax/_src/pallas", ], ) + +pytype_strict_library( + name = "pipeline", + srcs = ["pipeline.py"], + deps = [ + ":core", + ":primitives", + "//jax", + "//jax:pallas", + "//jax:util", + "//jax/_src/pallas", + ], +) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 8e760646ef05..174efd0757b5 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -104,11 +104,13 @@ def __str__(self) -> str: return self.value def __call__( + self, shape: tuple[int, ...], dtype: jnp.dtype, transforms: Sequence[MemoryRefTransform] = (), - ): + + ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index dd576704fa75..b43fc0147d1d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1126,7 +1126,6 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), - lax.div_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x / y), lax.rem_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x % y), lax.and_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x & y), lax.or_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x | y), @@ -1142,6 +1141,14 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): }) +@register_lowering_rule(lax.div_p) +def _div_lowering_rule(ctx: LoweringRuleContext, x, y): + x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) + if ir.FloatType.isinstance(x.mlir_dtype): + return x / y + return x // y + + @register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): [x_aval] = ctx.avals_in diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py new file mode 100644 index 000000000000..8d2274f1408c --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -0,0 +1,183 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for emitting custom GPU pipelines within a Pallas kernel.""" + +from __future__ import annotations + +from collections.abc import Sequence +import dataclasses +import functools +import itertools as it +import math +from typing import Any + +import jax +from jax import lax +from jax._src import util +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives +from jax.experimental import pallas as pl + + +map = util.safe_map +zip = util.safe_zip + + +@dataclasses.dataclass(frozen=True) +class BufferedRef: + spec: pallas_core.BlockSpec + gmem_ref: pallas_core.AbstractMemoryRef + smem_ref: pallas_core.AbstractMemoryRef # [num_slots, *spec.block_shape] + + def compute_gmem_slice(self, grid_indices) -> tuple[Any, ...]: + index_map = self.spec.index_map + assert index_map is not None + return tuple( + pl.ds(idx * size, size) + for idx, size in zip( + index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type] + ) + ) + + def copy_in(self, slot, grid_indices, barrier_ref): + gmem_slices = self.compute_gmem_slice(grid_indices) + gpu_primitives.copy_gmem_to_smem( + self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands + self.smem_ref.at[slot], + barrier=barrier_ref.at[slot], + ) + + def copy_out(self, slot, grid_indices): + gmem_slices = self.compute_gmem_slice(grid_indices) + gpu_primitives.copy_smem_to_gmem( + self.smem_ref.at[slot], self.gmem_ref.at[gmem_slices] # pytype: disable=unsupported-operands + ) + + +jax.tree_util.register_dataclass( + BufferedRef, data_fields=["gmem_ref", "smem_ref"], meta_fields=["spec"] +) + + +def make_grid_indices( + step: jax.typing.ArrayLike, grid: Sequence[int] +) -> tuple[jax.Array, ...]: + # TODO(slebedev): Maintain the grid index through the fori_loop instead. + indices = [] + for size in reversed(grid): + indices.append(lax.rem(step, size)) + step = lax.div(step, size) + return tuple(reversed(indices)) + + +def emit_pipeline( + body, + *, + grid: pallas_core.StaticGrid, + in_specs: Sequence[pallas_core.BlockSpec] = (), + out_specs: Sequence[pallas_core.BlockSpec] = (), + max_concurrent_steps: int = 1, +): + """Creates a function to emit a manual pipeline within a Pallas kernel.""" + num_steps = math.prod(grid) + + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to + # reduce the size of the allocated buffers below. + if max_concurrent_steps > num_steps: + max_concurrent_steps = num_steps + + def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): + in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) + in_smem_refs, out_smem_refs = util.split_list( + map( + lambda spec, ref: gpu_core.SMEM( + (max_concurrent_steps, *spec.block_shape), # type: ignore + ref.dtype, + ), + it.chain(in_specs, out_specs), + gmem_refs, + ), + [len(in_specs)], + ) + return pl.run_scoped( + functools.partial( + scoped_pipeline, + in_gmem_refs=in_gmem_refs, + out_gmem_refs=out_gmem_refs, + ), + in_smem_refs=in_smem_refs, + out_smem_refs=out_smem_refs, + barrier_ref=gpu_core.Barrier( + # TODO(slebedev): Change this to arrive only once. + len(in_specs), + num_barriers=max_concurrent_steps, + ), + ) + + def scoped_pipeline( + *, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref + ): + + in_brefs: Sequence[BufferedRef] = map( + BufferedRef, in_specs, in_gmem_refs, in_smem_refs + ) + out_brefs: Sequence[BufferedRef] = map( + BufferedRef, out_specs, out_gmem_refs, out_smem_refs + ) + + for step, indices in enumerate( + it.islice(it.product(*map(range, grid)), max_concurrent_steps) + ): + map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + + def loop_body(step, _): + slot = step % max_concurrent_steps + + # Wait for the current GMEM->SMEM copy to complete. + gpu_primitives.barrier_wait(barrier_ref.at[slot]) + # Wait for the previous output SMEM->GMEM copy to complete. + gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) + + indices = make_grid_indices(step, grid) + with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): + body( + *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs)) + ) + + # Copy the output from SMEM to GMEM. + map(lambda bref: bref.copy_out(slot, indices), out_brefs) + + fetch_step = step + max_concurrent_steps + fetch_slot = slot # (x + y) % y == x % y + jax.lax.cond( + fetch_step < num_steps, + lambda: map( + lambda bref: bref.copy_in( + fetch_slot, make_grid_indices(fetch_step, grid), barrier_ref + ), + in_brefs, + ), + lambda: [None] * len(in_brefs), + ) + + return () + + lax.fori_loop(0, num_steps, loop_body, ()) + + # Finalize the pipeline. + gpu_primitives.wait_smem_to_gmem(0) + + return pipeline diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index fbb3a3857c68..3903f3a9c0ae 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -28,6 +28,7 @@ from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef +from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 21e8380a4083..2df4512a8b25 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1032,7 +1032,7 @@ def body(step, _): # Wait for the previous output SMEM->GMEM copy to complete. plgpu.wait_smem_to_gmem(max_concurrent_steps - 1) - o_smem[...] = x_smem[...] + 1.0 + o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0 plgpu.copy_smem_to_gmem( o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)] @@ -1074,6 +1074,56 @@ def body(step, _): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit(self, max_concurrent_steps=2, num_steps=4): + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=max_concurrent_steps, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + + def test_emit_with_parallel_grid(self, max_concurrent_steps=2, num_steps=4): + self.skipTest("Enable once we support multiple levels of indexing") + + def kernel(x_gmem, o_gmem): + gmem_slice = pl.ds(pl.program_id(0) * 32, 32) + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=max_concurrent_steps, + )(x_gmem.at[gmem_slice], o_gmem.at[gmem_slice]) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(4 * 32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 1), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + class CoreMapTest(PallasTest): From 36c56fa19be6c8d6c4a19a9adaf58cbf382ad9df Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 28 Oct 2024 08:41:18 -0700 Subject: [PATCH 079/698] [Pallas:MGPU] Fix flaky debug_print tests Turns out that waiting for the kernel to finish it not enough, since the prints also need to be processed by the CUDA runtime. Using a test-only function that synchronizes all the devices seems to suffice. PiperOrigin-RevId: 690624999 --- jaxlib/mosaic/gpu/BUILD | 2 -- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 16 ++++++++++++++-- tests/pallas/mosaic_gpu_test.py | 28 ++++++++++++++++++++-------- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 6bd98fe35880..4ec643dc63c8 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -184,8 +184,6 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/synchronization", "@nanobind", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index ec574de4368f..55801ebdb8d4 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include "nanobind/nanobind.h" -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/service/custom_call_status.h" @@ -57,6 +55,20 @@ NB_MODULE(_mosaic_gpu_ext, m) { }); m.def("_record_event_capsule", []() { return EncapsulateFunction(EventRecordCall); }); + m.def("_sync_all_devices", []() { + int devices = 0; + if (cudaGetDeviceCount(&devices) != gpuSuccess) { + throw std::runtime_error("Failed to get device count"); + } + for (int i = 0; i < devices; ++i) { + if (cudaSetDevice(i) != gpuSuccess) { + throw std::runtime_error("Failed to set device"); + } + if (cudaDeviceSynchronize() != gpuSuccess) { + throw std::runtime_error("Failed to synchronize device"); + } + } + }); } } // namespace diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 2df4512a8b25..8d17d8458134 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import functools import math import os @@ -28,6 +29,10 @@ from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib +except ImportError: + mosaic_gpu_lib = None jax.config.parse_flags_with_absl() @@ -43,6 +48,15 @@ def setUp(self): super().setUp() + @contextlib.contextmanager + def capture_stdout(self): + if mosaic_gpu_lib is None: + raise ValueError("Running tests but missing Mosaic GPU extension") + with jtu.capture_stdout() as stdout: + yield stdout + # We need to cudaDeviceSynchronize to make sure printfs are flushed. + mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() + class PallasCallTest(PallasTest): @@ -466,9 +480,8 @@ def kernel(x_ref, o_ref): pl.debug_print("It works!") x = jnp.arange(256).astype(jnp.float32) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) - self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): @@ -480,7 +493,7 @@ def kernel(x_ref, o_ref): x = jnp.arange(size, dtype=jnp.float32).reshape(shape) f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) - with jtu.capture_stdout() as get_output: + with self.capture_stdout() as get_output: jax.block_until_ready(f(x)) output = get_output() @@ -500,7 +513,7 @@ def kernel(x_ref, o_ref): pl.debug_print("x.sum() = {}", x_ref[...].sum()) x = jnp.arange(256) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn(f"x.sum() = {x.sum()}", output()) @@ -515,7 +528,7 @@ def kernel(x_ref, o_ref): pl.debug_print("x.sum() = {}", x_ref[...].sum() + 1) x = jnp.arange(256) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn(f"x.sum() = {x.sum() + 1}", output()) @@ -532,7 +545,7 @@ def kernel(x_ref, o_ref): pl.debug_print("x: {}", x_ref[...]) x = jnp.arange(math.prod(in_shape)).reshape(in_shape) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) @@ -675,7 +688,6 @@ def body(idx, _): np.testing.assert_array_equal(kernel(x, y), x + y) def test_cond(self): - @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), @@ -690,7 +702,7 @@ def kernel(x_ref, o_ref): o_ref[...] = jnp.broadcast_to(acc, o_ref.shape) x = jnp.arange(256) - with jtu.capture_stdout() as output: + with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) self.assertIn("acc * 2:", output()) From c9f96642664f8ea844655db8226a778558f63986 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:39:11 +0000 Subject: [PATCH 080/698] Bump actions/cache from 4.1.1 to 4.1.2 Bumps [actions/cache](https://github.com/actions/cache) from 4.1.1 to 4.1.2. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/3624ceb22c1c5a301c8db4169662070a689d9ea8...6849a6489940f00c2f30c0fb92c6274307ccb58a) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index ef752c66b294..0afcee9335b7 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -72,7 +72,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -119,7 +119,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -156,7 +156,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -192,7 +192,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -231,7 +231,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} From 68428488c8d1b62ac5cc9385d3d3623424dba4dd Mon Sep 17 00:00:00 2001 From: minigoel Date: Mon, 28 Oct 2024 10:47:59 -0700 Subject: [PATCH 081/698] Add a link to Intel plugin for JAX --- README.md | 2 ++ docs/installation.md | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/README.md b/README.md index c99d3db10a2a..26c0797db6b0 100644 --- a/README.md +++ b/README.md @@ -390,6 +390,7 @@ Some standouts: | Google TPU | yes | n/a | n/a | n/a | n/a | n/a | | AMD GPU | yes | no | experimental | n/a | no | no | | Apple GPU | n/a | no | n/a | experimental | n/a | n/a | +| Intel GPU | experimental | n/a | n/a | n/a | no | no | ### Instructions @@ -401,6 +402,7 @@ Some standouts: | Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` | | AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). | | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | +| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) for information on alternative installation strategies. These include compiling diff --git a/docs/installation.md b/docs/installation.md index 5b8893628d85..7cf64955722c 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -35,6 +35,7 @@ The table below shows all supported platforms and installation options. Check if | Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | | AMD GPU | {ref}`experimental ` | no | {ref}`experimental ` | n/a | no | no | | Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a | +| Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no | (install-cpu)= @@ -230,6 +231,17 @@ JAX has experimental ROCm support. There are two ways to install JAX: * Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or * Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_). +(install-intel-gpu)= +## Intel GPU + +Intel provides an experimental OneAPI plugin: intel-extension-for-openxla for Intel GPU hardware. For more details and installation instructions, refer to one of the following two methods: +1. Pip installation: [JAX acceleration on Intel GPU](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). +2. Using [Intel's XLA Docker container](https://hub.docker.com/r/intel/intel-optimized-xla). + +Please report any issues related to: +* JAX: [JAX issue tracker](https://github.com/jax-ml/jax/issues). +* Intel's OpenXLA plugin: [Intel-extension-for-openxla issue tracker](https://github.com/intel/intel-extension-for-openxla/issues). + ## Conda (community-supported) ### Conda installation From bc03e5053a041720f3942ab7c3609c9246dc8ccf Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 28 Oct 2024 10:48:53 -0700 Subject: [PATCH 082/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/7b4c7f36ccb2a0afa511d98fe4cb024599c275ae. PiperOrigin-RevId: 690672528 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e3ed1f3a86d1..084ae0c1f3ba 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a7bd6795dbdfb0845ae6b14b84a035ee357e35d0" -XLA_SHA256 = "2dcea3c95e35f966dbeb95f8de6e2ab30a9992b5c00ac2af88176b38425b5fcd" +XLA_COMMIT = "7b4c7f36ccb2a0afa511d98fe4cb024599c275ae" +XLA_SHA256 = "3ebbee39182dfc8373e870aa69aa9821b6a5149da440a3f7503bdd8c8073165e" def repo(): tf_http_archive( From 987dfaef1c7b8e0ae5b69a77c67b5c29f67e2500 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 28 Oct 2024 10:49:06 -0700 Subject: [PATCH 083/698] Raise a better error message if `None` is passed to with_sharding_constraint. PiperOrigin-RevId: 690672618 --- jax/_src/pjit.py | 8 +++++++- tests/pjit_test.py | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 5b8856aa619e..c0a1cde4f8b6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2449,12 +2449,18 @@ def with_sharding_constraint(x, shardings): shardings_flat = [_create_sharding_for_array(mesh, a, 'shardings', 'with_sharding_constraint') for a in user_shardings_flat] + for s, u in zip(shardings_flat, user_shardings_flat): + if isinstance(s, (UnspecifiedValue, AUTO)): + raise ValueError( + f'One of with_sharding_constraint arguments got sharding {u} which is' + ' not allowed. Please only pass `jax.sharding.Sharding` instances.') + del user_shardings_flat + # TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's # already part of the shardings. unconstrained_dims = [get_unconstrained_dims(s) if isinstance(s, NamedSharding) else {} for s in shardings_flat] - del user_shardings_flat pjit_check_aval_sharding( shardings_flat, x_flat, None, "with_sharding_constraint arguments", diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5f44d5892044..df98f0156c92 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3544,6 +3544,13 @@ def identity(x): out2 = pjit(identity)(arr2) self.assertIsInstance(out2.sharding, PositionalSharding) + def test_wsc_error_on_none(self): + with self.assertRaisesRegex( + ValueError, + 'One of with_sharding_constraint arguments got sharding None which is' + ' not allowed'): + with_sharding_constraint(jnp.arange(8), None) + def test_sharding_preserved_aot(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) From ae0433afdcf5a27c5ebaceb01f17bf7b7d3772d4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 18:07:27 +0000 Subject: [PATCH 084/698] Bump actions/checkout from 4.2.1 to 4.2.2 Bumps [actions/checkout](https://github.com/actions/checkout) from 4.2.1 to 4.2.2. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871...11bd71901bbe5b1630ceea73d27597364c9af683) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/asan.yaml | 4 ++-- .github/workflows/ci-build.yaml | 12 ++++++------ .github/workflows/cloud-tpu-ci-nightly.yml | 2 +- .github/workflows/jax-array-api.yml | 4 ++-- .github/workflows/metal_plugin_ci.yml | 2 +- .github/workflows/upstream-nightly.yml | 4 ++-- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index a4ec78f96c97..9a49ed2a3e61 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -25,10 +25,10 @@ jobs: run: shell: bash -l {0} steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: python/cpython path: cpython diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 0afcee9335b7..d56ff2e2a22c 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -29,7 +29,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: @@ -57,7 +57,7 @@ jobs: prng-upgrade: 0 num_generated_cases: 1 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Image Setup run: | apt update @@ -108,7 +108,7 @@ jobs: matrix: python-version: ['3.10'] steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: @@ -145,7 +145,7 @@ jobs: matrix: python-version: ['3.10'] steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: @@ -181,7 +181,7 @@ jobs: enable-x64: 0 num_generated_cases: 10 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: @@ -220,7 +220,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 4bff1e87e7f3..a5fac5ebdbc3 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -43,7 +43,7 @@ jobs: # https://opensource.google/documentation/reference/github/services#actions # mandates using a specific commit for non-Google actions. We use # https://github.com/sethvargo/ratchet to pin specific versions. - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install JAX test requirements run: | pip install -U -r build/test-requirements.txt diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 010ebae78c43..d755f1f4a754 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -22,9 +22,9 @@ jobs: steps: - name: Checkout jax - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Checkout array-api-tests - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 3f6d4be94323..2b1100cc048b 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Get repo - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax - name: Setup build and test enviroment diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 46cd3f335fc6..ffbe3ba662bb 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -36,7 +36,7 @@ jobs: outputs: artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: @@ -106,7 +106,7 @@ jobs: run: shell: bash steps: - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: "3.x" diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index c06a12922a05..a788b0ce3d09 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -25,7 +25,7 @@ jobs: - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 795f1f6157ba..54628a5ca9c7 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -31,7 +31,7 @@ jobs: - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax From aef405b623221b61e68749b30cc74bf9b1315de3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 18:23:52 +0000 Subject: [PATCH 085/698] Bump actions/setup-python from 5.2.0 to 5.3.0 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.2.0 to 5.3.0. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/f677139bbe7f9c59b41e40162b753c062f5d49a3...0b93645e9fea7318ecaed2b359559ac225c90a2b) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 12 ++++++------ .github/workflows/jax-array-api.yml | 2 +- .github/workflows/upstream-nightly.yml | 4 ++-- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index d56ff2e2a22c..581fb858732c 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -31,7 +31,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: 3.11 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 @@ -63,7 +63,7 @@ jobs: apt update apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -110,7 +110,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -147,7 +147,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -183,7 +183,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -222,7 +222,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python 3.11 - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: 3.11 - name: Get pip cache dir diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index d755f1f4a754..648ea0bbe26c 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -32,7 +32,7 @@ jobs: submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index ffbe3ba662bb..04df278019a5 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -38,7 +38,7 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements @@ -107,7 +107,7 @@ jobs: shell: bash steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: "3.x" - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index a788b0ce3d09..2b4a616e224a 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 54628a5ca9c7..3173b81e6819 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.pyver }} cache: 'pip' From 77797f434d29fdbda415a561f102f1807ccf4cdc Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Mon, 28 Oct 2024 12:17:34 -0700 Subject: [PATCH 086/698] [JAX] Add the function API of jax.experimental.colocated_python This change adds an experimental API `jax.experimental.colocated_python`. The ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python code that runs close to (or on) accelerator hosts. Multi-controller JAX can trivially achieve this colocated Python code execution today, while single-controller JAX needed its own solution for distributed Python code execution, which creates fragmentation of the user code for these two runtime architectures. `colocated_python` is an attempt to define a single device model and portable API to allow the user to write a single code once that can run on both runtime architectures. This change includes an implementation of the function API portion of `jax.experimental.colocated_python`. A (stateful) object API will be added separately. Also there will be a separate change that expresses serialized functions as an IFRT `CustomCallProgram`. It is currently in an early development stage. Please proceed with a caution when using the API. PiperOrigin-RevId: 690705899 --- jax/BUILD | 21 + jax/experimental/colocated_python/__init__.py | 23 + jax/experimental/colocated_python/api.py | 59 +++ jax/experimental/colocated_python/func.py | 417 ++++++++++++++++++ .../colocated_python/func_backend.py | 44 ++ .../colocated_python/serialization.py | 228 ++++++++++ tests/BUILD | 8 + tests/colocated_python_test.py | 210 +++++++++ 8 files changed, 1010 insertions(+) create mode 100644 jax/experimental/colocated_python/__init__.py create mode 100644 jax/experimental/colocated_python/api.py create mode 100644 jax/experimental/colocated_python/func.py create mode 100644 jax/experimental/colocated_python/func_backend.py create mode 100644 jax/experimental/colocated_python/serialization.py create mode 100644 tests/colocated_python_test.py diff --git a/jax/BUILD b/jax/BUILD index ccdc27392843..07e6b77be433 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1159,3 +1159,24 @@ pytype_library( visibility = ["//visibility:public"], deps = [":jax"], ) + +pytype_library( + name = "experimental_colocated_python", + srcs = [ + "experimental/colocated_python/__init__.py", + "experimental/colocated_python/api.py", + "experimental/colocated_python/func.py", + "experimental/colocated_python/func_backend.py", + "experimental/colocated_python/serialization.py", + ], + visibility = ["//visibility:public"], + deps = [ + ":api_util", + ":jax", + ":traceback_util", + ":tree_util", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy") + py_deps("cloudpickle"), +) diff --git a/jax/experimental/colocated_python/__init__.py b/jax/experimental/colocated_python/__init__.py new file mode 100644 index 000000000000..2e9b4f967cd7 --- /dev/null +++ b/jax/experimental/colocated_python/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python API.""" + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +# pylint: disable=useless-import-alias +from jax.experimental.colocated_python.api import ( + colocated_cpu_devices as colocated_cpu_devices, + colocated_python as colocated_python, +) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py new file mode 100644 index 000000000000..971002f51160 --- /dev/null +++ b/jax/experimental/colocated_python/api.py @@ -0,0 +1,59 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python top-level API.""" + +from __future__ import annotations + +import collections +from typing import Any, Callable, Sequence + +import jax +from jax._src import api_util +from jax._src.lib import xla_extension_version +from jax.experimental.colocated_python.func import make_callable + + +def colocated_cpu_devices( + devices: Sequence[jax.Device], +) -> Sequence[jax.Device]: + """Finds CPU devices colocated with the given devices.""" + if xla_extension_version < 290: + raise NotImplementedError("Requires xla_extension_version >= 290") + + cpu_devices_by_colocation_id = collections.defaultdict(list) + for device in devices[0].backend._get_all_devices(): # pylint: disable=protected-access + if device.device_kind == "cpu": + cpu_devices_by_colocation_id[device.colocation_id].append(device) + if not cpu_devices_by_colocation_id: + raise ValueError("No CPU devices found") + + colocated_cpu_devices = [] + for device in devices: + matches = cpu_devices_by_colocation_id[device.colocation_id] + if not matches: + raise ValueError(f"Device {device} has no colocated devices") + elif len(matches) > 1: + raise ValueError( + f"Ambiguous colocated devices; device {device} has" + f" {len(matches)} colocated devices: f{matches}" + ) + colocated_cpu_devices.append(matches[0]) + return colocated_cpu_devices + + +def colocated_python(fun: Callable[..., Any]) -> Any: + """Executes the given Python function on the same device as the arguments.""" + return make_callable( + fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun) + ) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py new file mode 100644 index 000000000000..3e95ddf03c7e --- /dev/null +++ b/jax/experimental/colocated_python/func.py @@ -0,0 +1,417 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python function API implementation.""" + +from __future__ import annotations + +import dataclasses +import inspect +import random +import threading +from typing import Any, Callable, Sequence + +import jax +from jax._src import api +from jax._src import tree_util +from jax._src.lib import xla_client as xc +from jax._src.traceback_util import api_boundary +from jax._src.util import wraps +from jax.experimental.colocated_python import func_backend +from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs + +ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct] + + +@dataclasses.dataclass(frozen=True, slots=True) +class FunctionInfo: + """User function wrapped by colocated_python.""" + + fun: Callable[..., Any] + fun_sourceinfo: str | None + fun_signature: inspect.Signature | None + + +@dataclasses.dataclass(frozen=True, slots=True) +class Specialization: + """Specialization for a colocated_python function.""" + + in_specs_treedef: tree_util.PyTreeDef | None = None + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None + out_specs_treedef: tree_util.PyTreeDef | None = None + out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None + devices: xc.DeviceList | None = None + + def update( + self, + *, + in_specs_treedef: tree_util.PyTreeDef | None = None, + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, + out_specs_treedef: tree_util.PyTreeDef | None = None, + out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, + devices: Sequence[jax.Device] | xc.DeviceList | None = None, + ) -> Any: + """Creates a new specialization with overrides.""" + if in_specs_treedef is None: + in_specs_treedef = self.in_specs_treedef + elif self.in_specs_treedef is not None: + raise ValueError("in_specs already specified") + if in_specs_leaves is None: + in_specs_leaves = self.in_specs_leaves + elif self.in_specs_leaves is not None: + raise ValueError("in_specs already specified") + + if out_specs_fn is None: + out_specs_fn = self.out_specs_fn + elif self.out_specs_fn is not None: + raise ValueError("out_specs_fn already specified") + + if out_specs_treedef is None: + out_specs_treedef = self.out_specs_treedef + elif self.out_specs_treedef is not None: + raise ValueError("out_specs already specified") + if out_specs_leaves is None: + out_specs_leaves = self.out_specs_leaves + elif self.out_specs_leaves is not None: + raise ValueError("out_specs already specified") + + if devices is None: + devices = self.devices + elif self.devices is not None: + raise ValueError("devices already specified") + elif not isinstance(devices, xc.DeviceList): + devices = xc.DeviceList(tuple(devices)) + + return Specialization( + in_specs_treedef, + in_specs_leaves, + out_specs_fn, + out_specs_treedef, + out_specs_leaves, + devices, + ) + + +def _get_spec(x: Any) -> api.ShapeDtypeStruct: + """Extracts a spec for a value, which must be a JAX Array.""" + # TODO(hyeontaek): Allow Python values and automatically apply `shard_arg` + # with a suitable sharding and layout. + if not isinstance(x, jax.Array): + raise ValueError( + "colocated_python only supports jax.Array as input and output, but got" + f" {type(x)}." + ) + return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + + +def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None: + """Returns a representative device list from function call arguments.""" + device_list_set: set[xc.DeviceList] = set() + for x in args: + sharding = getattr(x, "sharding", None) + if sharding is not None: + device_list_set.add(x.sharding._internal_device_list) + if not device_list_set: + return None + if len(device_list_set) != 1: + raise ValueError( + "All arguments must use the same device list, but got" + f" multiple device lists: {device_list_set}." + ) + return device_list_set.pop() + + +def _compile_to_executable( + name: str, + fun: Callable[..., Any], + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...], + out_specs_leaves: tuple[api.ShapeDtypeStruct, ...], + devices: xc.DeviceList, +) -> Callable[..., Any]: + """Compiles a Python function into a runtime executable.""" + # TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an + # executable. + del name + del in_specs_leaves + del out_specs_leaves + del devices + return fun + + +def _make_output_specs_and_push_result_fun( + info: FunctionInfo, specialization: Specialization, uid: int +) -> Callable[..., Any]: + """Creates a function that computes output specs and pushes the result to the result store.""" + assert specialization.in_specs_treedef is not None + assert specialization.in_specs_leaves is not None + assert specialization.out_specs_treedef is None + assert specialization.out_specs_leaves is None + assert specialization.devices is not None + + devices = specialization.devices + + def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]: + result = info.fun(*args, **kwargs) + out_leaves, out_treedef = tree_util.tree_flatten(result) + out_spec_leaves = tuple(_get_spec(x) for x in out_leaves) + func_backend.SINGLETON_RESULT_STORE.push(uid, out_leaves) + return _serialize_specs(out_treedef, out_spec_leaves, devices) + + out_specs_leaves, _ = tree_util.tree_flatten( + _make_specs_for_serialized_specs(specialization.devices), + ) + name = getattr(info.fun, "__name__", "unknown") + name = f"{name}_output_specs_and_push_result" + return _compile_to_executable( + name=name, + fun=lowered_fun, + in_specs_leaves=specialization.in_specs_leaves, + out_specs_leaves=tuple(out_specs_leaves), + devices=specialization.devices, + ) + + +def _make_pop_result_fun( + info: FunctionInfo, specialization: Specialization, uid: int +) -> Callable[..., Any]: + """Makes a function that pops results from the result store.""" + assert specialization.out_specs_treedef is not None + assert specialization.out_specs_leaves is not None + assert specialization.devices is not None + + out_specs_treedef = specialization.out_specs_treedef + + def lowered_fun() -> Any: + flat_result = func_backend.SINGLETON_RESULT_STORE.pop(uid) + return tree_util.tree_unflatten(out_specs_treedef, flat_result) + + in_specs, _ = tree_util.tree_flatten(( + # args + (), + # kwargs + (), + )) + name = getattr(info.fun, "__name__", "unknown") + name = f"{name}_pop_result" + return _compile_to_executable( + name=name, + fun=lowered_fun, + in_specs_leaves=tuple(in_specs), + out_specs_leaves=specialization.out_specs_leaves, + devices=specialization.devices, + ) + + +def _make_async_execution_fun( + info: FunctionInfo, specialization: Specialization +) -> Callable[..., Any]: + """Makes a function that asynchronously executes the function.""" + assert specialization.in_specs_treedef is not None + assert specialization.in_specs_leaves is not None + assert specialization.out_specs_treedef is not None + assert specialization.out_specs_leaves is not None + assert specialization.devices is not None + + name = getattr(info.fun, "__name__", "unknown") + return _compile_to_executable( + name=name, + fun=info.fun, + in_specs_leaves=specialization.in_specs_leaves, + out_specs_leaves=specialization.out_specs_leaves, + devices=specialization.devices, + ) + + +@jax.util.cache(max_size=None) +def _get_specialized_func( + info: FunctionInfo, specialization: Specialization +) -> Callable[..., Any]: + """Returns a specialized function for the given specialization.""" + assert specialization.in_specs_treedef is not None + assert specialization.in_specs_leaves is not None + assert specialization.devices is not None + uid = random.getrandbits(63) + + mutex = threading.Lock() + # Asynchronous execution function that has known output_specs. + async_execution_func = None + + def specialized_func(*args, **kwargs) -> Any: + """Specialized function to be executed with given args and kwargs.""" + nonlocal specialization, async_execution_func + with mutex: + if async_execution_func is None: + if specialization.out_specs_treedef is None: + if specialization.out_specs_fn is None: + serialized_out_specs = _make_output_specs_and_push_result_fun( + info, specialization, uid + )(*args, **kwargs) + + # Waits for the output_specs. This may block. + out_specs_treedef, out_specs_leaves = _deserialize_specs( + serialized_out_specs + ) + + # Subsequent calls would use async_execution_func with discovered + # output_specs. + specialization = specialization.update( + out_specs_treedef=out_specs_treedef, + out_specs_leaves=out_specs_leaves, + ) + async_execution_func = _make_async_execution_fun( + info, specialization + ) + + return _make_pop_result_fun(info, specialization, uid)() + else: + # Compute out_specs using out_specs_fn and inputs. + out_specs = specialization.out_specs_fn(*args, **kwargs) + # Type checking is ignored to silence mypy error: Incompatible types + # in assignment (expression has type "list[Any]", variable has type + # "tuple[ShapeDtypeStruct, ...]") [assignment] + out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( # type: ignore[assignment] + out_specs + ) + specialization = specialization.update( + out_specs_treedef=out_specs_treedef, + out_specs_leaves=tuple(out_specs_leaves), + ) + async_execution_func = _make_async_execution_fun( + info, specialization + ) + # Fall-through. + else: + async_execution_func = _make_async_execution_fun(info, specialization) + # Fall-through. + + return async_execution_func(*args, **kwargs) + + return specialized_func + + +def make_callable( + fun: Callable[..., Any], + fun_sourceinfo: str | None, + fun_signature: inspect.Signature | None, +) -> Callable[..., Any]: + """Makes a colocated Python callable.""" + return _make_callable( + FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() + ) + + +def _make_callable( + info: FunctionInfo, + specialization: Specialization, +) -> Callable[..., Any]: + """Internal implementation of make_callable.""" + + def specialize( + in_specs: ShapeDtypeStructTree | None = None, + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, + devices: Sequence[jax.Device] | None = None, + ) -> Callable[..., Any]: + """Returns a colocated Python callable with extra specialization. + + Args: + in_specs: Optionally specifies the expected input specs. Input specs are + expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a + function call. + out_specs_fn: Optionally specifies a function that computes the output + specs from input specs. If unspecified, colocated_python will compute + the output specs during the very first execution, and this execution + will be synchronous. + devices: Optionally specifies the devices to execute the function on. Must + be provided if in_specs has no leaves because devices cannot be inferred + from input specs or arguments. + + Returns: + A colocated Python callable with extra specialization. + """ + # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if + # `out_specs_fn(in_specs)` returns at least one leaf that we can use for + # inferring `devices`. + if in_specs is None: + in_specs_leaves, in_specs_treedef = None, None + else: + in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs) + in_specs_leaves = tuple(in_specs_leaves_list) + return _make_callable( + info, + specialization.update( + in_specs_treedef=in_specs_treedef, + in_specs_leaves=in_specs_leaves, + out_specs_fn=out_specs_fn, + devices=devices, + ), + ) + + @api_boundary + def __call__(*args, **kwargs) -> Any: + """Executes the function. + + If the output specs are not known, the very first execution will be + synchronous. + """ + args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) + + in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) + if specialization.in_specs_treedef is None: + # Allow input polymorphism by applying input_specs specialization + # temporarily for this call. + return _make_callable( + info, + specialization.update( + in_specs_treedef=in_specs_treedef, + in_specs_leaves=in_specs_leaves, + ), + )(*args, **kwargs) + + if specialization.devices is None: + devices = _infer_devices_from_args(args_leaves) + if devices is None: + raise ValueError( + "No devices found. colocated_python function without input" + " arguments must be first specialized with devices." + ) + # Allow device polymorphism by applying devices specialization temporarily + # for this call. + return _make_callable(info, specialization.update(devices=devices))( + *args, **kwargs + ) + + # Assertion is added to silence mypy error: Unsupported operand types for != + # ("PyTreeDef" and "None") [operator] + assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef) + + # If input_specs is known, verify that it matches actual inputs. + if (specialization.in_specs_treedef != in_specs_treedef + or specialization.in_specs_leaves != in_specs_leaves): + raise ValueError( + "Input specs in specialization and input specs of arguments must have" + " the same pytree structure, but they have the following structural" + " differences:\n" + + ("\n".join( + f" - {tree_util.keystr(path)} is a {thing1} in value 1 and" + f" a {thing2} in value 2, so {explanation}.\n" + for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef( + specialization.in_specs_treedef, in_specs_treedef + )))) + + return _get_specialized_func(info, specialization)(*args, **kwargs) + + __call__ = wraps(info.fun)(__call__) + __call__.specialize = specialize + return __call__ diff --git a/jax/experimental/colocated_python/func_backend.py b/jax/experimental/colocated_python/func_backend.py new file mode 100644 index 000000000000..aa514015004d --- /dev/null +++ b/jax/experimental/colocated_python/func_backend.py @@ -0,0 +1,44 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Backend for colocated_python.func.""" + +from __future__ import annotations + +import threading +from typing import Sequence + +import jax + + +class _ResultStore: + """Temporarily stores results from synchronous execution of functions.""" + + def __init__(self) -> None: + self._lock = threading.Lock() + self._storage: dict[int, Sequence[jax.Array]] = {} + + def push(self, uid: int, out: Sequence[jax.Array]) -> None: + with self._lock: + if uid in self._storage: + raise ValueError(f"uid {uid} already exists") + self._storage[uid] = out + + def pop(self, uid: int) -> Sequence[jax.Array]: + with self._lock: + if uid not in self._storage: + raise ValueError(f"uid {uid} does not exist") + return self._storage.pop(uid) + + +SINGLETON_RESULT_STORE = _ResultStore() diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py new file mode 100644 index 000000000000..ced50b6eee3c --- /dev/null +++ b/jax/experimental/colocated_python/serialization.py @@ -0,0 +1,228 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Colocated Python serialization utilities.""" + +# TODO(jmudigonda): Use a string-typed array for output structure when it +# becomes available. Using a fixed uint8 array is only for prototyping. + +from __future__ import annotations + +import collections +import io +from typing import Any, Callable, Sequence + +try: + import cloudpickle # type: ignore[import-not-found] +except ImportError: + cloudpickle = None + +import jax +from jax._src import api +from jax._src import tree_util +from jax._src import xla_bridge as xb +from jax._src.lib import xla_client as xc +import numpy as np + +DeviceList = xc.DeviceList + +# Hard-coded limit for serialized specs size. +# TODO(jmudigonda): Use a string-typed array for output structure when it +# becomes available. Using a fixed uint8 array is only for prototyping. +_MAX_SERIALIZED_SPECS_SIZE = 1048576 + + +@jax.util.cache(max_size=None) +def _get_cpu_device_map() -> dict[int, jax.Device]: + """Returns a map from a device id to a matching device.""" + cpu_device_map: dict[int, jax.Device] = {} + # TODO(hyeontaek): We should look up CPU devices for a specific CPU backend. + # When deserializing a device on the controller, the backend should be the one + # associated with colocated_python. When deserializing on the colocated_python + # executor, it should be the CPU backend visible to the user function running + # under colocated_python. + for backed in xb.backends().values(): + for d in backed._get_all_devices(): # pylint: disable=protected-access + if d.device_kind == "cpu": + if d.id in cpu_device_map: + raise ValueError( + f"Multiple CPU devices with id {d.id} found:" + f" {cpu_device_map[d.id]} and {d}" + ) + cpu_device_map[d.id] = d + return cpu_device_map + + +def _reduce_mesh( + mesh: jax.sharding.Mesh, +) -> tuple[Callable[..., jax.sharding.Mesh], Any]: + def make_mesh( + mesh_device_ids: np.ndarray, axis_names: Any + ) -> jax.sharding.Mesh: + cpu_device_map = _get_cpu_device_map() + mesh_devices = np.vectorize(lambda device_id: cpu_device_map[device_id])( + mesh_device_ids + ) + return jax.sharding.Mesh(mesh_devices, axis_names) + + mesh_device_ids = np.vectorize(lambda d: d.id, otypes=[int])(mesh.devices) + return make_mesh, (mesh_device_ids, mesh.axis_names) + + +def _reduce_device_list( + device_list: DeviceList, +) -> tuple[Callable[..., DeviceList], Any]: + def make_device_list(device_ids: Sequence[int]) -> DeviceList: + cpu_device_map = _get_cpu_device_map() + devices = np.vectorize(lambda device_id: cpu_device_map[device_id])( + device_ids + ) + return DeviceList(devices) + + device_ids = [d.id for d in device_list] + return make_device_list, (device_ids,) + + +def _reduce_single_device_sharding( + sharding: jax.sharding.SingleDeviceSharding, +) -> tuple[Callable[..., jax.sharding.SingleDeviceSharding], Any]: + + def make_single_device_sharding(device_id: int): + cpu_device_map = _get_cpu_device_map() + return jax.sharding.SingleDeviceSharding(cpu_device_map[device_id]) + + return make_single_device_sharding, (sharding.device_set.pop().id,) + + +def _serialize(obj: Any) -> bytes: + """Serializes callables and input/output spec objects. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. + + This module contains utility functions used internally for implementiong + `colocated_python` when it ships callables and input/output specs through + IFRT. The pickled data is produced and consumed in an ephermeral fashion + without any persistence, and it does not expect any version compatibility + (which cloudpickle does not guarantee). Furthermore, serialization and + deserialization is expected to be done on machine(s) that are controlled by a + single tenant, which allows unpickling done during deserialization to be + trusted. + + Raises: + ModuleNotFoundError: If cloudpickle is not available. + """ + if cloudpickle is None: + raise ModuleNotFoundError('No module named "cloudpickle"') + + class _CustomPickler(cloudpickle.Pickler): + dispatch_table = collections.ChainMap( + {jax.sharding.Mesh: _reduce_mesh}, + {DeviceList: _reduce_device_list}, + {jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding}, + cloudpickle.CloudPickler.dispatch_table, # pylint: disable=attribute-error + ) + dispatch = dispatch_table + + with io.BytesIO() as file: + _CustomPickler(file).dump(obj) + return file.getvalue() + + +def _deserialize(serialized: bytes) -> Any: + """Deserializes callables and input/output spec objects. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. See serialize() for details. + + Raises: + ModuleNotFoundError: If cloudpickle is not available. + """ + if cloudpickle is None: + raise ModuleNotFoundError('No module named "cloudpickle"') + + return cloudpickle.loads(serialized) + + +def _make_specs_for_serialized_specs( + devices: DeviceList, +) -> tuple[api.ShapeDtypeStruct, api.ShapeDtypeStruct]: + """Makes output specs for serialized specs.""" + # TODO(jmudigonda): Use a string-typed array for output structure when it + # becomes available. Using a fixed uint8 array is only for prototyping. + mesh = jax.sharding.Mesh(tuple(devices), ("x",)) + replicated_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + return ( + api.ShapeDtypeStruct( + shape=(), dtype=np.int32, sharding=replicated_sharding + ), + api.ShapeDtypeStruct( + shape=(_MAX_SERIALIZED_SPECS_SIZE,), + dtype=np.uint8, + sharding=replicated_sharding, + ), + ) + + +def _serialize_specs( + specs_treedef: tree_util.PyTreeDef, + specs_leaves: tuple[api.ShapeDtypeStruct, ...], + devices: DeviceList, +) -> tuple[jax.Array, ...]: + """Serializes the output specs into a tuple of arrays. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. See serialize() for details. + """ + s = _serialize((specs_treedef, specs_leaves)) + assert ( + len(s) <= _MAX_SERIALIZED_SPECS_SIZE + ), f"Too large serialized spec size: {len(s)}" + # TODO(jmudigonda): Use a string-typed array for output structure when it + # becomes available. Using a fixed uint8 array is only for prototyping. + mesh = jax.sharding.Mesh(tuple(devices), ("x",)) + replicated_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + len_array = jax.make_array_from_callback( + shape=(), + sharding=replicated_sharding, + data_callback=lambda _: np.array(len(s), dtype=np.int32), + ) + data_array = jax.make_array_from_callback( + shape=(_MAX_SERIALIZED_SPECS_SIZE,), + sharding=replicated_sharding, + data_callback=lambda _: np.frombuffer( + s + b"\0" * (_MAX_SERIALIZED_SPECS_SIZE - len(s)), + dtype=np.uint8, + ), + ) + return len_array, data_array + + +def _deserialize_specs( + serialized_specs: tuple[jax.Array, ...], +) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]: + """Deserializes the specs from the serialized specs. + + DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF + colocated_python. See serialize() for details. + """ + # TODO(jmudigonda): Use a string-typed array for output structure when it + # becomes available. Using a fixed uint8 array is only for prototyping. + len_array, data_array = serialized_specs + length = int(len_array.addressable_shards[0].data) + data = np.asarray(data_array.addressable_shards[0].data).tobytes() + return _deserialize(data[:length]) diff --git a/tests/BUILD b/tests/BUILD index 47234adeddbf..657d169ba18a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1345,6 +1345,14 @@ jax_multiplatform_test( ], ) +jax_multiplatform_test( + name = "colocated_python_test", + srcs = ["colocated_python_test.py"], + deps = [ + "//jax:experimental_colocated_python", + ], +) + jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py new file mode 100644 index 000000000000..602904757d7a --- /dev/null +++ b/tests/colocated_python_test.py @@ -0,0 +1,210 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +from typing import Sequence + +from absl.testing import absltest +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member +from jax.experimental import colocated_python +from jax.experimental.colocated_python import func as colocated_python_func +import jax.numpy as jnp +import numpy as np + +config.parse_flags_with_absl() + + +def _colocated_cpu_devices( + devices: Sequence[jax.Device], +) -> Sequence[jax.Device]: + """Returns CPU devices colocated with the given devices.""" + # TODO(hyeontaek): Use `colocated_python.colocated_cpu_devices(devices)` once + # PjRt-IFRT prepares CPU devices by its own. + cpu_backend_devices = jax.local_devices(backend="cpu") + device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + return [ + cpu_backend_devices[device_index_map[device.id]] for device in devices + ] + + +@contextlib.contextmanager +def _count_colocated_python_specialization_cache_miss() -> list[int]: + """Counts the number of cache misses for colocated_python specialization.""" + original_get_specialized_func = colocated_python_func._get_specialized_func + count = [0] + + @jax.util.cache(max_size=None) + def get_specialized_func(*args, **kwargs): + count[0] += 1 + return original_get_specialized_func(*args, **kwargs) + + colocated_python_func._get_specialized_func = get_specialized_func + try: + yield count + finally: + colocated_python_func._get_specialized_func = original_get_specialized_func + + +_exit_stack = contextlib.ExitStack() + + +def setUpModule(): + # TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT + # prepares CPU devices by its own. + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + + +def tearDownModule(): + _exit_stack.close() + + +class ColocatedPythonTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if xla_extension_version < 290: + self.skipTest("Requires xla_extension_version >= 290") + + def testSimpleFunction(self): + @colocated_python.colocated_python + def add_one(x): + return x + 1 + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + + with _count_colocated_python_specialization_cache_miss() as count: + out = add_one(x) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + out = add_one(x) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + def testSimpleFunctioWithTree(self): + @colocated_python.colocated_python + def add_one(x): + return jax.tree.map(lambda x: x + 1, x) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = [np.array(1), (np.array(2), {"v": np.array(3)})] + x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) + + with _count_colocated_python_specialization_cache_miss() as count: + out = add_one(x) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 1) + + out = add_one(x) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 1) + + def testEmptyInputFailsWithoutSpecialization(self): + @colocated_python.colocated_python + def make_zero(): + return jnp.array(0) + + with self.assertRaisesRegex( + ValueError, + "No devices found. colocated_python function without input arguments" + " must be first specialized with devices.", + ): + _ = make_zero() + + def testEmptyInputWithDevicesSpecialization(self): + @colocated_python.colocated_python + def make_zero(): + return jnp.array(0) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + + with _count_colocated_python_specialization_cache_miss() as count: + make_zero = make_zero.specialize(devices=cpu_devices[:1]) + out = make_zero() + self.assertEqual(out, np.array(0)) + self.assertEqual(count[0], 1) + + out = make_zero() + self.assertEqual(out, np.array(0)) + self.assertEqual(count[0], 1) + + def testInputPolymorphismWithoutOutSpecsFn(self): + @colocated_python.colocated_python + def add_one(x): + return jax.tree.map(lambda x: x + 1, x) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + + with _count_colocated_python_specialization_cache_miss() as count: + out = add_one(x) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + out = add_one(x) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + # Different input tree structure and dtype/shape. + x = [np.array(1), (np.array(2), {"v": np.array(3)})] + x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) + + out = add_one(x) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + out = add_one(x) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + def testInputPolymorphismAllowedWithOutSpecsFn(self): + @colocated_python.colocated_python + def add_one(x): + return jax.tree.map(lambda x: x + 1, x) + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + + with _count_colocated_python_specialization_cache_miss() as count: + add_one = add_one.specialize(out_specs_fn=lambda x: x) + out = add_one(x) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + out = add_one(x) + self.assertEqual(out, np.array(2)) + self.assertEqual(count[0], 1) + + # Different input tree structure and dtype/shape. + x = [np.array(1), (np.array(2), {"v": jnp.array(3)})] + x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) + + out = add_one(x) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + out = add_one(x) + self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) + self.assertEqual(count[0], 2) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 20ed2f3317084d0371e1813c0207f54f59d44c6f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 28 Oct 2024 14:17:41 -0700 Subject: [PATCH 087/698] Improve docs for jnp.arctan2 --- jax/_src/numpy/ufuncs.py | 53 ++++++++++++++++++++++++++++++++++++++-- tests/lax_numpy_test.py | 1 - 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index b18b06d02f2c..ade9cb2062b8 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -35,7 +35,7 @@ from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, - promote_shapes, _where, implements, check_no_float0s) + promote_shapes, _where, check_no_float0s) from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy import reductions @@ -1463,9 +1463,58 @@ def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.sub(*promote_args("subtract", x, y)) -@implements(np.arctan2, module='numpy') @partial(jit, inline=True) def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + r"""Compute the arctangent of x1/x2, choosing the correct quadrant. + + JAX implementation of :func:`numpy.arctan2` + + Args: + x1: numerator array. + x2: denomniator array; should be broadcast-compatible with x1. + + Returns: + The elementwise arctangent of x1 / x2, tracking the correct quadrant. + + See also: + - :func:`jax.numpy.tan`: compute the tangent of an angle + - :func:`jax.numpy.atan2`: the array API version of this function. + + Examples: + Consider a sequence of angles in radians between 0 and :math:`2\pi`: + + >>> theta = jnp.linspace(-jnp.pi, jnp.pi, 9) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(theta) + [-3.14 -2.36 -1.57 -0.79 0. 0.79 1.57 2.36 3.14] + + These angles can equivalently be represented by ``(x, y)`` coordinates + on a unit circle: + + >>> x, y = jnp.cos(theta), jnp.sin(theta) + + To reconstruct the input angle, we might be tempted to use the identity + :math:`\tan(\theta) = y / x`, and compute :math:`\theta = \tan^{-1}(y/x)`. + Unfortunately, this does not recover the input angle: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.arctan(y / x)) + [-0. 0.79 1.57 -0.79 0. 0.79 1.57 -0.79 0. ] + + The problem is that :math:`y/x` contains some ambiguity: although + :math:`(y, x) = (-1, -1)` and :math:`(y, x) = (1, 1)` represent different points in + Cartesian space, in both cases :math:`y / x = 1`, and so the simple arctan + approach loses information about which quadrant the angle lies in. :func:`arctan2` + is built to address this: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.arctan2(y, x)) + [ 3.14 -2.36 -1.57 -0.79 0. 0.79 1.57 2.36 -3.14] + + The results match the input ``theta``, except at the endpoints where :math:`+\pi` + and :math:`-\pi` represent indistinguishable points on the unit circle. By convention, + :func:`arctan2` alwasy returns values between :math:`-\pi` and :math:`+\pi` inclusive. + """ return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ccfce51ec909..f853c742c811 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6189,7 +6189,6 @@ def testWrappedSignaturesMatch(self): jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)} func_pairs = {name: (fun, fun.__np_wrapped__) for name, fun in jnp_funcs.items() if getattr(fun, '__np_wrapped__', None) is not None} - assert len(func_pairs) > 0 # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names. unsupported_params = { From 14030801a57251a7d37e433ee70bc37e2886dd7b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 28 Oct 2024 15:22:09 -0700 Subject: [PATCH 088/698] Remove obsolete implements() decorator & fix tests --- jax/_src/numpy/util.py | 173 +------------------------------ tests/lax_numpy_test.py | 223 +++++++++++++++++++++------------------- 2 files changed, 122 insertions(+), 274 deletions(-) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 27496ad99056..15cbc22dfa0d 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -13,11 +13,9 @@ # limitations under the License. from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import partial -import re -import textwrap -from typing import Any, NamedTuple, TypeVar +from typing import Any import warnings @@ -34,173 +32,6 @@ zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map -_T = TypeVar("_T") - -_parameter_break = re.compile("\n(?=[A-Za-z_])") -_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE) -_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE) -_versionadded = re.compile(r'^\s+\.\.\s+versionadded::', re.MULTILINE) -_docreference = re.compile(r':doc:`(.*?)\s*<.*?>`') - -class ParsedDoc(NamedTuple): - """ - docstr: full docstring - signature: signature from docstring. - summary: summary from docstring. - front_matter: front matter before sections. - sections: dictionary of section titles to section content. - """ - docstr: str | None - signature: str = "" - summary: str = "" - front_matter: str = "" - sections: dict[str, str] = {} - - -def _parse_numpydoc(docstr: str | None) -> ParsedDoc: - """Parse a standard numpy-style docstring. - - Args: - docstr: the raw docstring from a function - Returns: - ParsedDoc: parsed version of the docstring - """ - if docstr is None or not docstr.strip(): - return ParsedDoc(docstr) - - # Remove any :doc: directives in the docstring to avoid sphinx errors - docstr = _docreference.sub( - lambda match: f"{match.groups()[0]}", docstr) - - signature, body = "", docstr - match = _numpy_signature_re.match(body) - if match: - signature = match.group() - body = docstr[match.end():] - - firstline, _, body = body.partition('\n') - body = textwrap.dedent(body.lstrip('\n')) - - match = _numpy_signature_re.match(body) - if match: - signature = match.group() - body = body[match.end():] - - summary = firstline - if not summary: - summary, _, body = body.lstrip('\n').partition('\n') - body = textwrap.dedent(body.lstrip('\n')) - - front_matter = "" - body = "\n" + body - section_list = _section_break.split(body) - if not _section_break.match(section_list[0]): - front_matter, *section_list = section_list - sections = {section.split('\n', 1)[0]: section for section in section_list} - - return ParsedDoc(docstr=docstr, signature=signature, summary=summary, - front_matter=front_matter, sections=sections) - - -def _parse_parameters(body: str) -> dict[str, str]: - """Parse the Parameters section of a docstring.""" - title, underline, content = body.split('\n', 2) - assert title == 'Parameters' - assert underline and not underline.strip('-') - parameters = _parameter_break.split(content) - return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} - - -def implements( - original_fun: Callable[..., Any] | None, - update_doc: bool = True, - sections: Sequence[str] = ('Parameters', 'Returns', 'References'), - module: str | None = None, -) -> Callable[[_T], _T]: - """Decorator for JAX functions which implement a specified NumPy function. - - This mainly contains logic to copy and modify the docstring of the original - function. In particular, if `update_doc` is True, parameters listed in the - original function that are not supported by the decorated function will - be removed from the docstring. For this reason, it is important that parameter - names match those in the original numpy function. - - Args: - original_fun: The original function being implemented - update_doc: whether to transform the numpy docstring to remove references of - parameters that are supported by the numpy version but not the JAX version. - If False, include the numpy docstring verbatim. - sections: a list of sections to include in the docstring. The default is - ["Parameters", "Returns", "References"] - module: an optional string specifying the module from which the original function - is imported. This is useful for objects such as ufuncs, where the module cannot - be determined from the original function itself. - """ - def decorator(wrapped_fun): - wrapped_fun.__np_wrapped__ = original_fun - # Allows this pattern: @implements(getattr(np, 'new_function', None)) - if original_fun is None: - return wrapped_fun - docstr = getattr(original_fun, "__doc__", None) - name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))) - try: - mod = module or original_fun.__module__ - except AttributeError: - if config.enable_checks.value: - raise ValueError(f"function {original_fun} defines no __module__; pass module keyword to implements().") - else: - name = f"{mod}.{name}" - if docstr: - try: - parsed = _parse_numpydoc(docstr) - - if update_doc and 'Parameters' in parsed.sections: - code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None) - # Remove unrecognized parameter descriptions. - parameters = _parse_parameters(parsed.sections['Parameters']) - parameters = {p: desc for p, desc in parameters.items() - if (code is None or p in code.co_varnames)} - if parameters: - parsed.sections['Parameters'] = ( - "Parameters\n" - "----------\n" + - "\n".join(_versionadded.split(desc)[0].rstrip() - for p, desc in parameters.items()) - ) - else: - del parsed.sections['Parameters'] - - docstr = parsed.summary.strip() + "\n" if parsed.summary else "" - docstr += f"\nLAX-backend implementation of :func:`{name}`.\n" - docstr += "\n*Original docstring below.*\n" - - # We remove signatures from the docstrings, because they redundant at best and - # misleading at worst: e.g. JAX wrappers don't implement all ufunc keyword arguments. - # if parsed.signature: - # docstr += "\n" + parsed.signature.strip() + "\n" - - if parsed.front_matter: - docstr += "\n" + parsed.front_matter.strip() + "\n" - kept_sections = (content.strip() for section, content in parsed.sections.items() - if section in sections) - if kept_sections: - docstr += "\n" + "\n\n".join(kept_sections) + "\n" - except: - if config.enable_checks.value: - raise - docstr = original_fun.__doc__ - - wrapped_fun.__doc__ = docstr - for attr in ['__name__', '__qualname__']: - try: - value = getattr(original_fun, attr) - except AttributeError: - pass - else: - setattr(wrapped_fun, attr, value) - return wrapped_fun - return decorator - _dtype = partial(dtypes.dtype, canonicalize=True) def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index f853c742c811..b37237cae28c 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,7 +51,6 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements from jax._src.util import safe_zip, NumpyComplexWarning config.parse_flags_with_absl() @@ -6186,9 +6185,114 @@ class NumpySignaturesTest(jtu.JaxTestCase): def testWrappedSignaturesMatch(self): """Test that jax.numpy function signatures match numpy.""" - jnp_funcs = {name: getattr(jnp, name) for name in dir(jnp)} - func_pairs = {name: (fun, fun.__np_wrapped__) for name, fun in jnp_funcs.items() - if getattr(fun, '__np_wrapped__', None) is not None} + # NumPy functions explicitly not implemented in JAX: + skip = {'array2string', + 'asanyarray', + 'asarray_chkfinite', + 'ascontiguousarray', + 'asfortranarray', + 'asmatrix', + 'base_repr', + 'binary_repr', + 'bmat', + 'broadcast', + 'busday_count', + 'busday_offset', + 'busdaycalendar', + 'common_type', + 'copyto', + 'datetime_as_string', + 'datetime_data', + 'errstate', + 'flatiter', + 'format_float_positional', + 'format_float_scientific', + 'fromregex', + 'genfromtxt', + 'get_include', + 'getbufsize', + 'geterr', + 'geterrcall', + 'in1d', + 'info', + 'is_busday', + 'isfortran', + 'isnat', + 'loadtxt', + 'matrix', + 'may_share_memory', + 'memmap', + 'min_scalar_type', + 'mintypecode', + 'ndenumerate', + 'ndindex', + 'nditer', + 'nested_iters', + 'poly1d', + 'put_along_axis', + 'putmask', + 'real_if_close', + 'recarray', + 'record', + 'require', + 'row_stack', + 'savetxt', + 'savez_compressed', + 'setbufsize', + 'seterr', + 'seterrcall', + 'shares_memory', + 'show_config', + 'show_runtime', + 'test', + 'trapz', + 'typename'} + + # symbols removed in NumPy 2.0 + skip |= {'add_docstring', + 'add_newdoc', + 'add_newdoc_ufunc', + 'alltrue', + 'asfarray', + 'byte_bounds', + 'compare_chararrays', + 'cumproduct', + 'deprecate', + 'deprecate_with_doc', + 'disp', + 'fastCopyAndTranspose', + 'find_common_type', + 'get_array_wrap', + 'geterrobj', + 'issctype', + 'issubclass_', + 'issubsctype', + 'lookfor', + 'mat', + 'maximum_sctype', + 'msort', + 'obj2sctype', + 'product', + 'recfromcsv', + 'recfromtxt', + 'round_', + 'safe_eval', + 'sctype2char', + 'set_numeric_ops', + 'set_string_function', + 'seterrobj', + 'sometrue', + 'source', + 'who'} + + self.assertEmpty(skip.intersection(dir(jnp))) + + names = (name for name in dir(np) if not (name.startswith('_') or name in skip)) + names = (name for name in names if callable(getattr(np, name))) + names = {name for name in names if not isinstance(getattr(np, name), type)} + self.assertEmpty(names.difference(dir(jnp))) + + self.assertNotEmpty(names) # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names. unsupported_params = { @@ -6199,6 +6303,7 @@ def testWrappedSignaturesMatch(self): 'copy': ['subok'], 'corrcoef': ['ddof', 'bias', 'dtype'], 'cov': ['dtype'], + 'cumulative_prod': ['out'], 'cumulative_sum': ['out'], 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], @@ -6210,9 +6315,7 @@ def testWrappedSignaturesMatch(self): 'full': ['order', 'like'], 'full_like': ['subok', 'order'], 'fromfunction': ['like'], - 'histogram': ['normed'], - 'histogram2d': ['normed'], - 'histogramdd': ['normed'], + 'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'], 'nanpercentile': ['weights'], 'nanquantile': ['weights'], 'nanstd': ['correction', 'mean'], @@ -6222,7 +6325,6 @@ def testWrappedSignaturesMatch(self): 'partition': ['kind', 'order'], 'percentile': ['weights'], 'quantile': ['weights'], - 'reshape': ['shape', 'copy'], 'row_stack': ['casting'], 'stack': ['casting'], 'std': ['mean'], @@ -6233,18 +6335,19 @@ def testWrappedSignaturesMatch(self): } extra_params = { - # TODO(micky774): Remove when np.clip has adopted the Array API 2023 - # standard - 'clip': ['x', 'max', 'min'], + 'compress': ['size', 'fill_value'], 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], + 'load': ['args', 'kwargs'], 'take_along_axis': ['mode', 'fill_value'], 'fill_diagonal': ['inplace'], } mismatches = {} - for name, (jnp_fun, np_fun) in func_pairs.items(): + for name in names: + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) if name in ['histogram', 'histogram2d', 'histogramdd']: # numpy 1.24 re-orders the density and weights arguments. # TODO(jakevdp): migrate histogram APIs to match newer numpy versions. @@ -6258,12 +6361,15 @@ def testWrappedSignaturesMatch(self): # TODO(dfm): After our deprecation period for the clip arguments ends # it should be possible to reintroduce the check. continue - # Note: can't use inspect.getfullargspec due to numpy issue + if name == "reshape": + # Similar issue to clip: we'd need logic specific to the NumPy version + # because of the change in argument name from `newshape` to `shape`. + continue + # Note: can't use inspect.getfullargspec for some functions due to numpy issue # https://github.com/numpy/numpy/issues/12225 try: np_params = inspect.signature(np_fun).parameters except ValueError: - # Some functions cannot be inspected continue jnp_params = inspect.signature(jnp_fun).parameters extra = set(extra_params.get(name, [])) @@ -6350,8 +6456,6 @@ def testUfuncInputTypes(self, name, arg_dtypes): class NumpyDocTests(jtu.JaxTestCase): def test_lax_numpy_docstrings(self): - # Test that docstring wrapping & transformation didn't fail. - unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', 'amax', 'amin', 'around', 'bitwise_invert', 'bitwise_left_shift', @@ -6371,15 +6475,6 @@ def test_lax_numpy_docstrings(self): elif hasattr(np, name) and obj is getattr(np, name): # Some APIs are imported directly from NumPy; we don't check these. pass - elif hasattr(obj, '__np_wrapped__'): - # Functions decorated with @implements(...) should have __np_wrapped__ - wrapped_fun = obj.__np_wrapped__ - if wrapped_fun is not None: - # If the wrapped function has a docstring, obj should too - if wrapped_fun.__doc__ and not obj.__doc__: - raise Exception(f"jnp.{name} does not contain wrapped docstring.") - if obj.__doc__ and "*Original docstring below.*" not in obj.__doc__: - raise Exception(f"jnp.{name} does not have a wrapped docstring.") elif name in aliases: assert "Alias of" in obj.__doc__ elif name not in skip_args_check: @@ -6391,84 +6486,6 @@ def test_lax_numpy_docstrings(self): if name not in ["frompyfunc", "isdtype", "promote_types"]: self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}") - @parameterized.named_parameters( - {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False]) - def test_wrapped_function_parameters(self, jit): - def orig(x): - """Example Docstring - - Parameters - ---------- - x : array_like - Input Data - - .. versionadded:: 1.8.0 - out : array_like, optional - Output to overwrite - other_arg : Any - not used - - Returns - ------- - x : input - """ - return x - - def wrapped(x, out=None): - return x - - if jit: - wrapped = jax.jit(wrapped) - - wrapped = implements(orig)(wrapped) - doc = wrapped.__doc__ - - self.assertStartsWith(doc, "Example Docstring") - self.assertIn("Original docstring below", doc) - self.assertIn("Parameters", doc) - self.assertIn("Returns", doc) - self.assertNotIn('other_arg', doc) - self.assertNotIn('versionadded', doc) - - - def test_parse_numpydoc(self): - # Unit test ensuring that _parse_numpydoc correctly parses docstrings for all - # functions in NumPy's top-level namespace. - section_titles = {'Attributes', 'Examples', 'Notes', - 'Parameters', 'Raises', 'References', - 'Returns', 'See also', 'See Also', 'Warnings', 'Warns'} - headings = [title + '\n' + '-'*len(title) for title in section_titles] - - for name in dir(np): - if name.startswith('_'): - continue - obj = getattr(np, name) - if isinstance(obj, type): - continue - if not callable(obj): - continue - if 'built-in function' in repr(obj): - continue - parsed = _parse_numpydoc(obj.__doc__) - - # Check that no docstring is handled gracefully. - if not obj.__doc__: - self.assertEqual(parsed, ParsedDoc(obj.__doc__)) - continue - - # Check that no unexpected section names are found. - extra_keys = parsed.sections.keys() - section_titles - if extra_keys: - raise ValueError(f"Extra section headers found in np.{name}: {extra_keys}") - - # Check that every docstring has a summary. - if not parsed.summary: - raise ValueError(f"No summary found for np.{name}") - - # Check that no expected headings are missed. - for heading in headings: - assert heading not in parsed.front_matter - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 12d26053e31c5c9f45da5a15ce7fb7fcbb0a96b7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 28 Oct 2024 16:57:11 -0700 Subject: [PATCH 089/698] [TPU][Mosaic] Add support for a no-op reshape where sublane_tiling = 1 and the res_tiled and src_tiled shapes both fill a full vreg (1024) PiperOrigin-RevId: 690796348 --- .../mosaic/dialect/tpu/transforms/apply_vector_layout.cc | 7 +++++++ .../mosaic/dialect/tpu/transforms/infer_vector_layout.cc | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 10585eec2920..79ebb725ccab 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4112,6 +4112,13 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, dst_tiled_dims[1] % dst_vreg_slice[1] == 0) { // Shapecast (..., 128) -> (..., m * 128 * packing). no_op = true; + } else if (layout_in.offsets() == LayoutOffsets{0, 0} && + layout_out.offsets() == LayoutOffsets{0, 0} && + layout_in.tiling()[0] == 1 && layout_out.tiling()[0] == 1 && + src_vreg_slice[1] == dst_vreg_slice[1] && + src_tiled_dims[1] % src_vreg_slice[1] == 0 && + dst_tiled_dims[1] % dst_vreg_slice[1] == 0) { + no_op = true; } FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_vregs, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 47e3baa0ff09..9fcb8afc7a47 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1448,7 +1448,13 @@ class VectorLayoutInferer { // 1D tilings that use 1 in the sublane dimension. int64_t sublane_tiling = vreg_slice[0]; do { - if (src_tiled_ishape[1] == res_tiled_ishape[1] && + auto src_res_tiled_equal = src_tiled_ishape[1] == res_tiled_ishape[1]; + auto vreg_num_elements = target_shape_[0] * target_shape_[1]; + auto single_subline_mod_1024 = + (sublane_tiling == 1 && + src_tiled_ishape[1] % vreg_num_elements == 0 && + res_tiled_ishape[1] % vreg_num_elements == 0); + if ((src_res_tiled_equal || single_subline_mod_1024) && src_tiled_ishape[0] % sublane_tiling == 0 && res_tiled_ishape[0] % sublane_tiling == 0) { std::array tiling = {sublane_tiling, target_shape_[1]}; From 86a47a7d4ebb7dcd0f77e3646249558f8853574f Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 27 Oct 2024 02:21:01 +0000 Subject: [PATCH 090/698] fix jax.custom_gradient to allow closing over non-autodiff tracers --- jax/_src/custom_derivatives.py | 67 ++++++++++++++++++++++------------ tests/api_test.py | 16 +++++++- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index f5ecdfcda286..216fb51f6a46 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -41,7 +41,7 @@ from jax._src.interpreters.batching import not_mapped from jax._src.lax import lax from jax._src.tree_util import ( - tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, + tree_flatten, tree_unflatten, tree_map, treedef_tuple, register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr, treedef_children) from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable, @@ -1029,32 +1029,51 @@ def custom_gradient(fun): >>> print(jax.grad(f, argnums=(0, 1))(3., 4.)) (Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True)) """ - @custom_vjp + # TODO(mattjj): better debug info def wrapped_fun(*args, **kwargs): - ans, _ = fun(*args, **kwargs) - return ans - - def fwd(*args, **kwargs): - ans, rule = fun(*args, **kwargs) - ans_flat, out_tree = tree_flatten((ans,)) - rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) - ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals) - return ans, Residuals(jaxpr, in_tree(), out_tree, consts) - - def bwd(res, cts): - jaxpr, in_tree, out_tree, consts = res - cts_flat, out_tree_ = tree_flatten((cts,)) - if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}') - cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat) - cts_out = tree_unflatten(in_tree, cts_out) - if treedef_is_leaf(in_tree): - cts_out = (cts_out,) - return cts_out - - wrapped_fun.defvjp(fwd, bwd) + args_flat, in_tree = tree_flatten((args, kwargs)) + in_avals = [core.get_aval(x) for x in args_flat] + primal_jaxpr, fwd_jaxpr, bwd_jaxpr, consts, out_tree = \ + _primal_fwd_bwd(in_tree, in_avals) + + @custom_vjp + def primal(consts, args): + return core.eval_jaxpr(primal_jaxpr, (), *consts, *args) + def fwd(consts, args): + ans_res = core.eval_jaxpr(fwd_jaxpr, (), *consts, *args) + return split_list(ans_res, [out_tree.num_leaves]) + def bwd(res, cts): + return None, core.eval_jaxpr(bwd_jaxpr, res, *cts) + primal.defvjp(fwd, bwd) + + out_flat = primal(consts, args_flat) + return tree_unflatten(out_tree, out_flat) + + def _primal_fwd_bwd(in_tree, in_avals): + out_tree, rule_jaxpr = None, None + @lu.wrap_init + def run(*args_flat): + nonlocal rule_jaxpr, out_tree + args, kwargs = tree_unflatten(in_tree, args_flat) + ans, rule = fun(*args, **kwargs) + ans_flat, out_tree = tree_flatten((ans,)) + ans_bar_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] + rule_, in_tree_ = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) + rule_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule_, ans_bar_avals) + out_tree, = treedef_children(out_tree) + return *ans_flat, *consts + fwd_jaxpr, _, fwd_consts, () = pe.trace_to_jaxpr_dynamic(run, in_avals) + fwd_jaxpr = pe.convert_constvars_jaxpr(fwd_jaxpr) + assert out_tree is not None and rule_jaxpr is not None + num_ans = out_tree.num_leaves + num_res = len(fwd_jaxpr.outvars) - num_ans + primal_jaxpr, _ = pe.dce_jaxpr(fwd_jaxpr, + [True] * num_ans + [False] * num_res, True) + return primal_jaxpr, fwd_jaxpr, rule_jaxpr, fwd_consts, out_tree + return wrapped_fun + @register_pytree_node_class class Residuals: def __init__(self, jaxpr, in_tree, out_tree, consts): diff --git a/tests/api_test.py b/tests/api_test.py index 2c2412093805..e7ac95baa55b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8976,13 +8976,27 @@ def f(x): vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) return jnp.sum(jnp.sin(x)), vjp - self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))), + self.assertAllClose(f(jnp.arange(3.)), jnp.sum(jnp.sin(jnp.arange(3.))), check_dtypes=False) self.assertAllClose( api.grad(f)(jnp.arange(3.)), api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.), check_dtypes=False) + def test_custom_gradient_jit_closure(self): + @jax.jit + def f(x, y): + y = jnp.sin(y) + + @jax.custom_gradient + def g(x): + return y * jnp.sin(x), lambda g: (y * jnp.cos(x) * g,) + + return g(x) + + g = jax.grad(f)(1., 2.) + self.assertAllClose(g, jnp.sin(2.) * jnp.cos(1.), check_dtypes=False) + def test_custom_gradient_can_return_singleton_value_in_vjp(self): @jax.custom_gradient def f(x): From a8d1048cb67524eca1255b6ee23a44bbdf9607e6 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 28 Oct 2024 18:52:39 -0700 Subject: [PATCH 091/698] [Pallas] Add tests for `jnp.logical_not` PiperOrigin-RevId: 690825419 --- tests/pallas/ops_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0988e56d7e41..9867fb2c0a9d 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -756,6 +756,7 @@ def kernel(x_ref, o_ref): ["float32", "float64"], ), ([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]), + ([jnp.logical_not], ["bool"]) ] @parameterized.named_parameters( @@ -792,6 +793,14 @@ def test_elementwise(self, fn, dtype): ): self.skipTest(f"{fn.__name__} not implemented on TPU") + # TODO: https://github.com/jax-ml/jax/issues/24243 + if ( + jtu.test_device_matches(["tpu"]) + and fn == jnp.logical_not + and not self.INTERPRET + ): + self.skipTest("logical_not on TPU is only supported in interpret mode") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype), From abf14323dc3e2c0cf8114d310592ed828f5cecc1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 28 Oct 2024 18:53:38 -0700 Subject: [PATCH 092/698] Adjust copyright notice. Previously we had been pulling-in NumPy and SciPy docs at runtime, but after the work in #21461 this is no longer the case. --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index d57420dec881..8007c0b3d828 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,7 +49,7 @@ def _do_not_evaluate_in_jax( # -- Project information ----------------------------------------------------- project = 'JAX' -copyright = '2024, The JAX Authors. NumPy and SciPy documentation are copyright the respective authors.' +copyright = '2024, The JAX Authors' author = 'The JAX authors' # The short X.Y version From 3a87348bfc56c60a0ed381407845b17eb41f962b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 29 Oct 2024 02:04:07 -0700 Subject: [PATCH 093/698] [Pallas:MGPU] Use shfl.sync after computing the warpgroup index The shuffle is completely unnecessary, but there's some mysterious black magic pattern patcher in ptxas that really wants us to do it. This tiny difference is what makes or breaks a kernel: if we shuffle the warpgroup index in attention kernels, we see ~70% utilization; if we don't we get at most ~50%... PiperOrigin-RevId: 690928489 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b43fc0147d1d..6697b4ada895 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1217,7 +1217,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.module_ctx.grid_mapping.grid_names if grid_names and axis_name in grid_names: if axis_name == grid_names[-1]: - return mgpu.warpgroup_idx(sync=False) + return mgpu.warpgroup_idx(sync=True) else: idx = grid_names.index(axis_name) return arith_dialect.index_cast( From de680184731f29e8fc46ac2820627ad6cfdc9b3e Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 29 Oct 2024 05:19:26 -0700 Subject: [PATCH 094/698] [NFC][Mosaic TPU] Clarify layout comment block PiperOrigin-RevId: 690977672 --- jaxlib/mosaic/dialect/tpu/layout.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 1b6b8b935c99..66217858fa7d 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -169,7 +169,7 @@ class RectangularVregBounds : public VRegDataBounds { // // The tiling attribute makes it possible to subdivide a single vector register // into multiple subtiles that traverse the last dimension of a value. For -// example, consider vregs of shape (4, 5) an array: +// example, consider vregs of shape (4, 5) on (2, 10) array: // // a b c d e f g h i j // k l m n o p q r s t From 8b216149730e724a058c018309bef0023c0dfc5e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 29 Oct 2024 05:20:02 -0700 Subject: [PATCH 095/698] [Pallas:MGPU] Add FlashAttention3 as an example PiperOrigin-RevId: 690977852 --- jax/BUILD | 18 +- jax/_src/pallas/mosaic_gpu/core.py | 4 +- jax/_src/pallas/mosaic_gpu/primitives.py | 4 +- jax/experimental/pallas/ops/gpu/BUILD | 35 +++ .../pallas/ops/gpu/attention_mgpu.py | 278 ++++++++++++++++++ tests/pallas/BUILD | 27 ++ tests/pallas/mgpu_attention_test.py | 78 +++++ 7 files changed, 439 insertions(+), 5 deletions(-) create mode 100644 jax/experimental/pallas/ops/gpu/BUILD create mode 100644 jax/experimental/pallas/ops/gpu/attention_mgpu.py create mode 100644 tests/pallas/mgpu_attention_test.py diff --git a/jax/BUILD b/jax/BUILD index 07e6b77be433..71be67368f3b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -627,7 +627,7 @@ pytype_strict_library( pytype_strict_library( name = "pallas_gpu_ops", - srcs = glob(["experimental/pallas/ops/gpu/**/*.py"]), + srcs = ["//jax/experimental/pallas/ops/gpu:triton_ops"], visibility = [ ":pallas_gpu_users", ], @@ -638,6 +638,22 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "pallas_experimental_gpu_ops", + testonly = True, + srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"], + visibility = [ + ":mosaic_gpu_users", + ], + deps = [ + ":jax", + ":mosaic_gpu", + ":pallas", + ":pallas_mosaic_gpu", + ":test_util", # This is only to make them runnable as jax_multiplatform_test... + ] + py_deps("numpy"), +) + pytype_strict_library( name = "pallas_tpu_ops", srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]), diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 174efd0757b5..7f9e0bef822e 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -497,6 +497,7 @@ class GPUMesh: # Those are NOT CUDA threads. On Hopper they correspond to warpgroups. num_threads: int | None = None axis_names: tuple[str, ...] = () + approx_math: bool = False def __post_init__(self): if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): @@ -547,9 +548,8 @@ def _gpu_mesh_discharge_rule( jaxpr=jaxpr, grid=tuple(mesh.shape.items()), backend="mosaic_gpu", - compiler_params=GPUCompilerParams(), + compiler_params=GPUCompilerParams(approx_math=mesh.approx_math), ) - pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 3874c7125d5e..f87e96a30c5f 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -108,7 +108,7 @@ def copy_smem_to_gmem( """ if src.memory_space is not gpu_core.SMEM: raise TypeError(f"src must be a SMEM reference, got {src.memory_space}") - if dst.memory_space is not gpu_core.GMEM: + if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM: raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}") src, src_transforms = state_primitives.get_ref_and_transforms( src, None, "copy_smem_to_gmem", force_trailing_indexer=False, @@ -203,7 +203,7 @@ def copy_gmem_to_smem( :func:`jax.experimental.mosaic.gpu.barrier_arrive` :func:`jax.experimental.mosaic.gpu.barrier_wait` """ - if src.memory_space is not gpu_core.GMEM: + if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM: raise TypeError(f"src must be a GMEM reference, got {src.memory_space}") if dst.memory_space is not gpu_core.SMEM: raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}") diff --git a/jax/experimental/pallas/ops/gpu/BUILD b/jax/experimental/pallas/ops/gpu/BUILD new file mode 100644 index 000000000000..20ff2152c356 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/BUILD @@ -0,0 +1,35 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_applicable_licenses = [], + default_visibility = ["//jax:__subpackages__"], +) + +exports_files( + srcs = glob(["*.py"]), +) + +filegroup( + name = "triton_ops", + srcs = glob( + ["*.py"], + exclude = ["*_mgpu.py"], + ), +) + +filegroup( + name = "mgpu_ops", + srcs = glob(["*_mgpu.py"]), +) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py new file mode 100644 index 000000000000..9320550969e8 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -0,0 +1,278 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FlashAttention3 implementation (using Mosaic GPU as the backend).""" + +import dataclasses +import functools +import itertools +import jax +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +import jax.numpy as jnp +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + block_q: int + block_kv: int + max_concurrent_steps: int + + +@functools.partial(jax.jit, static_argnames=["config"]) +def attention(q, k, v, config: TuningConfig): + if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: + raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") + batch_size, q_seq_len, num_q_heads, head_dim = q.shape + _, kv_seq_len, num_kv_heads, _ = k.shape + kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim) + if k.shape != kv_shape: + raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)") + if k.shape != kv_shape: + raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)") + if (dtype := q.dtype) != k.dtype or dtype != v.dtype: + raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}") + if batch_size != 1 or num_q_heads != 1 or num_kv_heads != 1: + raise NotImplementedError( + "Only batch_size=1, num_q_heads=1, and num_kv_heads=1 are supported," + f" got: {batch_size=}, {num_q_heads=}, {num_kv_heads=}" + ) + if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): + raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") + q, k, v = map(lambda x: x[0, :, 0, :], (q, k, v)) + max_concurrent_steps = min( + config.max_concurrent_steps, kv_seq_len // config.block_kv + ) + block_q, block_kv = config.block_q, config.block_kv + + def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped + wg_idx = lax.axis_index("wg") + qo_smem2, k_smem, v_smem = smem_buffers + k_barriers, v_barriers, q_barriers = buffer_barriers + k_consumed_barrier, v_consumed_barrier = consumed_barriers + def perform_schedule_barrier(): + plgpu.barrier_arrive(schedule_barrier) + plgpu.barrier_wait(schedule_barrier) + + @pl.when(wg_idx < 2) + def _compute_wg(): + plgpu.set_max_registers(232, action="increase") + qo_smem = qo_smem2.at[wg_idx] + q_seq_base = lax.axis_index("q") * (2 * block_q) + wg_idx * block_q + + plgpu.copy_gmem_to_smem( + q_ref.at[pl.ds(q_seq_base, block_q)], + qo_smem, + barrier=q_barriers.at[wg_idx], + ) + plgpu.barrier_wait(q_barriers.at[wg_idx]) + + m_i = plgpu.layout_cast( + jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, + ) + l_i = plgpu.layout_cast( + jnp.full((block_q,), 0, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, + ) + acc = plgpu.layout_cast( + jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + + plgpu.barrier_wait(k_barriers.at[0]) + + pl.when(wg_idx == 1)(perform_schedule_barrier) + def kv_loop(kv_step, carry): + acc, m_i, l_i = carry + slot = lax.rem(kv_step, max_concurrent_steps) + + # QK + def compute_qk(acc_ref): + plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem.at[slot], (1, 0))) + perform_schedule_barrier() + return acc_ref[...] + qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32)) + plgpu.barrier_arrive(k_consumed_barrier) + + # Softmax + m_ij = jnp.maximum(m_i, qk.max(axis=1)) + alpha = jnp.exp(m_i - m_ij) + m_i = m_ij + p = jnp.exp(qk - lax.broadcast_in_dim(m_ij, (block_q, block_kv), [0])) + acc *= lax.broadcast_in_dim(alpha, (block_q, head_dim), [0]) + l_i *= alpha + p16 = p.astype(dtype) + + plgpu.barrier_wait(v_barriers.at[slot]) + perform_schedule_barrier() + + l_i += p.sum(axis=1) + + # PV + def compute_pv(acc_ref): + plgpu.wgmma(acc_ref, p16, v_smem.at[slot]) + + wait_step = kv_step + 1 + wait_slot = lax.rem(wait_step, max_concurrent_steps) + @pl.when(wait_step < kv_seq_len // block_kv) + def _wait(): + plgpu.barrier_wait(k_barriers.at[wait_slot]) + acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc)) + plgpu.barrier_arrive(v_consumed_barrier) + return acc, m_i, l_i + if kv_seq_len % block_kv: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") + acc, m_i, l_i = lax.fori_loop( + 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) + ) + pl.when(wg_idx == 0)(perform_schedule_barrier) + del m_i # Not needed anymore + + # TODO(apaszke): Invert and multiply to avoid expensive divisions. + acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) + qo_smem[...] = acc.astype(dtype) + plgpu.copy_smem_to_gmem( + qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)], + ) + plgpu.wait_smem_to_gmem(0) + @pl.when(wg_idx == 2) + def _memory_wg(): + plgpu.set_max_registers(40, action="decrease") + for i in range(max_concurrent_steps): + s = pl.ds(i * block_kv, block_kv) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], barrier=k_barriers.at[i]) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], barrier=v_barriers.at[i]) + + def kv_loop(kv_step, _): + tma_step = kv_step + max_concurrent_steps + tma_slot = lax.rem(kv_step, max_concurrent_steps) + s = pl.ds(tma_step * block_kv, block_kv) + plgpu.barrier_wait(k_consumed_barrier) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], barrier=k_barriers.at[tma_slot]) + plgpu.barrier_wait(v_consumed_barrier) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], barrier=v_barriers.at[tma_slot]) + lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) + + def kv_epilogue(i, _): + plgpu.barrier_wait(k_consumed_barrier) + plgpu.barrier_wait(v_consumed_barrier) + lax.fori_loop(0, max_concurrent_steps, kv_epilogue, None) + + def run(refs): + q_ref, k_ref, v_ref, out_ref = refs + + num_q_tiles, rem = divmod(q_seq_len, block_q * 2) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") + mesh = plgpu.GPUMesh( + grid=(num_q_tiles,), num_threads=3, axis_names=("q", "wg"), approx_math=True, + ) + @pl.core_map(mesh) + def _kernel_entry(): + compute_wgs = 2 + barrier_2wg = plgpu.Barrier(num_arrivals=compute_wgs) + tiling = plgpu.TilingTransform((64, 64)) + swizzle = plgpu.SwizzleTransform(128) + qo_scratch = plgpu.SMEM( + (compute_wgs, block_q, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + k_scratch = plgpu.SMEM( + (max_concurrent_steps, block_kv, head_dim), jnp.float16, + transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle), + ) + v_scratch = plgpu.SMEM( + (max_concurrent_steps, block_kv, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + pl.run_scoped( + lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args), + (qo_scratch, k_scratch, v_scratch), + ( + plgpu.Barrier(1, num_barriers=max_concurrent_steps), + plgpu.Barrier(1, num_barriers=max_concurrent_steps), + plgpu.Barrier(1, num_barriers=compute_wgs), + ), + (barrier_2wg, barrier_2wg), + barrier_2wg, + ) + + _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) + return out[None, :, None, :] + + +@jax.jit +def attention_reference(q, k, v): + batch_size, q_seq_len, num_q_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) + q_reshaped = q.reshape( + batch_size, q_seq_len, num_kv_heads, num_q_heads // num_kv_heads, head_dim + ) + logits = jnp.einsum("bqHhc,bkHc->bqHhk", q_reshaped, k) + m = logits.max(axis=-1, keepdims=True) + unnormalized = jnp.exp(logits - m) + l = unnormalized.sum(axis=-1, keepdims=True) + weights = unnormalized / l + return jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) + + +def main(unused_argv): + batch_size = 1 + num_q_heads = 1 + num_kv_heads = 1 + problem_it = itertools.product((4096, 32768,), (64, 128, 256,)) + for seq_len, head_dim in problem_it: + q_seq_len = kv_seq_len = seq_len + print(f"==== {kv_seq_len=:<6} {q_seq_len=:<6} {num_q_heads=:<4} {head_dim=:<6} ====") + param_it = itertools.product((64,), (64, 128, 256)) + best = None + k1, k2, k3 = jax.random.split(jax.random.key(42), 3) + q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + for block_q, block_kv in param_it: + config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2) + try: + out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v) + out_ref = attention_reference(q, k, v) + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + continue + raise + runtime_us = runtime_ms * 1e3 + matmul_flops = ( + 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size + ) + peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + print( + f"block_q={block_q:<4}block_kv={block_kv:<4}: {runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) + if best is None or runtime_us < best[0]: + best = (runtime_us, achieved_tc_util) + if best is not None: + print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization") + + +if __name__ == "__main__": + from absl import app + import jax + jax.config.config_with_absl() + app.run(main) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index b5af90272510..09728e432a8c 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -467,3 +467,30 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", ] + py_deps("absl/testing") + py_deps("numpy"), ) + +jax_multiplatform_test( + name = "mgpu_attention_run", + srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"], + enable_backends = [], + enable_configs = ["gpu_h100_x32"], + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "mgpu_attention_test", + srcs = ["mgpu_attention_test.py"], + enable_backends = [], + enable_configs = ["gpu_h100_x32"], + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py new file mode 100644 index 000000000000..87d58e96ad40 --- /dev/null +++ b/tests/pallas/mgpu_attention_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of FlashAttention.""" + +import os + +import numpy as np +from absl.testing import absltest, parameterized +from jax._src import config +from jax._src import test_util as jtu +import jax.numpy as jnp + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. + import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + attention_mgpu = None +else: + from jax.experimental.pallas.ops.gpu import attention_mgpu + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +@jtu.with_config(jax_traceback_filtering="off") +class FlashAttentionTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if attention_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + + @parameterized.product( + batch_size=(1,), + q_seq_len=(4096,), + kv_seq_len=(4096,), + num_q_and_kv_heads=((1, 1),), + # TODO(apaszke): Enable once we support many heads. + # num_q_and_kv_heads=((4, 1), # MQA + # (6, 3), # GQA + # (4, 4),), # MHA + head_dim=(64, 128, 256), + ) + def test_flash_attention( + self, batch_size, q_seq_len, kv_seq_len, num_q_and_kv_heads, head_dim + ): + num_q_heads, num_kv_heads = num_q_and_kv_heads + k1, k2, k3 = jax.random.split(jax.random.key(42), 3) + q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + out = attention_mgpu.attention( + q, k, v, attention_mgpu.TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2) + ) + out_ref = attention_mgpu.attention_reference(q, k, v) + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 5ccfc8d71644c02e1aa45b6b5b2c02a5cd1254c4 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 29 Oct 2024 06:06:31 -0700 Subject: [PATCH 096/698] Reverts c3b4b76080dbedfebfed978c812338e2f680ee23 PiperOrigin-RevId: 690990311 --- jax/_src/custom_derivatives.py | 67 ++++++++++++---------------------- tests/api_test.py | 16 +------- 2 files changed, 25 insertions(+), 58 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 216fb51f6a46..f5ecdfcda286 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -41,7 +41,7 @@ from jax._src.interpreters.batching import not_mapped from jax._src.lax import lax from jax._src.tree_util import ( - tree_flatten, tree_unflatten, tree_map, treedef_tuple, + tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr, treedef_children) from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable, @@ -1029,51 +1029,32 @@ def custom_gradient(fun): >>> print(jax.grad(f, argnums=(0, 1))(3., 4.)) (Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True)) """ - # TODO(mattjj): better debug info + @custom_vjp def wrapped_fun(*args, **kwargs): - args_flat, in_tree = tree_flatten((args, kwargs)) - in_avals = [core.get_aval(x) for x in args_flat] - primal_jaxpr, fwd_jaxpr, bwd_jaxpr, consts, out_tree = \ - _primal_fwd_bwd(in_tree, in_avals) - - @custom_vjp - def primal(consts, args): - return core.eval_jaxpr(primal_jaxpr, (), *consts, *args) - def fwd(consts, args): - ans_res = core.eval_jaxpr(fwd_jaxpr, (), *consts, *args) - return split_list(ans_res, [out_tree.num_leaves]) - def bwd(res, cts): - return None, core.eval_jaxpr(bwd_jaxpr, res, *cts) - primal.defvjp(fwd, bwd) - - out_flat = primal(consts, args_flat) - return tree_unflatten(out_tree, out_flat) - - def _primal_fwd_bwd(in_tree, in_avals): - out_tree, rule_jaxpr = None, None - @lu.wrap_init - def run(*args_flat): - nonlocal rule_jaxpr, out_tree - args, kwargs = tree_unflatten(in_tree, args_flat) - ans, rule = fun(*args, **kwargs) - ans_flat, out_tree = tree_flatten((ans,)) - ans_bar_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] - rule_, in_tree_ = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) - rule_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule_, ans_bar_avals) - out_tree, = treedef_children(out_tree) - return *ans_flat, *consts - fwd_jaxpr, _, fwd_consts, () = pe.trace_to_jaxpr_dynamic(run, in_avals) - fwd_jaxpr = pe.convert_constvars_jaxpr(fwd_jaxpr) - assert out_tree is not None and rule_jaxpr is not None - num_ans = out_tree.num_leaves - num_res = len(fwd_jaxpr.outvars) - num_ans - primal_jaxpr, _ = pe.dce_jaxpr(fwd_jaxpr, - [True] * num_ans + [False] * num_res, True) - return primal_jaxpr, fwd_jaxpr, rule_jaxpr, fwd_consts, out_tree - + ans, _ = fun(*args, **kwargs) + return ans + + def fwd(*args, **kwargs): + ans, rule = fun(*args, **kwargs) + ans_flat, out_tree = tree_flatten((ans,)) + rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) + ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals) + return ans, Residuals(jaxpr, in_tree(), out_tree, consts) + + def bwd(res, cts): + jaxpr, in_tree, out_tree, consts = res + cts_flat, out_tree_ = tree_flatten((cts,)) + if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}') + cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat) + cts_out = tree_unflatten(in_tree, cts_out) + if treedef_is_leaf(in_tree): + cts_out = (cts_out,) + return cts_out + + wrapped_fun.defvjp(fwd, bwd) return wrapped_fun - @register_pytree_node_class class Residuals: def __init__(self, jaxpr, in_tree, out_tree, consts): diff --git a/tests/api_test.py b/tests/api_test.py index e7ac95baa55b..2c2412093805 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -8976,27 +8976,13 @@ def f(x): vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) return jnp.sum(jnp.sin(x)), vjp - self.assertAllClose(f(jnp.arange(3.)), jnp.sum(jnp.sin(jnp.arange(3.))), + self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))), check_dtypes=False) self.assertAllClose( api.grad(f)(jnp.arange(3.)), api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.), check_dtypes=False) - def test_custom_gradient_jit_closure(self): - @jax.jit - def f(x, y): - y = jnp.sin(y) - - @jax.custom_gradient - def g(x): - return y * jnp.sin(x), lambda g: (y * jnp.cos(x) * g,) - - return g(x) - - g = jax.grad(f)(1., 2.) - self.assertAllClose(g, jnp.sin(2.) * jnp.cos(1.), check_dtypes=False) - def test_custom_gradient_can_return_singleton_value_in_vjp(self): @jax.custom_gradient def f(x): From bee2bc443a2d56c154160664b290533c487d3738 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Oct 2024 09:29:56 -0400 Subject: [PATCH 097/698] Remove some dead code from gpu_prng.py --- jaxlib/gpu_prng.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index e364b91e278c..b96040acd614 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -13,12 +13,9 @@ # limitations under the License. from __future__ import annotations - -import functools from functools import partial import importlib import itertools -import operator import jaxlib.mlir.ir as ir @@ -61,8 +58,6 @@ xla_client.register_custom_call_target(_name, _value, platform="ROCM", api_version=api_version) -_prod = lambda xs: functools.reduce(operator.mul, xs, 1) - def _threefry2x32_lowering(prng, platform: str, keys, data, length: int | ir.Value | None = None, From eff6cb445b769e2b0aa0bbff0888f4d3c6713b43 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 29 Oct 2024 07:36:12 -0700 Subject: [PATCH 098/698] [export] Enable more cross-platform lowering tests for GPU. Thanks to a lot of work by Dan Foreman-Mackey and others, there has been much progress in how we lower linalg primitives for GPU and we can now enable cross-platform lowering tests for these primitives. PiperOrigin-RevId: 691013252 --- tests/export_harnesses_multi_platform_test.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 0f0c20fd78e3..c74eec550342 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -45,24 +45,6 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: else: return re.compile("(" + "|".join(parts) + ")") -# TODO(necula): Failures to be investigated (on GPU). -_known_failures_gpu = make_disjunction_regexp( - # Failures on GPU due to failure to export custom call targets, these - # involve GPU custom call targets withoutbackwards compatibility tests. - "custom_linear_solve_", - "lu_", - "svd_", - "tridiagonal_solve_", -) - -# Some primitive lowering rules need the GPU backend to be able to create -# CUDA lowering. -_skip_cuda_lowering_unless_have_gpus = make_disjunction_regexp( - "svd_", "lu_", "eigh_", "qr_", "custom_linear_", "tridiagonal_solve_", - # TODO(b/350111820): random should work once we enable FFI threefry2x32 - "random_", -) - class PrimitiveTest(jtu.JaxTestCase): @@ -105,8 +87,8 @@ def test_prim(self, harness: test_harnesses.Harness): "decompositions for equality.") if (jtu.device_under_test() == "gpu" - and _known_failures_gpu.search(harness.fullname)): - self.skipTest("failure to be investigated") + and "tridiagonal_solve_" in harness.fullname): + self.skipTest("tridiagonal_solve_ is not yet guaranteed stable.") if harness.params.get("enable_xla", False): self.skipTest("enable_xla=False is not relevant") @@ -118,11 +100,14 @@ def test_prim(self, harness: test_harnesses.Harness): for l in harness.jax_unimplemented: if l.filter(dtype=harness.dtype): unimplemented_platforms = unimplemented_platforms.union(l.devices) - if (_skip_cuda_lowering_unless_have_gpus.search(harness.fullname) + # Some primitive lowering rules need the GPU backend to be able to create + # CUDA lowering. + if ("tridiagonal_solve_" in harness.fullname and all(d.platform != "gpu" for d in self.devices)): unimplemented_platforms.add("gpu") - logging.info("Harness is not implemented on %s", unimplemented_platforms) + if unimplemented_platforms: + logging.info("Harness is not implemented on %s", unimplemented_platforms) # Tolerances. tol = None From 1785479cbd7203a90a11368df005477ae4d5b746 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 29 Oct 2024 10:41:59 -0400 Subject: [PATCH 099/698] Fix segfault caused by uninitialized LAPACK in FFI test. --- tests/extend_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/extend_test.py b/tests/extend_test.py index 69b4591f3e85..84a907c7331d 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -13,7 +13,6 @@ # limitations under the License. import os -import sys import unittest from functools import partial @@ -35,6 +34,7 @@ from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.layout import DeviceLocalLayout +from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo from jax._src.lax import linalg as lax_linalg_internal @@ -261,10 +261,6 @@ def testFfiCallBatching(self, shape, vmap_method): @jtu.run_on_devices("gpu", "cpu") def testVectorizedDeprecation(self): - if sys.version_info.major == 3 and sys.version_info.minor == 13: - # TODO(b/376025274): Remove the skip once the bug is fixed. - raise unittest.SkipTest("Crashes on Python 3.13") - x = self.rng().randn(3, 5, 4).astype(np.float32) with self.assertWarns(DeprecationWarning): ffi_call_geqrf(x, vectorized=True) @@ -332,6 +328,9 @@ def fun(x): def ffi_call_geqrf(x, **kwargs): + if jtu.test_device_matches(["cpu"]): + lapack._lapack.initialize() + assert x.dtype == np.float32 ndim = x.ndim x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2) From 03854cfce449263bbd4b95a3be99590484bacc60 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 24 Oct 2024 13:44:18 -0400 Subject: [PATCH 100/698] Allow dot algorithms in default_matmul_precision config. --- jax/_src/config.py | 18 ++++++++++++++++-- tests/lax_test.py | 12 ++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index a05e6e190d44..80804832476b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1469,7 +1469,16 @@ def _update_disable_jit_thread_local(val): default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', - enum_values=['default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32'], + enum_values=[ + # Legacy precision API values + 'default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32', + # Dot algorithm presets + 'ANY_F8_ANY_F8_F32', 'ANY_F8_ANY_F8_F32_FAST_ACCUM', 'ANY_F8_ANY_F8_ANY', + 'ANY_F8_ANY_F8_ANY_FAST_ACCUM', 'F16_F16_F16', 'F16_F16_F32', + 'BF16_BF16_BF16', 'BF16_BF16_F32', 'BF16_BF16_F32_X3', + 'BF16_BF16_F32_X6', 'TF32_TF32_F32', 'TF32_TF32_F32_X3', 'F32_F32_F32', + 'F64_F64_F64', + ], default=None, help=('Control the default matmul and conv precision for 32bit inputs.\n\n' @@ -1486,7 +1495,12 @@ def _update_disable_jit_thread_local(val): 'convolution on 32bit inputs. The levels roughly describe the ' "precision at which scalar products are computed. The 'bfloat16' " "option is the fastest and least precise; 'float32' is similar to " - "full float32 precision; 'tensorfloat32' is intermediate.\n\n"), + "full float32 precision; 'tensorfloat32' is intermediate.\n\n" + + 'This parameter can also be used to specify an accumulation ' + '"algorithm" for functions that perform matrix multiplications, like ' + ':func:`jax.lax.dot`. To specify an algorithm, set this option to ' + 'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'), update_global_hook=lambda val: \ _update_global_jit_state(default_matmul_precision=val), update_thread_local_hook=lambda val: \ diff --git a/tests/lax_test.py b/tests/lax_test.py index 4d20240e1940..79ad9fcfa02c 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1156,6 +1156,18 @@ def fun(lhs, rhs): lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) self.assertEqual(fun(lhs, rhs).dtype, np.float16) + def testDotAlgorithmConfig(self): + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + + expected = ("algorithm = Date: Tue, 29 Oct 2024 15:23:53 +0000 Subject: [PATCH 101/698] Fix missing f-string format in slogdet error message --- jax/_src/numpy/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 79b47d9090af..14798f6f6913 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -532,7 +532,7 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: a, = promote_dtypes_inexact(jnp.asarray(a)) a_shape = jnp.shape(a) if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: - raise ValueError("Argument to slogdet() must have shape [..., n, n], got {a_shape}") + raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}") if method is None or method == "lu": return SlogdetResult(*_slogdet_lu(a)) elif method == "qr": From c36e1f7c1ad4782060cbc8e8c596d85dfb83986f Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 29 Oct 2024 11:03:49 -0700 Subject: [PATCH 102/698] Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on. PiperOrigin-RevId: 691086496 --- jax/_src/ad_checkpoint.py | 11 +- jax/_src/api.py | 29 +- jax/_src/callback.py | 1 - jax/_src/config.py | 24 +- jax/_src/core.py | 833 ++++++---------------- jax/_src/custom_batching.py | 6 +- jax/_src/custom_derivatives.py | 129 +--- jax/_src/custom_partitioning.py | 4 +- jax/_src/custom_transpose.py | 15 +- jax/_src/dispatch.py | 3 +- jax/_src/dtypes.py | 2 +- jax/_src/interpreters/ad.py | 196 +++-- jax/_src/interpreters/batching.py | 539 ++++++-------- jax/_src/interpreters/partial_eval.py | 644 ++++++----------- jax/_src/interpreters/pxla.py | 133 ++-- jax/_src/lax/control_flow/__init__.py | 1 - jax/_src/lax/control_flow/conditionals.py | 46 +- jax/_src/lax/control_flow/for_loop.py | 14 +- jax/_src/lax/control_flow/loops.py | 55 +- jax/_src/lax/control_flow/solves.py | 22 +- jax/_src/lax/lax.py | 19 +- jax/_src/lax/parallel.py | 326 +++++---- jax/_src/linear_util.py | 22 +- jax/_src/numpy/array_methods.py | 1 - jax/_src/pallas/core.py | 11 - jax/_src/pallas/mosaic/primitives.py | 10 +- jax/_src/pallas/primitives.py | 15 +- jax/_src/pjit.py | 48 +- jax/_src/state/discharge.py | 16 - jax/_src/test_util.py | 6 +- jax/core.py | 34 +- jax/experimental/attrs.py | 79 +- jax/experimental/jax2tf/jax2tf.py | 83 +-- jax/experimental/jet.py | 76 +- jax/experimental/multihost_utils.py | 10 +- jax/experimental/shard_map.py | 424 ++++------- jax/experimental/sparse/transform.py | 89 +-- jax/interpreters/ad.py | 3 - jax/interpreters/batching.py | 2 +- jax/interpreters/partial_eval.py | 7 - jax/lax/__init__.py | 1 - tests/api_test.py | 35 +- tests/for_loop_test.py | 8 +- tests/infeed_test.py | 2 + tests/lax_control_flow_test.py | 1 + tests/pmap_test.py | 2 +- tests/xla_metadata_test.py | 5 +- 47 files changed, 1411 insertions(+), 2631 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index f5d5be6a2751..5ed0b0192a7b 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -701,20 +701,17 @@ def transposed(*args_flat): transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error -def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, - jaxpr, **params): +def remat_vmap(axis_data, args, dims, *, jaxpr, **params): assert not jaxpr.constvars jaxpr_batched_, out_batched = batching.batch_jaxpr_axes( - pe.close_jaxpr(jaxpr), axis_size, dims, - [batching.zero_if_mapped] * len(jaxpr.outvars), - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + pe.close_jaxpr(jaxpr), axis_data, dims, + [batching.zero_if_mapped] * len(jaxpr.outvars)) jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts if consts: jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched) out_dims = [0 if b else None for b in out_batched] return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims -batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None) -batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap +batching.fancy_primitive_batchers[remat_p] = remat_vmap # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn diff --git a/jax/_src/api.py b/jax/_src/api.py index 0c46517b2191..390d3ea337bb 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -34,7 +34,7 @@ import weakref import numpy as np -from contextlib import contextmanager, ExitStack +from contextlib import contextmanager from jax._src import linear_util as lu from jax._src import stages @@ -989,10 +989,10 @@ def vmap_f(*args, **kwargs): axis_size_ = (axis_size if axis_size is not None else _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) try: + axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name) out_flat = batching.batch( - flat_fun, axis_name, axis_size_, in_axes_flat, - lambda: flatten_axes("vmap out_axes", out_tree(), out_axes), - spmd_axis_name=spmd_axis_name + flat_fun, axis_data, in_axes_flat, + lambda: flatten_axes("vmap out_axes", out_tree(), out_axes) ).call_wrapped(*args_flat) except batching.SpecMatchError as e: out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes) @@ -1546,16 +1546,13 @@ def cache_miss(*args, **kwargs): is_explicit_global_axis_size=p.is_explicit_global_axis_size, ) - map_bind_continuation, top_trace, fun_, tracers, params = ( - core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun, - *p.flat_args, **params)) execute: Callable | None = None - if isinstance(top_trace, core.EvalTrace): - execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params) - out = map_bind_continuation(execute(*tracers)) - else: - out = map_bind_continuation( - pxla.xla_pmap_p.process(top_trace, fun_, tracers, params)) + with core.take_current_trace() as trace: + if isinstance(trace, core.EvalTrace): + execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params) + out = execute(*p.flat_args) + else: + out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) out_tree, out_flat = p.out_tree, out out_pytree_def = out_tree() @@ -1802,7 +1799,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) - (Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True)) + (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True)) >>> y, f_jvp = jax.linearize(f, 2.) >>> print(y) 3.2681944 @@ -2160,9 +2157,7 @@ def make_jaxpr( @wraps(fun) @api_boundary def make_jaxpr_f(*args, **kwargs): - with ExitStack() as stack: - for axis_name, size in axis_env or []: - stack.enter_context(core.extend_axis_env(axis_name, size, None)) + with core.extend_axis_env_nd(axis_env or []): traced = jit(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes).trace(*args, **kwargs) # `jit` converts tracers in consts to args but that breaks the semantics of diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 0b918c7a994e..71886b453bef 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -633,7 +633,6 @@ def io_callback( flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype), flat_shape_dtypes) - flat_args = map(core.raise_as_much_as_possible, flat_args) out_flat = io_callback_p.bind( *flat_args, callback=_FlatCallback(callback, in_tree), diff --git a/jax/_src/config.py b/jax/_src/config.py index a05e6e190d44..533f0a1b528d 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -217,7 +217,9 @@ def trace_context(): return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, compute_on_context_manager, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, numpy_dtype_promotion.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, default_device.value, random_seed_offset.value, threefry_partitionable.value, threefry_gpu_kernel_lowering.value, @@ -832,6 +834,7 @@ class _GlobalExtraJitContext(NamedTuple): numpy_dtype_promotion: str | None = None default_matmul_precision: Any | None = None dynamic_shapes: bool = False + eager_constant_folding: bool = False random_seed_offset: int = 0 threefry_partitionable: bool = False threefry_gpu_kernel_lowering: bool = False @@ -858,7 +861,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): The initialization, which uses both config.py and core.py is done using `_update_thread_local_jit_state` in core.py to prevent circular imports. """ - dynamic_trace_state: Any | None = None + trace_state: Any | None = None axis_env_state: Hashable = () mesh_context_manager: Hashable = () compute_on_context_manager: Hashable = () @@ -873,6 +876,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): numpy_dtype_promotion: str | None = None default_matmul_precision: Any | None = None dynamic_shapes: bool | None = None + eager_constant_folding : bool | None = None random_seed_offset: int | None = None threefry_partitionable: bool | None = None threefry_gpu_kernel_lowering: bool | None = None @@ -909,7 +913,6 @@ def update_thread_local_jit_state(**kw): tmp = context._replace(**kw) tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) - # TODO(b/214340779): remove flag when XLA:CPU is improved. jax2tf_associative_scan_reductions = bool_state( name='jax2tf_associative_scan_reductions', @@ -1163,6 +1166,11 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: update_thread_local_jit_state( sharding_in_types=val)) +data_dependent_tracing_fallback = bool_state( + name='jax_data_dependent_tracing_fallback', + default=False, + help=('When True, falls back to trace dispatch based on data dependence ' + 'instead of throwing an escaped tracer error.')) softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', @@ -1530,6 +1538,16 @@ def _update_disable_jit_thread_local(val): update_thread_local_hook=lambda val: \ update_thread_local_jit_state(dynamic_shapes=val)) +# This is for stackless backward compat with e.g. equinox +eager_constant_folding = bool_state( + name='eager_constant_folding', + default=False, + help=('Attempt constant folding during staging.'), + update_global_hook=lambda val: \ + _update_global_jit_state(eager_constant_folding=val), + update_thread_local_hook=lambda val: \ + update_thread_local_jit_state(eager_constant_folding=val)) + # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. remat_opt_barrier = bool_state( diff --git a/jax/_src/core.py b/jax/_src/core.py index 8379ce5e070f..2a2a0d601848 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -14,9 +14,8 @@ from __future__ import annotations from collections import Counter, defaultdict, deque, namedtuple -from collections.abc import (Callable, Collection, Generator, Hashable, - Iterable, Iterator, Set, Sequence, MutableSet, - MutableMapping) +from collections.abc import (Callable, Collection, Hashable, Iterable, Iterator, + Sequence, MutableSet, MutableMapping) from contextlib import contextmanager, ExitStack from dataclasses import dataclass import functools @@ -29,7 +28,7 @@ import threading import types from typing import (Any, ClassVar, Generic, NamedTuple, TypeVar, - cast, overload, Union) + overload, Union) import warnings from weakref import ref @@ -47,7 +46,7 @@ from jax._src import source_info_util from jax._src.util import (safe_zip, safe_map, curry, tuple_insert, - tuple_delete, as_hashable_function, + tuple_delete, HashableFunction, HashableWrapper, weakref_lru_cache, partition_list, StrictABCMeta) import jax._src.pretty_printer as pp @@ -433,14 +432,17 @@ def __repr__(self): return f'{self.name}' def bind(self, *args, **params): - assert (not config.enable_checks.value or - all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args - return self.bind_with_trace(find_top_trace(args), args, params) + for arg in args: + if isinstance(arg, Tracer) and not arg._trace.is_valid(): + raise escaped_tracer_error(arg) + # TODO: figure out how to handle function arguments + # assert (not config.enable_checks.value or + # all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args + with take_current_trace() as cur_trace: + return self.bind_with_trace(cur_trace, args, params) def bind_with_trace(self, trace, args, params): - with pop_level(trace.level): - out = trace.process_primitive(self, map(trace.full_raise, args), params) - return map(full_lower, out) if self.multiple_results else full_lower(out) + return trace.process_primitive(self, args, params) def def_impl(self, impl): self.impl = impl @@ -454,9 +456,9 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval): self.abstract_eval = effectful_abstract_eval return effectful_abstract_eval - def def_custom_bind(self, bind): - self.bind = bind - return bind + def def_bind_with_trace(self, bind_with_trace): + self.bind_with_trace = bind_with_trace + return bind_with_trace def impl(self, *args, **params): raise NotImplementedError("Evaluation rule for '{}' not implemented" @@ -519,65 +521,18 @@ def write(v: Var, val: Any) -> None: TracerType = TypeVar('TracerType', bound='Tracer') class Trace(Generic[TracerType]): - __slots__ = ['main', 'level', 'sublevel'] - - main: MainTrace - level: int - sublevel: Sublevel - - def __init__(self, main: MainTrace, sublevel: Sublevel) -> None: - self.main = main - self.level = main.level - self.sublevel = sublevel - - def full_raise(self, val) -> TracerType: - if not isinstance(val, Tracer): - # This check is only applied to non-Tracers, because the hasattr() is - # expensive (Tracer.__getattr__) in the common case that val is a Tracer. - if hasattr(val, "dimension_as_value"): # Used for shape_poly._DimExpr - val = val.dimension_as_value() - if not isinstance(val, Tracer): - return self.pure(val) - else: - return self.pure(val) - val._assert_live() - level = self.level - sublevel = self.sublevel - if val._trace.main is self.main: - if val._trace.sublevel == sublevel: - return cast(TracerType, val) - elif val._trace.sublevel < sublevel: - return self.sublift(val) - else: - raise escaped_tracer_error( - val, f"Can't lift sublevels {val._trace.sublevel} to {sublevel}") - elif val._trace.level < level: - if val._trace.sublevel > sublevel: - raise escaped_tracer_error( - val, f"Incompatible sublevel: {val._trace}, {(level, sublevel)}") - return self.lift(val) - elif val._trace.level > level: - raise escaped_tracer_error( - val, f"Can't lift level {val} to {self}") - else: # val._trace.level == self.level: - raise escaped_tracer_error( - val, f"Different traces at same level: {val}, {self}") - - def pure(self, val) -> TracerType: - raise NotImplementedError("must override") - def lift(self, tracer) -> TracerType: + def process_primitive(self, primitive, tracers, params): raise NotImplementedError("must override") - def sublift(self, tracer) -> TracerType: - raise NotImplementedError("must override") + def invalidate(self): + self._invalidated = True - def process_primitive(self, primitive, tracers, params): - raise NotImplementedError("must override") + def is_valid(self): + return not hasattr(self, "_invalidated") def __repr__(self): - return '{}(level={}/{})'.format( - self.__class__.__name__, self.level, self.sublevel) + return '{}'.format(self.__class__.__name__) def process_call(self, call_primitive, f, tracers, params): msg = (f"{type(self)} must override process_call to handle call-like " @@ -606,24 +561,14 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, "to handle custom_vjp primitives") raise NotImplementedError(msg) + # TODO(dougalm): deprecate/delete + def full_raise(self, x): + return x -def raise_as_much_as_possible(tracer) -> Tracer: - # Find effective bottom of trace stack (highest dynamic Trace on the stack). - trace_stack = thread_local_state.trace_state.trace_stack.stack - idx = next(i for i, m in enumerate(trace_stack) if m is - thread_local_state.trace_state.trace_stack.dynamic) - - # Only pay attention to effective part of trace stack. - trace_stack = trace_stack[idx:] - - # Lift tracer into everything in the effective stack higher than its level - for trace in trace_stack: - trace = trace.with_cur_sublevel() - if (not isinstance(tracer, Tracer) or tracer._trace.level < trace.level): - tracer = trace.full_raise(tracer) - - return tracer - + # TODO(dougalm): deprecate/delete + @property + def main(self): + return getattr(self, "tag", None) def escaped_tracer_error(tracer, detail=None): num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value @@ -729,6 +674,10 @@ def tobytes(self, order="C"): f"The tobytes() method was called on {self._error_repr()}." f"{self._origin_msg()}") + # TODO(dougalm): deprecate/delete + def full_lower(self): + raise NotImplementedError("must override: ", type(self)) + def __iter__(self): return iter(self.aval._iter(self)) @@ -777,9 +726,6 @@ def at(self): def aval(self): raise NotImplementedError("must override") - def _assert_live(self) -> None: - pass # Override for liveness checking - def get_referent(self) -> Any: return self # Override for object equivalence checking @@ -809,7 +755,7 @@ def __oct__(self): def __index__(self): check_integer_conversion(self) - raise self.aval._index(self) + return self.aval._index(self) # raises a useful error on attempts to pickle a Tracer. def __reduce__(self): @@ -940,19 +886,23 @@ def unsafe_buffer_pointer(self): aval_property = namedtuple("aval_property", ["fget"]) aval_method = namedtuple("aval_method", ["fun"]) - class EvalTrace(Trace): - # See comments in https://github.com/jax-ml/jax/pull/3370 - def pure(self, x): return x - lift = sublift = pure - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, args, params): if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error - return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params) + return call_impl_with_key_reuse_checks(primitive, primitive.impl, *args, **params) else: - return primitive.impl(*tracers, **params) + # TODO(dougalm): delete. this shouldn't be necessary + args = map(full_lower, args) + for arg in args: + if isinstance(arg, Tracer): + if config.data_dependent_tracing_fallback.value: + return primitive.bind_with_trace(arg._trace, args, params) + else: + raise escaped_tracer_error(arg) + return primitive.impl(*args, **params) def process_call(self, primitive, f, tracers, params): if config.debug_key_reuse.value: @@ -965,128 +915,134 @@ def process_call(self, primitive, f, tracers, params): def process_custom_transpose(self, primitive, call, tracers, **_): del primitive, _ - with new_sublevel(): - return call.call_wrapped(*tracers) + return call.call_wrapped(*tracers) def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_): del primitive, jvp, _ # Unused. - with new_sublevel(): - return fun.call_wrapped(*tracers) + return fun.call_wrapped(*tracers) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # pytype: disable=signature-mismatch del primitive, fwd, bwd, _ # Unused. - with new_sublevel(): - return fun.call_wrapped(*tracers) - - -class MainTrace: - level: int - trace_type: type[Trace] - payload: dict[str, Any] - - def __init__(self, level, trace_type, **payload) -> None: - self.level = level - self.trace_type = trace_type - self.payload = payload - - def __repr__(self) -> str: - return f"MainTrace({self.level},{self.trace_type.__name__})" - - def __hash__(self) -> int: - return hash((self.level, self.trace_type)) - - def __eq__(self, other: object) -> bool: - return (isinstance(other, MainTrace) and - self.level == other.level and - self.trace_type == other.trace_type and - self.payload == other.payload) - - def with_cur_sublevel(self): - return self.trace_type(self, cur_sublevel(), **self.payload) - -class TraceStack: - # See comments in https://github.com/jax-ml/jax/pull/3370 - stack: list[MainTrace] - dynamic: MainTrace + return fun.call_wrapped(*tracers) + + +class TraceTag: + # TODO: this works for surprisingly subtle reasons. Function transformations + # like `jvp_subtrace` are parameterized by a tag that identifies the set of + # pre-existing tracers we want to unpack during the transformation. A function + # defined in an outer scope can't have any closed-over traces, so the tag is + # irrelevant. A function defined in the current scope may have closed-over + # traces, but the tag will never change so we'll never get a spurious cache + # hit. The plan is to do away with `lu.cache` altogether, and use a simpler + # caching scheme that only caches top-level functions. Then we can remove this + # hack. + def __hash__(self): + return hash(TraceTag) + def __eq__(self, other): + return isinstance(other, TraceTag) - def __init__(self): - eval_trace = MainTrace(0, EvalTrace) - self.stack = [eval_trace] - self.dynamic = eval_trace +ParamDict = dict[str, Any] +AxisName = Hashable - def next_level(self) -> int: - return len(self.stack) +no_axis_name = object() - def push(self, main_trace: MainTrace) -> None: - self.stack.append(main_trace) +@dataclass(frozen=True) +class AxisEnv: + axis_sizes : dict[AxisName, int] - def pop(self) -> None: - self.stack.pop() + def axis_size(self, axis_name): + if axis_name not in self.axis_sizes: + raise NameError(f"unbound axis name: {axis_name}") + else: + return self.axis_sizes[axis_name] - def __repr__(self) -> str: - stack_str = map(' {}\n'.format, self.stack[::-1]) - return f'Trace stack\n{stack_str}\n{self.dynamic}' + def axis_exists(self, axis_name): + return axis_name in self.axis_sizes - def copy(self): - new = self.__new__(TraceStack) - new.stack = self.stack[:] - new.dynamic = self.dynamic - return new + def axis_names(self): + return tuple(k for k in self.axis_sizes) + def pop_pure(self, axis_name): + new_sizes = self.axis_sizes.copy() + new_sizes.pop(axis_name) + return AxisEnv(new_sizes) -@total_ordering -class Sublevel: + def extend_pure(self, name_size_pairs): + new_sizes = self.axis_sizes.copy() + new_sizes.update((name, size) for name, size in name_size_pairs + if name is not no_axis_name) + return AxisEnv(new_sizes) - def __init__(self, level: int): - self.level = level + def as_hashable_key(self): + return tuple((name, size) for (name, size) in self.axis_sizes.items() + if name is not no_axis_name) - def __repr__(self): - return str(self.level) +eval_trace = EvalTrace() +top_axis_env = AxisEnv({}) - def __eq__(self, other): - return type(other) is Sublevel and self.level == other.level +class TracingContext(threading.local): + trace: Trace | None + axis_env : AxisEnv - def __lt__(self, other): - return type(other) is Sublevel and self.level < other.level + def __init__(self): + self.reset() + def reset(self): + self.trace = eval_trace + self.axis_env = top_axis_env -AxisEnvFrame = namedtuple('AxisEnvFrame', ['name', 'size', 'main_trace']) -AxisName = Hashable - -no_axis_name = object() + def is_top_level(self) -> bool: + return (self.trace is eval_trace and + self.axis_env is top_axis_env) -class TraceState: - trace_stack: TraceStack - substack: list[Sublevel] - axis_env: list[AxisEnvFrame] + def set_trace(self, trace): + self.trace = trace + ts = ref(trace) if trace is not None else None + config.update_thread_local_jit_state(trace_state=ts) - def __init__(self) -> None: - self.trace_stack = TraceStack() - self.substack = [Sublevel(0)] - self.axis_env = [] + def set_axis_env(self, axis_env): + self.axis_env = axis_env + config.update_thread_local_jit_state( + axis_env_state=self.axis_env.as_hashable_key()) - def copy(self): - new = self.__new__(TraceState) - new.trace_stack = self.trace_stack.copy() - new.substack = self.substack[:] - new.axis_env = self.axis_env[:] - return new + def update_thread_local_jit_state(self): + ts = ref(self.trace) if self.trace is not None else None + config.update_thread_local_jit_state( + trace_state=ts, + axis_env_state=self.axis_env.as_hashable_key()) +trace_ctx = TracingContext() -def _update_thread_local_jit_state(dynamic): - state = (dynamic.level, dynamic.trace_type) - config.update_thread_local_jit_state(dynamic_trace_state=state) +@contextmanager +def take_current_trace(): + prev = trace_ctx.trace + try: + trace_ctx.set_trace(eval_trace) + yield prev + finally: + trace_ctx.set_trace(prev) -# The global state of the tracer is accessed by a thread-local object. -# This allows concurrent tracing in separate threads; passing traced objects -# between threads is forbidden. -class ThreadLocalState(threading.local): - def __init__(self): - self.trace_state = TraceState() +@contextmanager +def set_current_trace(new): + prev = trace_ctx.trace + try: + trace_ctx.set_trace(new) + yield + finally: + trace_ctx.set_trace(prev) -thread_local_state = ThreadLocalState() +@contextmanager +def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]): + prev = trace_ctx.axis_env + try: + trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs)) + yield + finally: + trace_ctx.set_axis_env(prev) +def get_axis_env(): + return trace_ctx.axis_env def _initialize_jax_jit_thread_local_state(): """Initializes the C++ thread-local context. @@ -1098,33 +1054,25 @@ def _initialize_jax_jit_thread_local_state(): This function does not live in `config.py`, to prevent circular imports. """ tls = jax_jit.thread_local_state() - if tls.extra_jit_context is None: - dynamic = thread_local_state.trace_state.trace_stack.dynamic - state = (dynamic.level, dynamic.trace_type) - config.update_thread_local_jit_state(dynamic_trace_state=state) + if tls.extra_jit_context is None: + trace_ctx.update_thread_local_jit_state() jax_jit.set_thread_local_state_initialization_callback( _initialize_jax_jit_thread_local_state) def trace_state_clean() -> bool: - trace_state = thread_local_state.trace_state - return (trace_state.substack == [Sublevel(0)] and - trace_state.axis_env == [] and - trace_state.trace_stack.stack == [MainTrace(0, EvalTrace)] and - trace_state.trace_stack.dynamic == MainTrace(0, EvalTrace)) + return trace_ctx.is_top_level() def reset_trace_state() -> bool: """Resets the global trace state and returns True if it was already clean.""" - if not trace_state_clean(): - thread_local_state.trace_state.__init__() + if not trace_ctx.is_top_level(): + trace_ctx.reset() + trace_ctx.update_thread_local_jit_state() return False else: return True -def cur_sublevel() -> Sublevel: - return thread_local_state.trace_state.substack[-1] - TRACER_LEAK_DEBUGGER_WARNING = """\ JAX check_tracer_leaks behavior can trigger false positives when used with a debugger. To avoid false positives and silence this warning, you can disable thread tracing using @@ -1134,13 +1082,21 @@ def cur_sublevel() -> Sublevel: threading.current_thread().pydev_do_not_trace = True """ -def maybe_find_leaked_tracers(x: MainTrace | Sublevel | None - ) -> list[Tracer]: - """Find the leaked tracers holding a reference to the MainTrace or SubLevel. +@contextmanager +def ensure_no_leaks(trace:Trace): + yield + trace.invalidate() + if config.check_tracer_leaks.value: + trace_ref = ref(trace) + del trace + live_trace = trace_ref() + if live_trace is not None: + leaked_tracers = maybe_find_leaked_tracers(live_trace) + if leaked_tracers: + raise leaked_tracer_error("trace", live_trace, leaked_tracers) - It's possible there's none! eg. there's some cases where JAX itself holds a - reference to `x` inside of a lambda closure, and no tracers were leaked - by the user. In this case an empty list is returned. +def maybe_find_leaked_tracers(trace: Trace) -> list[Tracer]: + """Find the leaked tracers holding a reference to the Trace """ if not getattr(threading.current_thread(), 'pydev_do_not_trace', True): warnings.warn(TRACER_LEAK_DEBUGGER_WARNING) @@ -1148,8 +1104,7 @@ def maybe_find_leaked_tracers(x: MainTrace | Sublevel | None # only due to cyclical dependencies. (We don't care about unreachable leaked # tracers since they can't interact with user code and cause a problem.) gc.collect() - traces = list(filter(lambda x: isinstance(x, Trace), gc.get_referrers(x))) - tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(*traces))) + tracers = list(filter(lambda x: isinstance(x, Tracer), gc.get_referrers(trace))) return tracers def leaked_tracer_error(name: str, t, tracers: list[Tracer]) -> Exception: @@ -1216,83 +1171,6 @@ def _why_alive_container_info(container, obj_id) -> str: return f' named {container.__name__}' return name - -@contextmanager -def new_main(trace_type: type[Trace], dynamic: bool = False, - **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/jax-ml/jax/pull/3370 - stack = thread_local_state.trace_state.trace_stack - level = stack.next_level() - main = MainTrace(level, trace_type, **payload) - stack.push(main) - if dynamic: - prev_dynamic, stack.dynamic = stack.dynamic, main - _update_thread_local_jit_state(stack.dynamic) - - try: - yield main - finally: - stack.pop() - if dynamic: - stack.dynamic = prev_dynamic - _update_thread_local_jit_state(stack.dynamic) - - if config.check_tracer_leaks.value: - t = ref(main) - del main - if t() is not None: - leaked_tracers = maybe_find_leaked_tracers(t()) - if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers) - -@contextmanager -def new_dynamic(level: int) -> Generator[None, None, None]: - stack = thread_local_state.trace_state.trace_stack - prev_dynamic, stack.dynamic = stack.dynamic, stack.stack[level] - _update_thread_local_jit_state(stack.dynamic) - try: - yield - finally: - stack.dynamic = prev_dynamic - _update_thread_local_jit_state(stack.dynamic) - -def dynamic_level() -> int: - return thread_local_state.trace_state.trace_stack.dynamic.level - -@contextmanager -def new_base_main(trace_type: type[Trace], - **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/jax-ml/jax/pull/3370 - stack = thread_local_state.trace_state.trace_stack - main = MainTrace(0, trace_type, **payload) - prev_dynamic, stack.dynamic = stack.dynamic, main - prev_base, stack.stack[0] = stack.stack[0], main - _update_thread_local_jit_state(stack.dynamic) - try: - yield main - finally: - stack.dynamic = prev_dynamic - stack.stack[0] = prev_base - _update_thread_local_jit_state(stack.dynamic) - - if config.check_tracer_leaks.value: - t = ref(main) - del main - if t() is not None: - leaked_tracers = maybe_find_leaked_tracers(t()) - if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers) - -@contextmanager -def pop_level(level: int): - if level == 0: - return (yield) # noqa: B901 - prev, thread_local_state.trace_state.trace_stack.stack = \ - thread_local_state.trace_state.trace_stack.stack, \ - thread_local_state.trace_state.trace_stack.stack[:level] - try: - yield - finally: - thread_local_state.trace_state.trace_stack.stack = prev - @contextmanager def ensure_compile_time_eval(): """Context manager to ensure evaluation at trace/compile time (or error). @@ -1353,50 +1231,21 @@ def jax_fn(x): But in some cases it can be more convenient to use this context manager. """ - with new_base_main(EvalTrace): + with config.eager_constant_folding(True): yield -eval_context = ensure_compile_time_eval # alias, backward compatibility @contextmanager -def new_sublevel() -> Generator[None, None, None]: - sublevel = Sublevel(len(thread_local_state.trace_state.substack)) - thread_local_state.trace_state.substack.append(sublevel) - try: +def eval_context(): + with set_current_trace(eval_trace): yield - finally: - thread_local_state.trace_state.substack.pop() - - if config.check_tracer_leaks.value: - t = ref(sublevel) - del sublevel - if t() is not None: - leaked_tracers = maybe_find_leaked_tracers(t()) - if leaked_tracers: - raise leaked_tracer_error("sublevel", t(), leaked_tracers) +# TODO(dougalm): deprecate/delete def full_lower(val): if isinstance(val, Tracer): return val.full_lower() else: return val - -def _get_trace_level(t: Tracer) -> int: return t._trace.level - - -def find_top_trace(xs) -> Trace: - top_tracer = max((x for x in xs if isinstance(x, Tracer)), - default=None, key=_get_trace_level) - if top_tracer is not None: - top_tracer._assert_live() - top_main = top_tracer._trace.main - else: - top_main = None - dynamic = thread_local_state.trace_state.trace_stack.dynamic - top_main = (dynamic if top_main is None or dynamic.level > top_main.level - else top_main) - return top_main.with_cur_sublevel() - def get_referent(x: Any) -> Any: return x.get_referent() if isinstance(x, Tracer) else x @@ -2355,11 +2204,10 @@ class CallPrimitive(Primitive): multiple_results = True call_primitive = True - def bind(self, fun, *args, **params): - call_bind_continuation, top_trace, fun_, tracers, params = ( - call_bind_with_continuation(self, fun, *args, **params)) - outs = top_trace.process_call(self, fun_, tracers, params) - return call_bind_continuation(outs) + def bind_with_trace(self, trace, fun_and_args, params): + fun = fun_and_args[0] + args = fun_and_args[1:] + return trace.process_call(self, fun, args, params) def get_bind_params(self, params): new_params = dict(params) @@ -2369,45 +2217,9 @@ def get_bind_params(self, params): subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr)) return [subfun], new_params -def call_bind_with_continuation(primitive: CallPrimitive, fun, *args, **params): - top_trace = find_top_trace(args) - fun_, env_trace_todo = process_env_traces_call( - fun, primitive, top_trace.level, tuple(params.items())) - tracers = map(top_trace.full_raise, args) - fun_ = lu.annotate(fun_, fun.in_type) - - def call_bind_continuation(outs): - return map(full_lower, apply_todos(env_trace_todo(), outs)) - return call_bind_continuation, top_trace, fun_, tracers, params - -@lu.transformation_with_aux -def process_env_traces_call(primitive: CallPrimitive, level: int, - params_tuple: tuple, *args): - outs = yield args, {} - params = dict(params_tuple) - todo = [] - while True: - tracers = [x for x in outs if isinstance(x, Tracer) and x._trace.level > level] - if not tracers: - break - ans = max(tracers, key=_get_trace_level) - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, cur_todo = trace.post_process_call(primitive, outs, params) - todo.append(cur_todo) - yield outs, tuple(todo) # Ensure the aux output is immutable - -def apply_todos(todos, outs): - todos_list = list(todos) - while todos_list: - outs = map(full_lower, todos_list.pop()(outs)) - return outs - - def call_impl(f: lu.WrappedFun, *args, **params): del params # params parameterize the call primitive, not the function - with new_sublevel(): - return f.call_wrapped(*args) + return f.call_wrapped(*args) call_p: CallPrimitive = CallPrimitive('call') call = call_p.bind @@ -2459,16 +2271,15 @@ class MapPrimitive(Primitive): multiple_results = True map_primitive = True - def bind(self, fun, *args, **params): + def bind_with_trace(self, trace, fun_and_args, params): + fun = fun_and_args[0] + args = fun_and_args[1:] assert len(params['in_axes']) == len(args) - return map_bind(self, fun, *args, **params) + return trace.process_map(self, fun, args, params) def process(self, trace, fun, tracers, params): return trace.process_map(self, fun, tracers, params) - def post_process(self, trace, out_tracers, params): - return trace.post_process_map(self, out_tracers, params) - def get_bind_params(self, params): new_params = dict(params) jaxpr = new_params.pop('call_jaxpr') @@ -2477,59 +2288,6 @@ def get_bind_params(self, params): new_params['out_axes_thunk'] = HashableFunction(lambda: axes, closure=axes) return [subfun], new_params - -def map_bind_with_continuation(primitive: MapPrimitive, fun, *args, - out_axes_thunk, **params): - # The new thunk depends deterministically on the old thunk and the wrapped - # function. Any caching already has to include the wrapped function as part - # of the key, so we only use the previous thunk for equality checks. - @as_hashable_function(closure=out_axes_thunk) - def new_out_axes_thunk(): - out_axes = out_axes_thunk() - _, out_axes_transforms = todo_and_xforms() - for t in out_axes_transforms: - out_axes = t(out_axes) - return out_axes - params = dict(params, out_axes_thunk=new_out_axes_thunk) - top_trace = find_top_trace(args) - fun, todo_and_xforms = process_env_traces_map( - fun, primitive, top_trace and top_trace.level, tuple(params.items())) - tracers = map(top_trace.full_raise, args) - - def map_bind_continuation(outs): - env_trace_todo, _ = todo_and_xforms() - return map(full_lower, apply_todos(env_trace_todo, outs)) - - return map_bind_continuation, top_trace, fun, tracers, params - - -def map_bind(primitive: MapPrimitive, fun, *args, **params): - map_bind_continuation, top_trace, fun, tracers, params = ( - map_bind_with_continuation(primitive, fun, *args, **params)) - return map_bind_continuation( - primitive.process(top_trace, fun, tracers, params)) - -@lu.transformation_with_aux -def process_env_traces_map(primitive: MapPrimitive, level: int, - params_tuple: tuple, *args): - outs = yield args, {} - params = dict(params_tuple) - todo = [] - out_axes_transforms = [] - while True: - tracers = [x for x in outs if isinstance(x, Tracer) - and (level is None or x._trace.level > level)] - if not tracers: - break - ans = max(tracers, key=_get_trace_level) - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, (cur_todo, cur_xform) = primitive.post_process(trace, outs, params) - todo.append(cur_todo) - out_axes_transforms.append(cur_xform) - yield outs, (tuple(todo), tuple(out_axes_transforms)) - - def mapped_aval(size: AxisSize, axis: int | None, aval: AbstractValue) -> AbstractValue: handler, _ = aval_mapping_handlers.get(type(aval), (None, None)) @@ -2588,56 +2346,6 @@ def _unmap_dshaped_array( AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a) } -@contextmanager -def extend_axis_env(axis_name: AxisName, size: int, tag: Any): - frame = AxisEnvFrame(axis_name, size, tag) - ts = thread_local_state.trace_state - ts.axis_env.append(frame) - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - try: - yield - finally: - ts.axis_env.pop() - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - -@contextmanager -def extend_axis_env_nd(axes: Iterable[tuple[AxisName, int]], tag: Any = None): - frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes] - ts = thread_local_state.trace_state - ts.axis_env.extend(frames) - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - try: - yield - finally: - for _ in frames: ts.axis_env.pop() - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - - -@contextmanager -def stash_axis_env(): - "Promise that a function or with-suite does not depend implicitly on axis env" - # If the promise is broken, then a NameError about an unbound axis name will - # be raised. - ts = thread_local_state.trace_state - prev_axis_env, ts.axis_env = ts.axis_env, [] - config.update_thread_local_jit_state(axis_env_state=()) - try: - yield - finally: - ts.axis_env = prev_axis_env - config.update_thread_local_jit_state( - axis_env_state=tuple(f for f in ts.axis_env - if f.name is not no_axis_name)) - - # When a mapped function is given no axis name, we generate a name object based # on the id of the function object. Collisions aren't important because this # name can't be used in collectives, as user code never gets a ref to this @@ -2663,20 +2371,6 @@ def __lt__(self, other): return type(other) is _TempAxisName and self.id < other.id -def axis_frame(axis_name: AxisName, main_trace: MainTrace | None = None - ) -> AxisEnvFrame: - frames = thread_local_state.trace_state.axis_env - for frame in reversed(frames): - if (frame.name == axis_name and - (main_trace is None or frame.main_trace is main_trace)): - return frame - named_axes = [frame.name for frame in reversed(frames) - if not isinstance(frame.name, _TempAxisName)] - raise NameError( - f'unbound axis name: {axis_name}. The following axis names (e.g. defined ' - f'by pmap) are available to collective operations: {named_axes}') - - @dataclass(frozen=True) class NamedAxisEffect(effects.Effect): """A side-effect introducing a new named axis into the current scope.""" @@ -2704,98 +2398,9 @@ def remove_named_axis_effects( return jaxpr return jaxpr.replace(effects=filter_named_axis_effects(jaxpr.effects, names)) - -ParamDict = dict[str, Any] -AxisSubst = Callable[[AxisName], tuple[AxisName, ...]] - -class NameGatheringSubst: - def __init__(self): - self.axis_names = set() - def __call__(self, axis_name): - self.axis_names.add(axis_name) - return (axis_name,) - -def used_axis_names(primitive: Primitive, params: ParamDict) -> set[AxisName]: - subst = NameGatheringSubst() - subst_axis_names(primitive, params, subst) - return subst.axis_names - -def subst_axis_names(primitive: Primitive, params: ParamDict, subst: AxisSubst, traverse: bool = True) -> ParamDict: - if primitive in axis_substitution_rules: - return axis_substitution_rules[primitive](params, subst, traverse) - if not traverse: - return params - # Default implementation: substitute names in all jaxpr parameters - if isinstance(primitive, MapPrimitive): - def shadowed_subst(name): - return (name,) if name == params['axis_name'] else subst(name) - else: - shadowed_subst = subst - jaxpr_params = [(n, v) for n, v in params.items() if isinstance(v, (Jaxpr, ClosedJaxpr))] - if not jaxpr_params: - return params - new_params = dict(params) - for name, jaxpr in jaxpr_params: - new_params[name] = subst_axis_names_jaxpr(jaxpr, shadowed_subst) - return new_params - -class DuplicateAxisNameError(Exception): - def __init__(self, var): - self.var = var - self.eqn = None - -def subst_axis_names_effects(effects: Set[Effect], subst: AxisSubst) -> Set[Effect]: - new_effects = set[Effect]() - for e in effects: - if isinstance(e, NamedAxisEffect): - new_effects.update(map(NamedAxisEffect, subst(e.name))) - else: - new_effects.add(e) - return new_effects - -def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var: - # Var identity is load-bearing, so we can't have duplicates! - if isinstance(v, DropVar): return v - assert v not in var_map - var_map[v] = v - return v - -def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn: - invars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars] - try: - outvars = [subst_axis_names_var(v, subst, var_map) for v in eqn.outvars] - except DuplicateAxisNameError as e: - e.eqn = eqn - raise - params = subst_axis_names(eqn.primitive, eqn.params, subst) - effects = subst_axis_names_effects(eqn.effects, subst) - return eqn.replace(invars=invars, outvars=outvars, params=params, effects=effects) - -def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst): - consts = None - if isinstance(jaxpr, ClosedJaxpr): - consts = jaxpr.consts - jaxpr = jaxpr.jaxpr - var_map: dict[Var, Var] = {} - invars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.invars] # type: ignore[union-attr] - constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr] - eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] - outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr] - effects = subst_axis_names_effects(jaxpr.effects, subst) - new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, effects) - if consts is not None: - return ClosedJaxpr(new_jaxpr, consts) - return new_jaxpr - def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr): return {e.name for e in jaxpr.effects if isinstance(e, NamedAxisEffect)} -def subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst): - if isinstance(subst, NameGatheringSubst): # This is a common case, so we optimize it! - subst.axis_names |= used_axis_names_jaxpr(jaxpr) - return jaxpr - return do_subst_axis_names_jaxpr(jaxpr, subst) - def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects): return _replace_jaxpr_effects(jaxpr, frozenset(effects)) @@ -2803,23 +2408,6 @@ def replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: Effects): def _replace_jaxpr_effects(jaxpr: ClosedJaxpr, effects: frozenset[Effect]): return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(effects=set(effects))) - -axis_substitution_rules: dict[Primitive, Callable[[ParamDict, AxisSubst, bool], ParamDict]] = {} - -# ------------------- AxisPrimitive ------------------- -# Primitives that store axis names in params and want those axis names to -# participate in dispatch should subclass AxisPrimitive. - -class AxisPrimitive(Primitive): - def bind(self, *args, **params): - top_trace = find_top_trace(args) - axis_main = max((axis_frame(a).main_trace for a in used_axis_names(self, params)), - default=None, key=lambda t: getattr(t, 'level', -1)) - top_trace = (top_trace if not axis_main or axis_main.level < top_trace.level - else axis_main.with_cur_sublevel()) - return self.bind_with_trace(top_trace, args, params) - - # ------------------- Jaxpr checking ------------------- def typecheck(aval: AbstractValue, x) -> bool: @@ -3143,7 +2731,7 @@ def _check_map(ctx_factory, prim, in_avals, params): raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} " f"to jaxpr expecting {binder_aval}") - with extend_axis_env(params['axis_name'], axis_size, None): + with extend_axis_env_nd([(params['axis_name'], axis_size)]): _check_jaxpr(ctx_factory, call_jaxpr) mapped_out_avals = [v.aval for v in call_jaxpr.outvars] @@ -3460,46 +3048,45 @@ def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any], # Comparable object for checking whether JAX's trace state has changed. class OpaqueTraceState: - def __init__(self, trace_info, convention): - self._trace_info = trace_info - self._convention = convention + def __init__(self, trace_ref): + self._trace_ref = trace_ref def __eq__(self, other): if isinstance(other, OpaqueTraceState): - if self._convention in ["nnx"]: - return self._trace_info is other._trace_info - elif self._convention in ["haiku", "flax"]: - return self._trace_info == other._trace_info - else: - raise Exception(f"unrecognized convention: {self._convention}") - - -# Each library has its own opinion about what the important fragment of jax's -# internal state is. TODO: reconcile the differences and remove the flag. -def get_opaque_trace_state(convention="flax"): - if convention == "flax": - trace_info = find_top_trace(()).level - elif convention == "haiku": - trace_stack = thread_local_state.trace_state.trace_stack.stack - top_type = trace_stack[0].trace_type - level = trace_stack[-1].level - sublevel = cur_sublevel() - trace_info = (top_type, level, sublevel) - elif convention == "nnx": - trace_info = thread_local_state.trace_state.trace_stack.dynamic - else: - raise Exception(f"unrecognized convention: {convention}") + return self._trace_ref == other._trace_ref + else: + return False - return OpaqueTraceState(trace_info, convention) +def get_opaque_trace_state(convention): + del convention + return OpaqueTraceState(ref(trace_ctx.trace)) def nonempty_axis_env() -> bool: - return bool(thread_local_state.trace_state.axis_env) + return bool(trace_ctx.axis_env.axis_sizes) def unsafe_am_i_under_a_jit() -> bool: - return 'DynamicJaxprTrace' in str(thread_local_state.trace_state.trace_stack) + return 'DynamicJaxprTrace' in str(unsafe_get_trace_stack(trace_ctx.trace)) def unsafe_am_i_under_a_vmap() -> bool: - return 'BatchTrace' in str(thread_local_state.trace_state.trace_stack) + return 'BatchTrace' in str(unsafe_get_trace_stack(trace_ctx.trace)) + +# TODO(douglam): deprecate/delete +def find_top_trace(_): + return unsafe_get_current_trace() + + +def unsafe_get_current_trace(): + return trace_ctx.trace + +def unsafe_get_trace_stack(trace): + if hasattr(trace, "parent_trace"): + return unsafe_get_trace_stack(trace.parent_trace) + [trace] + else: + return [trace] + +def unsafe_get_axis_names() -> list[Any]: + return list(trace_ctx.axis_env.axis_sizes) -def unsafe_get_axis_names() -> list[str]: - return [axis.name for axis in thread_local_state.trace_state.axis_env] +# TODO(douglam): deprecate/delete +def axis_frame(axis_name): + return trace_ctx.axis_env.axis_size(axis_name) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 35e7d33430bd..afeef1e18456 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -138,9 +138,9 @@ def maybe_bdim_at_front(x, bdim): # axes instead of accepting and matching a given spec of output axes. Assumes # `f` is pytree-flattened def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size): - f, out_axes = batching.batch_subtrace(f) - f = batching._batch_outer(f, axis_name, axis_size, in_axes, - batching.BatchTrace, None) + axis_data = batching.AxisData(axis_name, axis_size, None) + tag = core.TraceTag() + f, out_axes = batching.batch_subtrace(f, tag, axis_data, in_axes) outs = f.call_wrapped(*args) return outs, out_axes() diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index f5ecdfcda286..0b57ff9028f6 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -354,25 +354,12 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): class CustomJVPCallPrimitive(core.Primitive): multiple_results = True - def bind(self, fun, jvp, *args, symbolic_zeros): - args = map(core.full_lower, args) - top_trace = core.find_top_trace(args) - fun, env_trace_todo1 = process_env_traces( - fun, self, top_trace and top_trace.level, False) - jvp, env_trace_todo2 = process_env_traces( - jvp, self, top_trace and top_trace.level, True) - tracers = map(top_trace.full_raise, args) - outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers, - symbolic_zeros=symbolic_zeros) - _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) - return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) + def bind_with_trace(self, trace, args, params): + fun, jvp, tracers = args[0], args[1], args[2:] + return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params) def impl(self, fun, _, *args): - with core.new_sublevel(): - return fun.call_wrapped(*args) - - def post_process(self, trace, out_tracers, jvp_was_run: bool): - return trace.post_process_custom_jvp_call(out_tracers, jvp_was_run) + raise NotImplementedError def get_bind_params(self, params): new_params = dict(params) @@ -402,24 +389,6 @@ def jvp(*xs): return [*out_primals, *out_tangents] return jvp -@partial(lu.transformation_with_aux, use_eq_store=True) -def process_env_traces(primitive, level: int, jvp_was_run: bool, *args): - outs = yield args, {} - todo = [] - while True: - tracers = [x for x in outs if isinstance(x, core.Tracer) - and (level is None or x._trace.level > level)] - if tracers: - ans = max(tracers, key=lambda x: x._trace.level) - else: - break - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, cur_todo = primitive.post_process(trace, outs, jvp_was_run) - todo.append(cur_todo) - yield outs, tuple(todo) # Ensure the aux output is immutable - - effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') @@ -824,55 +793,12 @@ def _temporary_shape_exception(a, a_) -> bool: class CustomVJPCallPrimitive(core.CallPrimitive): initial_style: core.Primitive - def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros): - args = map(core.full_lower, args) - top_trace = core.find_top_trace(args) - fun, env_trace_todo1 = process_env_traces( - fun, self, top_trace and top_trace.level, False) - fwd, env_trace_todo2 = process_env_traces_fwd( - fwd, top_trace and top_trace.level, out_trees) - tracers = map(top_trace.full_raise, args) - bwd_ = lambda *args: bwd(*args) - outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers, - out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2) - if fst: - return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) - else: - env_trace_todo, bwd_transform = env_trace_todo - bwd = _apply_bwd_transform(bwd_transform, bwd) - return core.apply_todos(env_trace_todo, map(core.full_lower, outs)) - - def impl(self, fun, fwd, bwd, *args, out_trees): - del fwd, bwd, out_trees - with core.new_sublevel(): - return fun.call_wrapped(*args) + def bind_with_trace(self, trace, args, params): + fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] + return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) - def post_process(self, trace, out_tracers, params): - return trace.post_process_custom_vjp_call(out_tracers, params) custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') -@partial(lu.transformation_with_aux, use_eq_store=True) -def process_env_traces_fwd(level: int, out_trees, *args): - outs = yield args, {} - todo = [] - bwd_transforms = [] - while True: - tracers = [x for x in outs if isinstance(x, core.Tracer) - and (level is None or x._trace.level > level)] - if tracers: - ans = max(tracers, key=lambda x: x._trace.level) - else: - break - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, cur_todo, bwd_xform = trace.post_process_custom_vjp_call_fwd(outs, out_trees) - todo.append(cur_todo) - bwd_transforms.append(bwd_xform) - yield outs, (tuple(todo), tuple(bwd_transforms)) - - def _apply_bwd_transform(todos, bwd): todos_list = list(todos) while todos_list: @@ -889,7 +815,7 @@ def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): f'Effects not supported in `custom_vjp`: {disallowed_effects}') return fun_jaxpr.out_avals, fun_jaxpr.effects -custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr') +custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr') custom_vjp_call_jaxpr_p.multiple_results = True custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl) custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval) @@ -921,18 +847,16 @@ def _custom_vjp_call_jaxpr_jvp( ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp def _custom_vjp_call_jaxpr_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *, + axis_data, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr, fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] _, args_batched = split_list(in_batched, [num_consts]) batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name, - main_type) + fun_jaxpr, axis_data, in_batched, False) out_dims1 = [0 if b else not_mapped for b in out_batched] out_dims2 = [] @@ -940,16 +864,15 @@ def _custom_vjp_call_jaxpr_vmap( def batched_fwd_jaxpr_thunk(*zeros): fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name, - main_type) + fwd_jaxpr, axis_data, args_batched, False) out_dims2.append([0 if b else not_mapped for b in out_batched]) return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts fwd_args_batched = [0 if b else not_mapped for b in args_batched] fwd_out_dims = lambda: out_dims2[0] + tag = core.TraceTag() batched_bwd = batching.batch_custom_vjp_bwd( - bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type, - spmd_axis_name) + bwd, tag, axis_data, fwd_out_dims, fwd_args_batched) batched_outs = custom_vjp_call_jaxpr_p.bind( *args, fun_jaxpr=batched_fun_jaxpr, @@ -957,10 +880,7 @@ def batched_fwd_jaxpr_thunk(*zeros): num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) out_dims = out_dims2[0] if out_dims2 else out_dims1 return batched_outs, out_dims -batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \ - _custom_vjp_call_jaxpr_vmap -batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial( - _custom_vjp_call_jaxpr_vmap, None) +batching.fancy_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) @@ -1144,11 +1064,12 @@ def rev(objective_fn, res, g): def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/jax-ml/jax/issues/6415 for motivation. - x = core.full_lower(x) + # See https://github.com/google/jax/issues/6415 for motivation. if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. return False + elif isinstance(x, ad.JVPTracer) and isinstance(x.tangent, ad.Zero): + return _maybe_perturbed(x.primal) elif isinstance(x, pe.DynamicJaxprTracer): # If x is a DynamicJaxprTracer then we're staging out; differentiation could # happen later, but some types always have trivial tangents. @@ -1532,7 +1453,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_): return fwd_jaxpr.out_avals, fwd_jaxpr.effects def _remat_opt_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, + axis_data, args, in_dims, *, num_consts: int, num_res: int, @@ -1541,11 +1462,9 @@ def _remat_opt_vmap( ): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_size, in_batched, False, - axis_name, spmd_axis_name, main_type) + fwd_jaxpr, axis_data, in_batched, False) extra_consts = batched_fwd_jaxpr.consts batched_fwd_jaxpr = pe.close_jaxpr( pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) @@ -1557,8 +1476,7 @@ def _remat_opt_vmap( def batched_fun_jaxpr_thunk(): fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name, - main_type) + fun_jaxpr, axis_data, prim_batched, False) return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts batched_outs = remat_opt_p.bind(*extra_consts, *args, @@ -1592,7 +1510,7 @@ def _remat_opt_jvp( [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out]) fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr)) - @pe._memoize + # @pe._memoize def fun_jvp_jaxpr_thunk(): fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) in_nz = [True] * len(primals) @@ -1666,8 +1584,9 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): xla.register_initial_style_primitive(remat_opt_p) mlir.register_lowering(remat_opt_p, mlir.lower_fun( _remat_opt_impl, multiple_results=True)) -batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap -batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None) + + +batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose pe.dce_rules[remat_opt_p] = _remat_opt_dce diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index c5cf0edf14c6..95e0578f0b2d 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -458,7 +458,9 @@ def __call__(self, *args, **kwargs): in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_partitioning") - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + mesh = mesh_lib.thread_resources.env.physical_mesh + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) assert not len(consts) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) out_flat = custom_partitioning_p.bind( diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index a4de1b8cc46c..9fe77ca0a6ac 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -155,18 +155,9 @@ class CustomTransposePrimitive(core.Primitive): map_primitive = False multiple_results = True - def bind(self, call, *args, **params): - # TODO(frostig,mattjj): This doesn't handle closures yet, which is - # a bit involved. Closures are complicated by us binding `call` - # twice in the JVP rule for custom transpose. The `env_trace_todo` - # output by `process_env_traces` due to one of those two bindings - # should be passable to the other, and need to be passed onward - # since the second bind is deferred by partial eval (since it - # typically receives unknowns) - top_trace = core.find_top_trace(args) - tracers = map(top_trace.full_raise, args) - outs = top_trace.process_custom_transpose(self, call, tracers, **params) - return outs + def bind_with_trace(self, trace, call_args, params): + call, tracers = call_args[0], call_args[1:] + return trace.process_custom_transpose(self, call, tracers, **params) # TODO(frostig,mattjj): consider keeping `call` as a named parameter # instead of following this "call primitive" convention. diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e1e4bce2743a..97e702a9f25c 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -95,7 +95,8 @@ def apply_primitive(prim, *args, **params): @util.cache() def xla_primitive_callable(prim: core.Primitive, **params): def prim_fun(*args): - return prim.bind(*args, **params) + with config.eager_constant_folding(False): + return prim.bind(*args, **params) prim_fun.__name__ = prim.name prim_fun.__qualname__ = prim.name return api.jit(prim_fun) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index d2a55933cad9..ac0418932b83 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -814,7 +814,7 @@ def check_user_dtype_supported(dtype, fun_name=None): int2, int4, uint2, - uint4, + uint4 ] if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index f1f46a5c18f7..9b350fdd6a87 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -29,7 +29,7 @@ from jax._src import core from jax._src import source_info_util from jax._src.ad_util import ( - add_jaxvals, replace_internal_symbolic_zeros, zeros_like_jaxval, + add_jaxvals, replace_internal_symbolic_zeros, replace_rule_output_symbolic_zeros, Zero, zeros_like_aval) from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401 from jax._src.api_util import flatten_fun, flatten_fun_nokwargs @@ -69,16 +69,15 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, fun, aux = jvp_subtrace_aux(fun) return jvpfun(fun, instantiate, transform_stack), aux - @lu.transformation def jvpfun(instantiate, transform_stack, primals, tangents): + tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) - with core.new_main(JVPTrace) as main, ctx: - out_primals, out_tangents = yield (main, primals, tangents), {} - del main + with ctx: + out_primals, out_tangents = yield (tag, primals, tangents), {} if type(instantiate) is bool: instantiate = [instantiate] * len(out_tangents) out_tangents = [instantiate_zeros(t) if inst else t for t, inst @@ -86,35 +85,26 @@ def jvpfun(instantiate, transform_stack, primals, tangents): yield out_primals, out_tangents @lu.transformation -def jvp_subtrace(main, primals, tangents): - trace = JVPTrace(main, core.cur_sublevel()) - for x in list(primals) + list(tangents): - if isinstance(x, Tracer): - if x._trace.level >= trace.level: - raise core.escaped_tracer_error( - x, f"Tracer from a higher level: {x} in trace {trace}") - assert x._trace.level < trace.level - in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x - for x, t in zip(primals, tangents)] - ans = yield in_tracers, {} - out_tracers = map(trace.full_raise, ans) - yield unzip2([(out_tracer.primal, out_tracer.tangent) - for out_tracer in out_tracers]) +def jvp_subtrace(tag, primals, tangents): + with core.take_current_trace() as parent_trace: + trace = JVPTrace(parent_trace, tag) + in_tracers = [maybe_jvp_tracer(trace, x, t) + for x, t in zip(primals, tangents)] + with core.set_current_trace(trace): + ans = yield in_tracers, {} + out = unzip2(map(trace.to_primal_tangent_pair, ans)) + yield out @lu.transformation_with_aux -def jvp_subtrace_aux(main, primals, tangents): - trace = JVPTrace(main, core.cur_sublevel()) - for x in list(primals) + list(tangents): - if isinstance(x, Tracer): - assert x._trace.level < trace.level - ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {} - ans_tracers = map(trace.full_raise, ans) - out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers) - aux_primals = [core.full_lower(x.primal) - if isinstance(x, JVPTracer) and x._trace.level == trace.level - else x for x in aux] - yield (out_primals, out_tangents), aux_primals - +def jvp_subtrace_aux(tag, primals, tangents): + with core.take_current_trace() as parent_trace: + trace = JVPTrace(parent_trace, tag) + with core.set_current_trace(trace): + ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {} + out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) + aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag + else x for x in aux] + yield (out_primals, out_tangents), aux_primals def linearize(traceable, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) @@ -166,7 +156,6 @@ def unpair_pval(pval): aval_1, aval_2 = aval return (aval_1, const_1), (aval_2, const_2) - # NOTE: The FIXMEs below are caused by primal/tangent mixups (type # errors if you will) def backward_pass(jaxpr: core.Jaxpr, transform_stack, @@ -281,37 +270,40 @@ def nonzero_tangent_outputs(*args, **kwargs): class JVPTrace(Trace): + def __init__(self, parent_trace, tag): + self.tag = tag + self.parent_trace = parent_trace - def pure(self, val): - tangent_zero = Zero.from_primal_value(val) - return JVPTracer(self, val, tangent_zero) - - def lift(self, val): - tangent_zero = Zero.from_primal_value(val) - return JVPTracer(self, val, tangent_zero) - - def sublift(self, val): - return JVPTracer(self, val.primal, val.tangent) + def to_primal_tangent_pair(self, val): + if isinstance(val, JVPTracer) and val._trace.tag is self.tag: + return (val.primal, val.tangent) + else: + tangent_zero = Zero.from_primal_value(val) + return (val, tangent_zero) def process_primitive(self, primitive, tracers, params): - primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) + if all(type(t) is Zero for t in tangents_in): + return primitive.bind_with_trace(self.parent_trace, primals_in, params) jvp = primitive_jvps.get(primitive) if not jvp: msg = f"Differentiation rule for '{primitive}' not implemented" raise NotImplementedError(msg) - primal_out, tangent_out = jvp(primals_in, tangents_in, **params) + with core.set_current_trace(self.parent_trace): + primal_out, tangent_out = jvp(primals_in, tangents_in, **params) + if primitive.multiple_results: - return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)] + return [maybe_jvp_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)] else: - return JVPTracer(self, primal_out, tangent_out) + return maybe_jvp_tracer(self, primal_out, tangent_out) def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results - primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) + primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not Zero for t in tangents] tangents = [t if type(t) is not Zero else None for t in tangents] args, in_tree = tree_flatten((primals, tangents)) - f_jvp = jvp_subtrace(f, self.main) + f_jvp = jvp_subtrace(f, self.tag) f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp) if isinstance(call_primitive, core.MapPrimitive): in_axes = params['in_axes'] @@ -328,76 +320,59 @@ def new_out_axes_thunk(): f_jvp, out_tree = traceable(f_jvp, in_tree) update_params = call_param_updaters.get(call_primitive) new_params = update_params(params, which_nz) if update_params else params - result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz), - *args, **new_params) + fun_and_args = (_update_annotation(f_jvp, f.in_type, which_nz),) + tuple(args) + result = call_primitive.bind_with_trace(self.parent_trace, fun_and_args, new_params) primal_out, tangent_out = tree_unflatten(out_tree(), result) tangent_out = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primal_out, tangent_out)] - return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] - - def post_process_call(self, call_primitive, out_tracers, params): - primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) - out, treedef = tree_flatten((primals, tangents)) - tangents_nz = [type(t) is not Zero for t in tangents] - del primals, tangents - main = self.main - def todo(x): - primals, tangents = tree_unflatten(treedef, x) - trace = JVPTrace(main, core.cur_sublevel()) - return map(partial(JVPTracer, trace), primals, tangents) - if call_primitive.map_primitive: - def out_axes_transform(out_axes): - return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz)) - todo = (todo, out_axes_transform) - return out, todo + return [maybe_jvp_tracer(self, p, t) for p, t in zip(primal_out, tangent_out)] # The only difference between process_map and process_call is that # the `in_axes` and `out_axes_thunk` params must be updated; # that's handled in process_call. process_map = process_call - post_process_map = post_process_call - def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): - primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) - primals_in = map(core.full_lower, primals_in) - if not symbolic_zeros: - tangents_in = map(instantiate_zeros, tangents_in) - else: - tangents_in = map(replace_internal_symbolic_zeros, tangents_in) - outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) + def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros): + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) + if all(type(t) is Zero for t in tangents_in): + return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in), + dict(symbolic_zeros=symbolic_zeros)) + with core.set_current_trace(self.parent_trace): + if not symbolic_zeros: + tangents_in = map(instantiate_zeros, tangents_in) + else: + tangents_in = map(replace_internal_symbolic_zeros, tangents_in) + outs = f_jvp.call_wrapped(*(tuple(primals_in) + tuple(tangents_in))) + primals_out, tangents_out = split_list(outs, [len(outs) // 2]) tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) - return map(partial(JVPTracer, self), primals_out, tangents_out) - - def post_process_custom_jvp_call(self, out_tracers, _): - raise CustomJVPException() + return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) - def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): - # Local import to prevent an import cycle. - from jax._src.lax import lax - - primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) - fwd_in = [(core.full_lower(p), type(t) is not Zero) - for p, t in zip(primals_in, tangents_in)] + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) + if all(type(t) is Zero for t in tangents_in): + return prim.bind_with_trace(self.parent_trace, + (fun, fwd, bwd, *primals_in), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) + fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)] fwd_in = [x for pair in fwd_in for x in pair] # flatten - res_and_primals_out = fwd.call_wrapped(*fwd_in) + with core.set_current_trace(self.parent_trace): + res_and_primals_out = fwd.call_wrapped(*fwd_in) + _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! - tangents_in = map(instantiate_zeros, tangents_in) - tangents_out = custom_lin_p.bind( + with core.set_current_trace(self.parent_trace): + tangents_in = map(instantiate_zeros, tangents_in) + tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - return map(partial(JVPTracer, self), primals_out, tangents_out) - - def post_process_custom_vjp_call(self, out_tracers, _): - raise CustomVJPException() + return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) def process_custom_transpose(self, prim, call, tracers, **params): - ps_in, ts_in = unzip2((t.primal, t.tangent) for t in tracers) + ps_in, ts_in = unzip2(map(self.to_primal_tangent_pair, tracers)) res_ps_in, lin_ps_in = split_list(ps_in, [params['res_tree'].num_leaves]) res_ts_in, lin_ts_in = split_list(ts_in, [params['res_tree'].num_leaves]) @@ -421,24 +396,18 @@ def process_custom_transpose(self, prim, call, tracers, **params): raise NotImplementedError( 'JVP of custom transpose with respect to non-symbolic-zero residuals') - ps_out = prim.bind(call, *ps_in, **params) - - lin_ts_in = map(instantiate_zeros, lin_ts_in) - ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params) + with core.set_current_trace(self.parent_trace): + ps_out = prim.bind(call, *ps_in, **params) + lin_ts_in = map(instantiate_zeros, lin_ts_in) + ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params) - return map(partial(JVPTracer, self), ps_out, ts_out) - - def join(self, xt, yt): - xz, yz = type(xt) is Zero, type(yt) is Zero - if xz == yz: - return xt, yt - elif yz and not xz: - return xt, zeros_like_jaxval(xt) - elif xz and not yz: - return zeros_like_jaxval(yt), yt - else: - raise TypeError((xt, yt)) + return map(partial(maybe_jvp_tracer, self), ps_out, ts_out) +def maybe_jvp_tracer(trace, primal, tangent): + if type(tangent) is Zero: + return primal + else: + return JVPTracer(trace, primal, tangent) class JVPTracer(Tracer): __slots__ = ['primal', 'tangent'] @@ -452,7 +421,6 @@ def __init__(self, trace, primal, tangent): @property def aval(self): - # TODO(dougalm): add epsilon ball return get_aval(self.primal) def full_lower(self): diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index b40a3807dea2..2ff27f0c5d74 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -14,7 +14,7 @@ from __future__ import annotations import collections -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Sequence import dataclasses from functools import partial from typing import Any, Union @@ -29,12 +29,12 @@ from jax._src.ad_util import (Zero, instantiate, SymbolicZero, replace_rule_output_symbolic_zeros, add_jaxvals, add_jaxvals_p) -from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName +from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) from jax._src.typing import Array -from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, +from jax._src.util import (unzip2, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache) @@ -284,8 +284,7 @@ def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, def _cont(axis_size, elt, axis): return from_elt(trace, axis_size, i, elt, axis) return handler(_cont, axis_size, x, spec) - x_ = trace.full_raise(x) - val, bdim = x_.val, x_.batch_dim + val, bdim = trace.to_batch_info(x) if type(bdim) is RaggedAxis: if spec is not jumble_axis: # TODO(mattjj): improve this error message @@ -293,9 +292,9 @@ def _cont(axis_size, elt, axis): return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val) else: try: - return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val) + return matchaxis(trace.axis_data.name, axis_size, bdim, spec, val) except SpecMatchError: - raise SpecMatchError(i, x_.batch_dim, spec) from None + raise SpecMatchError(i, x.batch_dim, spec) from None from_elt_handlers: dict[type, FromEltHandler] = {} def make_iota(axis_size: AxisSize) -> Array: @@ -435,165 +434,118 @@ def get_referent(self): else: # TODO(mattjj): could handle the RaggedAxis case? return self +@dataclasses.dataclass(frozen=True) +class AxisData: + name : Any + size : Any + spmd_name : Any + + class BatchTrace(Trace): - def __init__(self, *args, axis_name, spmd_axis_name = None): - super().__init__(*args) - self.axis_name = axis_name - self.spmd_axis_name = spmd_axis_name - - def pure(self, val): - return BatchTracer(self, val, not_mapped, source_info_util.current()) - - def lift(self, val): - return BatchTracer(self, val, not_mapped, source_info_util.current()) - - def sublift(self, val): - return BatchTracer(self, val.val, val.batch_dim, source_info_util.current()) - - def get_primitive_batcher(self, primitive, frame): - if primitive in primitive_batchers: - return primitive_batchers[primitive] - elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers: - return partial(spmd_axis_primitive_batchers[primitive], - self.spmd_axis_name, frame.size, frame.name, - frame.main_trace.trace_type) - elif primitive in axis_primitive_batchers: - return self.get_axis_primitive_batcher(primitive, frame) - msg = "Batching rule for '{}' not implemented" - raise NotImplementedError(msg.format(primitive)) - - def get_axis_primitive_batcher(self, primitive, frame): - return partial(axis_primitive_batchers[primitive], - frame.size, frame.name, frame.main_trace.trace_type) - - def get_frame(self, vals, dims) -> core.AxisEnvFrame: - if any(d is not not_mapped for d in dims): - sizes = (x.shape[d] if type(d) is int else d.size - for x, d in zip(vals, dims) if d is not not_mapped) - axis_size, = core.dedup_referents(sizes) + def __init__(self, parent_trace, tag, axis_data): + self.parent_trace = parent_trace + assert isinstance(axis_data, AxisData) + self.axis_data = axis_data + self.tag = tag + + def to_batch_info(self, val): + if isinstance(val, BatchTracer) and val._trace.tag is self.tag: + return val.val, val.batch_dim else: - axis_size = None # can't be inferred from data - if self.axis_name is core.no_axis_name: - assert axis_size is not None # must be inferable from data - return core.AxisEnvFrame(self.axis_name, axis_size, self.main) - frame = core.axis_frame(self.axis_name, self.main) - assert axis_size is None or axis_size == frame.size, (axis_size, frame.size) - assert frame.main_trace is self.main - return frame - - def process_primitive(self, primitive, tracers, params): + return val, not_mapped + + def process_primitive(self, p, tracers, params): if config.dynamic_shapes.value: - primitive.abstract_eval(*(t.aval for t in tracers), **params) - vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers) - is_axis_primitive = primitive in axis_primitive_batchers - used_names = core.used_axis_names(primitive, params) - if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names): - frame = self.get_frame(vals_in, dims_in) - batcher_primitive = self.get_axis_primitive_batcher(primitive, frame) - val_out, dim_out = batcher_primitive(vals_in, dims_in, **params) - elif all(bdim is not_mapped for bdim in dims_in): - return primitive.bind(*vals_in, **params) + p.abstract_eval(*(map(core.get_aval, tracers)), **params) + vals_in, dims_in = unzip2(map(self.to_batch_info, tracers)) + args_not_mapped = all(bdim is not_mapped for bdim in dims_in) + if p in fancy_primitive_batchers: + if (args_not_mapped + and p in skippable_batchers + and not any(self.axis_data.name == axis_name + for axis_name in skippable_batchers[p](params))): + # no-op shortcut + return p.bind_with_trace(self.parent_trace, vals_in, params) + else: + with core.set_current_trace(self.parent_trace): + val_out, dim_out = fancy_primitive_batchers[p](self.axis_data, vals_in, dims_in, **params) + elif args_not_mapped: + # no-op shortcut + return p.bind_with_trace(self.parent_trace, vals_in, params) + elif p in primitive_batchers: + with core.set_current_trace(self.parent_trace): + val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params) else: - frame = self.get_frame(vals_in, dims_in) - batched_primitive = self.get_primitive_batcher(primitive, frame) - val_out, dim_out = batched_primitive(vals_in, dims_in, **params) + raise NotImplementedError("Batching rule for '{}' not implemented".format(p)) src = source_info_util.current() - if primitive.multiple_results: - return [BatchTracer(self, x, d, src) for x, d in zip(val_out, dim_out)] + if p.multiple_results: + with core.set_current_trace(self.parent_trace): # val_out may be lazy map + return [BatchTracer(self, x, d, src) if d is not not_mapped else x + for x, d in zip(val_out, dim_out)] else: - return BatchTracer(self, val_out, dim_out, src) + return (BatchTracer(self, val_out, dim_out, src) + if dim_out is not not_mapped else val_out) def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results params = dict(params, name=params.get('name', f.__name__)) - vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) - if all(bdim is not_mapped for bdim in dims): - return call_primitive.bind(f, *vals, **params) - sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths) - for x, d in zip(vals, dims) if d is not not_mapped) - axis_size, = core.dedup_referents(sizes) + vals, dims = unzip2(map(self.to_batch_info, tracers)) segment_lens, dims = indirectify_ragged_axes(dims) - f_, dims_out = batch_subtrace(f, self.main, tuple(dims)) + f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims)) f_ = _update_annotation( - f_, f.in_type, axis_size, self.axis_name, dims, segment_lens) - vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params) + f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens) + + with core.set_current_trace(self.parent_trace): + vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params) vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out()) src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)] - def post_process_call(self, call_primitive, out_tracers, params): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main - def todo(vals): - trace = main.with_cur_sublevel() - return map(partial(BatchTracer, trace), vals, dims, srcs) - return vals, todo - def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): - vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) - if all(dim is not_mapped for dim in dims): - return map_primitive.bind(f, *vals, **params) - else: - assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1 - # The logic for the dimension math below is as follows: - # ╔═════════════╦════════════════════════════════════════╦═══════════╗ - # ║ d / in_axis ║ None ║ int ║ - # ╠═════════════╬════════════════════════════════════════╩═══════════╣ - # ║ None ║ No extra axis, so in_axis unaffected ║ - # ╠═════════════╬════════════════════════════════════════╦═══════════╣ - # ║ int ║ Not mapped, so batching dim unaffected ║ See below ║ - # ╚═════════════╩════════════════════════════════════════╩═══════════╝ - # When both d and in_axis are defined then: - # - If `d <= in_axis`, we have to move the `in_axis` one dimension further; - # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed). - def both_mapped(in_out_axis, d): - return in_out_axis is not None and d is not not_mapped - new_in_axes = tuple( - in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis - for d, in_axis in zip(dims, params['in_axes'])) - new_dims = tuple( - d - 1 if both_mapped(in_axis, d) and in_axis < d else d - for d, in_axis in zip(dims, params['in_axes'])) - f, dims_out = batch_subtrace(f, self.main, new_dims) - out_axes_thunk = params['out_axes_thunk'] - # NOTE: This assumes that the choice of the dimensions over which outputs - # are batched is entirely dependent on the function and not e.g. on the - # data or its shapes. - @as_hashable_function(closure=out_axes_thunk) - def new_out_axes_thunk(): - return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis - for out_axis, d in zip(out_axes_thunk(), dims_out())) - new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk) - vals_out = map_primitive.bind(f, *vals, **new_params) - dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d - for d, out_axis in zip(dims_out(), out_axes_thunk())] - src = source_info_util.current() - return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)] - - def post_process_map(self, call_primitive, out_tracers, params): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main + vals, dims = unzip2(map(self.to_batch_info, tracers)) + # The logic for the dimension math below is as follows: + # ╔═════════════╦════════════════════════════════════════╦═══════════╗ + # ║ d / in_axis ║ None ║ int ║ + # ╠═════════════╬════════════════════════════════════════╩═══════════╣ + # ║ None ║ No extra axis, so in_axis unaffected ║ + # ╠═════════════╬════════════════════════════════════════╦═══════════╣ + # ║ int ║ Not mapped, so batching dim unaffected ║ See below ║ + # ╚═════════════╩════════════════════════════════════════╩═══════════╝ + # When both d and in_axis are defined then: + # - If `d <= in_axis`, we have to move the `in_axis` one dimension further; + # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed). def both_mapped(in_out_axis, d): return in_out_axis is not None and d is not not_mapped - def todo(vals): - trace = main.with_cur_sublevel() - return [BatchTracer(trace, v, d + 1 if both_mapped(oa, d) and oa <= d else d, s) - for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)] - if call_primitive.map_primitive: - def out_axes_transform(out_axes): - return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis - for out_axis, d in zip(out_axes, dims)) - todo = (todo, out_axes_transform) - return vals, todo + new_in_axes = tuple( + in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis + for d, in_axis in zip(dims, params['in_axes'])) + new_dims = tuple( + d - 1 if both_mapped(in_axis, d) and in_axis < d else d + for d, in_axis in zip(dims, params['in_axes'])) + f, dims_out = batch_subtrace(f, self.tag, self.axis_data, new_dims) + out_axes_thunk = params['out_axes_thunk'] + # NOTE: This assumes that the choice of the dimensions over which outputs + # are batched is entirely dependent on the function and not e.g. on the + # data or its shapes. + @as_hashable_function(closure=out_axes_thunk) + def new_out_axes_thunk(): + return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis + for out_axis, d in zip(out_axes_thunk(), dims_out())) + new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk) + with core.set_current_trace(self.parent_trace): + vals_out = map_primitive.bind(f, *vals, **new_params) + dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d + for d, out_axis in zip(dims_out(), out_axes_thunk())] + src = source_info_util.current() + return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)] def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) - jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims) - out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) + in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) + fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) + jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims) + out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp) + tuple(in_vals), + dict(symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: assert out_dims == out_dims[:len(out_dims) // 2] * 2 @@ -601,34 +553,18 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] - def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main - def todo(vals): - trace = main.with_cur_sublevel() - if jvp_was_run: - primal_dims, tangent_dims = dims[:len(vals)], dims[len(vals):] - assert primal_dims == tangent_dims - primal_srcs = srcs[:len(vals)] - return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs) - else: - return map(partial(BatchTracer, trace), vals, dims, srcs) - return vals, todo - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, symbolic_zeros): # pytype: disable=signature-mismatch - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) - if d is not not_mapped} + in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] - fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) - fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims) - bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, - out_dims2, in_dims, self.main.trace_type, - self.spmd_axis_name) - out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) + + fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) + fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims) + + bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, out_dims2, in_dims) + out_vals = prim.bind_with_trace(self.parent_trace, + (fun, fwd, bwd) + tuple(in_vals), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: _, res_tree = out_trees() @@ -636,83 +572,46 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] - def post_process_custom_vjp_call(self, out_tracers, _): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - main = self.main - def todo(vals): - trace = main.with_cur_sublevel() - return map(partial(BatchTracer, trace), vals, dims, srcs) - return vals, todo - - def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees): - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped} - main, trace_type = self.main, self.main.trace_type - axis_name = self.axis_name - _, res_tree = out_trees() - num_res = res_tree.num_leaves - res_dims, primal_dims = split_list(dims, [num_res]) - _, primal_srcs = split_list(srcs, [num_res]) - def todo(vals): - trace = main.with_cur_sublevel() - return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs) - def bwd_transform(bwd): - return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,), - trace_type, self.spmd_axis_name) - return vals, todo, bwd_transform - -def _main_trace_for_axis_names(main_trace: core.MainTrace, - axis_name: Iterable[AxisName], - ) -> bool: - # This function exists to identify whether a main trace corresponds to any of - # the axis names used by a primitive. Axis names alone aren't enough because - # axis names can shadow, so we use the main trace as a tag. - return any(main_trace is core.axis_frame(n).main_trace for n in axis_name) - ### API for batching callables with vmappable inputs and outputs -def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size, - in_dims, out_dim_dests, main_type: type[BatchTrace] = BatchTrace, - spmd_axis_name: tuple[AxisName, ...] | None = None - ) -> lu.WrappedFun: +def batch(fun: lu.WrappedFun, axis_data, + in_dims, out_dim_dests) -> lu.WrappedFun: # we split up _batch_inner and _batch_outer for the leak checker - f = _batch_inner(fun, axis_size, out_dim_dests) - return _batch_outer(f, axis_name, axis_size, in_dims, main_type, - spmd_axis_name) + f = _batch_inner(fun, axis_data, out_dim_dests) + return _batch_outer(f, axis_data, in_dims) @lu.transformation -def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name, - *in_vals): - with core.new_main( - main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main: - with core.extend_axis_env(axis_name, axis_size, main): - with source_info_util.transform_name_stack('vmap'): - outs = yield (main, in_dims, *in_vals), {} - del main +def _batch_outer(axis_data, in_dims, *in_vals): + tag = TraceTag() + with source_info_util.transform_name_stack('vmap'): + outs, trace = yield (tag, in_dims, *in_vals), {} + with core.ensure_no_leaks(trace): del trace yield outs @lu.transformation -def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals): +def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims - trace = main.with_cur_sublevel() - idx = memoize(lambda: BatchTracer(trace, make_iota(axis_size), 0, - source_info_util.current())) - in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) - outs = yield in_tracers, {} + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, + source_info_util.current())) + in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) + with core.set_current_trace(trace): + with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]): + outs = yield in_tracers, {} + out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests - out_vals = map(partial(from_elt, trace, axis_size), range(len(outs)), + out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), outs, out_dim_dests) - yield out_vals + + yield out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. def vtile(f_flat: lu.WrappedFun, in_axes_flat: tuple[int | None, ...], out_axes_flat: tuple[int | None, ...], tile_size: int | None, - axis_name: AxisName, - main_type: type[BatchTrace] = BatchTrace): + axis_name: AxisName): @curry def tile_axis(arg, axis: int | None, tile_size): if axis is None: @@ -736,23 +635,24 @@ def _map_to_tile(*args_flat): outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {} yield map(untile_axis, outputs_flat, out_axes_flat) - return _map_to_tile(batch( - f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type)) + axis_data = AxisData(axis_name, tile_size, None) + return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat)) ### API for batching functions with jaxpr type inputs and outputs @lu.transformation_with_aux -def batch_subtrace(main, in_dims, *in_vals): - trace = main.with_cur_sublevel() - in_dims = in_dims() if callable(in_dims) else in_dims - in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) - in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) - if dim is not None else x for x, dim in zip(in_vals, in_dims)] - outs = yield in_tracers, {} - out_tracers = map(trace.full_raise, outs) - out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) - segment_lens, out_dims = indirectify_ragged_axes(out_dims) - yield (*segment_lens, *out_vals), out_dims +def batch_subtrace(tag, axis_data, in_dims, *in_vals): + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + with core.set_current_trace(trace): + in_dims = in_dims() if callable(in_dims) else in_dims + in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) + in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) + if dim is not None else x for x, dim in zip(in_vals, in_dims)] + outs = yield in_tracers, {} + out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) + segment_lens, out_dims = indirectify_ragged_axes(out_dims) + yield (*segment_lens, *out_vals), out_dims def indirectify_ragged_axes(dims): if not any(type(d) is RaggedAxis for d in dims): @@ -823,38 +723,30 @@ def fetch(idx): # Can reuse same pattern for all dynamic shape stuff. def batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, - axis_size: core.AxisSize, + axis_data, in_axes: tuple[int | NotMapped | RaggedAxis, ...], - axis_name: AxisName, - spmd_axis_name: AxisName, - main_type: type[BatchTrace], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]: # This is only ever used in pjit. The difference vs batch_jaxpr is that # batch_jaxpr2 lets the callee decide which outputs are batched and what # their batch axes are; whereas batch_jaxpr has to obey caller-imposed # consistency constraints, such as type-agreement across arms of a # `lax.cond`, or input-output agreement for the body of a `lax.scan`. - return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name, - spmd_axis_name, main_type) + return _batch_jaxpr2(closed_jaxpr, axis_data, tuple(in_axes)) @weakref_lru_cache def _batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, - axis_size: core.AxisSize, + axis_data, in_axes: tuple[int | NotMapped | RaggedAxis, ...], - axis_name: AxisName, - spmd_axis_name: AxisName, - main_type: type[BatchTrace], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]: f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) - f, out_axes = _batch_jaxpr_inner(f, axis_size) - f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, - main_type) + f, out_axes = _batch_jaxpr_inner(f, axis_data) + f = _batch_jaxpr_outer(f, axis_data, in_axes) in_axes2, avals_in = unzip2([ handle_ragged(closed_jaxpr.in_avals, dim, aval) if isinstance(dim, RaggedAxis) else (dim, aval) for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) - avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval) + avals_in2 = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped else aval for aval, b in unsafe_zip(avals_in, in_axes2)] jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) @@ -868,14 +760,11 @@ def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis, new_aval = aval.update(shape=tuple(new_shape)) return dim.stacked_axis, new_aval -def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, - spmd_axis_name, main_type): +def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate): inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _batch_jaxpr(closed_jaxpr, axis_size, tuple(in_batched), inst, - axis_name, spmd_axis_name, main_type) + return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), inst) -def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, - spmd_axis_name, main_type): +def _batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate): assert (isinstance(instantiate, bool) or isinstance(instantiate, (list, tuple)) and all(isinstance(b, bool) for b in instantiate)) @@ -883,46 +772,41 @@ def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, instantiate = [instantiate] * len(closed_jaxpr.out_avals) in_axes = [0 if b else not_mapped for b in in_batched] out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate] - return batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, - axis_name, spmd_axis_name, main_type) + return batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest) -def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name, - spmd_axis_name, main_type): - return _batch_jaxpr_axes(closed_jaxpr, axis_size, tuple(in_axes), - tuple(out_axes_dest), axis_name, spmd_axis_name, - main_type) +def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): + return _batch_jaxpr_axes(closed_jaxpr, axis_data, tuple(in_axes), tuple(out_axes_dest)) @weakref_lru_cache -def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, - axis_name, spmd_axis_name, main_type): +def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) - f, out_axes = _batch_jaxpr_inner(f, axis_size) - f, out_batched = _match_axes_jaxpr(f, axis_size, out_axes_dest, out_axes) - f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, - main_type) - avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped + f, out_axes = _batch_jaxpr_inner(f, axis_data) + f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes) + f = _batch_jaxpr_outer(f, axis_data, in_axes) + avals_in = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)] jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) return core.ClosedJaxpr(jaxpr_out, consts), out_batched() @lu.transformation_with_aux -def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals): - trace = main.with_cur_sublevel() - _, in_axes = resolve_ragged_axes(in_vals, in_axes) - in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val - for val, dim in zip(in_vals, in_axes)] - outs = yield in_tracers, {} - out_tracers = map(trace.full_raise, outs) - out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in out_tracers) - new_out_axes = indirectify_ragged_axes_against_inputs_outputs( - out_axes, in_vals, out_vals) - yield out_vals, new_out_axes +def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + _, in_axes = resolve_ragged_axes(in_vals, in_axes) + in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val + for val, dim in zip(in_vals, in_axes)] + with core.set_current_trace(trace): + with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]): + outs = yield in_tracers, {} + out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) + new_out_axes = indirectify_ragged_axes_against_inputs_outputs( + out_axes, in_vals, out_vals) + yield out_vals, new_out_axes @lu.transformation_with_aux -def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes, +def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, *in_vals): - trace = main.with_cur_sublevel() - out_vals = yield (main, in_axes, *in_vals), {} + out_vals = yield (trace, in_axes, *in_vals), {} out_axes = out_axes() out_axes_dest = [(None if src is not_mapped else 0) if dst is zero_if_mapped else dst @@ -930,24 +814,18 @@ def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes, if len(out_axes_dest) != len(out_axes): out_axis_dest, = out_axes_dest out_axes_dest = [out_axis_dest] * len(out_axes) - out_vals = map(partial(matchaxis, trace.axis_name, axis_size), + out_vals = map(partial(matchaxis, axis_data.name, axis_data.size), out_axes, out_axes_dest, out_vals) out_batched = [dst is not None for dst in out_axes_dest] yield out_vals, out_batched @lu.transformation -def _batch_jaxpr_outer(axis_name, spmd_axis_name, axis_size, in_dims, main_type, - *in_vals): - if axis_size is None: - axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} +def _batch_jaxpr_outer(axis_data, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) else ax for x, ax in unsafe_zip(in_vals, in_dims)] - with core.new_main(main_type, axis_name=axis_name, - spmd_axis_name=spmd_axis_name) as main: - with core.extend_axis_env(axis_name, axis_size, main): - out_vals = yield (main, in_dims, *in_vals), {} - del main + tag = TraceTag() + out_vals = yield (tag, in_dims, *in_vals), {} yield out_vals def _merge_bdims(x, y): @@ -966,31 +844,33 @@ class ZeroIfMapped: pass ### functions for handling custom_vjp @lu.transformation_with_aux -def batch_custom_jvp_subtrace(main, in_dims, *in_vals): - size, = {x.shape[d] for x, d in zip(in_vals, in_dims * 2) - if d is not not_mapped} - trace = main.with_cur_sublevel() - in_tracers = [val if dim is None else - SymbolicZero(core.mapped_aval(size, dim, val.aval)) - if type(val) is SymbolicZero else BatchTracer(trace, val, dim) - for val, dim in zip(in_vals, in_dims * 2)] - outs = yield in_tracers, {} - # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can - # be wasteful in the rare case it actually triggers; handle symbolically! - outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] - out_tracers = map(trace.full_raise, outs) - out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) +def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): + size = axis_data.size + with core.take_current_trace() as parent_trace: + trace = BatchTrace(parent_trace, tag, axis_data) + in_tracers = [val if dim is None else + SymbolicZero(core.mapped_aval(size, dim, val.aval)) + if type(val) is SymbolicZero else BatchTracer(trace, val, dim) + for val, dim in zip(in_vals, in_dims * 2)] + with core.set_current_trace(trace): + outs = yield in_tracers, {} + # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can + # be wasteful in the rare case it actually triggers; handle symbolically! + outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] + + out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) - out_primals = map(partial(matchaxis, trace.axis_name, size), + out_primals = map(partial(matchaxis, trace.axis_data.name, size), out_primal_bds, out_dims, out_primals) - out_tangents = map(partial(matchaxis, trace.axis_name, size), + out_tangents = map(partial(matchaxis, trace.axis_data.name, size), out_tangent_bds, out_dims, out_tangents) yield out_primals + out_tangents, out_dims * 2 -def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, - main_type, spmd_axis_name): +def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): + axis_size = axis_data.size + axis_name = axis_data.name def new_bwd(*args): in_dims_ = in_dims() if callable(in_dims) else in_dims args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval)) @@ -998,9 +878,7 @@ def new_bwd(*args): for x, dim in zip(args, in_dims_)] in_dims_ = [None if type(x) is SymbolicZero else d for x, d in zip(args, in_dims_)] - bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd)) - bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type, - spmd_axis_name) + bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_) bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests) return bwd_.call_wrapped(*args) @@ -1039,8 +917,23 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False) tuple[Any, Union[int, None, tuple[Union[int, None], ...]]] ] primitive_batchers : dict[core.Primitive, BatchingRule] = {} -axis_primitive_batchers: dict[core.Primitive, Callable] = {} -spmd_axis_primitive_batchers: dict[core.Primitive, Callable] = {} +# "fancy" primitive batchers just take a extra leading `AxisData` and "trace type" args +fancy_primitive_batchers: dict[core.Primitive, Callable] = {} + +# backwards compat shim. TODO: delete +class AxisPrimitiveBatchersProxy: + def __setitem__(self, prim, batcher): + def wrapped(axis_data, vals, dims, **params): + return batcher(axis_data.size, axis_data.name, None, vals, dims, **params) + fancy_primitive_batchers[prim] = wrapped + +axis_primitive_batchers = AxisPrimitiveBatchersProxy() + + +# Presence in this table allows fancy batchers to be skipped by batch traces for +# irrelevant axes. The Callable takes the params and returns a list of relevant +# axes. +skippable_batchers : dict[core.Primitive, Callable] = {} def defvectorized(prim): primitive_batchers[prim] = partial(vectorized_batcher, prim) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ab00e5729cc2..00c970186673 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -15,7 +15,7 @@ from collections import namedtuple from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager, AbstractContextManager +from contextlib import contextmanager from functools import partial import inspect import itertools as it @@ -38,7 +38,7 @@ from jax._src import xla_metadata as xla_metadata_lib from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, fun_sourceinfo) -from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, +from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, ConcreteArray, Var, DropVar, raise_to_shaped, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, @@ -143,22 +143,21 @@ def get_aval(self) -> AbstractValue: class JaxprTrace(Trace['JaxprTracer']): - def __init__(self, *args, name_stack: source_info_util.NameStack): - super().__init__(*args) + def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, tag:TraceTag): self.name_stack = name_stack + self.tag = tag + self.parent_trace = parent_trace - def pure(self, val: Any) -> JaxprTracer: - return self.new_const(val) - - def lift(self, val: Tracer) -> JaxprTracer: - return self.new_const(val) - - def sublift(self, val: JaxprTracer) -> JaxprTracer: - return JaxprTracer(self, val.pval, FreeVar(val)) + def to_jaxpr_tracer(self, x): + if isinstance(x, JaxprTracer) and x._trace.tag is self.tag: + if x._trace is self: + return x + else: + return JaxprTracer(self, x.pval, FreeVar(x)) + else: + return self.new_const(x) def new_const(self, val) -> JaxprTracer: - if isinstance(val, Tracer) and val._trace.level == self.level: - raise Exception return JaxprTracer(self, PartialVal.known(val), None) def new_instantiated_literal(self, val) -> JaxprTracer: @@ -206,18 +205,21 @@ def instantiate_const_abstracted(self, tracer) -> JaxprTracer: return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) def process_primitive(self, primitive, tracers, params): - if primitive in custom_partial_eval_rules: - return custom_partial_eval_rules[primitive](self, *tracers, **params) - else: - return self.default_process_primitive(primitive, tracers, params) + with core.set_current_trace(self.parent_trace): + if primitive in custom_partial_eval_rules: + tracers = map(self.to_jaxpr_tracer, tracers) + return custom_partial_eval_rules[primitive](self, *tracers, **params) + else: + return self.default_process_primitive(primitive, tracers, params) def default_process_primitive(self, primitive, tracers, params): # By default, if all the input tracers are known, then bind the primitive # and consider all outputs known. Otherwise, stage the application into the # jaxpr and consider all outputs unknown. + tracers = map(self.to_jaxpr_tracer, tracers) consts = [t.pval.get_known() for t in tracers] if all(c is not None for c in consts): - return primitive.bind(*consts, **params) + return primitive.bind_with_trace(self.parent_trace, consts, params) tracers = map(self.instantiate_const, tracers) avals = [t.aval for t in tracers] out_aval, effects = primitive.abstract_eval(*avals, **params) @@ -237,6 +239,7 @@ def default_process_primitive(self, primitive, tracers, params): return out_tracer def process_call(self, primitive, f, tracers, params): + tracers = map(self.to_jaxpr_tracer, tracers) rule = call_partial_eval_rules.get(primitive) if rule: return rule(self, primitive, f, tracers, params) @@ -253,15 +256,15 @@ def process_call(self, primitive, f, tracers, params): # which were unknown to the first call (corresponding to in_avals). # Wrap f to perform the partial evaluation and plumb out aux data. - f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) - f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), - tuple(in_avals)) + f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, False) + f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals)) + # Adjust parameters (e.g. donated_invars) for the call to be evaluated now. const_params = update_params(params, in_knowns, 0) # Run the call, getting known out vals and aux data used for staged-out call - out = primitive.bind(_update_annotation_known(f_, f.in_type, in_knowns), - *in_consts, **const_params) + fun_and_args = (_update_annotation_known(f_, f.in_type, in_knowns),) + tuple(in_consts) + out = primitive.bind_with_trace(self.parent_trace, fun_and_args, const_params) fwds, out_knowns, out_type, jaxpr, env = aux() # Split apart known outputs from the original call and non-fwded residuals. out_consts, non_fwd_res = split_list(out, [sum(out_knowns)]) @@ -284,7 +287,7 @@ def process_call(self, primitive, f, tracers, params): # Create the input tracers for the staged-out (unknown-value) call. res_tracers = map(self.instantiate_const, map(self.new_const, res)) - env_tracers = map(self.full_raise, env) + env_tracers = map(self.to_jaxpr_tracer, env) unknown_arg_tracers = [t for t in tracers if not t.is_known()] # Adjust parameters (e.g. donated_invars) for the staged-out call's args. num_new_args = len(res_tracers) + len(env_tracers) @@ -314,6 +317,7 @@ def process_call(self, primitive, f, tracers, params): return merge_lists(out_knowns, out_tracers, out_consts) def process_map(self, primitive, f: lu.WrappedFun, tracers, params): + tracers = map(self.to_jaxpr_tracer, tracers) update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) @@ -329,7 +333,7 @@ def process_map(self, primitive, f: lu.WrappedFun, tracers, params): for ax, aval in zip(unk_in_axes, in_avals)] # Wrap f to perform partial evaluation and plumb out aux data. - f = trace_to_subjaxpr_nounits(f, self.main, False) + f = trace_to_subjaxpr_nounits2(f, self.tag, False) f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals_mapped)) # Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk) @@ -344,13 +348,13 @@ def const_out_axes_thunk(): out_axes_thunk=const_out_axes_thunk) # Run the map, getting known out vals and aux data used for staged-out map. - out = primitive.bind(f, *in_consts, **const_params) + out = primitive.bind_with_trace(self.parent_trace, (f, *in_consts), const_params) out_knowns, out_avals_mapped, jaxpr, env = aux() # Split apart known outputs from the original call and residuals. out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) # We can only check_jaxpr with the dynamic axis environment extended: - with core.extend_axis_env(params['axis_name'], params['axis_size'], None): + with core.extend_axis_env_nd([(params['axis_name'], params['axis_size'])]): call_jaxpr = convert_constvars_jaxpr(jaxpr) # Compute staged and const out_axes, taking into account residuals. @@ -360,7 +364,7 @@ def const_out_axes_thunk(): # Create the input tracers for the staged-out (unkonwn-value) call. const_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.full_raise, env) + env_tracers = map(self.to_jaxpr_tracer, env) unknown_arg_tracers = [t for t in tracers if not t.is_known()] # Adjust params for staged-out call on unknown values. num_new_args = len(const_tracers) + len(env_tracers) @@ -381,95 +385,24 @@ def const_out_axes_thunk(): return merge_lists(out_knowns, out_tracers, out_consts) - def post_process_call(self, primitive, out_tracers, params): - unknown_out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) - out_pvals = [t.pval for t in out_tracers] - out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - out = [*out_consts, *res] - main = self.main - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) - const_tracers = map(trace.new_instantiated_const, res) - in_tracers = (*const_tracers, *map(trace.full_raise, env)) - out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) - for a in out_avals] - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - new_params = update_params(params, [], len(in_tracers)) - new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, - jaxpr.effects, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - return out, todo - - def post_process_map(self, primitive, out_tracers, params): - unknown_out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr([], unknown_out_tracers) - out_pvals = [t.pval for t in out_tracers] - out_knowns, out_avals_mapped, out_consts = partition_pvals(out_pvals) - out = [*out_consts, *res] - main = self.main - - with core.extend_axis_env(params['axis_name'], params['axis_size'], None): - call_jaxpr = convert_constvars_jaxpr(jaxpr) - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) - const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.full_raise, env) - - staged_out_axes = tuple(out_axes_unknown) # set by out_axes_transform - staged_in_axes = (0,) * len(res) + (None,) * len(env) - - update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) - staged_params = update_params(params, [], len(res) + len(env)) - staged_params = dict(staged_params, in_axes=staged_in_axes, - out_axes=tuple(staged_out_axes), - call_jaxpr=call_jaxpr) - - out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], d, a) - for d, a in zip(staged_out_axes, out_avals_mapped)] - out_tracers = [JaxprTracer(trace, PartialVal.unknown(a), None) - for a in out_avals] - name_stack = self._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, - primitive, staged_params, jaxpr.effects, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - def out_axes_transform(out_axes): - nonlocal out_axes_unknown - out_axes_unknown, out_axes_known = partition_list(out_knowns, out_axes) - return tuple(out_axes_known) + (0,) * len(jaxpr.constvars) - out_axes_unknown: list | None = None - - return out, (todo, out_axes_transform) - def _current_truncated_name_stack(self): return source_info_util.current_name_stack()[len(self.name_stack):] - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - # We assume partial evaluation is only performed to build linear functions, - # and hence we don't need to keep the custom JVP rule around anymore. + def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) + if all(t.is_known() for t in tracers): + with core.set_current_trace(self.parent_trace): + vals = [t.pval[1] for t in tracers] + return prim.bind(fun, jvp, *vals, symbolic_zeros=symbolic_zeros) + # We assume non-trivial partial evaluation is only performed to build linear + # functions, and hence we don't need to keep the custom JVP rule around + # anymore. del jvp, symbolic_zeros - assert not all(t.is_known() for t in tracers) - return fun.call_wrapped(*tracers) - - def post_process_custom_jvp_call(self, out_tracers, _): - # This path should only be reachable if we expose a partial eval API - # unrelated to autodiff, since we raise an error when differentiation with - # respect to values over which a custom_jvp function closes is detected. - raise NotImplementedError # TODO(mattjj) + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) def process_custom_transpose(self, prim, call, tracers, **params): + tracers = map(self.to_jaxpr_tracer, tracers) res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves]) assert all(t.is_known() for t in res_ts) lin_all_known = all(t.is_known() for t in lin_ts) @@ -487,36 +420,41 @@ def process_custom_transpose(self, prim, call, tracers, **params): for t in out_tracers: t.recipe = eqn return out_tracers - def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, - symbolic_zeros): - # TODO(mattjj): after old remat is deleted, make this method trivial. - # Because we instantiate all tracers, in_knowns is all False. - tracers = map(self.instantiate_const_abstracted, tracers) - in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) - f = trace_to_subjaxpr_nounits(f, self.main, True) - f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - out_knowns, out_avals, jaxpr, env = aux() - out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - res_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.full_raise, env) - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_avals] - closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) - - @_memoize - def fwd_jaxpr_thunk(*zeros): - fwd_ = _interleave_fun(fwd, zeros) - fwd_ = trace_to_subjaxpr_nounits(fwd_, self.main, True) - fwd_, aux = partial_eval_wrapper_nounits( - fwd_, tuple(in_knowns), tuple(in_avals)) - with core.new_sublevel(): - out_flat = fwd_.call_wrapped() + def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) + if all(t.is_known() for t in tracers): + vals = [t.pval[1] for t in tracers] + with core.set_current_trace(self.parent_trace): + return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) + else: + # TODO(mattjj): remove non-ad users of partial eval, then drop this case. + # We stage out the whole thing, i.e. no nontrivial partial evaluation. + tracers = map(self.instantiate_const_abstracted, tracers) + # Because we instantiate all tracers, in_knowns is all False. + in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) + f = trace_to_subjaxpr_nounits(f, self, True) + f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,)) + with core.set_current_trace(self.parent_trace): + out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) out_knowns, out_avals, jaxpr, env = aux() - _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) - return converted_jaxpr, (*res, *env) + out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) + res_tracers = map(self.new_instantiated_const, res) + env_tracers = map(self.to_jaxpr_tracer, env) + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_avals] + closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) + + @_memoize + def fwd_jaxpr_thunk(*zeros): + fwd_ = _interleave_fun(fwd, zeros) + fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True) + fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,)) + out_flat = fwd_.call_wrapped() + out_knowns, out_avals, jaxpr, env = aux() + _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) + converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) + return converted_jaxpr, (*res, *env) name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) @@ -531,12 +469,6 @@ def fwd_jaxpr_thunk(*zeros): for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) - def post_process_custom_vjp_call(self, out_tracers, _): - # This path should only be reachable if we expose a partial eval API - # unrelated to autodiff, since we raise an error when differentiation with - # respect to values over which a custom_vjp function closes is detected. - raise NotImplementedError # TODO(mattjj) - def partition_pvals( pvals: list[PartialVal] ) -> tuple[list[bool], list[AbstractValue], list[Any]]: @@ -587,12 +519,6 @@ def __init__(self, trace: JaxprTrace, pval: PartialVal, recipe: JaxprTracerRecipe | None): assert isinstance(pval, PartialVal) pv, const = pval - if isinstance(const, Tracer) and const._trace.level >= trace.level: - raise core.escaped_tracer_error( - const, f"Tracer from a higher level: {const} in trace {trace}") - if isinstance(pv, DShapedArray): - assert all(not isinstance(d, Tracer) or isinstance(d, JaxprTracer) and - d._trace.level == trace.level for d in pv.shape) self._trace = trace self.pval = pval self.recipe = recipe @@ -614,13 +540,6 @@ def parents(self) -> Sequence[JaxprTracer]: else: return [] - def full_lower(self): - known = self.pval.get_known() - if known is not None: - return core.full_lower(known) - else: - return self - def is_known(self): return self.pval.is_known() @@ -633,84 +552,66 @@ def get_referent(self): return self -@profiler.annotate_function -def trace_to_jaxpr( - fun: lu.WrappedFun, pvals: Sequence[PartialVal], - instantiate: bool | Sequence[bool] = False, - ) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]: - """ - Partially evaluate a function, building a jaxpr for un-evaluated computation. - - Args: - fun: lu.WrappedFun representing the function to be partially evaluated. The - function must be flattened, in the sense of accepting jaxpr type arguments - and returning a flat list of jaxpr type outputs. - pvals: sequence of PartialVals of length equal to the number of inputs to - `fun` indicating which inputs are known or unknown. - instantiate: optional bool or sequence of bools of length equal to the - number of outputs of `fun` indicating which outputs should be forced to be - treated as unknown and hence instantiated in the jaxpr. If a single bool, - the value is applied to all outputs. Default False. - - Returns: - A triple where the first element is a jaxpr representing the computation - which depends on unknown inputs; the second element is a list of PartialVals - of length equal to the length of the output of `fun` representing which - outputs are known and unknown (along with their values and abstract values, - respectively); the third element is a list of known residual values. The - returned jaxpr takes as inputs the known residual values followed by values - of the originally unknown inputs. - """ - current_name_stack = source_info_util.current_name_stack() - with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: - fun = trace_to_subjaxpr(fun, main, instantiate) - jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) - assert not env - del main, fun, env - - return jaxpr, out_pvals, consts - @profiler.annotate_function def trace_to_jaxpr_nounits( fun: lu.WrappedFun, pvals: Sequence[PartialVal], instantiate: bool | Sequence[bool] = False, ) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]: current_name_stack = source_info_util.current_name_stack() - with core.new_main(JaxprTrace, name_stack=current_name_stack) as main: - fun = trace_to_subjaxpr_nounits(fun, main, instantiate) - jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) - assert not env - del main, fun, env - return jaxpr, out_pvals, consts - - + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, TraceTag()) + with core.ensure_no_leaks(trace): + fun = trace_to_subjaxpr_nounits(fun, trace, instantiate) + with core.set_current_trace(trace): + jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) + assert not env + del trace, fun + return jaxpr, out_pvals, consts + +# TODO(mattjj): superfluous wrapper...? @lu.transformation def trace_to_subjaxpr_nounits( - main: core.MainTrace, + trace: JaxprTrace, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - main, instantiate, in_pvals) + trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers yield jaxpr, (out_pvals, out_consts, env) -def _trace_to_subjaxpr_nounits(main, instantiate, in_pvals): - trace = main.with_cur_sublevel() +@lu.transformation +def trace_to_subjaxpr_nounits2( + tag: TraceTag, + instantiate: bool | Sequence[bool], + in_pvals: Sequence[PartialVal]): + assert isinstance(tag, TraceTag) + assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals + current_name_stack = source_info_util.current_name_stack() + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, tag) + out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( + trace, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] + del out_tracers + yield jaxpr, (out_pvals, out_consts, env) + +def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] in_args = merge_lists(in_knowns, in_tracers, in_consts) - ans = yield in_args, {} + with core.set_current_trace(trace): + ans = yield in_args, {} assert isinstance(ans, (list, tuple)), ( f"Got unexpected return type when tracing function to jaxpr: {ans}") assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( f"Got unexpected return type when tracing function to jaxpr: {ans}") if isinstance(instantiate, bool): instantiate = [instantiate] * len(ans) - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t + out_tracers = map(trace.to_jaxpr_tracer, ans) + out_tracers = [trace.instantiate_const(t) if inst else t for inst, t in zip(instantiate, out_tracers)] out_tracers_ = [t for t in out_tracers if not t.is_known()] jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_) @@ -721,22 +622,26 @@ def _trace_to_subjaxpr_nounits(main, instantiate, in_pvals): # TODO(mattjj): update all callers to use this version, delete other version. @lu.transformation def trace_to_subjaxpr_nounits_fwd( - main: core.MainTrace, + tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - main, instantiate, in_pvals) - out_pvals = [t.pval for t in out_tracers] + current_name_stack = source_info_util.current_name_stack() + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, tag) + with core.set_current_trace(trace): + out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( + trace, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] - # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. - in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] - id_map = {id(c): i for i, c in enumerate(in_consts)} - fwds: list[int | None] = [id_map.get(id(c)) for c in out_consts] - pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] + # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. + in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] + id_map = {id(c): i for i, c in enumerate(in_consts)} + fwds: list[int | None] = [id_map.get(id(c)) for c in out_consts] + pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] - del out_tracers - yield jaxpr, (fwds, out_pvals, pruned_consts, env) + del out_tracers + yield jaxpr, (fwds, out_pvals, pruned_consts, env) # The below variant implements two optimizations: # 1. residuals that are also primal inputs are indicated in aux data rather @@ -745,13 +650,16 @@ def trace_to_subjaxpr_nounits_fwd( # than passed as redundant outputs. @lu.transformation def trace_to_subjaxpr_nounits_fwd2( - main: core.MainTrace, + tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits( - main, instantiate, in_pvals) - out_pvals = [t.pval for t in out_tracers] + current_name_stack = source_info_util.current_name_stack() + with core.take_current_trace() as parent_trace: + trace = JaxprTrace(parent_trace, current_name_stack, tag) + out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits( + trace, instantiate, in_pvals) + out_pvals = [t.pval for t in out_tracers] # Which consts (aka residuals) are just forwarded inputs? Check obj id. in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] @@ -1283,7 +1191,7 @@ def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx: Callable[[core.ParamDict], AbstractContextManager[None]] = trivial_ctx, + ctx = trivial_ctx, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): @@ -1614,13 +1522,7 @@ def _contents(self): return () def _origin_msg(self): - if not self._trace.main.jaxpr_stack: - # If this Tracer has been leaked the jaxpr stack may no longer be - # available. So we can't print as much origin information. - return ("\nThis DynamicJaxprTracer was created on line " - f"{source_info_util.summarize(self._line_info)}") - else: - invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self) + invar_pos, progenitor_eqns = self._trace.frame.find_progenitors(self) dbg = self._debug_info if dbg is None: return "" @@ -1653,10 +1555,6 @@ def _origin_msg(self): origin += "\n\n(Additional originating lines are not shown.)" return "\n" + origin - def _assert_live(self) -> None: - if not self._trace.main.jaxpr_stack: # type: ignore - raise core.escaped_tracer_error(self, None) - def get_referent(self): frame = self._trace.frame val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) @@ -1737,7 +1635,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, out_tracers: Sequence[Tracer] invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) - state_outvars = [self.tracer_to_var[id(trace.full_raise(x))] + state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x))] for xs in state_ans for x in xs] explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars @@ -1892,11 +1790,25 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: class DynamicJaxprTrace(core.Trace): - __slots__ = [] - - @property - def frame(self): - return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error + def __init__(self, frame): + self.frame = frame + + def invalidate(self): + # avoid cyclic refs + self.frame.tracers = [] + self.frame.constid_to_tracer = {} + + def to_jaxpr_tracer(self, x): + as_local_var = self.frame.tracer_to_var.get(id(x)) + if as_local_var is None: + if hasattr(x, "dimension_as_value"): # Used for shape_poly._DimExpr + with core.set_current_trace(self): + x = x.dimension_as_value() + return self.to_jaxpr_tracer(x) + else: + return self.new_const(x) + else: + return x def new_arg(self, aval): tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) @@ -1924,22 +1836,11 @@ def _new_const(self, aval, c) -> DynamicJaxprTracer: self.frame.constvar_to_val[var] = c return tracer - def sublift(self, t): - # When lifting closed-over tracers corresponding to this same trace, the - # variable to lift could have tracers (representing axis size variables) in - # its shape. We must lift those too! - tracer = self.frame.constid_to_tracer.get(id(t)) - if tracer is None: - aval = raise_to_shaped(get_aval(t), weak_type=dtypes.is_weakly_typed(t)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, t) - return tracer - def _lift_tracers_in_aval(self, aval): if (not isinstance(aval, DShapedArray) or not any(isinstance(d, Tracer) for d in aval.shape)): return aval - shape = [self.full_raise(d) if isinstance(d, Tracer) else d + shape = [self.to_jaxpr_tracer(d) if isinstance(d, Tracer) else d for d in aval.shape] return aval.update(shape=tuple(shape)) @@ -1956,17 +1857,16 @@ def makevar(self, tracer): var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) return var - def instantiate_const(self, val): - if (isinstance(val, Tracer) and val._trace.main is self.main - and val._trace.sublevel == self.sublevel): - return val - else: - return self.new_const(val) + def is_const(self, tracer): + return self.frame.tracer_to_var.get(id(tracer)) is None def process_primitive(self, primitive, tracers, params): + if (config.eager_constant_folding.value and all(map(self.is_const, tracers))): + return primitive.bind_with_trace(core.eval_trace, tracers, params) + jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: - return custom_staging_rules[primitive](self, *tracers, **params) - return self.default_process_primitive(primitive, tracers, params) + return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) + return self.default_process_primitive(primitive, jaxpr_tracers, params) def default_process_primitive(self, primitive, tracers, params): avals = [t.aval for t in tracers] @@ -1986,16 +1886,13 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f, explicit_tracers, params): if f.in_type is None: - f = lu.annotate(f, tuple((raise_to_shaped(t.aval), True) + f = lu.annotate(f, tuple((raise_to_shaped(get_aval(t)), True) for t in explicit_tracers)) implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) - in_tracers = [*implicit_tracers, *explicit_tracers] + in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation - with core.new_sublevel(): - # TODO(lenamartens): Make call_primitive name -> API function name mapping. - # (currently this will display eg. 'xla_call' instead of `jit`) - dbg = debug_info_final(f, call_primitive.name) - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg) + dbg = debug_info_final(f, call_primitive.name) + jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f, debug_info=dbg) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) @@ -2009,7 +1906,7 @@ def process_call(self, call_primitive, f, explicit_tracers, params): aval = aval.update(shape=tuple(get_referent(d) for d in shape)) out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) update_params = call_param_updaters.get(call_primitive) @@ -2017,25 +1914,21 @@ def process_call(self, call_primitive, f, explicit_tracers, params): new_params = update_params(new_params, [True] * len(explicit_tracers), len(consts) + len(implicit_tracers)) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, - new_params, new_params['call_jaxpr'].effects, - source_info) + new_params, new_params['call_jaxpr'].effects, source_info) self.frame.add_eqn(eqn) return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] - def post_process_call(self, call_primitive, out_tracers, params): - assert False # unreachable - def process_map(self, map_primitive, f, tracers, params): + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) if in_axis is not None else a for a, in_axis in zip(in_avals, params['in_axes'])] - with core.extend_axis_env(axis_name, params["global_axis_size"], None): - with core.new_sublevel(): - jaxpr, reduced_out_avals, consts, () = trace_to_subjaxpr_dynamic( - f, self.main, reduced_in_avals, - debug_info=debug_info_final(f, map_primitive.name)) + with core.extend_axis_env_nd([(axis_name, params["global_axis_size"])]): + jaxpr, reduced_out_avals, consts, () = trace_to_jaxpr_dynamic( + f, reduced_in_avals, + debug_info=debug_info_final(f, map_primitive.name)) ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects) if ordered_effects: raise ValueError("Ordered effects not supported for " @@ -2047,7 +1940,7 @@ def process_map(self, map_primitive, f, tracers, params): source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_in_axes = (None,) * len(consts) + params['in_axes'] new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, @@ -2062,16 +1955,12 @@ def process_map(self, map_primitive, f, tracers, params): self.frame.add_eqn(eqn) return out_tracers - def post_process_map(self, map_primitive, out_tracers, params): - assert False # unreachable - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] in_tangent_avals = [t.to_tangent_aval() for t in in_avals] - with core.new_sublevel(): - fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) + fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) - main_ = ref(self.main) @_memoize def jvp_jaxpr_thunk(*in_zeros): @@ -2079,12 +1968,12 @@ def jvp_jaxpr_thunk(*in_zeros): nz_tangent_avals, zero_avals = partition_list(in_zeros, in_tangent_avals) jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals)) in_avals_ = (*in_avals, *nz_tangent_avals) - jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_) + jaxpr, _, out_consts, () = trace_to_jaxpr_dynamic(jvp_, in_avals_) return jaxpr, out_consts, out_zeros() out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, @@ -2096,29 +1985,24 @@ def jvp_jaxpr_thunk(*in_zeros): self.frame.add_eqn(eqn) return out_tracers - def post_process_custom_jvp_call(self, out_tracers, _): - assert False # unreachable - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): + tracers = map(self.to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] - with core.new_sublevel(): - fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) + fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) - main_ = ref(self.main) - @_memoize def fwd_jaxpr_from_zeros(*zeros): for store in fwd.stores: store and store.reset() fwd_ = _interleave_fun(fwd, zeros) - jaxpr, _, consts, atr = trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals) + jaxpr, _, consts, atr = trace_to_jaxpr_dynamic(fwd_, in_avals) if atr: raise NotImplementedError return jaxpr, consts out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style, dict(fun_jaxpr=closed_fun_jaxpr, @@ -2131,38 +2015,32 @@ def fwd_jaxpr_from_zeros(*zeros): self.frame.add_eqn(eqn) return out_tracers - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable - def process_custom_transpose(self, prim, call, tracers, *, transpose, out_types, lin_tree, res_tree, out_tree): + tracers = map(self.to_jaxpr_tracer, tracers) tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) in_avals_p = [t.aval for t in tracers] in_avals_t = [*[t.aval for t in tracers_res], *out_types] - with core.new_sublevel(): - call_jaxpr, out_avals, call_consts, () = trace_to_subjaxpr_dynamic( - call, self.main, in_avals_p) + call_jaxpr, out_avals, call_consts, _ = trace_to_jaxpr_dynamic(call, in_avals_p) closed_call_jaxpr = core.ClosedJaxpr( convert_constvars_jaxpr(call_jaxpr), ()) transpose_flat, in_tree2 = flatten_fun_nokwargs( lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree))) - main_ = ref(self.main) # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts @_memoize def transpose_jaxpr_thunk(): for store in transpose_flat.stores: store.reset() - jaxpr, _, consts, () = trace_to_subjaxpr_dynamic( - transpose_flat, main_(), in_avals_t) + jaxpr, _, consts, () = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) return jaxpr, consts out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.instantiate_const, call_consts)) + constvars = map(self.getvar, map(self.to_jaxpr_tracer, call_consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_call_jaxpr, @@ -2182,19 +2060,15 @@ def _interleave_fun(every_others, *args, **kwargs): args_ = [x for pair in zip(args, every_others) for x in pair] yield (yield (args_, kwargs)) +# TODO: consider renaming to "lazy_thunk" def _memoize(fn): cells = {} - saved_state = core.thread_local_state.trace_state.copy() sentinel = object() def memoized(*args): out = cells.get(args, sentinel) if out is sentinel: - prev_state = core.thread_local_state.trace_state - core.thread_local_state.trace_state = saved_state - try: + with core.set_current_trace(None): out = cells[args] = fn(*args) - finally: - core.thread_local_state.trace_state = prev_state return out return memoized @@ -2271,106 +2145,45 @@ def trace_to_jaxpr_dynamic( debug_info: DebugInfo | None = None, *, keep_inputs: list[bool] | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: - main.jaxpr_stack = () # type: ignore - jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) - del main, fun - return jaxpr, out_avals, consts, attrs_tracked - - -def trace_to_subjaxpr_dynamic( - fun: lu.WrappedFun, - main: core.MainTrace, - in_avals: Sequence[AbstractValue], - *, - keep_inputs: Sequence[bool] | None = None, - debug_info: DebugInfo | None = None, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs frame = JaxprStackFrame() frame.debug_info = debug_info - with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): - trace = DynamicJaxprTrace(main, core.cur_sublevel()) + + trace = DynamicJaxprTrace(frame) + with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) - in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - ans = fun.call_wrapped(*in_tracers_) - out_tracers = map(trace.full_raise, ans) + in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + with core.set_current_trace(trace): + ans = fun.call_wrapped(*in_tracers) + + out_tracers = map(trace.to_jaxpr_tracer, ans) jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers) - del fun, main, trace, frame, in_tracers, out_tracers, ans + del trace, fun, frame, in_tracers, out_tracers, ans + config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked - @profiler.annotate_function def trace_to_jaxpr_dynamic2( fun: lu.WrappedFun, debug_info: DebugInfo | None = None ) -> tuple[Jaxpr, OutputType, list[Any]]: - with core.new_main(DynamicJaxprTrace, dynamic=True) as main: - main.jaxpr_stack = () # type: ignore - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) - del main, fun - return jaxpr, out_type, consts - -def trace_to_subjaxpr_dynamic2( - fun: lu.WrappedFun, main: core.MainTrace, - debug_info: DebugInfo | None = None -) -> tuple[Jaxpr, OutputType, list[Any]]: - in_avals, keep_inputs = unzip2(fun.in_type) - frame = JaxprStackFrame() - frame.debug_info = debug_info - with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack(): - trace = DynamicJaxprTrace(main, core.cur_sublevel()) - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) - in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - ans = fun.call_wrapped(*in_tracers_) - out_tracers = map(trace.full_raise, ans) - jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers) - del fun, main, trace, frame, in_tracers, out_tracers, ans - return jaxpr, out_type, consts - - -@contextmanager -def extend_jaxpr_stack(main, frame): - main.jaxpr_stack = main.jaxpr_stack + (frame,) - try: - yield - finally: - assert frame is main.jaxpr_stack[-1] - main.jaxpr_stack = main.jaxpr_stack[:-1] - - -@profiler.annotate_function -def trace_to_jaxpr_final( - fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue], - debug_info: DebugInfo | None = None, - keep_inputs: Sequence[bool] | None = None, -) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: - with core.new_base_main(DynamicJaxprTrace) as main: - main.jaxpr_stack = () # type: ignore - with core.new_sublevel(): - jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic( - fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info) - del fun, main - return jaxpr, out_avals, consts + trace = DynamicJaxprTrace(JaxprStackFrame()) + with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): + trace.frame.debug_info = debug_info + in_avals, keep_inputs = unzip2(fun.in_type) + in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] + with core.set_current_trace(trace): + ans = fun.call_wrapped(*in_tracers) + out_tracers = map(trace.to_jaxpr_tracer, ans) + jaxpr = trace.frame.to_jaxpr2(out_tracers) + del trace, in_tracers, out_tracers, ans -@profiler.annotate_function -def trace_to_jaxpr_final2( - fun: lu.WrappedFun, debug_info: DebugInfo | None = None - ) -> tuple[Jaxpr, OutputType, list[Any]]: - with core.new_base_main(DynamicJaxprTrace) as main: - main.jaxpr_stack = () # type: ignore - with core.new_sublevel(): - jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) - del fun, main - return jaxpr, out_type, consts - + return jaxpr AbstractedAxisName = Hashable AbstractedAxesSpec = Union[ @@ -2555,8 +2368,8 @@ def _extract_implicit_args( for d1, d2 in zip(aval.shape, tracer.aval.shape): if isinstance(d1, DBIdx): if tracers[d1.val] is None: - tracers[d1.val] = trace.instantiate_const(d2) - assert tracers[d1.val] is trace.instantiate_const(d2) + tracers[d1.val] = trace.to_jaxpr_tracer(d2) + assert tracers[d1.val] is trace.to_jaxpr_tracer(d2) assert all(t is not None for t in tracers) return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore @@ -2693,32 +2506,9 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): return prim.bind(*subfuns, *args, **bind_params) -# TODO(mattjj): the following are deprecated; update callers to _nounits version -# See https://github.com/jax-ml/jax/pull/9498 -@lu.transformation -def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool], - pvals: Sequence[PartialVal]): - assert all(isinstance(pv, PartialVal) for pv in pvals), pvals - trace = main.with_cur_sublevel() - in_tracers = map(trace.new_arg, pvals) - ans = yield in_tracers, {} - assert isinstance(ans, (list, tuple)), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( - f"Got unexpected return type when tracing function to jaxpr: {ans}") - instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers) - jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers) - out_pvals = [t.pval for t in out_tracers] - del trace, in_tracers, out_tracers - yield jaxpr, (out_pvals, consts, env) - -partial_eval_jaxpr: Callable - def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): if instantiate: - return trace.instantiate_const(trace.full_raise(tracer)) + return trace.instantiate_const(tracer) else: return tracer diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b81cb9ef9238..02ec54ba5d39 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -16,7 +16,6 @@ from __future__ import annotations import enum -from contextlib import contextmanager import collections from collections import namedtuple from collections.abc import Callable, Sequence, Iterable @@ -374,14 +373,15 @@ def _emap_impl(fun: lu.WrappedFun, *args, emap_info = EmapInfo(backend, devices) shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes] - with core.new_base_main(MapTrace, emap_info=emap_info) as main: - with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main): - t = main.with_cur_sublevel() - tracers = [MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)] + trace = MapTrace(axis_name, emap_info) + with core.extend_axis_env_nd([(axis_name, axis_size)]): + tracers = [MapTracer(trace, arg, s) for arg, s in zip(args, shard_axes)] + with core.set_current_trace(trace): ans = fun.call_wrapped(*tracers) - out_tracers = map(t.full_raise, ans) - outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers) - del main + + out_tracers = map(trace.to_map_tracer, ans) + outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers) + out_axes = out_axes_thunk() platform = xb.get_backend(backend).platform @@ -441,25 +441,33 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName], class MapTrace(core.Trace): - def __init__(self, *args, emap_info): - super().__init__(*args) + def __init__(self, axis_name, emap_info): self.emap_info = emap_info + self.axis_name = axis_name - def pure(self, val): - return MapTracer(self, val, {}) - - def sublift(self, tracer): - return MapTracer(self, tracer.val, tracer.shard_axes) + def to_map_tracer(self, val): + if isinstance(val, MapTracer): + return val + else: + return MapTracer(self, val, {}) def process_primitive(self, primitive, tracers, params): - info = self.main.payload["emap_info"] + if primitive is jax._src.lax.parallel.axis_index_p: + return self.process_axis_index(**params) + if primitive is jax._src.lax.parallel.psum_p: + f = HashableFunction( + lambda *xs: jax._src.lax.parallel.psum( + xs, axis_name=params['axes'], axis_index_groups=params['axis_index_groups']), + (primitive, tuple(params.items()))) + else: + f = HashableFunction(lambda *args: primitive.bind(*args, **params), + (primitive, tuple(params.items()))) + tracers = map(self.to_map_tracer, tracers) vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers]) - names = tuple(f.name for f in core.thread_local_state.trace_state.axis_env - if f.main_trace is self.main) + info = self.emap_info + names = core.get_axis_env().axis_names() all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes) # pytype: disable=wrong-arg-types # always-use-return-annotations - f = HashableFunction(lambda *args: primitive.bind(*args, **params), - (primitive, tuple(params.items()))) - f_mapped, out_shard_axes = _multi_pmap(f, info, names, all_axes) + f_mapped, out_shard_axes = _multi_pmap(f, self.emap_info, names, all_axes) with core.eval_context(), jax.disable_jit(False): outvals = f_mapped(*vals) if primitive.multiple_results: @@ -484,14 +492,12 @@ def process_map(self, map_primitive, fun, tracers, params): shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s} if ax is not None else s for v, ax, s in zip(vals, in_axes, shard_axes)] - # TODO(mattjj): use _emap_subtrace here? - with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main): - t = self.main.with_cur_sublevel() - in_tracers = map(partial(MapTracer, t), vals, shard_axes) - ans = fun.call_wrapped(*in_tracers) - out_tracers = map(t.full_raise, ans) + in_tracers = map(partial(MapTracer, self), vals, shard_axes) + with core.extend_axis_env_nd([(axis_name, axis_size)]): + with core.set_current_trace(self): + ans = fun.call_wrapped(*in_tracers) + out_tracers = map(self.to_map_tracer, ans) out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers) - del t, in_tracers, ans, out_tracers out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst) for v, s, dst in zip(out, outaxes, out_axes_thunk())) return map(partial(MapTracer, self), out, outaxes) @@ -502,11 +508,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros # always base main, can drop jvp - in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) - fun, out_axes = _emap_subtrace(fun, self.main, in_axes) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(MapTracer, self), out_vals, out_axes()) + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): @@ -515,32 +518,18 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp - in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) - fun, out_axes = _emap_subtrace(fun, self.main, in_axes) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(MapTracer, self), out_vals, out_axes()) + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) - def process_axis_index(self, frame): + def process_axis_index(self, axis_name): bind = HashableFunction( - lambda _: jax.lax.axis_index(frame.name), - (jax.lax.axis_index, frame.name)) + lambda _: jax.lax.axis_index(axis_name), + (jax.lax.axis_index, axis_name)) fake_primitive = FakePrimitive(multiple_results=False, bind=bind) - with core.eval_context(): - range = jax.lax.iota(np.int32, frame.size) - dummy_tracer = MapTracer(self, range, {frame.name: 0}) + range = jax.lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) + dummy_tracer = MapTracer(self, range, {axis_name: 0}) return self.process_primitive(fake_primitive, (dummy_tracer,), {}) -@lu.transformation_with_aux -def _emap_subtrace(main, in_axes, *in_vals): - t = main.with_cur_sublevel() - in_tracers = map(partial(MapTracer, t), in_vals, in_axes) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - out_vals, out_axes = unzip2((t.val, t.shard_axes) for t in out_tracers) - del t, in_tracers, ans, out_tracers - yield out_vals, out_axes - def _annot_to_flat(ndim: int, mapped_axes: Iterable[int], annotation: int | None) -> int | None: if annotation is None: return None @@ -706,11 +695,11 @@ def stage_parallel_callable( fun = _change_argument_ranks(fun, pci.in_axes, pci.out_axes_thunk) else: fun = orig_fun - with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): + with core.extend_axis_env_nd([(pci.axis_name, pci.global_axis_size)]): with dispatch.log_elapsed_time( - "Finished tracing + transforming {fun_name} for pmap in {elapsed_time:.9f} sec", + "Finished tracing + transforming {fun_name} for pmap in {elapsed_time} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): - jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( + jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic( fun, sharded_avals, pe.debug_info_final(fun, "pmap")) jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) @@ -748,7 +737,8 @@ def get_pmap_jaxpr( pci = ParallelCallableInfo( name, backend, axis_name, axis_size, global_axis_size, devices, in_axes, out_axes_thunk, avals) - jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) + with core.extend_axis_env_nd([(axis_name, axis_size)]): + jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) return closed_jaxpr, backend, replicas, shards, pci @@ -847,7 +837,7 @@ def lower_parallel_callable( backend.platform) module_name = f"pmap_{fun.__name__}" platforms = lowering_platforms or (backend.platform,) - with maybe_extend_axis_env(axis_name, global_axis_size, None): + with core.extend_axis_env_nd([(axis_name, global_axis_size)]): ordered_effects = list( effects.ordered_effects.filter_in(closed_jaxpr.effects)) if ordered_effects: @@ -1343,7 +1333,7 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval): def _pmap_dce_rule(used_outputs, eqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes axis_name = eqn.params["axis_name"] - with maybe_extend_axis_env(axis_name, eqn.params["global_axis_size"], None): + with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]): new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars']) _, in_axes = partition_list(used_inputs, eqn.params['in_axes']) @@ -1402,21 +1392,6 @@ def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts): ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) -def _pmap_axis_subst(params, subst, traverse): - if 'call_jaxpr' not in params: - return params - if not traverse: - return params - def shadowed_subst(name): - return (name,) if name in params['axis_name'] else subst(name) - with maybe_extend_axis_env(params['axis_name'], - params['global_axis_size'], None): - new_jaxpr = core.subst_axis_names_jaxpr(params['call_jaxpr'], - shadowed_subst) - return dict(params, call_jaxpr=new_jaxpr) -core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst - - def _unravel_index_hlo(axis_env): div = mlir.ir_constant( np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32)) @@ -1525,7 +1500,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, if in_axis is not None else in_node for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes)) - with maybe_extend_axis_env(axis_name, global_axis_size, None): + with core.extend_axis_env_nd([(axis_name, global_axis_size)]): sub_ctx = ctx.module_context.replace( axis_context=sharding_impls.ReplicaAxisContext(new_env)) sharded_outs, _ = mlir.jaxpr_subcomp( @@ -3203,9 +3178,3 @@ def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: parsed_pspec = sharding_impls.prepare_axis_resources( pspec, "pspec to array_mapping") return _get_array_mapping(parsed_pspec) - - -@contextmanager -def maybe_extend_axis_env(*args, **kwargs): - with core.extend_axis_env(*args, **kwargs): - yield diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index db03143f1083..34395756f25a 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -28,7 +28,6 @@ fori_loop as fori_loop, map as map, scan as scan, - scan_bind as scan_bind, scan_p as scan_p, _scan_impl as _scan_impl, while_loop as while_loop, diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index c634148768fc..d189dc0bd2cf 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -148,11 +148,6 @@ def switch(index, branches, *operands): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') - if joined_effects: - # Raise index in case of effects to allow data-dependence-based discharging - # of those effects (even if they don't have an explicit data dependence). - index = core.raise_as_much_as_possible(index) - out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) return tree_unflatten(out_trees[0], out) @@ -263,10 +258,6 @@ def cond(pred, true_fun, false_fun, *operands): f'Effects not supported in `cond`: {disallowed_effects}') index = lax.convert_element_type(pred, np.int32) - if joined_effects: - # Raise index in case of effects to allow data-dependence-based discharging - # of those effects (even if they don't have an explicit data dependence). - index = core.raise_as_much_as_possible(index) false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects) true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) @@ -352,8 +343,7 @@ def _bcast_select_n(pred, *cases): pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) return lax.select_n(pred, *cases) -def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, - dims, branches): +def _cond_batching_rule(axis_data, args, dims, branches): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -375,15 +365,13 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, # optimizations to XLA. # TODO(mattjj,frostig): assumes branches are side-effect-free, revise! index, *ops = ( - batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)) + batching.bdim_at_front(x, d, axis_data.size) for x, d in zip(args, dims)) in_batched = [True] * len(branches[0].in_avals) out_batched = [True] * len(branches[0].out_avals) branches_batched = [ - batching.batch_jaxpr( - jaxpr, axis_size, in_batched, out_batched, axis_name, spmd_axis_name, - main_type)[0] + batching.batch_jaxpr(jaxpr, axis_data, in_batched, out_batched)[0] for jaxpr in branches] branch_outs = [] @@ -401,13 +389,11 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, for b, x, d in zip(ops_bat, ops, op_dims)] branches_out_bat = [ - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, - spmd_axis_name, main_type)[1] + batching.batch_jaxpr(jaxpr, axis_data, ops_bat, False)[1] for jaxpr in branches] out_bat = [any(bat) for bat in zip(*branches_out_bat)] branches_batched = tuple( - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, - spmd_axis_name, main_type)[0] + batching.batch_jaxpr(jaxpr, axis_data, ops_bat, out_bat)[0] for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] @@ -733,12 +719,6 @@ def _cond_transpose(cts, *args, branches): assert next(out_iter, None) is None return [None] + out -def _cond_axis_substitution(params, subst, traverse): - if not traverse: - return params - branches = tuple(core.subst_axis_names_jaxpr(jaxpr, subst) for jaxpr in params['branches']) - return dict(params, branches=branches) - def _cond_typecheck(bind_time, *in_atoms, branches): if not bind_time: _, *in_atoms = in_atoms @@ -793,28 +773,16 @@ def _cond_typecheck(bind_time, *in_atoms, branches): f'called with operands of type {_avals_short(op_avals)}') return jaxpr0.out_avals, joined_effects -def cond_bind(*args, branches): - if config.enable_checks.value: - avals = map(core.get_aval, args) - in_atoms = [core.Var('', a) for a in avals] # dummies - _cond_typecheck(True, *in_atoms, branches=branches) - for jaxpr in branches: - core.check_jaxpr(jaxpr.jaxpr) - return core.AxisPrimitive.bind(cond_p, *args, branches=branches) - -cond_p = core.AxisPrimitive('cond') +cond_p = core.Primitive('cond') cond_p.multiple_results = True cond_p.def_impl(partial(dispatch.apply_primitive, cond_p)) cond_p.def_effectful_abstract_eval(_cond_abstract_eval) -cond_p.def_custom_bind(cond_bind) ad.primitive_jvps[cond_p] = _cond_jvp ad.reducing_transposes[cond_p] = _cond_transpose pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval -batching.spmd_axis_primitive_batchers[cond_p] = _cond_batching_rule -batching.axis_primitive_batchers[cond_p] = partial(_cond_batching_rule, None) +batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule xla.register_initial_style_primitive(cond_p) core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) -core.axis_substitution_rules[cond_p] = _cond_axis_substitution pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.dce_rules[cond_p] = _cond_dce_rule batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 21b522b3d8bb..b6ae09d364a3 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -278,34 +278,30 @@ def _cached_for_jaxpr(jaxpr): discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) return core.ClosedJaxpr(discharged_jaxpr, body_consts) -def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, +def _for_vmap(axis_data, args, dims, *, jaxpr, nsteps, reverse, which_linear, unroll): init_batched = [d is not batching.not_mapped for d in dims] closed_jaxpr = _cached_for_jaxpr(jaxpr) batched = init_batched for _ in range(len(batched)): _, out_batched = batching.batch_jaxpr( - closed_jaxpr, - axis_size, [False] + batched, instantiate=batched, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + closed_jaxpr, axis_data, [False] + batched, instantiate=batched) if out_batched == batched: break batched = map(operator.or_, batched, out_batched) else: raise Exception("Invalid fixpoint") - args = [batching.broadcast(x, axis_size, 0) if now_bat and not was_bat + args = [batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat else x for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)] batched_jaxpr_, _ = batching.batch_jaxpr( - pe.close_jaxpr(jaxpr), axis_size, [False] + batched, [], - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + pe.close_jaxpr(jaxpr), axis_data, [False] + batched, []) batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts # TODO consts out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps, reverse=reverse, which_linear=which_linear, unroll=unroll) return out_flat, [0 if b else batching.not_mapped for b in batched] -batching.axis_primitive_batchers[for_p] = functools.partial(_for_vmap, None) -batching.spmd_axis_primitive_batchers[for_p] = _for_vmap +batching.fancy_primitive_batchers[for_p] = _for_vmap def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, unroll): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 7a9596bf2c0d..598601cc4097 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -885,7 +885,7 @@ def transposed(*res1_cbar_bbar_res2): b_ys_avals_stripped + res2_avals)) -def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, +def _scan_batching_rule(axis_data, args, dims, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): @@ -902,11 +902,8 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, for _ in range(1 + len(carry_batched)): batched = const_batched + carry_batched + xs_batched jaxpr_batched, batched_out = batching.batch_jaxpr( - jaxpr, axis_size, batched, - instantiate=carry_batched + [False] * num_ys, - axis_name=axis_name, - spmd_axis_name=spmd_axis_name, - main_type=main_type) + jaxpr, axis_data, batched, + instantiate=carry_batched + [False] * num_ys) carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] if carry_batched_out == carry_batched: break @@ -919,7 +916,7 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 else x for x, d in zip(consts, consts_bdims)] - new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched + new_init = [batching.broadcast(x, axis_data.size, 0) if now_batched and not was_batched else batching.moveaxis(x, d, 0) if now_batched else x for x, d, was_batched, now_batched in zip(init, init_bdims, init_batched, carry_batched)] @@ -1209,17 +1206,8 @@ def arrange_jaxpr_args_for_wrapped(args): assert len(refs_out_matching_in_avals) == len(in_avals) return refs_out_matching_in_avals, [*carry_out, *ys] -def scan_bind(*args, **params): - if config.enable_checks.value: - avals = _map(core.get_aval, args) - in_atoms = [core.Var('', a) for a in avals] # dummies - _scan_typecheck(True, *in_atoms, **params) - core.check_jaxpr(params['jaxpr'].jaxpr) - return core.AxisPrimitive.bind(scan_p, *args, **params) - -scan_p = core.AxisPrimitive("scan") +scan_p = core.Primitive("scan") scan_p.multiple_results = True -scan_p.def_custom_bind(scan_bind) scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp @@ -1228,8 +1216,7 @@ def scan_bind(*args, **params): xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, mlir.lower_fun(_scan_impl, multiple_results=True)) -batching.axis_primitive_batchers[scan_p] = partial(_scan_batching_rule, None) -batching.spmd_axis_primitive_batchers[scan_p] = _scan_batching_rule +batching.fancy_primitive_batchers[scan_p] = _scan_batching_rule core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom pe.padding_rules[scan_p] = _scan_padding_rule @@ -1382,8 +1369,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects -def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, - args, dims, cond_nconsts, cond_jaxpr, +def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): from jax._src.callback import _IOEffect, _OrderedIOEffect if any(_OrderedIOEffect in fn.effects for fn in [body_jaxpr, cond_jaxpr]): @@ -1401,8 +1387,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # reach a fixpoint. for _ in range(1 + len(carry_bat)): _, carry_bat_out = batching.batch_jaxpr( - body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + body_jaxpr, axis_data, bconst_bat + carry_bat, instantiate=carry_bat) if carry_bat == carry_bat_out: break carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out) @@ -1412,8 +1397,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # Knowing how the carry is batched now, we can determine if the predicate is # batched. _, (pred_bat,) = batching.batch_jaxpr( - cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + cond_jaxpr, axis_data, cconst_bat + carry_bat, instantiate=False) if pred_bat: # If the predicate is batched, we have to batch *all* of the carry @@ -1424,13 +1408,9 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, carry_bat = [True] * len(carry_bat) carry_dims = [0] * len(carry_bat) body_jaxpr_batched, _ = batching.batch_jaxpr_axes( - body_jaxpr, axis_size, bconst_dims + carry_dims, - carry_dims, axis_name=axis_name, spmd_axis_name=spmd_axis_name, - main_type=main_type) + body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims) cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( - cond_jaxpr, axis_size, cconst_dims + carry_dims, [0], - axis_name=axis_name, spmd_axis_name=spmd_axis_name, - main_type=main_type) + cond_jaxpr, axis_data, cconst_dims + carry_dims, [0]) else: # If the predicate is not batched, we can look at the `cond_jaxpr`'s out # shape to determine the rank of the predicate. From this rank we pick the @@ -1440,13 +1420,11 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, cond_rank = len(cond_jaxpr.out_avals[0].shape) carry_dims = [cond_rank if b else None for b in carry_bat] body_jaxpr_batched, _ = batching.batch_jaxpr_axes( - body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims) # Now we need to rebatch the `cond_jaxpr` according to the new dims of the # carry. cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( - cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,), - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + cond_jaxpr, axis_data, cconst_dims + carry_dims, (None,)) # To prepare the `init` to the `while_p`, we broadcast values if they are # unbatched and need to have an out axis. If their current batch axis does not @@ -1455,7 +1433,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, new_init = [] for x, old_axis, new_axis in zip(init, init_dims, carry_dims): if old_axis is batching.not_mapped and new_axis is not batching.not_mapped: - new_init.append(batching.broadcast(x, axis_size, new_axis)) + new_init.append(batching.broadcast(x, axis_data.size, new_axis)) elif old_axis is batching.not_mapped and new_axis is batching.not_mapped: new_init.append(x) else: @@ -1891,7 +1869,7 @@ def new_cond(*consts_refs_carry): *[None] * num_carry] return invals_out, carry_out -while_p = core.AxisPrimitive('while') +while_p = core.Primitive('while') while_p.multiple_results = True while_p.def_impl(partial(dispatch.apply_primitive, while_p)) while_p.def_effectful_abstract_eval(_while_loop_abstract_eval) @@ -1899,8 +1877,7 @@ def new_cond(*consts_refs_carry): pe.custom_partial_eval_rules[while_p] = _while_partial_eval xla.register_initial_style_primitive(while_p) ad.primitive_transposes[while_p] = _while_transpose_error -batching.axis_primitive_batchers[while_p] = partial(_while_loop_batching_rule, None) -batching.spmd_axis_primitive_batchers[while_p] = _while_loop_batching_rule +batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) core.custom_typechecks[while_p] = _while_typecheck diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 4e0f5086b121..9a5a01e3987d 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -376,8 +376,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs): return [None] * sum(const_lengths) + cotangent_b -def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, - args, dims, const_lengths, jaxprs): +def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs): orig_bat = [d is not batching.not_mapped for d in dims] params, b = _split_linear_solve_args(args, const_lengths) @@ -397,15 +396,13 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): # Apply vecmat and solve -> new batched parts of x solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( - solve, axis_size, solve_bat + b_bat, instantiate=x_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + solve, axis_data, solve_bat + b_bat, instantiate=x_bat) if vecmat is None: vecmat_jaxpr_batched = None x_bat_out = solve_x_bat else: vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( - vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + vecmat, axis_data, vecmat_bat + b_bat, instantiate=b_bat) # batch all aux data by default x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat) # keep a slice of only the linear operator part of solve's avals @@ -413,15 +410,13 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # Apply matvec and solve_t -> new batched parts of b matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( - matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + matvec, axis_data, matvec_bat + x_bat_noaux, instantiate=b_bat) if solve_t is None: solve_t_jaxpr_batched = None b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) else: solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr( - solve_t, axis_size, solve_t_bat + x_bat_noaux, instantiate=x_bat_out, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + solve_t, axis_data, solve_t_bat + x_bat_noaux, instantiate=x_bat_out) assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)]) b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, @@ -445,7 +440,7 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, ] # Broadcast out b if necessary new_b = [ - batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else + batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) ] @@ -458,7 +453,7 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, return outs, out_dims -linear_solve_p = core.AxisPrimitive('custom_linear_solve') +linear_solve_p = core.Primitive('custom_linear_solve') linear_solve_p.multiple_results = True linear_solve_p.def_impl(_custom_linear_solve_impl) linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval) @@ -468,5 +463,4 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True)) ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule -batching.axis_primitive_batchers[linear_solve_p] = partial(_linear_solve_batching_rule, None) -batching.spmd_axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule +batching.fancy_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c0c594c4abdc..bbb23bcd1725 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1759,6 +1759,9 @@ def stop(x): return x elif (dtypes.issubdtype(_dtype(x), np.floating) or dtypes.issubdtype(_dtype(x), np.complexfloating)): + # break abstractions to support legacy leaked tracer use cases + if isinstance(x, ad.JVPTracer): + return stop(x.primal) return ad_util.stop_gradient_p.bind(x) else: return x @@ -2979,14 +2982,18 @@ def _convert_elt_type_pp_rule(eqn, context, settings): return core._pp_eqn(eqn.replace(params=params), context, settings) convert_element_type_p = Primitive('convert_element_type') -def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): - operand = core.Primitive.bind(convert_element_type_p, operand, - new_dtype=new_dtype, weak_type=weak_type, - sharding=sharding) + +# TODO(dougalm): I'm overriding bind_with_trace here because that's the closest thing to +# the old "custom bind" but it might not be the best way to do this. +def _convert_element_type_bind_with_trace(trace, args, params): + sharding = params['sharding'] + operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params) if sharding is not None and not config.sharding_in_types.value: - operand = pjit.with_sharding_constraint(operand, sharding) + with core.set_current_trace(trace): + operand = pjit.with_sharding_constraint(operand, sharding) return operand -convert_element_type_p.def_custom_bind(_convert_element_type_bind) +convert_element_type_p.def_bind_with_trace(_convert_element_type_bind_with_trace) + convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 9d4614f344fb..cbea424a9d95 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,6 +24,7 @@ from jax import tree_util from jax._src import core +from jax._src import dispatch from jax._src import dtypes from jax._src import sharding_impls from jax._src.core import AxisName, ShapedArray, raise_to_shaped @@ -119,8 +120,25 @@ def psum(x, axis_name, *, axis_index_groups=None): leaves = [lax.convert_element_type(l, np.int32) if dtypes.dtype(l) == np.bool_ else l for l in leaves] axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - out_flat = psum_p.bind( - *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + # handle the constant case specially + if all(not isinstance(leaf, core.Tracer) for leaf in leaves): + named_axes, pos_axes = axes_partition = [], [] + for axis in axis_name: + axes_partition[isinstance(axis, int)].append(axis) + def pos_reduce(x): + if not pos_axes: + return x + return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0)) + for axis in pos_axes]) + if axis_index_groups is not None: + assert not pos_axes + size = len(axis_index_groups[0]) + else: + size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) + out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) + else: + out_flat = psum_p.bind( + *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) def pmean(x, axis_name, *, axis_index_groups=None): @@ -233,7 +251,7 @@ def _axis_index_of_val(x, val, axis_name): mask = (val == x) validx = lax.select(mask, lax.full(mask.shape, idx), - lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtype=idx.dtype)) + lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtypes.dtype(idx))) return pmin(validx, axis_name) def _validate_reduce_axis_index_groups(axis_index_groups): @@ -303,6 +321,8 @@ def ppermute(x, axis_name, perm): Array(s) with the same shape as ``x`` with slices along the axis ``axis_name`` gathered from ``x`` according to the permutation ``perm``. """ + if not isinstance(axis_name, (list, tuple)): + axis_name = (axis_name,) return tree_util.tree_map( partial(ppermute_p.bind, axis_name=axis_name, perm=tuple(map(tuple, perm))), x) @@ -472,8 +492,15 @@ def axis_index(axis_name): [0 1] [0 1]] """ - return axis_index_p.bind(axis_name=axis_name) - + if not isinstance(axis_name, (tuple, list)): + return axis_index_p.bind(axis_name=axis_name) + else: + inner_size = 1 + index = 0 + for name in reversed(axis_name): + index += axis_index(name) * inner_size + inner_size *= psum(1, name) + return index def pgather(src, idx, axes: int | AxisName): """Uses the last positional axis of idx to index into src's axes.""" @@ -485,18 +512,30 @@ def pgather(src, idx, axes: int | AxisName): ### parallel primitives -def _subst_all_names_in_param( - pname: str, params: core.ParamDict, subst: core.AxisSubst, traverse: bool) -> core.ParamDict: - axis_name = params[pname] - if not isinstance(axis_name, (tuple, list)): - axis_name = (axis_name,) - result = dict(params) - result[pname] = sum(((name,) if isinstance(name, int) else subst(name) - for name in axis_name), - ()) - return result +def _names_in_param(pname: str, params: core.ParamDict) -> tuple[str]: + axis_names = params[pname] + if isinstance(axis_names, (tuple, list)): + return tuple(axis_names) + else: + return (axis_names,) + +def _constant_reduction(prim, axis_data, args, axes, axis_index_groups): + assert axis_data.name in axes + if axis_index_groups: raise NotImplementedError + new_axes = tuple(n for n in axes if n != axis_data.name) + if new_axes: + args = prim.bind(*args, axes=new_axes, axis_index_groups=axis_index_groups) + if prim is psum_p: + outs = [lax._const(x, axis_data.size) * x for x in args] + elif prim in (pmin_p, pmax_p): + outs = args + else: + raise Exception(f"Unrecognized reducer: {prim}") + + return outs, [None] * len(outs) -def _reduction_with_positional_batcher(prim, vals_in, dims_in, axis_index_groups, +def _reduction_with_positional_batcher( + prim, vals_in, dims_in, axis_index_groups, transform_unmapped, transform_mapped): if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap collectives. " @@ -536,10 +575,19 @@ def _reduction_batcher(prim, vals_in, dims_in, *, axes, axis_index_groups): return vals_out, [d if d is batching.not_mapped else 0 for d in dims_in] def _batched_reduction_collective( - prim, if_unmapped, axis_size, frame_name, _, vals_in, dims_in, axes, + prim, if_unmapped, axis_data, vals_in, dims_in, axes, axis_index_groups): assert prim.multiple_results - assert frame_name in axes + if all(d is None for d in dims_in): + if axis_data.name in axes: + return _constant_reduction(prim, axis_data, vals_in, axes, axis_index_groups) + else: + return prim.bind(*vals_in, axes=axes, axis_index_groups=axis_index_groups), dims_in + + if axis_data.name not in axes: + return _reduction_batcher(prim, vals_in, dims_in, axes=axes, + axis_index_groups=axis_index_groups) + # Note that we have a choice here. We can either unfuse the reduction into one # that handles the batched dims and then another one that handles the rest. # Alternatively, we can keep the dimension reduction fused with the rest, but @@ -548,12 +596,11 @@ def _batched_reduction_collective( # We choose the second strategy here. vals_out = _reduction_with_positional_batcher( prim, vals_in, dims_in, axis_index_groups, - lambda d, d_vals_in: (tuple(axis for axis in axes if axis != frame_name), - [if_unmapped(v, axis_size) for v in d_vals_in]), + lambda d, d_vals_in: (tuple(axis for axis in axes if axis != axis_data.name), + [if_unmapped(v, axis_data.size) for v in d_vals_in]), lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else - axis if axis != frame_name else - d - for axis in axes), + axis if axis != axis_data.name else + d for axis in axes), d_vals_in)) return vals_out, [batching.not_mapped] * len(vals_out) @@ -572,12 +619,16 @@ def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]] dtype=np.int64).T return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups)) -def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups): +def _allreduce_impl(prim, pos_reducer, *args, axes, axis_index_groups): assert axis_index_groups is None + if not all(isinstance(axis, int) for axis in axes): + return dispatch.apply_primitive(prim, *args, axes=axes, + axis_index_groups=axis_index_groups) assert all(isinstance(axis, int) for axis in axes) return [pos_reducer(arg, axes) for arg in args] def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): + _check_axis_names(axes) named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) if axis_index_groups is not None: @@ -589,6 +640,13 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): arg.dtype) for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +def _check_axis_names(axes): + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + axis_env = core.get_axis_env() + for name in named_axes: + if not axis_env.axis_exists(name): + raise NameError(f"unbound axis name: {name}") + def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms): len_0 = len(axis_index_groups[0]) @@ -669,64 +727,37 @@ def broadcast_positional(ct, arg): axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, nonzero_in_cts) -psum_p = core.AxisPrimitive('psum') +psum_p = core.Primitive('psum') psum_p.multiple_results = True -psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum)) +psum_p.def_impl(partial(_allreduce_impl, psum_p, lax._reduce_sum)) psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum)) ad.deflinear2(psum_p, _psum_transpose_rule) -batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p) -batching.axis_primitive_batchers[psum_p] = \ +batching.fancy_primitive_batchers[psum_p] = \ partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v) -core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes') - - -# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at -# tracing time. -@psum_p.def_custom_bind -def psum_bind(*args, axes, axis_index_groups): - if all(not isinstance(x, core.Tracer) for x in args): - named_axes, pos_axes = axes_partition = [], [] - for axis in axes: - axes_partition[isinstance(axis, int)].append(axis) - def pos_reduce(x): - if not pos_axes: - return x - return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0)) - for axis in pos_axes]) - if axis_index_groups is not None: - assert not pos_axes - size = len(axis_index_groups[0]) - else: - size = math.prod([core.axis_frame(name).size for name in named_axes]) - return tuple(lax._const(x, size) * pos_reduce(x) for x in args) - return core.AxisPrimitive.bind( - psum_p, *args, axes=axes, axis_index_groups=axis_index_groups) - +batching.skippable_batchers[psum_p] = partial(_names_in_param, 'axes') -pmax_p = core.AxisPrimitive('pmax') +pmax_p = core.Primitive('pmax') pmax_p.multiple_results = True -pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max)) +pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax._reduce_max)) pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max)) -batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p) -batching.axis_primitive_batchers[pmax_p] = \ +batching.fancy_primitive_batchers[pmax_p] = \ partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v) -core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes') +batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes') -pmin_p = core.AxisPrimitive('pmin') +pmin_p = core.Primitive('pmin') pmin_p.multiple_results = True -pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min)) +pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax._reduce_min)) pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min)) -batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p) -batching.axis_primitive_batchers[pmin_p] = \ +batching.fancy_primitive_batchers[pmin_p] = \ partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v) -core.axis_substitution_rules[pmin_p] = partial(_subst_all_names_in_param, 'axes') +batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes') def _ppermute_lowering(ctx, x, *, axis_name, perm): @@ -765,15 +796,16 @@ def _ppermute_transpose_rule(t, x, perm, axis_name): inverse_perm = list(zip(dsts, srcs)) return [ppermute(t, axis_name=axis_name, perm=inverse_perm)] -def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, perm): +def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): + axis_size, frame_name = axis_data.size, axis_data.name (v,), (d,) = vals_in, dims_in if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) + if axis_data.name not in axis_name: + return ppermute_p.bind(v, perm=perm, axis_name=axis_name), d remaining_axes = tuple(axis for axis in axis_name if axis != frame_name) - if axis_size == 1 and remaining_axes: - return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d if remaining_axes: - raise NotImplementedError("ppermute batcher only supports a single axis") + return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d assert axis_name[0] == frame_name, "ppermute batcher called with a wrong axis!" assert len(perm) == axis_size, "Permutation doesn't match the axis size!" if d is batching.not_mapped: @@ -783,30 +815,33 @@ def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, per perm_indices[dst] = src return v.take(perm_indices, d), d -def _collective_batcher(prim, args, dims, **params): - return prim.bind(*args, **params), dims if prim.multiple_results else dims[0] +def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): + _check_axis_names(axis_name) + return raise_to_shaped(x) -ppermute_p = core.AxisPrimitive('ppermute') -ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) +ppermute_p = core.Primitive('ppermute') +ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval) ad.deflinear2(ppermute_p, _ppermute_transpose_rule) mlir.register_lowering(ppermute_p, _ppermute_lowering) -batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p) -batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher -core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[ppermute_p] = _ppermute_batcher +batching.skippable_batchers[ppermute_p] = partial(_names_in_param, 'axis_name') def _pbroadcast_transpose_rule(t, x, source, axis_name): is_source = axis_index(axis_name) == source tsum = psum(t, axis_name) return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))] -def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source): +def _pbroadcast_batcher(axis_data, vals_in, dims_in, axis_name, source): + axis_size = axis_data.size (v,), (d,) = vals_in, dims_in if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) - remaining_axes = tuple(axis for axis in axis_name if axis != frame_name) + if axis_data.name not in axis_name: + return pbroadcast_p.bind(v, axis_name=axis_name, source=source), d + remaining_axes = tuple(axis for axis in axis_name if axis != axis_data.name) if remaining_axes: raise NotImplementedError("pbroadcast batcher only supports a single axis") - assert axis_name[0] == frame_name, "pbroadcast batcher called with a wrong axis!" + assert axis_name[0] == axis_data.name, "pbroadcast batcher called with a wrong axis!" assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!" if axis_size == 1 and remaining_axes: return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d @@ -823,13 +858,12 @@ def source_to_front(group): return hlo.CollectiveBroadcastOp( x, replica_groups=_replica_groups_hlo(replica_groups)).results -pbroadcast_p = core.AxisPrimitive('pbroadcast') -pbroadcast_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) +pbroadcast_p = core.Primitive('pbroadcast') +pbroadcast_p.def_abstract_eval(_raise_to_shaped_abstract_eval) ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule) mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering) -batching.primitive_batchers[pbroadcast_p] = partial(_collective_batcher, pbroadcast_p) -batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher -core.axis_substitution_rules[pbroadcast_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher +batching.skippable_batchers[pbroadcast_p] = partial(_names_in_param, 'axis_name') def _moveaxis(src, dst, x): @@ -914,11 +948,22 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, ) return result, d -def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, +def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): + axis_size, frame_name = axis_data.size, axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") + + if isinstance(axis_name, (list, tuple)): + axes_names = axis_name + else: + axes_names = [axis_name] + if axis_data.name not in axes_names: + return _all_to_all_batcher( + vals_in, dims_in, axis_name=axis_name, split_axis=split_axis, + concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled) + x, = vals_in d, = dims_in if d is batching.not_mapped: @@ -979,6 +1024,7 @@ def _all_to_all_effectful_abstract_eval( del tiled # expand_dims and squeeze is done in `all_to_all` if `True` if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) + _check_axis_names(axis_name) input_aval = raise_to_shaped(x) shape = list(input_aval.shape) axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) @@ -990,13 +1036,12 @@ def _all_to_all_effectful_abstract_eval( return out_aval, effects -all_to_all_p = core.AxisPrimitive('all_to_all') +all_to_all_p = core.Primitive('all_to_all') all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval) mlir.register_lowering(all_to_all_p, _all_to_all_lowering) ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule) -batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher -batching.axis_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective -core.axis_substitution_rules[all_to_all_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective +batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name') def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): @@ -1063,6 +1108,8 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): [[12 13 14 15] [ 4 5 6 7]]] """ + if not isinstance(axis_name, tuple): + axis_name = axis_name, axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) def bind(leaf): @@ -1071,7 +1118,7 @@ def bind(leaf): all_gather_dimension=canonicalize_axis( axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1), axis_name=axis_name, axis_index_groups=axis_index_groups, - axis_size=axis_size, tiled=tiled) + axis_size=int(axis_size), tiled=tiled) return tree_util.tree_map(bind, x) def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): @@ -1126,6 +1173,7 @@ def _all_gather_effectful_abstract_eval( ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) + _check_axis_names(axis_name) x_aval = raise_to_shaped(x) new_shape = list(x_aval.shape) if tiled: @@ -1144,10 +1192,11 @@ def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): (x,), (d,) = vals_in, dims_in - if d <= all_gather_dimension: - all_gather_dimension += 1 - elif not tiled: # Tiled all-gather doesn't modify the set of dimensions - d += 1 + if d is not batching.not_mapped: + if d <= all_gather_dimension: + all_gather_dimension += 1 + elif not tiled: # Tiled all-gather doesn't modify the set of dimensions + d += 1 result = all_gather_p.bind( x, all_gather_dimension=all_gather_dimension, @@ -1157,9 +1206,15 @@ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, ax tiled=tiled) return result, d -def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, +def _all_gather_batched_collective(axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): + frame_size, frame_name = axis_data.size, axis_data.name + if frame_name not in axis_name: + return _all_gather_batcher( + vals_in, dims_in, all_gather_dimension=all_gather_dimension, + axis_name=axis_name, axis_index_groups=axis_index_groups, + axis_size=axis_size, tiled=tiled) if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap") assert axis_size == frame_size, "axis size doesn't match" @@ -1180,7 +1235,7 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, y = _foldaxis(all_gather_dimension, y) return y, batching.not_mapped -all_gather_p = core.AxisPrimitive('all_gather') +all_gather_p = core.Primitive('all_gather') all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval) all_gather_p.def_impl(_all_gather_impl) mlir.register_lowering(all_gather_p, _all_gather_lowering) @@ -1189,9 +1244,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, partial(_all_gather_lowering, platform=p), platform=p) ad.deflinear2(all_gather_p, _all_gather_transpose_rule) -batching.primitive_batchers[all_gather_p] = _all_gather_batcher -batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective -core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name') +batching.fancy_primitive_batchers[all_gather_p] = _all_gather_batched_collective +batching.skippable_batchers[all_gather_p] = partial(_names_in_param, 'axis_name') def _reduce_scatter_lowering( @@ -1248,6 +1302,7 @@ def _reduce_scatter_effectful_abstract_eval( ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) + _check_axis_names(axis_name) x_aval = core.raise_to_shaped(x) new_shape = list(x_aval.shape) scatter_dim_input_size = x_aval.shape[scatter_dimension] @@ -1289,9 +1344,15 @@ def _reduce_scatter_batcher(vals_in, dims_in, *, scatter_dimension, axis_name, tiled=tiled) return result, d -def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, +def _reduce_scatter_collective(axis_data, vals_in, dims_in, scatter_dimension, axis_name, axis_index_groups, axis_size, tiled): + frame_size, frame_name = axis_data.size, axis_data.name + if frame_name not in axis_name: + return _reduce_scatter_batcher( + vals_in, dims_in, scatter_dimension=scatter_dimension, + axis_name=axis_name, axis_index_groups=axis_index_groups, + axis_size=axis_size, tiled=tiled) if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap") assert axis_size == frame_size, "axis size doesn't match" @@ -1310,21 +1371,17 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, return y, dy -reduce_scatter_p = core.AxisPrimitive("reduce_scatter") +reduce_scatter_p = core.Primitive("reduce_scatter") reduce_scatter_p.def_effectful_abstract_eval( _reduce_scatter_effectful_abstract_eval ) ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule) -batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher -batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective +batching.fancy_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective +batching.skippable_batchers[reduce_scatter_p] = partial(_names_in_param, 'axis_name') mlir.register_lowering(reduce_scatter_p, partial(_reduce_scatter_lowering, lax.add_p)) -core.axis_substitution_rules[reduce_scatter_p] = \ - partial(_subst_all_names_in_param, 'axis_name') - - def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False): """ @@ -1401,6 +1458,8 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, [12 14] [16 18]] """ + if not isinstance(axis_name, tuple): + axis_name = axis_name, axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) bind = partial( @@ -1420,6 +1479,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): raise NotImplementedError( '`axis_index` translation rule does not support multiple axis names.') axis_name, = axis_name + if axis_name not in axis_env.names: + raise NameError(f"unbound axis name: {axis_name}") axis_pos = list(axis_env.names).index(axis_name) nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( @@ -1443,51 +1504,22 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): unsigned_index) def _axis_index_lowering(ctx, *, axis_name): - return [ - _build_axis_index_lowering_hlo(ctx, axis_name, - ctx.module_context.axis_env) - ] - + return [_build_axis_index_lowering_hlo(ctx, axis_name, + ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - frame = core.axis_frame(axis_name) + _check_axis_names([axis_name]) return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} +def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): + return lax.iota(np.int32, axis_data.size), 0 + axis_index_p = core.Primitive('axis_index') +axis_index_p.def_impl(partial(dispatch.apply_primitive, axis_index_p)) mlir.register_lowering(axis_index_p, _axis_index_lowering) axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval) -core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name') - -# Axis index doesn't get any arguments, so that the default bind would have no -# way to call into a data-dependency based trace such as vmap. Each trace that -# wants to bind an axis name has to additionally implement `process_axis_index` -# and put its main trace on the axis env stack. -def _axis_index_bind(*, axis_name): - def name_idx(name): - frame = core.axis_frame(name) - dynamic = core.thread_local_state.trace_state.trace_stack.dynamic - if (frame.main_trace is None or dynamic.level > frame.main_trace.level): - return core.Primitive.bind(axis_index_p, axis_name=name) - else: - trace = frame.main_trace.with_cur_sublevel() - return trace.process_axis_index(frame) - - if not isinstance(axis_name, (tuple, list)): - return name_idx(axis_name) - else: - inner_size = 1 - index = 0 - for name in reversed(axis_name): - index += name_idx(name) * inner_size - inner_size *= psum(1, name) - return index -axis_index_p.def_custom_bind(_axis_index_bind) - -def _vmap_process_axis_index(self, frame): - assert frame.size is not None - return batching.BatchTracer(self, lax.iota(np.int32, frame.size), 0) -batching.BatchTrace.process_axis_index = _vmap_process_axis_index # type: ignore - +batching.fancy_primitive_batchers[axis_index_p] = _axis_index_batcher +batching.skippable_batchers[axis_index_p] = partial(_names_in_param, 'axis_name') def _pgather_impl(src, idx, *, axes): assert all(isinstance(axis, int) for axis in axes) @@ -1508,6 +1540,7 @@ def _pgather_impl(src, idx, *, axes): def _pgather_abstract_eval(src, idx, *, axes): # TODO: Avals with names rule: remove all axes from src, insert those from idx # The order is important, because it is ok to re-insert one of the deleted axes! + _check_axis_names(axes) shape = list(src.shape) for axis in sorted((a for a in axes if isinstance(a, int)), reverse=True): del shape[axis] @@ -1559,11 +1592,10 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a else: return pgather_p.bind(src, idx, axes=new_axes), batching.not_mapped -pgather_p = core.AxisPrimitive('pgather') +pgather_p = core.Primitive('pgather') pgather_p.def_impl(_pgather_impl) pgather_p.def_abstract_eval(_pgather_abstract_eval) mlir.register_lowering(pgather_p, _pgather_parallel_lowering) # TODO: Transpose? That requires adding pscatter... -batching.primitive_batchers[pgather_p] = _pgather_batcher -batching.axis_primitive_batchers[pgather_p] = _pgather_collective_batcher -core.axis_substitution_rules[pgather_p] = partial(_subst_all_names_in_param, 'axes') +batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher +batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes') diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 8cb1fedb9ef3..dd8f671c639c 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -64,14 +64,12 @@ def trans1(static_arg, *dynamic_args, **kwargs): from __future__ import annotations from collections.abc import Callable -from functools import partial from typing import Any, NamedTuple import weakref from jax._src import config from jax._src import core from jax._src import traceback_util -from jax._src.tree_util import tree_map from jax._src.util import curry, cache_clearing_funs @@ -337,13 +335,8 @@ def cache(call: Callable, *, explain: Callable | None = None): def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore - if config.check_tracer_leaks.value: - key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args, - config.enable_x64.value, config.default_device.value, - config.trace_context()) - else: - key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value, - config.default_device.value, config.trace_context()) + key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value, + config.default_device.value, config.trace_context()) result = cache.get(key, None) if result is not None: ans, stores = result @@ -364,17 +357,6 @@ def _evict_function(f): cache_clearing_funs.add(memoized_fun.cache_clear) return memoized_fun - -def _copy_main_trace(x): - if isinstance(x, core.MainTrace): - return core.MainTrace(x.level, x.trace_type, **x.payload) - else: - return x - -_copy_main_traces = partial(tree_map, _copy_main_trace) - - - @transformation def hashable_partial(*args): yield (yield args, {}) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 7b98a5314744..4768a8126c72 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -607,7 +607,6 @@ def __array_module__(self, types): return NotImplemented -@core.stash_axis_env() @partial(jax.jit, static_argnums=(1,2,3)) def _multi_slice(self: Array, start_indices: tuple[tuple[int, ...]], diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index dad45bbae207..b697810b8967 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1142,14 +1142,3 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh): effs.add(eff) return [], effs jax_core.custom_typechecks[core_map_p] = _core_map_typecheck_rule - - -def _core_map_axis_subst(params, subst, traverse): - if not traverse: - return params - def shadowed_subst(name): - return (name,) if name in params['mesh'].shape else subst(name) - with jax_core.extend_axis_env_nd(params['mesh'].shape.items()): - new_jaxpr = jax_core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst) - return dict(params, jaxpr=new_jaxpr) -jax_core.axis_substitution_rules[core_map_p] = _core_map_axis_subst diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 7aab30ffc2ab..9ea2b59f66f1 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -593,14 +593,14 @@ def dma_start_discharge_rule(in_avals, out_avals, # Note that this code only works in SPMD mode. If not all devices execute # the DMA then the devices that do will hang. # TODO(justinfu): Verify that code only works in SPMD mode. - axis_env = jax_core.thread_local_state.trace_state.axis_env - nonempty_axes = [frame for frame in axis_env if frame.name is not None] + axis_env = jax_core.get_axis_env() + nonempty_axes = [name for name in axis_env.axis_sizes if name is not None] if device_id_type == DeviceIdType.LOGICAL: if len(nonempty_axes) > 1: raise NotImplementedError("Sharding with more than one named axis not " "implemented in dma_start_p for LOGICAL " "device_id_type.") - shard_axis = nonempty_axes[0].name + shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) elif device_id_type == DeviceIdType.MESH: device_id_len = 1 @@ -608,9 +608,9 @@ def dma_start_discharge_rule(in_avals, out_avals, device_id_len = device_id.size elif hasattr(device_id, '__len__'): device_id_len = len(device_id) - if device_id_len != len(axis_env): + if device_id_len != len(axis_env.axis_sizes): raise ValueError( - f"device_id ({device_id_len}) and mesh ({len(axis_env)}) " + f"device_id ({device_id_len}) and mesh ({len(axis_env.axis_sizes)}) " "must have same length.") if device_id_len > 1 or len(nonempty_axes) > 1: raise NotImplementedError("Meshes with more than 1 named dimension not " diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index b41ce3632468..c7bd7dd7178f 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -68,8 +68,8 @@ def program_id(axis: int) -> jax.Array: """ return program_id_p.bind(axis=axis) -@program_id_p.def_custom_bind -def program_id_bind(*, axis: int): +def program_id_bind_with_trace(trace, _, params): + axis = params.pop("axis") grid_env = pallas_core.current_grid_env() if grid_env: return grid_env[axis].index @@ -77,7 +77,9 @@ def program_id_bind(*, axis: int): # Query the size of the axis to make sure it's a valid axis (and error # otherwise). _ = frame.size(axis) - return jax_core.Primitive.bind(program_id_p, axis=axis) + return jax_core.Primitive.bind_with_trace(program_id_p, trace, (), dict(axis=axis)) +# TODO(dougalm): figure out how put the grid_env contest on the relevant trace +program_id_p.def_bind_with_trace(program_id_bind_with_trace) @program_id_p.def_abstract_eval def _program_id_abstract_eval(**_): @@ -89,8 +91,8 @@ def num_programs(axis: int) -> int | jax.Array: """Returns the size of the grid along the given axis.""" return num_programs_p.bind(axis=axis) -@num_programs_p.def_custom_bind -def _num_programs_bind(*, axis: int): +def _num_programs_bind_with_trace(trace, _, params): + axis = params.pop("axis") # We might be using a local grid env grid_env = pallas_core.current_grid_env() if grid_env: @@ -99,8 +101,9 @@ def _num_programs_bind(*, axis: int): frame = pallas_core.axis_frame() size = frame.size(axis) if size is pallas_core.dynamic_grid_dim: - return jax_core.Primitive.bind(num_programs_p, axis=axis) + return jax_core.Primitive.bind_with_trace(num_programs_p, trace, (), dict(axis=axis)) return size +num_programs_p.def_bind_with_trace(_num_programs_bind_with_trace) @num_programs_p.def_abstract_eval def _num_programs_abstract_eval(**_): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index c0a1cde4f8b6..904e92af2f91 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1437,7 +1437,7 @@ def check_aval_layout_compatibility( # -------------------- pjit rules -------------------- -pjit_p = core.AxisPrimitive("pjit") +pjit_p = core.Primitive("pjit") pjit_p.multiple_results = True @@ -1786,8 +1786,9 @@ def pjit_staging_rule(trace, *args, **params): # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, # but redundantly performs abstract evaluation again. - out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, - propagate_source_info=False) + with core.set_current_trace(trace): + out_tracers = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, + propagate_source_info=False) else: out_tracers = pe.inline_jaxpr_into_trace( trace, jaxpr.jaxpr, jaxpr.consts, *args) @@ -1807,7 +1808,7 @@ def pjit_staging_rule(trace, *args, **params): trace.frame.add_eqn(eqn) elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(trace.instantiate_const, consts) + consts = map(trace.new_const, consts) in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) @@ -1936,14 +1937,11 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, mlir.register_lowering(pjit_p, _pjit_lowering) -def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, - vals_in, dims_in, jaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): +def _pjit_batcher(axis_data, vals_in, dims_in, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) - new_jaxpr, axes_out = batching.batch_jaxpr2( - jaxpr, axis_size, dims_in, axis_name=axis_name, - spmd_axis_name=spmd_axis_name, main_type=main_type) + new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) if resource_env is not None: mesh = resource_env.physical_mesh @@ -1952,11 +1950,11 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs in_shardings = tuple( - _pjit_batcher_for_sharding(i, axis_in, spmd_axis_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim) if axis_in is not None else i for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals)) out_shardings = tuple( - _pjit_batcher_for_sharding(o, axis_out, spmd_axis_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim) if axis_out is not None else o for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) # TODO(yashkatariya): Figure out layouts should change under vmap. @@ -1982,8 +1980,7 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, vals_in, vals_out, axes_out) return vals_out, resolved_axes_out -batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher -batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None) +batching.fancy_primitive_batchers[pjit_p] = _pjit_batcher batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule def _pjit_batcher_for_sharding( @@ -2541,24 +2538,23 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, def _sharding_constraint_batcher( - spmd_axis_name, axis_size, axis_name, main_type, vals_in, - dims_in, sharding, layout, resource_env, unconstrained_dims): - if spmd_axis_name is not None and isinstance(sharding, NamedSharding): + axis_data, vals_in, dims_in, sharding, layout, resource_env, unconstrained_dims): + if axis_data.spmd_name is not None and isinstance(sharding, NamedSharding): used = {n for ns in sharding.spec for n in (ns if isinstance(ns, tuple) else (ns,))} - if set(spmd_axis_name) & used: - raise ValueError(f"vmap spmd_axis_name {spmd_axis_name} cannot appear in " + if set(axis_data.spmd_name) & used: + raise ValueError(f"vmap spmd_axis_name {axis_data.spmd_name} cannot appear in " "with_sharding_constraint spec, but got spec " f"{sharding.spec}") x, = vals_in d, = dims_in - + # None means unconstrained in ParsedPartitionSpec unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims} - if spmd_axis_name is None: + if axis_data.spmd_name is None: unconstrained_dims.add(d) vmapped_sharding = _pjit_batcher_for_sharding( - sharding, d, spmd_axis_name, resource_env.physical_mesh, x.ndim) + sharding, d, axis_data.spmd_name, resource_env.physical_mesh, x.ndim) if unconstrained_dims and isinstance(vmapped_sharding, NamedSharding): new_spec = list(vmapped_sharding.spec) + [None] * (x.ndim - len(vmapped_sharding.spec)) for u in unconstrained_dims: @@ -2579,9 +2575,9 @@ def _sharding_constraint_batcher( resource_env=resource_env, unconstrained_dims=unconstrained_dims) return y, d -batching.spmd_axis_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher -batching.axis_primitive_batchers[sharding_constraint_p] = partial( - _sharding_constraint_batcher, None) +batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher +batching.skippable_batchers[sharding_constraint_p] = lambda _: () + # -------------------- helpers -------------------- diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index ecfedad971f4..2c38878c7112 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -23,7 +23,6 @@ from jax._src import ad_util from jax._src import api_util -from jax._src import config from jax._src import core from jax._src import linear_util as lu from jax._src import source_info_util @@ -478,20 +477,6 @@ def _closed_call_discharge_rule( run_state_p = core.Primitive("run_state") run_state_p.multiple_results = True -def _run_state_bind(*args: Any, jaxpr: core.Jaxpr, - which_linear: tuple[bool, ...], - is_initialized: tuple[bool, ...]): - if config.enable_checks.value: - core.check_jaxpr(jaxpr) - num_uninitialized = sum(not i for i in is_initialized) - assert len(jaxpr.invars) == len(args) + num_uninitialized - assert len(which_linear) == len(args) + num_uninitialized - return core.Primitive.bind(run_state_p, *args, jaxpr=jaxpr, - which_linear=which_linear, - is_initialized=is_initialized) -run_state_p.def_custom_bind(_run_state_bind) - - def _default_initialization(x): assert hasattr(x, 'shape') assert hasattr(x, 'dtype') @@ -502,7 +487,6 @@ def _default_initialization(x): value = math.nan return lax.full(x.shape, value, dtype) - def _run_state_impl(*args: Any, jaxpr: core.Jaxpr, which_linear: tuple[bool, ...], is_initialized: tuple[bool, ...]): diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 4ec3123bd3e6..bb81c979bc48 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1162,10 +1162,8 @@ class JaxTestCase(parameterized.TestCase): _compilation_cache_exit_stack: ExitStack | None = None - # TODO(mattjj): this obscures the error messages from failures, figure out how - # to re-enable it - # def tearDown(self) -> None: - # assert core.reset_trace_state() + def tearDown(self) -> None: + assert core.reset_trace_state() def setUp(self): super().setUp() diff --git a/jax/core.py b/jax/core.py index 9682d106e202..6869f747b0d8 100644 --- a/jax/core.py +++ b/jax/core.py @@ -19,7 +19,9 @@ AbstractToken as AbstractToken, AbstractValue as AbstractValue, Atom as Atom, + axis_frame as axis_frame, AxisSize as AxisSize, + AxisName as AxisName, CallPrimitive as CallPrimitive, ClosedJaxpr as ClosedJaxpr, ConcreteArray as ConcreteArray, @@ -40,36 +42,28 @@ JaxprPpSettings as JaxprPpSettings, JaxprTypeError as JaxprTypeError, Literal as Literal, - MainTrace as MainTrace, MapPrimitive as MapPrimitive, nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 OpaqueTraceState as OpaqueTraceState, - NameGatheringSubst as NameGatheringSubst, OutDBIdx as OutDBIdx, OutputType as OutputType, ParamDict as ParamDict, Primitive as Primitive, ShapedArray as ShapedArray, - Sublevel as Sublevel, TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING, - ThreadLocalState as ThreadLocalState, Token as Token, Trace as Trace, - TraceStack as TraceStack, - TraceState as TraceState, Tracer as Tracer, unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401 unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401 + unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE, # noqa: F401 UnshapedArray as UnshapedArray, Value as Value, Var as Var, abstract_token as abstract_token, - apply_todos as apply_todos, aval_mapping_handlers as aval_mapping_handlers, - axis_frame as axis_frame, call as call, - call_bind_with_continuation as call_bind_with_continuation, call_impl as call_impl, call_p as call_p, check_jaxpr as check_jaxpr, @@ -77,15 +71,12 @@ concrete_aval as concrete_aval, concrete_or_error as concrete_or_error, concretization_function_error as concretization_function_error, - cur_sublevel as cur_sublevel, custom_typechecks as custom_typechecks, dedup_referents as dedup_referents, - do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr, ensure_compile_time_eval as ensure_compile_time_eval, escaped_tracer_error as escaped_tracer_error, eval_context as eval_context, eval_jaxpr as eval_jaxpr, - extend_axis_env as extend_axis_env, extend_axis_env_nd as extend_axis_env_nd, find_top_trace as find_top_trace, full_lower as full_lower, @@ -102,44 +93,33 @@ lattice_join as lattice_join, leaked_tracer_error as leaked_tracer_error, literalable_types as literalable_types, - map_bind as map_bind, - map_bind_with_continuation as map_bind_with_continuation, mapped_aval as mapped_aval, maybe_find_leaked_tracers as maybe_find_leaked_tracers, max_dim as max_dim, min_dim as min_dim, - new_base_main as new_base_main, new_jaxpr_eqn as new_jaxpr_eqn, - new_main as new_main, - new_sublevel as new_sublevel, no_axis_name as no_axis_name, no_effects as no_effects, outfeed_primitives as outfeed_primitives, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primitive_uses_outfeed as primitive_uses_outfeed, - process_env_traces_call as process_env_traces_call, - process_env_traces_map as process_env_traces_map, pytype_aval_mappings as pytype_aval_mappings, - raise_as_much_as_possible as raise_as_much_as_possible, raise_to_shaped as raise_to_shaped, raise_to_shaped_mappings as raise_to_shaped_mappings, reset_trace_state as reset_trace_state, - stash_axis_env as stash_axis_env, + set_current_trace as set_current_trace, str_eqn_compact as str_eqn_compact, subjaxprs as subjaxprs, - subst_axis_names as subst_axis_names, - subst_axis_names_eqn as subst_axis_names_eqn, - subst_axis_names_jaxpr as subst_axis_names_jaxpr, - subst_axis_names_var as subst_axis_names_var, substitute_vars_in_output_ty as substitute_vars_in_output_ty, - thread_local_state as thread_local_state, + take_current_trace as take_current_trace, + trace_ctx as trace_ctx, trace_state_clean as trace_state_clean, + TraceTag as TraceTag, traverse_jaxpr_params as traverse_jaxpr_params, typecheck as typecheck, typecompat as typecompat, typematch as typematch, unmapped_aval as unmapped_aval, - used_axis_names as used_axis_names, used_axis_names_jaxpr as used_axis_names_jaxpr, valid_jaxtype as valid_jaxtype, ) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 62da0f231d50..a25d93a35c51 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -14,18 +14,20 @@ from __future__ import annotations -from contextlib import contextmanager from typing import Any from jax._src import core +from jax._src import source_info_util from jax._src import api_util from jax._src import linear_util as lu +from jax._src.ad_util import (Zero) from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, treedef_tuple) from jax._src.util import unzip2, safe_map, safe_zip, split_list +from jax._src.dtypes import dtype, float0 map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -35,23 +37,13 @@ register = api_util.register_class_with_attrs -@contextmanager -def top_trace(): - stack = core.thread_local_state.trace_state.trace_stack.stack - main = stack.pop() - try: - trace = main.with_cur_sublevel() - yield trace - finally: - stack.append(main) - def jax_getattr(obj: Any, attr: str): - with top_trace() as trace: - return trace.process_getattr(obj, attr) + with core.take_current_trace() as t: + return t.process_getattr(obj, attr) def jax_setattr(obj: Any, attr: str, val: Pytree): - with top_trace() as trace: - return trace.process_setattr(obj, attr, val) + with core.take_current_trace() as t: + return t.process_setattr(obj, attr, val) def _getattr_impl(_, obj, attr): return getattr(obj, attr) @@ -62,7 +54,7 @@ def _setattr_impl(_, obj, attr, val): core.EvalTrace.process_setattr = _setattr_impl def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): - frame = trace.main.jaxpr_stack[-1] # type: ignore + frame = trace.frame def new_tracer(x): aval = core.raise_to_shaped(core.get_aval(x)) @@ -116,37 +108,40 @@ def _jvp(fun: lu.WrappedFun): @lu.transformation def jvpfun2(primals, tangents): - with core.new_main(ad.JVPTrace) as main: - out_primals, out_tangents, tangent_attrs_out = \ - yield (main, primals, tangents), {} - del main + tag = core.TraceTag() + tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) + and dtype(t) == float0 else t for t in tangents] + ctx = source_info_util.transform_name_stack('jvp') + with ctx: + out_primals, out_tangents, tangent_attrs_out = yield (tag, primals, tangents), {} yield out_primals, out_tangents, tangent_attrs_out @lu.transformation -def jvp_subtrace2(main, primals, tangents): - main.attrs_tracked = [] # attrs written to - trace = main.with_cur_sublevel() - in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x - for x, t in zip(primals, tangents)] - ans = yield in_tracers, {} - out_tracers = map(trace.full_raise, ans) - out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in out_tracers) - tangent_attrs_out = [] - for (obj, name) in main.attrs_tracked: - tracer = trace.full_raise(jax_getattr(obj, name)) - jax_setattr(obj, name, tracer.primal) - if type(tracer.tangent) is not ad.Zero: - tangent_attrs_out.append((obj, name, tracer.tangent)) - del main.attrs_tracked - yield out_primals, out_tangents, tangent_attrs_out +def jvp_subtrace2(tag, primals, tangents): + with core.take_current_trace() as parent_trace: + trace = ad.JVPTrace(parent_trace, tag) + tag.attrs_tracked = [] # attrs written to + in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x + for x, t in zip(primals, tangents)] + with core.set_current_trace(trace): + ans = yield in_tracers, {} + out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) + tangent_attrs_out = [] + for (obj, name) in tag.attrs_tracked: + primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name)) + jax_setattr(obj, name, primal) + if type(tangent) is not ad.Zero: + tangent_attrs_out.append((obj, name, tangent)) + del tag.attrs_tracked + yield out_primals, out_tangents, tangent_attrs_out def _setattr_jvp(trace, obj, attr, maybe_tracer): - tracer = trace.full_raise(maybe_tracer) - if isinstance(tracer.tangent, ad.Zero): - return setattr(obj, attr, tracer.primal) - if (obj, attr) not in trace.main.attrs_tracked: - trace.main.attrs_tracked.append((obj, attr)) - return setattr(obj, attr, tracer) + primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) + if isinstance(tangent, ad.Zero): + return setattr(obj, attr, primal) + if (obj, attr) not in trace.tag.attrs_tracked: + trace.tag.attrs_tracked.append((obj, attr)) + return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent)) ad.JVPTrace.process_setattr = _setattr_jvp def _getattr_jvp(trace, obj, attr): diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 273f756fe634..972d1b3dd570 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -399,7 +399,7 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal: # It is Ok to nest convert when we are inside a call_tf raise ValueError( "convert must be used outside all JAX transformations." + - f"Trace state: {core.thread_local_state.trace_state.trace_stack}") + f"Trace state: {core.trace_ctx}") global _has_registered_tf_source_path if not _has_registered_tf_source_path: @@ -844,15 +844,11 @@ def _interpret_fun_jax( extra_name_stack: str | None, fresh_constant_cache: bool = False, ) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]: - with core.new_base_main(TensorFlowTrace) as main: - subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals) - with _extended_name_stack(extra_name_stack): - with core.new_sublevel(): - out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \ - _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf, - fresh_constant_cache=fresh_constant_cache) - del main - + subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), args_avals) + with _extended_name_stack(extra_name_stack): + out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \ + _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf, + fresh_constant_cache=fresh_constant_cache) return util.unzip2(out_vals) @@ -1036,16 +1032,16 @@ def impl_multiple_results_jax(*args_jax): @lu.transformation -def _interpret_subtrace(main: core.MainTrace, - in_avals: Sequence[core.ShapedArray], +def _interpret_subtrace(in_avals: Sequence[core.ShapedArray], *in_vals: TfVal): - trace = TensorFlowTrace(main, core.cur_sublevel()) + trace = TensorFlowTrace() in_tracers = tuple( TensorFlowTracer(trace, val, aval) for val, aval in zip(in_vals, in_avals)) - outs = yield in_tracers, {} # type: Sequence[TfVal] + with core.set_current_trace(trace): + outs = yield in_tracers, {} # type: Sequence[TfVal] out_tracers: Iterable[TensorFlowTracer] = ( - map(trace.full_raise, outs)) + map(trace.to_tf_tracer, outs)) out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( tuple((t.val, t.aval) for t in out_tracers)) yield out_vals_with_avals @@ -1321,13 +1317,14 @@ class TensorFlowTrace(core.Trace): those will introduce their own MainTrace, and any operations involving those will be done on those traces, i.e., not a concern for TFT. """ - def pure(self, val: TfVal) -> TensorFlowTracer: + def to_tf_tracer(self, val: TfVal) -> TensorFlowTracer: """Lifts a non-Tracer into the TensorFlowTracer. - - This function may be called by way of trace.full_raise. """ + if isinstance(val, TensorFlowTracer): + return val if hasattr(val, "__jax_array__"): - val = val.__jax_array__() + with core.set_current_trace(self): + val = val.__jax_array__() if isinstance(val, TensorFlowTracer): return val tf_val, jax_dtype = _tfval_to_tensor_jax_dtype(val, memoize_constants=True) @@ -1335,20 +1332,10 @@ def pure(self, val: TfVal) -> TensorFlowTracer: self, tf_val, core.ShapedArray(np.shape(val), jax_dtype, weak_type=dtypes.is_weakly_typed(val))) - def lift(self, val: core.Tracer) -> TensorFlowTracer: - # This would be called when we need to raise a tracer from a lower-level - # main into the TensorFlowTrace. Since the TensorFlowTrace is never nested - # inside another transform, there are no lower-level main traces. - assert False - - def sublift(self, val: TensorFlowTracer) -> TensorFlowTracer: - # This is called when we need to raise a tracer from the same main, - # but a lower sublevel. This could come from a nested jit. - return TensorFlowTracer(self, val.val, val._aval) - def process_primitive(self, primitive: core.Primitive, tracers: Sequence[TensorFlowTracer], params) -> TensorFlowTracer: + tracers = map(self.to_tf_tracer, tracers) impl, impl_needs_avals = self.get_primitive_impl(primitive) args_avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) # This is a bit conservative, doing abstract_eval even in op-by-op execution @@ -1424,39 +1411,18 @@ def invoke_impl() -> TfVal: def process_call(self, call_primitive: core.Primitive, fun: lu.WrappedFun, tracers: Sequence[TensorFlowTracer], params): assert call_primitive.multiple_results + tracers = map(self.to_tf_tracer, tracers) vals: Sequence[TfVal] = [t.val for t in tracers] avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers) - interpreted_fun = _interpret_subtrace(fun, self.main, avals) + interpreted_fun = _interpret_subtrace(fun, avals) extra_name_stack = None with _extended_name_stack(extra_name_stack): - with core.new_sublevel(): - vals_out = interpreted_fun.call_wrapped(*vals) + vals_out = interpreted_fun.call_wrapped(*vals) return [TensorFlowTracer(self, v, a) for v, a in vals_out] - def post_process_call(self, call_primitive: core.Primitive, - out_tracers: Sequence[TensorFlowTracer], params): - # We encountered a call primitive whose result (out_tracers) include - # TensorFlowTracer that were not passed through its arguments (captured from - # the environment). - vals = tuple(t.val for t in out_tracers) - main = self.main - - def todo(vals: Sequence[TfVal]): - # TODO: is name_stack correct? - trace = TensorFlowTrace(main, core.cur_sublevel()) - return [ - TensorFlowTracer(trace, v, out_tracer.aval) - for v, out_tracer in zip(vals, out_tracers) - ] - - return vals, todo - def process_map(self, map_primitive, f, tracers, params): raise NotImplementedError("process_map") - def post_process_map(self, map_primitive, out_tracers, params): - raise NotImplementedError("post_process_map") - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Drop the custom differentiation rule and act like a call primitive. This # behavior is desirable because jax2tf stages code out of the JAX system, so @@ -1464,9 +1430,6 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): del jvp, symbolic_zeros # Unused. return self.process_call(core.call_p, fun, tracers, {}) - def post_process_custom_jvp_call(self, out_tracers, _): - assert False # unreachable assuming jax2tf runs with clean trace state - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): # Drop the custom differentiation rule and act like a call primitive. This @@ -1475,12 +1438,6 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, del fwd, bwd, out_trees, symbolic_zeros # Unused. return self.process_call(core.call_p, fun, tracers, {}) - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable assuming jax2tf runs with clean trace state - - def post_process_custom_vjp_call_fwd(self, *_, **__): - assert False # unreachable assuming jax2tf runs with clean trace state - def get_primitive_impl(self, p: core.Primitive) -> tuple[Callable, bool]: # Returns the primitive implementation and whether the implementation # takes abstract values (see definition of tf_impl_with_avals) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index ffe362974dcb..8dd2a319a1cb 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -152,22 +152,22 @@ def flatten_fun_output(*args): @lu.transformation def jet_fun(order, primals, series): - with core.new_main(JetTrace) as main: - main.order = order - out_primals, out_terms = yield (main, primals, series), {} - del main + tag = core.TraceTag() + out_primals, out_terms = yield (tag, order, primals, series), {} out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s for p, s in zip(out_primals, out_terms)] yield out_primals, out_terms @lu.transformation -def jet_subtrace(main, primals, series): - trace = JetTrace(main, core.cur_sublevel()) - in_tracers = map(partial(JetTracer, trace), primals, series) - ans = yield in_tracers, {} - out_tracers = map(trace.full_raise, ans) - out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers) - yield out_primals, out_terms +def jet_subtrace(tag, order, primals, series): + with core.take_current_trace() as parent_trace: + trace = JetTrace(tag, parent_trace, order) + in_tracers = map(partial(JetTracer, trace), primals, series) + with core.set_current_trace(trace): + ans = yield in_tracers, {} + + out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans)) + yield out_primals, out_terms @lu.transformation_with_aux def traceable(in_tree_def, *primals_and_series): @@ -198,33 +198,44 @@ def full_lower(self): class JetTrace(core.Trace): - def pure(self, val): - return JetTracer(self, val, zero_series) + def __init__(self, tag, parent_trace, order): + self.tag = tag + self.parent_trace = parent_trace + self.order = order - def lift(self, val): - return JetTracer(self, val, zero_series) - - def sublift(self, val): - return JetTracer(self, val.primal, val.terms) + def to_primal_terms_pair(self, val): + if isinstance(val, JetTracer) and val._trace.tag is self.tag: + return val.primal, val.terms + else: + return val, zero_series def process_primitive(self, primitive, tracers, params): - order = self.main.order # pytype: disable=attribute-error - primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) + order = self.order # pytype: disable=attribute-error + primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers)) + + if all(t is zero_series for t in series_in): + primal_out = primitive.bind_with_trace(self.parent_trace, primals_in, params) + if primitive.multiple_results: + return [JetTracer(self, p, zero_series) for p in primal_out] + else: + return JetTracer(self, primal_out, zero_series) + series_in = [[zero_term] * order if s is zero_series else s for s in series_in] - # TODO(mattjj): avoid always instantiating zeros - series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x)) - if t is zero_term else t for t in series] - for x, series in zip(primals_in, series_in)] - rule = jet_rules[primitive] - primal_out, terms_out = rule(primals_in, series_in, **params) + with core.set_current_trace(self.parent_trace): + # TODO(mattjj): avoid always instantiating zeros + series_in = [[jnp.zeros(np.shape(x), dtype=jnp.result_type(x)) + if t is zero_term else t for t in series] + for x, series in zip(primals_in, series_in)] + rule = jet_rules[primitive] + primal_out, terms_out = rule(primals_in, series_in, **params) if not primitive.multiple_results: return JetTracer(self, primal_out, terms_out) else: return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)] def process_call(self, call_primitive, f, tracers, params): - primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers) + primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers)) primals_and_series, in_tree_def = tree_flatten((primals_in, series_in)) f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def) update_params = call_param_updaters.get(call_primitive) @@ -234,17 +245,6 @@ def process_call(self, call_primitive, f, tracers, params): primals_out, series_out = tree_unflatten(out_tree_def(), result) return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)] - def post_process_call(self, call_primitive, out_tracers, params): - primals, series = unzip2((t.primal, t.terms) for t in out_tracers) - out, treedef = tree_flatten((primals, series)) - del primals, series - main = self.main - def todo(x): - primals, series = tree_unflatten(treedef, x) - trace = JetTrace(main, core.cur_sublevel()) - return map(partial(JetTracer, trace), primals, series) - return out, todo - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): # TODO(mattjj): don't just ignore custom jvp rules? diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 803efa19056e..b38edcaba10a 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -359,22 +359,18 @@ def ltg_abstract_eval(arr, *, global_mesh, pspec): lambda ct, _, **params: ( host_local_array_to_global_array_p.bind(ct, **params),)) -def ltg_batcher(insert_axis, spmd_axis_name, axis_size, - axis_name, main_type, vals_in, dims_in, - global_mesh, pspec): +def ltg_batcher(insert_axis, axis_data, vals_in, dims_in, global_mesh, pspec): x, = vals_in d, = dims_in - new_parts = None if spmd_axis_name is None else spmd_axis_name + new_parts = None if axis_data.spmd_name is None else axis_data.spmd_name new_pspec = list(pspec) new_pspec.insert(d, new_parts) new_pspec = P(*new_pspec) y = host_local_array_to_global_array_p.bind( x, global_mesh=global_mesh, pspec=new_pspec) return y, d -batching.spmd_axis_primitive_batchers[host_local_array_to_global_array_p] = partial( +batching.fancy_primitive_batchers[host_local_array_to_global_array_p] = partial( ltg_batcher, False) -batching.axis_primitive_batchers[host_local_array_to_global_array_p] = partial( - ltg_batcher, False, None) def _ltg_lowering(ctx, x, *, global_mesh, pspec): return [x] diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 03f3c96005ec..2fa028b2fe1e 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -53,9 +53,9 @@ special, control_flow, ann) from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy -from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, +from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2) + split_list, subs_list2) from jax.api_util import flatten_fun_nokwargs, shaped_abstractify, argnums_partial from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -454,30 +454,9 @@ def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] class ShardMapPrimitive(core.Primitive): multiple_results = True - def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, rewrite: bool, auto: frozenset[AxisName] - ) -> Sequence[MaybeTracer]: - top_trace = core.find_top_trace(args) - fun, env_todo = process_env_traces(fun, top_trace.level, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto) - - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - out_names = out_names_thunk() - _, xforms = env_todo() - for t in xforms: - out_names = t(out_names) - return out_names - - tracers = map(top_trace.full_raise, args) - outs = top_trace.process_shard_map( # pytype: disable=attribute-error - shard_map_p, fun, tracers, mesh=mesh, in_names=in_names, - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - todos, _ = env_todo() - return map(core.full_lower, core.apply_todos(todos, outs)) + def bind_with_trace(self, trace, fun_and_args, params): + fun, *args = fun_and_args + return trace.process_shard_map(shard_map_p, fun, args, **params) def get_bind_params(self, params): new_params = dict(params) @@ -489,56 +468,37 @@ def get_bind_params(self, params): shard_map_p = ShardMapPrimitive('shard_map') -@lu.transformation_with_aux -def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep, - rewrite, auto, *args: Any): - outs = yield args, {} - todos, out_names_transforms = [], [] - while True: - tracers = [x for x in outs if isinstance(x, core.Tracer) - and (level is None or x._trace.level > level)] - if tracers: - ans = max(tracers, key=op.attrgetter('_trace.level')) - else: - break - trace = ans._trace.main.with_cur_sublevel() - outs = map(trace.full_raise, outs) - outs, (todo, xform) = trace.post_process_shard_map( - outs, mesh, in_names, out_names_thunk, check_rep, rewrite, auto) - todos.append(todo) - out_names_transforms.append(xform) - yield outs, (tuple(todos), tuple(out_names_transforms)) - # Staging def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, - in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh, + in_tracers: Sequence[Any], *, mesh: Mesh, in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, rewrite: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: + in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) - main = trace.main - with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) - out_avals_ = map(_check_shapedarray, genavals) + with core.extend_axis_env_nd(list(mesh.shape.items())): + jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) - in_rep = map(partial(_in_names_to_rep, mesh), in_names) if check_rep: + in_rep = map(partial(_in_names_to_rep, mesh), in_names) out_rep = _check_rep(mesh, jaxpr, in_rep) _check_reps(mesh, out_names_thunk(), out_rep) - out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_) + out_avals = map(_check_shapedarray, out_avals_) + out_avals = [_check_shapedarray(_unshard_aval(mesh, names, aval)) + for names, aval in zip(out_names_thunk(), out_avals)] source_info = source_info_util.current() out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(trace.instantiate_const, consts)) + constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with core.extend_axis_env_nd(mesh.shape.items()): + with core.extend_axis_env_nd(list(mesh.shape.items())): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, @@ -804,28 +764,23 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, mesh = get_mesh_from_args(args, mesh) args = map(partial(_unmatch_spec, mesh), in_names, args) in_rep = map(partial(_in_names_to_rep, mesh), in_names) - with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main: - fun, out_rep = _shmap_subtrace(fun, main, in_rep) - with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main): - outs = fun.call_wrapped(*args) - del main + outs, out_rep = _run_shmap(fun, mesh, args, in_rep, check_rep) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types if check_rep: - _check_reps(mesh, out_names_thunk(), out_rep()) + _check_reps(mesh, out_names_thunk(), out_rep) pspecs = map(_names_to_pspec, out_names_thunk()) return map(partial(_match_spec, mesh, check_rep), pspecs, outs) core.EvalTrace.process_shard_map = _shard_map_impl -@lu.transformation_with_aux -def _shmap_subtrace(main, in_rep, *in_vals): - t = main.with_cur_sublevel() - in_tracers = map(partial(ShardMapTracer, t), in_rep, in_vals) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - outs, out_rep = unzip2((t.val, t.rep) for t in out_tracers) - del t, in_tracers, ans, out_tracers - yield outs, out_rep +def _run_shmap(f, mesh, args, reps, check_rep): + trace = ShardMapTrace(mesh, check_rep) + in_tracers = map(partial(ShardMapTracer, trace), reps, args) + with core.set_current_trace(trace): + with core.extend_axis_env_nd(mesh.shape.items()): + ans = f.call_wrapped(*in_tracers) + outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) + return outs, out_rep def _names_to_pspec(names: AxisNames) -> PartitionSpec: ndmin = max(names) + 1 if names else 0 @@ -877,20 +832,21 @@ class ShardMapTrace(core.Trace): mesh: Mesh check: bool - def __init__(self, *args, mesh, check): - super().__init__(*args) + def __init__(self, mesh, check): self.mesh = mesh self.check = check - def pure(self, val): - val_ = _unmatch_spec(self.mesh, {}, val) - return ShardMapTracer(self, None, val_) - - def sublift(self, tracer): - return ShardMapTracer(self, tracer.rep, tracer.val) + def to_val_rep_pair(self, val): + if isinstance(val, ShardMapTracer): + return val.val, val.rep + elif isinstance(val, Tracer): + raise Exception("Shouldn't have any non-shard_map tracers") + else: + val_ = _unmatch_spec(self.mesh, {}, val) + return val_, None def process_primitive(self, prim, tracers, params): - in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) + in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) eager_rule = eager_rules.get(prim) if eager_rule: out_vals = eager_rule(self.mesh, *in_vals, **params) @@ -926,36 +882,21 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros - in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) - fun, out_rep = _shmap_subtrace(fun, self.main, in_rep) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(ShardMapTracer, self), out_rep(), out_vals) - - def post_process_custom_jvp_call(self, out_tracers, _): - assert False # unreachable + in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) + out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check) + return map(partial(ShardMapTracer, self), out_rep, out_vals) def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): - # Since ShardMapTrace is only used as a base main, we can drop the jvp. if symbolic_zeros: msg = ("custom_vjp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros - in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) - fun, out_rep = _shmap_subtrace(fun, self.main, in_rep) - with core.new_sublevel(): - out_vals = fun.call_wrapped(*in_vals) - return map(partial(ShardMapTracer, self), out_rep(), out_vals) - - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable - - def process_axis_index(self, frame): - with core.eval_context(), jax.disable_jit(False): - return jax.jit(lambda: jax.lax.axis_index(frame.name))() + in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) + out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check) + return map(partial(ShardMapTracer, self), out_rep, out_vals) class ShardMapTracer(core.Tracer): @@ -978,9 +919,6 @@ def aval(self): aval = core.raise_to_shaped(aval) return core.mapped_aval(self._trace.mesh.size, 0, aval) - def full_lower(self) -> ShardMapTracer: - return self - def __str__(self) -> str: with core.eval_context(): blocks = list(self.val) @@ -1023,17 +961,16 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): # New primitives for efficient transposition # psum2_p is like psum_p except has a different transpose, so mostly copied: -psum2_p = core.AxisPrimitive('psum2') +psum2_p = core.Primitive('psum2') psum2_p.multiple_results = True psum2_p.def_impl(lax_parallel.psum_p.impl) psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) -batching.primitive_batchers[psum2_p] = partial(lax_parallel._reduction_batcher, psum2_p) -batching.axis_primitive_batchers[psum2_p] = \ +batching.fancy_primitive_batchers[psum2_p] = \ partial(lax_parallel._batched_reduction_collective, psum2_p, lambda v, axis_size: axis_size * v) -core.axis_substitution_rules[psum2_p] = \ - partial(lax_parallel._subst_all_names_in_param, 'axes') +batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes') + def _psum2_transpose_rule(cts, *args, axes, axis_index_groups): del args return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) @@ -1046,7 +983,7 @@ def pbroadcast(x, axis_name): xs, treedef = tree_flatten(x) ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) return tree_unflatten(treedef, ys) -pbroadcast_p = core.AxisPrimitive('pbroadcast') +pbroadcast_p = core.Primitive('pbroadcast') pbroadcast_p.multiple_results = True pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) @@ -1057,12 +994,6 @@ def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): axis_index_groups=axis_index_groups) return vals_out, dims_in batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher -def _pbroadcast_axis_batcher(size, name, trace_type, vals_in, dims_in, *, axes, - groups): - raise NotImplementedError # vmap with axis name involved in this primitive -batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_axis_batcher -core.axis_substitution_rules[pbroadcast_p] = \ - partial(lax_parallel._subst_all_names_in_param, 'axes') ad.deflinear2(pbroadcast_p, lambda cts, *_, axes, axis_index_groups: psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)) @@ -1421,23 +1352,23 @@ def _shard_map_batch( check_rep: bool, rewrite: bool, auto: frozenset) -> Sequence[batching.BatchTracer]: - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers) - if all(bdim is batching.not_mapped for bdim in in_dims): - return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names, - out_names_thunk=out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) + in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): raise NotImplementedError - fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims)) - new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] # type: ignore + new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] for ax in names} for names, d in zip(in_names, in_dims)] - spmd_axis_name = trace.spmd_axis_name + spmd_axis_name = trace.axis_data.spmd_name if spmd_axis_name is not None: used = {n for names in in_names for ns in names.values() for n in ns} if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") - new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore + new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped else ns for ns, d in zip(new_in_names, in_dims)] + new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) + new_axis_data = batching.AxisData(trace.axis_data.name, new_size, trace.axis_data.spmd_name) + else: + new_axis_data = trace.axis_data + fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) @as_hashable_function(closure=out_names_thunk) def new_out_names_thunk(): return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) @@ -1445,25 +1376,13 @@ def new_out_names_thunk(): new_params = dict(mesh=mesh, in_names=new_in_names, out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, auto=auto) - out_vals = prim.bind(fun, *in_vals, **new_params) + with core.set_current_trace(trace.parent_trace): + out_vals = prim.bind(fun, *in_vals, **new_params) make_tracer = partial(batching.BatchTracer, trace, source_info=source_info_util.current()) return map(make_tracer, out_vals, out_dims()) batching.BatchTrace.process_shard_map = _shard_map_batch -def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - del mesh, in_names, out_names_thunk, check_rep, rewrite, auto - vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) - for t in out_tracers) - m = trace.main - def todo(vals): - trace = m.with_cur_sublevel() - return map(partial(batching.BatchTracer, trace), vals, dims, srcs) - out_names_transform = partial(_batch_out_names, trace.spmd_axis_name, dims) - return vals, (todo, out_names_transform) -batching.BatchTrace.post_process_shard_map = _shard_map_batch_post_process - def _batch_out_names(spmd_axis_name, dims, out_names): out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] for ax in names} for names, d in zip(out_names, dims)] @@ -1480,11 +1399,11 @@ def _batch_out_names(spmd_axis_name, dims, out_names): def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): - primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] args, in_tree = tree_flatten((primals, tangents)) - f_jvp = ad.jvp_subtrace(f, trace.main) + f_jvp = ad.jvp_subtrace(f, trace.tag) f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] @@ -1496,36 +1415,22 @@ def new_out_names_thunk(): out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, auto=auto) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) - result = shard_map_p.bind(f_jvp, *args, **params) + result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) primal_out, tangent_out = tree_unflatten(out_tree(), result) tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t for p, t in zip(primal_out, tangent_out)] return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] ad.JVPTrace.process_shard_map = _shard_map_jvp -def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - del mesh, in_names, out_names_thunk, check_rep, rewrite, auto - primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) - out, treedef = tree_flatten((primals, tangents)) - tangents_nz = [type(t) is not ad.Zero for t in tangents] - m = trace.main - def todo(x): - primals, tangents = tree_unflatten(treedef, x) - return map(partial(ad.JVPTracer, m.with_cur_sublevel()), primals, tangents) - def out_names_transform(out_names): - return (*out_names, *(n for n, nz in zip(out_names, tangents_nz) if nz)) - return out, (todo, out_names_transform) -ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process - def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): + tracers = map(trace.to_jaxpr_tracer, tracers) in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - all_names = _all_mesh_names(mesh) + all_names = _all_mesh_names_except_spmd(mesh, trace) in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals) - f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.main, False) + f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, False) f = _promote_scalar_residuals(f) f_known, aux = pe.partial_eval_wrapper_nounits( f, (*in_knowns,), (*in_avals_sharded,)) @@ -1540,7 +1445,7 @@ def known_out_names(): known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_rep=check_rep, rewrite=rewrite, auto=auto) - out = shard_map_p.bind(f_known, *in_consts, **known_params) + out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux() num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) @@ -1553,7 +1458,7 @@ def known_out_names(): {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.full_raise, env) + env_tracers = map(trace.to_jaxpr_tracer, env) unk_arg_tracers = [t for t in tracers if not t.is_known()] unk_params = dict(mesh=mesh, in_names=unk_in_names, out_names=unk_out_names, jaxpr=jaxpr, check_rep=False, @@ -1569,55 +1474,6 @@ def known_out_names(): return pe.merge_lists(out_knowns, out_tracers, out_consts) pe.JaxprTrace.process_shard_map = _shard_map_partial_eval -def _shard_map_partial_eval_post_process( - trace, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): - del check_rep - all_names = _all_mesh_names(mesh) - unk_tracers = [t for t in tracers if not t.is_known()] - jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers) - # TODO(mattjj): output forwarding optimization - which = [not getattr(v.aval, 'shape', True) for v in jaxpr.constvars] - res = [jax.lax.broadcast(x, (1,)) if not getattr(v.aval, 'shape', True) else x - for x, v in zip(res, jaxpr.constvars)] - jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) - - out_knowns, out_avals_, consts = pe.partition_pvals([t.pval for t in tracers]) - out = [*consts, *res] - main = trace.main - with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_ = pe.convert_constvars_jaxpr(jaxpr) - - def todo(out): - trace = main.with_cur_sublevel() - out_consts, res_ = split_list(out, [len(out) - len(res)]) - const_tracers = map(trace.new_instantiated_const, res_) - env_tracers = map(trace.full_raise, env) - - staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env) - staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names, - out_names=(*out_names_unknown,), check_rep=False, - rewrite=rewrite, auto=auto) - - out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_) - out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in out_avals] - name_stack = trace._current_truncated_name_stack() - source = source_info_util.current().replace(name_stack=name_stack) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers, - shard_map_p, staged_params, effs, source) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) - - def out_names_transform(out_names): - nonlocal out_names_unknown - out_names_unknown, out_names_known = partition_list(out_knowns, out_names) - return (*out_names_known,) + ({0: all_names},) * len(res) - out_names_unknown: list | None = None - - return out, (todo, out_names_transform) -pe.JaxprTrace.post_process_shard_map = _shard_map_partial_eval_post_process - @lu.transformation def _promote_scalar_residuals(*args, **kwargs): jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs @@ -1645,7 +1501,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-rep case. name_set = {n for ns in names.values() for n in ns} - return [n for n in _all_mesh_names(mesh) if n not in name_set] + return [n for n in mesh.axis_names if n not in name_set] def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, @@ -1692,18 +1548,6 @@ def new_out_names_thunk(): return tree_unflatten(out_tree(), out_flat) ad.primitive_transposes[shard_map_p] = _shard_map_transpose -def _shard_map_axis_subst(params, subst, traverse): - if 'jaxpr' not in params: - return params - if not traverse: - return params - def shadowed_subst(name): - return (name,) if name in params['mesh'].shape else subst(name) - with core.extend_axis_env_nd(params['mesh'].shape.items()): - new_jaxpr = core.subst_axis_names_jaxpr(params['jaxpr'], shadowed_subst) - return dict(params, jaxpr=new_jaxpr) -core.axis_substitution_rules[shard_map_p] = _shard_map_axis_subst - # Remat def _partial_eval_jaxpr_custom_rule( @@ -1783,7 +1627,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, in_fwd, out_fwd, which, params_known, params_staged): # prune inputs to jaxpr_known according to unks_in mesh = params_known['mesh'] - all_names = _all_mesh_names(mesh) + all_names = _all_mesh_names_except_spmd(mesh) in_names_known, _ = partition_list(unks_in, params_known['in_names']) _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) out_names_known = out_names_known + [{0: all_names}] * sum(which) @@ -1801,15 +1645,13 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, out_names=tuple(out_names_staged), check_rep=False) return new_params_known, new_params_staged, all_names - # TODO(mattjj): remove this mechanism when we revise mesh scopes -def _all_mesh_names(mesh: Mesh) -> tuple[AxisName, ...]: - stack = core.thread_local_state.trace_state.trace_stack.stack - names = {n for frame in stack - if (ns := frame.payload.get('spmd_axis_name', ())) is not None - for n in ns} - return tuple(name for name in mesh.axis_names if name not in names) - +def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]: + trace = core.unsafe_get_current_trace() if trace is None else trace + stack = core.unsafe_get_trace_stack(trace) + batch_traces = [t for t in stack if isinstance(t, batching.BatchTrace)] + spmd_names = {n for trace in batch_traces for n in trace.axis_data.spmd_name } + return tuple(name for name in mesh.axis_names if name not in spmd_names) # DCE @@ -1926,59 +1768,52 @@ def __init__(self, trace, rep, val): def aval(self) -> core.AbstractValue: return core.get_aval(self.val) - def full_lower(self) -> RewriteTracer: - return self - def __str__(self) -> str: return str(self.val) # TODO(mattjj): could show replication info here __repr__ = __str__ # for debuggers, like `p x` class RewriteTrace(core.Trace): + parent_trace : core.Trace + tag : core.TraceTag mesh: Mesh - dyna: int - def __init__(self, *args, mesh, dyna): - super().__init__(*args) + def __init__(self, parent_trace, tag, mesh): + self.parent_trace = parent_trace + self.tag = tag self.mesh = mesh - self.dyna = dyna - - def pure(self, val) -> RewriteTracer: - return RewriteTracer(self, set(self.mesh.axis_names), val) - - def lift(self, tracer: core.Tracer) -> RewriteTracer: - return RewriteTracer(self, set(self.mesh.axis_names), tracer) - def sublift(self, tracer: core.Tracer) -> RewriteTracer: - return RewriteTracer(self, tracer.rep, tracer.val) + def to_val_rep_pair(self, val): + # TODO: add a tag to tell if self + if isinstance(val, RewriteTracer) and val._trace.tag is self.tag: + return val.val, val.rep + else: + return val, set(self.mesh.axis_names) def process_primitive(self, prim, in_tracers, params): rule = _rewrite_rules.get(prim, partial(_rule_missing, prim)) - in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) - with core.new_dynamic(self.dyna): + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) + with core.set_current_trace(self.parent_trace): out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params) out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals) return out_tracers if prim.multiple_results else out_tracers[0] def process_call(self, call_primitive, f, in_tracers, params): - in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) - f, out_reps = _rewrite_subtrace(f, self.main, tuple(in_reps)) - with core.new_dynamic(self.dyna): + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) + f, out_reps = _rewrite_subtrace(f, self.tag, self.mesh, tuple(in_reps)) + with core.set_current_trace(self.parent_trace): out_vals = call_primitive.bind(f, *in_vals, **params) return map(partial(RewriteTracer, self), out_reps(), out_vals) - def post_process_call(self, call_primitive, out_tracers, params): - assert False # unreachable - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) - in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) - fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) - jvp, out_reps2 = _rewrite_subtrace(jvp, self.main, in_reps * 2) - with core.new_dynamic(self.dyna): + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) + fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) + jvp, out_reps2 = _rewrite_subtrace(jvp, self.tag, self.mesh, in_reps * 2) + with core.set_current_trace(self.parent_trace): out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) if not fst: @@ -1986,9 +1821,6 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): out_reps = out_reps[:len(out_reps) // 2] return map(partial(RewriteTracer, self), out_reps, out_vals) - def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): - assert False # unreachable - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): if symbolic_zeros: @@ -1996,12 +1828,12 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) - in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) - fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) + in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) + fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]] - fwd, out_reps2 = _rewrite_subtrace(fwd, self.main, fwd_in_reps) + fwd, out_reps2 = _rewrite_subtrace(fwd, self.tag, self.mesh, fwd_in_reps) bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps) - with core.new_dynamic(self.dyna): + with core.set_current_trace(self.parent_trace): out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) @@ -2010,36 +1842,24 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, _, out_reps = split_list(out_reps, [res_tree.num_leaves]) return map(partial(RewriteTracer, self), out_reps, out_vals) - def post_process_custom_vjp_call(self, out_tracers, _): - assert False # unreachable - - # TODO process_axis_index - def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): in_reps = map(partial(_in_names_to_rep, mesh), in_names) out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()] fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) return _match_rep(fun, mesh, out_reps_src, out_reps_dst) -def _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps): - return _efficient_transpose_outer(_efficient_transpose_inner(fun), mesh, in_reps) - @lu.transformation_with_aux -def _efficient_transpose_outer(mesh, in_reps, *args): - lvl = core.dynamic_level() - with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: - out_vals, out_reps = yield (main, mesh, in_reps, args), {} - del main +def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args): + with core.take_current_trace() as parent: + tag = core.TraceTag() + t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) + in_tracers = map(partial(RewriteTracer, t), in_reps, args) + with core.set_current_trace(t): + ans = yield in_tracers, {} + out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) + del t, in_tracers, ans yield out_vals, out_reps -@lu.transformation -def _efficient_transpose_inner(main, mesh, in_reps, args): - t = main.with_cur_sublevel() - in_tracers = map(partial(RewriteTracer, t), in_reps, args) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - yield unzip2((t.val, t.rep) for t in out_tracers) - @lu.transformation def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): outs = yield args, {} @@ -2060,8 +1880,7 @@ def _replication_rewrite_match( f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) f = _match_rep(f, mesh, out_rep, out_rep_dst) - with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts) # TODO(mattjj): caching @@ -2072,28 +1891,25 @@ def _replication_rewrite_nomatch( ) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts), out_rep() @lu.transformation_with_aux -def _rewrite_subtrace(main, in_reps, *in_vals): - assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) - t = main.with_cur_sublevel() - in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) - with core.new_dynamic(main.level): - outs = yield in_tracers, {} - out_tracers = map(t.full_raise, outs) - out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers) - yield out_vals, out_reps +def _rewrite_subtrace(tag, mesh, in_reps, *in_vals): + with core.take_current_trace() as parent_trace: + assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) + t = RewriteTrace(parent_trace, tag, mesh) + in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) + with core.set_current_trace(t): + outs = yield in_tracers, {} + ans = unzip2(map(t.to_val_rep_pair, outs)) + yield ans def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): def new_bwd(*args): - lvl = core.dynamic_level() - with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: - bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), main, in_reps()) - out = bwd_.call_wrapped(*args) - del main + tag = core.TraceTag() + bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), tag, mesh, in_reps()) + out = bwd_.call_wrapped(*args) return map(_match_replication, reps_thunk(), reps_dst, out) return new_bwd diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index efdf1888f436..5348dd62a32e 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -276,16 +276,6 @@ def spvalue_to_aval(spvalue): # ------------------------------------------------------------------------------ # Implementation of sparsify() using tracers. -def popattr(obj: Any, name: str) -> Any: - assert hasattr(obj, name) - val = getattr(obj, name) - delattr(obj, name) - return val - -def setnewattr(obj: Any, name: str, val: Any): - assert not hasattr(obj, name) - setattr(obj, name, val) - class SparseTracer(core.Tracer): def __init__(self, trace: core.Trace, *, spvalue): self._spvalue = spvalue @@ -293,9 +283,9 @@ def __init__(self, trace: core.Trace, *, spvalue): @property def spenv(self): - if not hasattr(self._trace.main, 'spenv'): - raise RuntimeError("Internal: main does not have spenv defined.") - return self._trace.main.spenv + if not hasattr(self._trace, 'spenv'): + raise RuntimeError("Internal: trace does not have spenv defined.") + return self._trace.spenv @property def aval(self): @@ -305,71 +295,70 @@ def full_lower(self): return self class SparseTrace(core.Trace): - def pure(self, val: Any): - if not hasattr(self.main, 'spenv'): - raise RuntimeError("Internal: main does not have spenv defined.") - spvalue, = arrays_to_spvalues(self.main.spenv, [val]) - return SparseTracer(self, spvalue=spvalue) - def lift(self, val: core.Tracer): - if not hasattr(self.main, 'spenv'): - raise RuntimeError("Internal: main does not have spenv defined.") - spvalue, = arrays_to_spvalues(self.main.spenv, [val]) - return SparseTracer(self, spvalue=spvalue) + def __init__(self, parent_trace, tag, spenv): + self.parent_trace = parent_trace + self.tag = tag + self.spenv = spenv - def sublift(self, val: SparseTracer): - return SparseTracer(val._trace, spvalue=val._spvalue) + def to_sparse_tracer(self, val): + if isinstance(val, SparseTracer) and self.tag is val._trace.tag: + return val + else: + with core.set_current_trace(self.parent_trace): + spvalue, = arrays_to_spvalues(self.spenv, [val]) + return SparseTracer(self, spvalue=spvalue) def process_primitive(self, primitive, tracers, params): - spenv = popattr(self.main, 'spenv') + tracers = [self.to_sparse_tracer(t) for t in tracers] spvalues = [t._spvalue for t in tracers] if any(spvalue.is_sparse() for spvalue in spvalues): if primitive not in sparse_rules_bcoo: _raise_unimplemented_primitive(primitive) - out_spvalues = sparse_rules_bcoo[primitive](spenv, *(t._spvalue for t in tracers), **params) + with core.set_current_trace(self.parent_trace): + out_spvalues = sparse_rules_bcoo[primitive](self.spenv, *(t._spvalue for t in tracers), **params) else: - out_bufs = primitive.bind(*(spenv.data(spvalue) for spvalue in spvalues), **params) - out_spvalues = arrays_to_spvalues(spenv, out_bufs if primitive.multiple_results else [out_bufs]) - setnewattr(self.main, 'spenv', spenv) + out_bufs = primitive.bind_with_trace(self.parent_trace, tuple(self.spenv.data(spvalue) for spvalue in spvalues), params) + out_spvalues = arrays_to_spvalues(self.spenv, out_bufs if primitive.multiple_results else [out_bufs]) out_tracers = tuple(SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues) return out_tracers if primitive.multiple_results else out_tracers[0] def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): - spenv = popattr(self.main, 'spenv') + assert False spvalues = tuple(t._spvalue for t in tracers) - in_bufs = spenv._buffers + in_bufs = self.spenv._buffers fun, out_spvalues = sparsify_subtrace(f, self.main, spvalues) if any(params['donated_invars']): raise NotImplementedError("sparsify does not support donated_invars") params = dict(params, donated_invars=tuple(False for buf in in_bufs)) bufs_out = call_primitive.bind(fun, *in_bufs, **params) - setnewattr(self.main, 'spenv', SparsifyEnv(bufs_out)) return [SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues()] def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): # TODO(jakevdp): handle the jvp here del primitive, jvp, symbolic_zeros - return fun.call_wrapped(*tracers) + with core.set_current_trace(self): + return fun.call_wrapped(*tracers) @lu.transformation_with_aux -def sparsify_subtrace(main, spvalues, *bufs): - setnewattr(main, 'spenv', SparsifyEnv(bufs)) - trace = main.with_cur_sublevel() - in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] - outs = yield in_tracers, {} - out_traces = [trace.full_raise(out) for out in outs] - buffers = popattr(main, 'spenv')._buffers - yield buffers, [out._spvalue for out in out_traces] +def sparsify_subtrace(tag, spenv, spvalues, *bufs): + with core.take_current_trace() as parent: + trace = SparseTrace(parent, tag, spenv) + with core.set_current_trace(trace): + in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] + outs = yield in_tracers, {} + out_traces = [trace.to_sparse_tracer(out) for out in outs] + buffers = spenv._buffers + yield buffers, [out._spvalue for out in out_traces] def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]): - with core.new_main(SparseTrace) as main: - spenv = SparsifyEnv() - spvalues = arrays_to_spvalues(spenv, args) - in_bufs = spenv._buffers - fun, out_spvalues = sparsify_subtrace(wrapped_fun, main, spvalues) - out_bufs = fun.call_wrapped(*in_bufs) - spenv = SparsifyEnv(out_bufs) - del main + tag = core.TraceTag() + spenv = SparsifyEnv() + spvalues = arrays_to_spvalues(spenv, args) + in_bufs = spenv._buffers + fun, out_spvalues = sparsify_subtrace(wrapped_fun, tag, spenv, spvalues) + out_bufs = fun.call_wrapped(*in_bufs) + spenv = SparsifyEnv(out_bufs) return spvalues_to_arrays(spenv, out_spvalues()) def _sparsify_with_tracer(fun): diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 28816afb01e3..160a96fae368 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -18,8 +18,6 @@ from __future__ import annotations from jax._src.interpreters.ad import ( - CustomJVPException as CustomJVPException, - CustomVJPException as CustomVJPException, JVPTrace as JVPTrace, JVPTracer as JVPTracer, UndefinedPrimal as UndefinedPrimal, @@ -67,7 +65,6 @@ vjp as vjp, zero_jvp as zero_jvp, zeros_like_aval as zeros_like_aval, - zeros_like_jaxval as zeros_like_jaxval, zeros_like_p as zeros_like_p, ) diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 607fc6fa596d..7a93a6942c21 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -50,6 +50,7 @@ defbroadcasting as defbroadcasting, defreducer as defreducer, defvectorized as defvectorized, + fancy_primitive_batchers as fancy_primitive_batchers, flatten_fun_for_vmap as flatten_fun_for_vmap, from_elt as from_elt, from_elt_handlers as from_elt_handlers, @@ -64,7 +65,6 @@ reducer_batcher as reducer_batcher, register_vmappable as register_vmappable, spec_types as spec_types, - spmd_axis_primitive_batchers as spmd_axis_primitive_batchers, to_elt as to_elt, to_elt_handlers as to_elt_handlers, unregister_vmappable as unregister_vmappable, diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 3c63948bee63..1aa3ebc67b06 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -62,7 +62,6 @@ debug_info as debug_info, debug_info_final as debug_info_final, def_trivial_padding as def_trivial_padding, - extend_jaxpr_stack as extend_jaxpr_stack, forwarding_rules as forwarding_rules, infer_lambda_input_type as infer_lambda_input_type, instantiate_const_at as instantiate_const_at, @@ -81,15 +80,9 @@ recipe_to_eqn as recipe_to_eqn, result_info as result_info, sig_info as sig_info, - trace_to_jaxpr as trace_to_jaxpr, trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic, trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2, - trace_to_jaxpr_final as trace_to_jaxpr_final, - trace_to_jaxpr_final2 as trace_to_jaxpr_final2, trace_to_jaxpr_nounits as trace_to_jaxpr_nounits, - trace_to_subjaxpr as trace_to_subjaxpr, - trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic, - trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2, trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 7f42cfca5fe8..5f3bfa057912 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -330,7 +330,6 @@ linear_solve_p as linear_solve_p, map as map, scan as scan, - scan_bind as scan_bind, scan_p as scan_p, switch as switch, while_loop as while_loop, diff --git a/tests/api_test.py b/tests/api_test.py index 2c2412093805..197784d99772 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1458,6 +1458,8 @@ def test_caches_depend_on_axis_env(self): ans = jax.vmap(g, axis_name="i", axis_size=3, out_axes=None)() self.assertEqual(ans, expected) + # Since stackless, the vmap(f) version gets compiled a second time + @unittest.skip def test_caches_dont_depend_on_unnamed_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 f = jax.jit(lambda: jnp.sin(1)) @@ -3004,9 +3006,11 @@ def test_error_for_invalid_dtype(self): with jax.enable_checks(False): with self.assertRaisesRegex(TypeError, err_str): lax.add(jnp.array(7), np.array("hello")) - with jax.enable_checks(True): - with self.assertRaises(AssertionError): - lax.add(jnp.array(7), np.array("hello")) + # TODO(dougalm): re-enable checks at the beginning of `bind`. We just + # need to know which arguments to a generic primitive are ordinary operands vs functions. + # with jax.enable_checks(True): + # with self.assertRaises(AssertionError): + # lax.add(jnp.array(7), np.array("hello")) def test_vmap_preserves_docstr(self): def superfun(a): @@ -3438,13 +3442,10 @@ def test_escaped_tracers_cant_lift_sublevels(self): re.DOTALL)): api.jit(lambda x: x)(self._saved_tracer) + @unittest.skip # TODO(dougalm): rethink what this should do under stackless def test_escaped_tracers_tracer_from_higher_level(self): api.grad(self.helper_save_tracer)(0.) - with self.assertRaisesRegex( - UnexpectedTracerError, - re.compile( - "Encountered an unexpected tracer.*Tracer from a higher level", - re.DOTALL)): + with self.assertRaises(UnexpectedTracerError): api.grad(lambda x: x)(self._saved_tracer) def test_escaped_tracers_incompatible_sublevel(self): @@ -3464,8 +3465,7 @@ def func1(x): return x + self._saved_tracer with self.assertRaisesRegex( UnexpectedTracerError, - re.compile("Encountered an unexpected tracer.*Can't lift", - re.DOTALL)): + re.compile("unexpected tracer")): api.grad(func1)(2.) def test_escaped_tracers_not_among_input_tracers(self): @@ -3860,7 +3860,7 @@ def g(x): x = g(x) return x - msg = r'Leaked trace MainTrace\(2,DynamicJaxprTrace\)' + msg = r'Leaked trace DynamicJaxprTrace' with self.assertRaisesRegex(Exception, f"{msg}"): f(3) @@ -4725,6 +4725,7 @@ def f(inputs): for a, b in zip(ans, expected): self.assertAllClose(a, b) + @unittest.skip # TODO(dougalm): figure out with Matt what to do with this feature def test_inner_jit_forwarded_consts_stay_const(self): out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash self.assertEqual(out, 3) @@ -4874,6 +4875,7 @@ def g(x): msg = str(e) self.assertNotIn('static_argnums', msg) + @unittest.skip def test_remat_grad_python_control_flow_static_argnums(self): @partial(jax.remat, static_argnums=(0,)) def g(x): @@ -4896,6 +4898,7 @@ def f(x): expected = np.cos(2.) self.assertAllClose(ans, expected, check_dtypes=False) + @unittest.skip def test_remat_grad_python_control_flow_unhashable_static_argnums(self): @partial(jax.remat, static_argnums=(0,)) def g(x): @@ -7138,8 +7141,8 @@ def g_jvp(primals, tangents): g.defjvp(g_jvp) return g(1.) - self.assertRaises(ad.CustomJVPException, lambda: api.jvp(f, (3.,), (1.,))) - self.assertRaises(ad.CustomJVPException, lambda: api.grad(f)(3.)) + self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) def test_nondiff_arg(self): @partial(jax.custom_jvp, nondiff_argnums=(0,)) @@ -7214,7 +7217,7 @@ def g_jvp(h, primals, tangents): h = lambda y: x + y # capture x return g(h, x) - with self.assertRaisesRegex(ad.CustomJVPException, "Detected differentiation"): + with self.assertRaises(UnexpectedTracerError): api.jvp(f, (2.,), (1.,)) def test_vmap_axes(self): @@ -7625,8 +7628,8 @@ def f_jvp(primals, _): f.defjvp(f_jvp) primals = (2., 3) - tangents = (np.ones(()), np.zeros((), float0),) - expected_tangents = (2., np.zeros((), float0)) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) self.assertAllClose(api.jvp(f, primals, tangents), (primals, expected_tangents)) diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 438ba55203a9..9e0ebd4ff922 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -223,7 +223,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase): [dict(for_impl=for_impl, impl_name=impl_name) for for_impl, impl_name in FOR_LOOP_IMPLS], ) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name, impl_name): for_ = for_impl @@ -255,7 +255,7 @@ def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name, [dict(for_impl=for_impl, impl_name=impl_name) for for_impl, impl_name in FOR_LOOP_IMPLS], ) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? def test_for_linearize(self, f, ref, body_shapes, n, for_impl, for_body_name, impl_name): for_ = for_impl @@ -365,7 +365,7 @@ def g(a, b): [dict(for_impl=for_impl, impl_name=impl_name) for for_impl, impl_name in FOR_LOOP_IMPLS], ) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name, impl_name): @@ -385,7 +385,7 @@ def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name, jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2, rtol=7e-3, atol=1e-2) - @jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts? + @jtu.skip_on_devices("gpu", "cpu") # TODO(mattjj,sharadmv, dougalm): timeouts? @jax.legacy_prng_key('allow') def test_grad_of_triple_nested_for_loop(self): diff --git a/tests/infeed_test.py b/tests/infeed_test.py index e378fe37a2f5..5dd52b4167d5 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -37,6 +37,7 @@ def setUp(self): @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. def testInfeed(self): + raise SkipTest("skipping temporarily for stackless") @jax.jit def f(x): @@ -56,6 +57,7 @@ def f(x): self.assertAllClose(f(x), x + y + z) def testInfeedPytree(self): + raise SkipTest("skipping temporarily for stackless") x = np.float32(1.5) y = np.reshape(np.arange(12, dtype=np.int16), (3, 4)) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 7fb118d47256..79d5fb79b44c 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2095,6 +2095,7 @@ def apply_carry(x, i): jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash def testIssue804(self): + # https://github.com/google/jax/issues/804 num_devices = jax.device_count() f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.) jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 9a8d0b91272b..6e0e795df334 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2057,7 +2057,7 @@ def testSizeOverflow(self): def test_axis_env_length(self): f = lambda x: jax.pmap(g)(jnp.array([x]))[0] def g(x): - assert len(core.thread_local_state.trace_state.axis_env) == 1 + assert len(core.get_axis_env().axis_names()) == 1 return x jax.grad(f)(3.) # doesn't fail diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py index 38bd7e05533e..d141bc15c249 100644 --- a/tests/xla_metadata_test.py +++ b/tests/xla_metadata_test.py @@ -20,7 +20,6 @@ from absl.testing import absltest import jax from jax._src import config -from jax._src import dispatch from jax._src import test_util as jtu from jax._src.lax import lax from jax.experimental.xla_metadata import set_xla_metadata @@ -65,7 +64,7 @@ def f(a, b): def test_f_nonjitted(self): def f_add(a, b): - return dispatch.apply_primitive(lax.add_p, a, b) + return lax.add(a, b) arg1 = jnp.arange(2) with set_xla_metadata(a="b"): @@ -126,7 +125,7 @@ def f_add_jit(a, b): def test_attr_caching_nonjit(self): def f_add(a, b): - return dispatch.apply_primitive(lax.add_p, a, b) + return lax.add(a, b) arg1 = jnp.arange(2) arg2 = jnp.arange(2) + 1 From 63e8aff2685638270e4520469098e037f885475f Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 29 Oct 2024 11:44:38 -0700 Subject: [PATCH 103/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b5690e93ea1e4da5ca9f135d0a0e5796694e706a. PiperOrigin-RevId: 691102818 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 084ae0c1f3ba..47c5c832a19d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "7b4c7f36ccb2a0afa511d98fe4cb024599c275ae" -XLA_SHA256 = "3ebbee39182dfc8373e870aa69aa9821b6a5149da440a3f7503bdd8c8073165e" +XLA_COMMIT = "b5690e93ea1e4da5ca9f135d0a0e5796694e706a" +XLA_SHA256 = "38505101e6c62b8afd29c31eba6ac7e4f0709aaba6d1c3006bd6afdb9757cf9b" def repo(): tf_http_archive( From b9ad519a2950156eb780bfed50dea90f88cf8673 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 29 Oct 2024 12:34:46 -0700 Subject: [PATCH 104/698] Implement device_get for typed PRNG keys --- jax/_src/api.py | 7 +++++++ jax/_src/prng.py | 5 +++++ tests/random_test.py | 5 +++++ 3 files changed, 17 insertions(+) diff --git a/jax/_src/api.py b/jax/_src/api.py index 0c46517b2191..c0f411a7c26e 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2445,6 +2445,13 @@ def _device_put_replicated(x): def _device_get(x): if isinstance(x, core.Tracer): return x + if dtypes.issubdtype(getattr(x, "dtype", None), dtypes.extended): + try: + to_device = x.dtype._rules.device_get + except AttributeError: + pass + else: + return to_device(x) try: toarray = x.__array__ except AttributeError: diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 039b0a309775..0a5a1dff2659 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -400,6 +400,11 @@ def make_sharded_array(aval, sharding, arrays, committed): phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) + @staticmethod + def device_get(val): + buffer = api.device_get(random_unwrap(val)) + return random_wrap(buffer, impl=val.dtype._impl) + @staticmethod def device_put_sharded(vals, aval, sharding, devices): physical_aval = core.physical_aval(aval) diff --git a/tests/random_test.py b/tests/random_test.py index da182dbccae9..fed12792d5c6 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -936,6 +936,11 @@ def f(x): x = jnp.array([True, False, False]) f(x) # doesn't crash + def test_device_get(self): + keys = self.make_keys(4) + keys_on_host = jax.device_get(keys) + self.assertKeysEqual(keys, keys_on_host) + def test_device_put(self): device = jax.devices()[0] keys = self.make_keys(4) From 80fde785f536a2aac9178ce38b57da184570d405 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 29 Oct 2024 20:46:07 +0000 Subject: [PATCH 105/698] Fix a reference cycle bug. When we use a context manager within a linear_util.transformation we should leave the scope of the context manager before the final yield. Otherwise we create spurious reference cycles. This was causing CoreTest.test_reference_cycles to fail on Python 3.10 (but not 3.13 for some reason). --- jax/_src/interpreters/ad.py | 4 ++-- jax/_src/interpreters/partial_eval.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 9b350fdd6a87..b9cace3dec70 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -93,7 +93,7 @@ def jvp_subtrace(tag, primals, tangents): with core.set_current_trace(trace): ans = yield in_tracers, {} out = unzip2(map(trace.to_primal_tangent_pair, ans)) - yield out + yield out @lu.transformation_with_aux def jvp_subtrace_aux(tag, primals, tangents): @@ -104,7 +104,7 @@ def jvp_subtrace_aux(tag, primals, tangents): out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag else x for x in aux] - yield (out_primals, out_tangents), aux_primals + yield (out_primals, out_tangents), aux_primals def linearize(traceable, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 00c970186673..5bfb758040e3 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -595,7 +595,7 @@ def trace_to_subjaxpr_nounits2( trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + yield jaxpr, (out_pvals, out_consts, env) def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): in_knowns = [pval.is_known() for pval in in_pvals] @@ -641,7 +641,7 @@ def trace_to_subjaxpr_nounits_fwd( pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] del out_tracers - yield jaxpr, (fwds, out_pvals, pruned_consts, env) + yield jaxpr, (fwds, out_pvals, pruned_consts, env) # The below variant implements two optimizations: # 1. residuals that are also primal inputs are indicated in aux data rather From 5ad066eeaad60ce48d7d0afb9069d65a161cf91a Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 29 Oct 2024 15:56:44 -0700 Subject: [PATCH 106/698] [TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases) PiperOrigin-RevId: 691192121 --- jaxlib/mosaic/dialect/tpu/tpu.td | 2 + .../tpu/transforms/apply_vector_layout.cc | 62 ------------------- .../tpu/transforms/canonicalize_mosaic.cc | 27 +++++++- .../tpu/transforms/infer_vector_layout.cc | 10 --- tests/pallas/tpu_pallas_test.py | 19 +++++- 5 files changed, 45 insertions(+), 75 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index faee869663c4..783101e839b1 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -295,6 +295,8 @@ def TPU_IotaOp : TPU_Op<"iota", [Pure]> { let assemblyFormat = [{ attr-dict `:` type($output) }]; } +// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically. +// b/376295711 def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { let arguments = (ins AnyVector:$source, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 79ebb725ccab..1b4b2cad1f97 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -170,25 +170,6 @@ FailureOr> getInternalScratch( .getResult(); } -// Models Numpy's np.repeat, repeating each element `repeats` times along the -// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is -// 3, this will return [1, 1, 1, 2, 2, 2]. -xla::Array repeat(const xla::Array &src, const int repeats, - const int64_t axis) { - SmallVector dims(toArrayRef(src.dimensions())); - dims[axis] *= repeats; - xla::Array res(dims); - src.Each([&](absl::Span idx, const Value v) { - SmallVector res_idx(toArrayRef(idx)); - res_idx[axis] *= repeats; - for (int i = 0; i < repeats; ++i) { - res(res_idx) = v; - ++res_idx[axis]; - } - }); - return res; -} - // Models Numpy's np.concatenate xla::Array concatenate(const ArrayRef> arrays, const int64_t axis) { @@ -2949,48 +2930,6 @@ LogicalResult tpu_region_rule(RewriteContext &ctx, Operation &op, return success(); } -LogicalResult tpu_repeat_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(layouts_out.front().has_value()); - const VectorLayout &layout_in = *layouts_in.front(); - const VectorLayout &layout_out = *layouts_out.front(); - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError("Not implemented: Only 2D layouts supported"); - } - if (layout_in != layout_out) { - return op.emitOpError("Not implemented: Changing layout mid-repeat"); - } - if (!layout_in.hasNaturalTopology(ctx.target_shape) || - layout_in.offsets() != LayoutOffsets{0, 0}) { - return op.emitOpError("Not implemented: Non-trivial layouts unsupported"); - } - OpBuilder builder(&op); - tpu::RepeatOp repeat_op = cast(op); - VectorType src_ty = repeat_op.getSource().getType(); - const uint32_t dim = repeat_op.getDimension(); - if (dim != src_ty.getRank() - 1) { - return op.emitOpError( - "Not implemented: Only repeats along the last dim supported"); - } - if (src_ty.getShape().back() % ctx.target_shape.back() != 0) { - return op.emitOpError("Not implemented: Only free repeats are suppported"); - } - FAILUREOR_ASSIGN_OR_RETURN( - const xla::Array &in_vregs, - disassemble(builder, layout_in, repeat_op.getSource(), ctx.target_shape)); - xla::Array out_vregs = repeat(in_vregs, repeat_op.getTimes(), dim); - repeat_op->replaceAllUsesWith( - assemble(builder, repeat_op.getResult().getType(), layout_out, out_vregs, - ctx.target_shape) - .getOperation()); - repeat_op->erase(); - return success(); -} - LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -4648,7 +4587,6 @@ const llvm::StringMap &rules() { {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, {tpu::RegionOp::getOperationName(), tpu_region_rule}, - {tpu::RepeatOp::getOperationName(), tpu_repeat_rule}, {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, {tpu::TraceOp::getOperationName(), tpu_trace_rule}, {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 9f2a8ed73a44..b95ff5067734 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -350,6 +351,29 @@ LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) { return success(); } +LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) { + auto op = dyn_cast(raw_op); + if (!isa(op.getType())) { + return op.emitOpError("Only vector types supported"); + } + auto operand = op.getSource(); + auto times = op.getTimes(); + if (times == 1) { + // A true no op - kind of an odd edge case, but this does come up in + // flash_attention_backward tests. + op.replaceAllUsesWith(operand); + op.erase(); + return success(); + } + auto operands = std::vector(times, operand); + ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + auto concat = builder.create(op.getLoc(), op.getType(), + operands, op.getDimension()); + op.replaceAllUsesWith(concat.getResult()); + op.erase(); + return success(); +} + using canonicalize_rule_type = std::function; @@ -360,7 +384,8 @@ const llvm::StringMap &rules() { {vector::ContractionOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}, - {arith::SelectOp::getOperationName(), canonicalize_select}}; + {arith::SelectOp::getOperationName(), canonicalize_select}, + {tpu::RepeatOp::getOperationName(), canonicalize_repeat}}; return *rules; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 9fcb8afc7a47..a9b5ed6876b2 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -288,10 +288,6 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { - return failure(); - } } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); @@ -1020,12 +1016,6 @@ class VectorLayoutInferer { return success(); } - LogicalResult infer(tpu::RepeatOp op) { - auto src_layout = getLayout(op.getSource()); - setLayout(op, src_layout, src_layout); - return success(); - } - LogicalResult infer(tpu::TraceOp op) { static LogicalResult (*match_yield)(Operation *) = [](Operation *op) { TPU_CHECK_OP(isa(op), "expected yield terminator"); diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 49dd127b76fe..d92991caa6fe 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1587,7 +1587,6 @@ def kernel(x, y): self.assertEqual(analysis_result['transcendentals'], 21) self.assertEqual(analysis_result['bytes accessed'], 12345) - def test_cost_analysis_vmap(self): def kernel(x, y): y[:] = x[:] @@ -1606,7 +1605,6 @@ def kernel(x, y): self.assertEqual(analysis_result['transcendentals'], batch_size * 21) self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345) - def test_vmem_limit(self): shape = (128, 128) @@ -1673,6 +1671,23 @@ def kernel(x_ref, y_ref): ), )(x) + @parameterized.product(dtype=[jnp.bfloat16, jnp.float32]) + def test_pltpu_repeat(self, dtype): + def test_kernel(x_ref, o_ref): + x = x_ref[...] + o_ref[...] = pltpu.repeat(x, 2, axis=1) + + @jax.jit + def test(x: jax.Array) -> jax.Array: + return pl.pallas_call( + test_kernel, + out_shape=jax.ShapeDtypeStruct([x.shape[0], x.shape[1] * 2], x.dtype), + )(x) + + x = jnp.arange(2048, dtype=dtype).reshape((8, 256)) + y = test(x) + np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1)) + class PallasUXTest(PallasBaseTest): From 539c94094676650d62c326a4135eca2fc61856d3 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 29 Oct 2024 16:08:23 -0700 Subject: [PATCH 107/698] Removed unused `_tan_impl` Also removed the legacy lowering for `tan_p`. PiperOrigin-RevId: 691195720 --- jax/_src/lax/lax.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index bbb23bcd1725..0d8bfafe6932 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -60,7 +60,6 @@ from jax._src.lax.utils import ( _input_dtype, dtype_to_string, standard_abstract_eval, standard_multi_result_abstract_eval, standard_primitive) -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2388,21 +2387,9 @@ def _cos_lowering(ctx, x): ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) mlir.register_lowering(cos_p, _cos_lowering) -@_upcast_fp16_for_computation -def _tan_impl(x): - return div(sin(x), cos(x)) - tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this -# lowering is mostly supported, but it fails on export or with the PJRT plugin -# because those modes target an older StableHLO version, and the -# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't -# included in the 0.4.33 release. -if jaxlib_version <= (0, 4, 33): - mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) -else: - mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): From b65fdcc61266308eb875d053f74487fc28cbd15f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 29 Oct 2024 16:41:12 -0700 Subject: [PATCH 108/698] pallas: remove build dependency on jax.experimental.export jax.experimental.export is deprecated, and it looks like the build rule is unused. PiperOrigin-RevId: 691205626 --- tests/pallas/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 09728e432a8c..b1f1b12a1d70 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -221,7 +221,6 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep - "//jax/experimental/export", ], ) From 6d8950c04f23ad15a0443006f1e5bd21bfa84156 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Tue, 29 Oct 2024 16:50:16 -0700 Subject: [PATCH 109/698] Cleanup requirements.in and test-requirements.txt PiperOrigin-RevId: 691208596 --- build/requirements.in | 5 ----- build/test-requirements.txt | 5 ++++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index a8d81fa5c670..e122aaa4ad78 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -3,11 +3,6 @@ # -r test-requirements.txt -# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement -# below. -matplotlib~=3.8.4; python_version<="3.10" -matplotlib; python_version>="3.11" - # # build deps # diff --git a/build/test-requirements.txt b/build/test-requirements.txt index bec6afce1853..7e7e5b847009 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -6,7 +6,6 @@ filelock flatbuffers hypothesis mpmath>=1.3 -numpy>=1.22 pillow>=10.4.0 portpicker pytest-xdist @@ -14,3 +13,7 @@ wheel rich # TODO(ybaturina): remove setuptools version setuptools<71.0.0 +# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement +# below. +matplotlib~=3.8.4; python_version<="3.10" +matplotlib; python_version>="3.11" \ No newline at end of file From 249f0101b336a332f2ac3a14b836ef7160c4a5fb Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 29 Oct 2024 16:52:25 -0700 Subject: [PATCH 110/698] Use approximate cost estimates for flash attention instead of reference XLA estimates. PiperOrigin-RevId: 691209201 --- .../pallas/ops/tpu/flash_attention.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 82bcde8153ef..9b122fcc03ef 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -574,28 +574,26 @@ def _fwd_cost_estimate( q: jax.Array, k: jax.Array, v: jax.Array, - ab: jax.Array | None, - segment_ids: SegmentIds | None, *, - causal: bool, - sm_scale: jax.Array | None, kernel_inputs_specs, kernel_outputs_specs, ) -> pl.CostEstimate | None: - full_cost = ( - mha_reference.lower( - q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale - ) - .compile() - .cost_analysis() - ) - if not full_cost: - return None + b, h, tq, dqk = q.shape + tk = k.shape[-2] + dv = v.shape[-1] + + # Simplify flop computation to include only matmul operations. + qk_flops = 2 * tq * tk * dqk + av_flops = 2 * tq * tk * dv + per_head_flops = qk_flops + av_flops + flops = b * h * per_head_flops + + transcendentals = b * tq * tk * h input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs)) output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs)) return pl.CostEstimate( - flops=full_cost[0]["flops"], - transcendentals=full_cost[0]["transcendentals"], + flops=flops, + transcendentals=transcendentals, bytes_accessed=input_bytes + output_bytes, ) @@ -792,10 +790,6 @@ def kv_segment_ids_index_map( q, k, v, - ab, - segment_ids, - causal=causal, - sm_scale=sm_scale, kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids), kernel_outputs_specs=out_shape, ), From 72f9a493589a1046e6927a5f16d7dc71df530743 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Oct 2024 17:46:08 -0700 Subject: [PATCH 111/698] Reverts 6d8950c04f23ad15a0443006f1e5bd21bfa84156 PiperOrigin-RevId: 691222756 --- build/requirements.in | 5 +++++ build/test-requirements.txt | 5 +---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index e122aaa4ad78..a8d81fa5c670 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -3,6 +3,11 @@ # -r test-requirements.txt +# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement +# below. +matplotlib~=3.8.4; python_version<="3.10" +matplotlib; python_version>="3.11" + # # build deps # diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 7e7e5b847009..bec6afce1853 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -6,6 +6,7 @@ filelock flatbuffers hypothesis mpmath>=1.3 +numpy>=1.22 pillow>=10.4.0 portpicker pytest-xdist @@ -13,7 +14,3 @@ wheel rich # TODO(ybaturina): remove setuptools version setuptools<71.0.0 -# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement -# below. -matplotlib~=3.8.4; python_version<="3.10" -matplotlib; python_version>="3.11" \ No newline at end of file From e35e7f8e205632c6914cabaea3f54b89c35985b5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 29 Oct 2024 17:58:12 -0700 Subject: [PATCH 112/698] Allow sparsecore compute with T(8) layout via the layout API and `compute_on` API. To annotate compute on sparsecore, use `@compute_on('tpu_sparsecore')`. PiperOrigin-RevId: 691225280 --- jax/_src/compute_on.py | 7 ++++--- jax/_src/interpreters/mlir.py | 6 +++++- tests/BUILD | 3 +++ tests/layout_test.py | 24 ++++++++++++++++++++++++ 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 4495d38f9da8..b5194ddad21d 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -46,9 +46,10 @@ def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None def _check_valid(c_type: str): - if c_type not in {'device_host', 'device'}: - raise ValueError('Invalid compute type received. Current supported values ' - f'are `device_host` and `device`. Got {c_type}') + if c_type not in {'device_host', 'device', 'tpu_sparsecore'}: + raise ValueError( + 'Invalid compute type received. Current supported values ' + f'are `device_host`, `device` and `tpu_sparsecore`. Got {c_type}') @contextmanager def compute_on(compute_type: str): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2adeb4b16cd9..c71e52385386 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1878,6 +1878,8 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None return () if eqn_ctx.compute_type == 'device_host': return ('cpu',) + if eqn_ctx.compute_type == 'tpu_sparsecore': + return ('tpu',) return () @@ -2160,8 +2162,10 @@ def map_compute_type(c_type): return 'host' elif c_type == 'device': return 'dense' + elif c_type == 'tpu_sparsecore': + return 'sparse' raise ValueError('Invalid compute type received. Current supported values ' - 'are `device_host` and `device`') + 'are `device_host`, `device` and `tpu_sparsecore') def wrap_compute_type_in_place(ctx, op): if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None: diff --git a/tests/BUILD b/tests/BUILD index 657d169ba18a..316e98f5ba31 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -271,6 +271,9 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, tags = ["multiaccelerator"], + deps = [ + "//jax:experimental", + ], ) jax_multiplatform_test( diff --git a/tests/layout_test.py b/tests/layout_test.py index 9d26d96e2ae5..406d06dacc9f 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -25,6 +25,7 @@ from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() @@ -600,6 +601,29 @@ def g(x): ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"): g(jnp.arange(8)) + def test_sparsecore_compute(self): + if not (jtu.is_device_tpu('5', 'f') or jtu.is_device_tpu_at_least(6)): + self.skipTest('Does not have a sparsecore present') + shape = (128, 128) + inp = jnp.arange(math.prod(shape)).reshape(shape) + + dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) + s = SingleDeviceSharding(jax.devices()[0]) + sparse_layout = Layout(dll, s) + sparecore_arr = jax.device_put(inp, sparse_layout) + dense_layout = Layout(DLL(major_to_minor=(0, 1)), s) + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_compute(x): + return x * x + + @partial(jax.jit, out_shardings=(dense_layout, sparse_layout)) + def f(x, y): + return x * 2, sparsecore_compute(y) + + f(inp, sparecore_arr) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 908c8a828082ce8a0a0d7e021add9f8676b3897b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 02:39:00 -0700 Subject: [PATCH 113/698] Removed unused `_get_memory_space_from_ref` PiperOrigin-RevId: 691342830 --- jax/_src/pallas/pallas_call.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index e20c7783439e..43cd4c2aca8e 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1558,12 +1558,6 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) -def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any: - if isinstance(ref_aval, pallas_core.AbstractMemoryRef): - return ref_aval.memory_space - return pallas_core.MemorySpace.ANY - - @state_discharge.register_discharge_rule(pallas_call_p) def _pallas_call_state_discharge_rule( avals_in, From bdf2ca10fc2d8a3a9cb2beadb5bd3ec398ca2b18 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 02:39:01 -0700 Subject: [PATCH 114/698] Removed more dead code from various submodules PiperOrigin-RevId: 691342832 --- jax/_src/custom_derivatives.py | 6 ------ jax/_src/interpreters/pxla.py | 5 ----- jax/_src/lax/parallel.py | 16 ---------------- jax/_src/pjit.py | 9 --------- 4 files changed, 36 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 0b57ff9028f6..77f73562aecd 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -799,12 +799,6 @@ def bind_with_trace(self, trace, args, params): custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') -def _apply_bwd_transform(todos, bwd): - todos_list = list(todos) - while todos_list: - bwd = todos_list.pop()(bwd) - return bwd - def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): return core.jaxpr_as_fun(fun_jaxpr)(*args) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 02ec54ba5d39..e59d8c89e3e8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -172,11 +172,6 @@ def is_default_layout(curr_layout, sharding, aval): raise -@lru_cache(maxsize=1024) -def _get_replicated_slices(num_addressable_devices: int): - return ((slice(None),),) * num_addressable_devices - - def _masked_array_error(xs, shardings, layouts): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index cbea424a9d95..932fd4b88c08 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1554,22 +1554,6 @@ def _pgather_parallel_lowering(ctx, src, idx, *, axes): return mlir.lower_fun(_pgather_impl, multiple_results=False)( ctx, src, idx, axes=axes) -def _pgather_batcher(vals_in, dims_in, *, axes): - src, idx = vals_in - dsrc, didx = dims_in - if didx is not batching.not_mapped and dsrc is not batching.not_mapped: - # NB: We could just go forward with it and take the diagonal along the - # two axes we get in the output, but that would be quite inefficient - raise NotImplementedError("Please open a feature request!") - elif didx is not batching.not_mapped: - return pgather_p.bind(src, idx, axes=axes), didx - elif dsrc is not batching.not_mapped: - src_last_batched = moveaxis(src, dsrc, -1) - result = pgather_p.bind(src_last_batched, idx, axes=axes) - return result, result.ndim - 1 - else: - assert False # This shouldn't get called anyway - def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, axes): src, idx = vals_in dsrc, didx = dims_in diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 904e92af2f91..3d8df8664052 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2066,15 +2066,6 @@ def _filter_zeros(is_nz_l, l): ad.primitive_jvps[pjit_p] = _pjit_jvp -@weakref_lru_cache -def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr, - in_fwd: tuple[int | None, ...]) -> core.ClosedJaxpr: - updated_jaxpr = known_jaxpr.jaxpr.replace( - outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, in_fwd) - if i is None]) - return known_jaxpr.replace(jaxpr=updated_jaxpr) - - def _pjit_partial_eval(trace, *in_tracers, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, From f1c3109bf503435cf0bae37e744510493aae621d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 02:39:05 -0700 Subject: [PATCH 115/698] Removed `mesh_utils._bounds_from_last_device` which was only used in tests PiperOrigin-RevId: 691342846 --- jax/_src/mesh_utils.py | 10 ---------- tests/mesh_utils_test.py | 25 ------------------------- 2 files changed, 35 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index bb6152167658..c37bbba4d836 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -572,16 +572,6 @@ def _generate_logical_mesh( return logical_mesh -def _bounds_from_last_device(last_device) -> Sequence[int]: - """Gets the bound from the given last device.""" - # Must be passed the device at the highest-coordinate corner of the - # relevant mesh, which is a requirement we know is satisfied by the last - # device in jax.devices(). - assert hasattr(last_device, 'coords'), 'Only TPU supported' - x, y, z = last_device.coords - return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 - - def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: r"""Rearrange TPU devices in a slice into a physical mesh. diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 42522d7f4b1b..66f1fc9f6cfb 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -205,31 +205,6 @@ def mock_2x2x2_v5e_devices(one_device_per_chip=True): class MeshUtilsTest(test_util.JaxTestCase): - @parameterized.named_parameters( - ('1x1', mock_1x1_devices, (1, 1, 1, 2)), - ('2x2', mock_2x2_devices, (2, 2, 1, 2)), - ('4x4', mock_4x4_devices, (4, 4, 1, 2)), - ('8x8', mock_8x8_devices, (8, 8, 1, 2)), - ) - def test_bounds_from_last_device_2d(self, devices, expected_bounds): - self.assertEqual( - mesh_utils._bounds_from_last_device(devices()[-1]), - expected_bounds) - - @parameterized.named_parameters( - ('1x2x1_t', mock_1x2x1_devices, True, (1, 2, 1, 1)), - ('1x2x1_f', mock_1x2x1_devices, False, (1, 2, 1, 2)), - ('2x2x1_t', mock_2x2x1_devices, True, (2, 2, 1, 1)), - ('2x2x1_f', mock_2x2x1_devices, False, (2, 2, 1, 2)), - ('8x8x16_t', mock_8x8x16_devices, True, (8, 8, 16, 1)), - ('8x8x16_f', mock_8x8x16_devices, False, (8, 8, 16, 2)), - ) - def test_bounds_from_last_device_3d(self, devices, one_device_per_chip, - expected_bounds): - self.assertEqual( - mesh_utils._bounds_from_last_device(devices(one_device_per_chip)[-1]), - expected_bounds) - @parameterized.named_parameters( ('1x2x1_t', (1, 2, 1), True), ('4x4x4_t', (4, 4, 4), True), From e61a20b45adcbe21787dd59a620b4f7d4c021abd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 30 Oct 2024 05:27:29 -0700 Subject: [PATCH 116/698] Remove deprecated jax.experimental.export module. These tools are now available at jax.export. --- CHANGELOG.md | 3 ++ jax/experimental/export/BUILD | 42 ----------------- jax/experimental/export/__init__.py | 73 ----------------------------- tests/BUILD | 3 -- tests/export_test.py | 16 ------- 5 files changed, 3 insertions(+), 134 deletions(-) delete mode 100644 jax/experimental/export/BUILD delete mode 100644 jax/experimental/export/__init__.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 542a7d417269..9b629631ea4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. + * The deprecated module `jax.experimental.export` has been removed. It was replaced + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + for information on migrating to the new API. ## jax 0.4.35 (Oct 22, 2024) diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD deleted file mode 100644 index 1246b0d407af..000000000000 --- a/jax/experimental/export/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# JAX-export provides APIs for exporting StableHLO for serialization purposes. - -load("@rules_python//python:defs.bzl", "py_library") -load( - "//jaxlib:jax.bzl", - "py_deps", -) - -licenses(["notice"]) - -package( - default_applicable_licenses = [], - default_visibility = ["//visibility:private"], -) - -py_library( - name = "export", - srcs = [ - "__init__.py", - ], - srcs_version = "PY3", - # TODO: b/255503696: enable pytype - tags = ["pytype_unchecked_annotations"], - visibility = ["//visibility:public"], - deps = [ - "//jax", - ] + py_deps("numpy") + py_deps("flatbuffers"), -) diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py deleted file mode 100644 index d49aa296328a..000000000000 --- a/jax/experimental/export/__init__.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2023 The JAX Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -_deprecation_message = ( - "The jax.experimental.export module is deprecated. " - "Use jax.export instead. " - "See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export." -) - -from jax._src.export import _export as _src_export -from jax._src.export import shape_poly as _src_shape_poly -from jax._src.export import serialization as _src_serialization -# Import only to set the shape poly decision procedure -from jax._src.export import shape_poly_decision -del shape_poly_decision - -# All deprecations added Jun 14, 2024 -_deprecations = { - # Added Jun 13, 2024 - "Exported": (_deprecation_message, _src_export.Exported), - "DisabledSafetyCheck": (_deprecation_message, _src_export.DisabledSafetyCheck), - "export": (_deprecation_message, _src_export.export_back_compat), - "call": (_deprecation_message, _src_export.call), - "call_exported": (_deprecation_message, _src_export.call_exported), - "default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform), - "minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version), - "maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version), - - "serialize": (_deprecation_message, _src_serialization.serialize), - "deserialize": (_deprecation_message, _src_serialization.deserialize), - - "SymbolicScope": (_deprecation_message, _src_shape_poly.SymbolicScope), - "is_symbolic_dim": (_deprecation_message, _src_shape_poly.is_symbolic_dim), - "symbolic_shape": (_deprecation_message, _src_shape_poly.symbolic_shape), - "symbolic_args_specs": (_deprecation_message, _src_shape_poly.symbolic_args_specs), -} - -import typing -if typing.TYPE_CHECKING: - Exported = _src_export.Exported - DisabledSafetyCheck = _src_export.DisabledSafetyCheck - export = _src_export.export_back_compat - call = _src_export.call - call_exported = _src_export.call_exported - default_lowering_platform = _src_export.default_lowering_platform - - serialize = _src_serialization.serialize - deserialize = _src_serialization.deserialize - - SymbolicScope = _src_shape_poly.SymbolicScope - is_symbolic_dim = _src_shape_poly.is_symbolic_dim - symbolic_shape = _src_shape_poly.symbolic_shape - symbolic_args_specs = _src_shape_poly.symbolic_args_specs -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing -del _src_export -del _src_serialization -del _src_shape_poly diff --git a/tests/BUILD b/tests/BUILD index 316e98f5ba31..39e1d35a3407 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1409,9 +1409,6 @@ jax_multiplatform_test( "tpu_v3_2x2", ], tags = [], - deps = [ - "//jax/experimental/export", - ], ) jax_multiplatform_test( diff --git a/tests/export_test.py b/tests/export_test.py index fd6bef11ee43..b6dde23721a3 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -244,22 +244,6 @@ def test_export_error_no_jit(self): "Function to be exported must be the result of `jit`"): _ = export.export(lambda x: jnp.sin(x)) - @jtu.ignore_warning(category=DeprecationWarning, - message="The jax.experimental.export module is deprecated") - def test_export_experimental_back_compat(self): - if not CAN_SERIALIZE: - self.skipTest("serialization disabled") - from jax.experimental import export - # Can export a lambda, without jit - exp = export.export(lambda x: jnp.sin(x))(.1) - self.assertAllClose(exp.call(1.), np.sin(1.)) - - blob = export.serialize(exp, vjp_order=1) - rehydrated = export.deserialize(blob) - - self.assertAllClose(export.call(exp)(1.), np.sin(1.)) - self.assertAllClose(export.call_exported(exp)(1.), np.sin(1.)) - def test_call_exported_lambda(self): # When we export a lambda, the exported.fun_name is not a valid MLIR function name f = jax.jit(lambda x: jnp.sin(x)) From 15a11365e4f4f97f3353744cfa290421b9770150 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Wed, 30 Oct 2024 06:19:54 -0700 Subject: [PATCH 117/698] Change the lowering rule for `jax.lax.scan` to avoid emitting a `while` loop when the intent is to fully unroll the loop. PiperOrigin-RevId: 691393597 --- jax/_src/lax/control_flow/loops.py | 4 ++++ tests/lax_control_flow_test.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 598601cc4097..6d6338b0bfd5 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -418,6 +418,10 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, consts, carry, xs_ = split_list(args, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) num_trips, remainder = divmod(length, unroll) + if unroll != 1 and num_trips == 1 and remainder == 0: + # In that case, we explicitly want to fully unroll the loop. Put everything + # into the remainder block and avoid lowering to a while loop. + num_trips, remainder = 0, length if unroll == 1: xss = xs_ yss = _map(partial(_empty_array, (length,)), y_avals) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 79d5fb79b44c..d383e4c6ac20 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2424,6 +2424,7 @@ def f(c, a): scan = lambda c, xs: lax.scan(f, c, xs) scan_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=2) + scan_fully_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=True) # jaxprs should be the same size self.assertEqual( @@ -2431,9 +2432,19 @@ def f(c, a): len(str(jax.make_jaxpr(scan_unrolled)(c, xs)))) # but HLO should grow due to unrolling - self.assertLess( - len(str(jax.jit(scan).lower(c, xs).as_text('hlo'))), - len(str(jax.jit(scan_unrolled).lower(c, xs).as_text('hlo')))) + scan_hlo = str(jax.jit(scan).lower(c, xs).as_text("hlo")) + scan_unrolled_hlo = str(jax.jit(scan_unrolled).lower(c, xs).as_text("hlo")) + scan_fully_unrolled_hlo = str( + jax.jit(scan_fully_unrolled).lower(c, xs).as_text("hlo")) + + self.assertLess(len(scan_hlo), len(scan_unrolled_hlo)) + self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo)) + + # and the lowering should contain a while loop, unless the scan is fully + # unrolled + self.assertIn("while(", scan_hlo) + self.assertIn("while(", scan_unrolled_hlo) + self.assertNotIn("while(", scan_fully_unrolled_hlo) def test_scan_xs_none(self): def f(h, _): From 8f96e9082a428c29894c959d66e62ea5f5f43970 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Wed, 30 Oct 2024 06:54:46 -0700 Subject: [PATCH 118/698] [Pallas TPU] Add lowerings for scalar `absi` This PR is a follow-up of https://github.com/jax-ml/jax/pull/24504, which adds lowerings for scalar `absf` and `rsqrt`. PiperOrigin-RevId: 691402430 --- tests/pallas/ops_test.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 9867fb2c0a9d..75f95ae1b630 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -841,16 +841,6 @@ def test_elementwise_scalar(self, fn, dtype): "Scalar population count on TPU is only supported in interpret mode" ) - if ( - jtu.test_device_matches(["tpu"]) - and fn == jnp.abs - and jnp.issubdtype(dtype, jnp.integer) - and not self.INTERPRET - ): - self.skipTest( - "Scalar abs for integers on TPU is only supported in interpret mode" - ) - # TODO(b/370578663): implement these lowerings on TPU if jtu.test_device_matches(["tpu"]) and fn in ( jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, From a45b0856c5017af37180afdd0c7b1a0377f7326d Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Oct 2024 07:21:25 -0700 Subject: [PATCH 119/698] Relax leak checks under the jax_data_dependent_tracing_fallback flag. PiperOrigin-RevId: 691409392 --- jax/_src/core.py | 4 +++- jax/_src/interpreters/partial_eval.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 2a2a0d601848..4c006fc95b89 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -433,7 +433,9 @@ def __repr__(self): def bind(self, *args, **params): for arg in args: - if isinstance(arg, Tracer) and not arg._trace.is_valid(): + if (isinstance(arg, Tracer) + and not arg._trace.is_valid() + and not config.data_dependent_tracing_fallback.value): raise escaped_tracer_error(arg) # TODO: figure out how to handle function arguments # assert (not config.enable_checks.value or diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5bfb758040e3..9e5e1ee9bd42 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -540,6 +540,13 @@ def parents(self) -> Sequence[JaxprTracer]: else: return [] + def full_lower(self): + known = self.pval.get_known() + if known is not None: + return core.full_lower(known) + else: + return self + def is_known(self): return self.pval.is_known() From 2b70ad30fba5ba6bc1cbdb7162cbc4e8f3b823a9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 07:23:08 -0700 Subject: [PATCH 120/698] Removed unused `_upcast_fp16_for_computation` PiperOrigin-RevId: 691409888 --- jax/_src/lax/lax.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0d8bfafe6932..2a44d9ec980d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1912,17 +1912,6 @@ def reciprocal(x: ArrayLike) -> Array: r"""Elementwise reciprocal: :math:`1 \over x`.""" return integer_pow(x, -1) -def _upcast_fp16_for_computation(f): - @functools.wraps(f) - def f_wrapped(x): - dtype = _dtype(x) - if dtype == np.float16 or dtype == dtypes.bfloat16: - return convert_element_type( - f(convert_element_type(x, np.float32)), dtype) - return f(x) - - return f_wrapped - def tan(x: ArrayLike) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`.""" return tan_p.bind(x) From 2652ab56082981783baa85c9918df9ca9dd92718 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 07:29:58 -0700 Subject: [PATCH 121/698] [mosaic_gpu] Added support for bitwise and, or and xor to `FragmentedArray` PiperOrigin-RevId: 691411447 --- .../mosaic/gpu/fragmented_array.py | 38 ++++++++++++++----- tests/mosaic/gpu_test.py | 16 ++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 1767f3edb976..5639f7356ea9 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -472,6 +472,35 @@ def __rmod__(self, other): else: return self._pointwise(lambda s, o: arith.remui(o, s), other) + def __invert__(self): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self ^ ~0 + + def __or__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self._pointwise(arith.ori, other) + + def __ror__(self, other): + return self | other + + def __and__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self._pointwise(arith.andi, other) + + def __rand__(self, other): + return self & other + + def __xor__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + return self._pointwise(arith.xori, other) + + def __rxor__(self, other): + return self ^ other + def __eq__(self, other): return self._compare( other, @@ -607,15 +636,6 @@ def fast_instr(x): raise NotImplementedError(x.type) return fast_instr - def __and__(self, other): - if not ir.IntegerType.isinstance(self.mlir_dtype): - raise ValueError( - "Bitwise operations only defined for integer types, not" - f" {self.mlir_dtype}" - ) - - return self._pointwise(arith.andi, other) - def bitcast(self, elt: ir.Type): reg_type = self.registers.flat[0].type if ir.VectorType.isinstance(reg_type): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index acd06e4e258f..062d2de02bac 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1284,6 +1284,22 @@ def kernel(ctx, dst, _): iota = np.arange(m * n, dtype=dtype).reshape(m, n) np.testing.assert_array_equal(result, op(iota, iota + 1)) + @parameterized.product( + op=[operator.and_, operator.or_, operator.xor], + dtype=[jnp.uint32], + ) + def test_bitwise(self, op, dtype, m=64, n=8): + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, dtype) + op(iota, iota + 1).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_array_equal(result, op(iota, iota + 1)) + @parameterized.product( ops=( (lambda x: -x, jax.lax.neg), From 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20K=C3=B6ppe?= Date: Wed, 30 Oct 2024 07:53:24 -0700 Subject: [PATCH 122/698] Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf PiperOrigin-RevId: 691417832 --- jax/_src/api.py | 7 ------- jax/_src/prng.py | 5 ----- tests/random_test.py | 5 ----- 3 files changed, 17 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index f4ea70e3b5ad..390d3ea337bb 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2440,13 +2440,6 @@ def _device_put_replicated(x): def _device_get(x): if isinstance(x, core.Tracer): return x - if dtypes.issubdtype(getattr(x, "dtype", None), dtypes.extended): - try: - to_device = x.dtype._rules.device_get - except AttributeError: - pass - else: - return to_device(x) try: toarray = x.__array__ except AttributeError: diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 0a5a1dff2659..039b0a309775 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -400,11 +400,6 @@ def make_sharded_array(aval, sharding, arrays, committed): phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) - @staticmethod - def device_get(val): - buffer = api.device_get(random_unwrap(val)) - return random_wrap(buffer, impl=val.dtype._impl) - @staticmethod def device_put_sharded(vals, aval, sharding, devices): physical_aval = core.physical_aval(aval) diff --git a/tests/random_test.py b/tests/random_test.py index fed12792d5c6..da182dbccae9 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -936,11 +936,6 @@ def f(x): x = jnp.array([True, False, False]) f(x) # doesn't crash - def test_device_get(self): - keys = self.make_keys(4) - keys_on_host = jax.device_get(keys) - self.assertKeysEqual(keys, keys_on_host) - def test_device_put(self): device = jax.devices()[0] keys = self.make_keys(4) From a8f44c47005873ff422278c4256fc19e073e5813 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 30 Oct 2024 08:29:35 -0700 Subject: [PATCH 123/698] Fix a CI failure under NumPy 2.1. PiperOrigin-RevId: 691428702 --- tests/lax_numpy_reducers_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 4dc0ff5f481f..98c8785f0fb1 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -843,7 +843,8 @@ def np_op(x, axis=None, dtype=None, include_initial=False): np_fun = lambda x: np_op(x, **kwargs) jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, + rtol={jnp.bfloat16: 5e-2}) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( From d2f5804449cc30e35993bc5be76ca0f2ab644e20 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 30 Oct 2024 09:37:14 -0700 Subject: [PATCH 124/698] [Pallas] Add test cases for var + constant. PiperOrigin-RevId: 691450143 --- tests/pallas/ops_test.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 75f95ae1b630..10fac0ac1ade 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -721,6 +721,28 @@ def kernel(x_ref, o_ref): expected.astype(jnp.float32), ) + # TODO(twsung): Add more types once lowering is implemented. + @parameterized.parameters( + jnp.float32, + jnp.bfloat16, + jnp.int32, + ) + def test_add_constant(self, dtype): + + shape = (256, 256) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + np.testing.assert_array_equal( + kernel(jnp.zeros(shape, dtype=dtype)), + jnp.ones(shape, dtype=dtype), + ) + @parameterized.parameters( -3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4, ) From da994d355232fc4e7461089e6556bed7a69e1889 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 30 Oct 2024 09:59:56 -0700 Subject: [PATCH 125/698] Move utility functions in build.py to utils.py This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability. PiperOrigin-RevId: 691458051 --- build/build.py | 272 ++++++------------------------------------- build/tools/utils.py | 249 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+), 239 deletions(-) create mode 100644 build/tools/utils.py diff --git a/build/build.py b/build/build.py index 44343ebab4ef..62e4217c10a2 100755 --- a/build/build.py +++ b/build/build.py @@ -16,225 +16,16 @@ # # Helper script for building JAX's libjax easily. - import argparse -import collections -import hashlib import logging import os -import pathlib import platform -import re -import shutil -import stat -import subprocess -import sys import textwrap -import urllib.request -logger = logging.getLogger(__name__) +from tools import utils -def is_windows(): - return sys.platform.startswith("win32") - - -def shell(cmd): - try: - logger.info("shell(): %s", cmd) - output = subprocess.check_output(cmd) - except subprocess.CalledProcessError as e: - logger.info("subprocess raised: %s", e) - if e.output: print(e.output) - raise - except Exception as e: - logger.info("subprocess raised: %s", e) - raise - return output.decode("UTF-8").strip() - - -# Python - -def get_python_bin_path(python_bin_path_flag): - """Returns the path to the Python interpreter to use.""" - path = python_bin_path_flag or sys.executable - return path.replace(os.sep, "/") - - -def get_python_version(python_bin_path): - version_output = shell( - [python_bin_path, "-c", - ("import sys; print(\"{}.{}\".format(sys.version_info[0], " - "sys.version_info[1]))")]) - major, minor = map(int, version_output.split(".")) - return major, minor - -def check_python_version(python_version): - if python_version < (3, 10): - print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) - sys.exit(-1) - - -def get_githash(): - try: - return subprocess.run( - ["git", "rev-parse", "HEAD"], - encoding='utf-8', - capture_output=True).stdout.strip() - except OSError: - return "" - -# Bazel - -BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" -BazelPackage = collections.namedtuple("BazelPackage", - ["base_uri", "file", "sha256"]) -bazel_packages = { - ("Linux", "x86_64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-linux-x86_64", - sha256= - "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"), - ("Linux", "aarch64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-linux-arm64", - sha256= - "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"), - ("Darwin", "x86_64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-darwin-x86_64", - sha256= - "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"), - ("Darwin", "arm64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-darwin-arm64", - sha256= - "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"), - ("Windows", "AMD64"): - BazelPackage( - base_uri=None, - file="bazel-6.5.0-windows-x86_64.exe", - sha256= - "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"), -} - - -def download_and_verify_bazel(): - """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" - package = bazel_packages.get((platform.system(), platform.machine())) - if package is None: - return None - - if not os.access(package.file, os.X_OK): - uri = (package.base_uri or BAZEL_BASE_URI) + package.file - sys.stdout.write(f"Downloading bazel from: {uri}\n") - - def progress(block_count, block_size, total_size): - if total_size <= 0: - total_size = 170**6 - progress = (block_count * block_size) / total_size - num_chars = 40 - progress_chars = int(num_chars * progress) - sys.stdout.write("{} [{}{}] {}%\r".format( - package.file, "#" * progress_chars, - "." * (num_chars - progress_chars), int(progress * 100.0))) - - tmp_path, _ = urllib.request.urlretrieve( - uri, None, progress if sys.stdout.isatty() else None - ) - sys.stdout.write("\n") - - # Verify that the downloaded Bazel binary has the expected SHA256. - with open(tmp_path, "rb") as downloaded_file: - contents = downloaded_file.read() - - digest = hashlib.sha256(contents).hexdigest() - if digest != package.sha256: - print( - "Checksum mismatch for downloaded bazel binary (expected {}; got {})." - .format(package.sha256, digest)) - sys.exit(-1) - - # Write the file as the bazel file name. - with open(package.file, "wb") as out_file: - out_file.write(contents) - - # Mark the file as executable. - st = os.stat(package.file) - os.chmod(package.file, - st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) - - return os.path.join(".", package.file) - - -def get_bazel_paths(bazel_path_flag): - """Yields a sequence of guesses about bazel path. Some of sequence elements - can be None. The resulting iterator is lazy and potentially has a side - effects.""" - yield bazel_path_flag - yield shutil.which("bazel") - yield download_and_verify_bazel() - - -def get_bazel_path(bazel_path_flag): - """Returns the path to a Bazel binary, downloading Bazel if not found. Also, - checks Bazel's version is at least newer than 6.5.0 - - A manual version check is needed only for really old bazel versions. - Newer bazel releases perform their own version check against .bazelversion - (see for details - https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes). - """ - for path in filter(None, get_bazel_paths(bazel_path_flag)): - version = get_bazel_version(path) - if version is not None and version >= (6, 5, 0): - return path, ".".join(map(str, version)) - - print("Cannot find or download a suitable version of bazel." - "Please install bazel >= 6.5.0.") - sys.exit(-1) - - -def get_bazel_version(bazel_path): - try: - version_output = shell([bazel_path, "--version"]) - except (subprocess.CalledProcessError, OSError): - return None - match = re.search(r"bazel *([0-9\\.]+)", version_output) - if match is None: - return None - return tuple(int(x) for x in match.group(1).split(".")) - - -def get_clang_path_or_exit(): - which_clang_output = shutil.which("clang") - if which_clang_output: - # If we've found a clang on the path, need to get the fully resolved path - # to ensure that system headers are found. - return str(pathlib.Path(which_clang_output).resolve()) - else: - print( - "--clang_path is unset and clang cannot be found" - " on the PATH. Please pass --clang_path directly." - ) - sys.exit(-1) - -def get_clang_major_version(clang_path): - clang_version_proc = subprocess.run( - [clang_path, "-E", "-P", "-"], - input="__clang_major__", - check=True, - capture_output=True, - text=True, - ) - major_version = int(clang_version_proc.stdout) - - return major_version - +logger = logging.getLogger(__name__) def write_bazelrc(*, remote_build, cuda_version, cudnn_version, rocm_toolkit_path, @@ -272,10 +63,10 @@ def write_bazelrc(*, remote_build, if target_cpu_features == "release": if wheel_cpu == "x86_64": - f.write("build --config=avx_windows\n" if is_windows() + f.write("build --config=avx_windows\n" if utils.is_windows() else "build --config=avx_posix\n") elif target_cpu_features == "native": - if is_windows(): + if utils.is_windows(): print("--target_cpu_features=native is not supported on Windows; ignoring.") else: f.write("build --config=native_arch_posix\n") @@ -575,18 +366,18 @@ def main(): else host_cpu) # Find a working Bazel. - bazel_path, bazel_version = get_bazel_path(args.bazel_path) + bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) print(f"Bazel binary path: {bazel_path}") print(f"Bazel version: {bazel_version}") if args.python_version: python_version = args.python_version else: - python_bin_path = get_python_bin_path(args.python_bin_path) + python_bin_path = utils.get_python_bin_path(args.python_bin_path) print(f"Python binary path: {python_bin_path}") - python_version = get_python_version(python_bin_path) + python_version = utils.get_python_version(python_bin_path) print("Python version: {}".format(".".join(map(str, python_version)))) - check_python_version(python_version) + utils.check_python_version(python_version) python_version = ".".join(map(str, python_version)) print("Use clang: {}".format("yes" if args.use_clang else "no")) @@ -594,9 +385,9 @@ def main(): clang_major_version = None if args.use_clang: if not clang_path: - clang_path = get_clang_path_or_exit() + clang_path = utils.get_clang_path_or_exit() print(f"clang path: {clang_path}") - clang_major_version = get_clang_major_version(clang_path) + clang_major_version = utils.get_clang_major_version(clang_path) print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no")) print(f"Target CPU: {wheel_cpu}") @@ -648,7 +439,7 @@ def main(): update_command = ([bazel_path] + args.bazel_startup_options + ["run", "--verbose_failures=true", task, *args.bazel_options]) print(" ".join(update_command)) - shell(update_command) + utils.shell(update_command) return if args.configure_only: @@ -675,27 +466,29 @@ def main(): if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: build_cpu_wheel_command = [ - *command_base, - "//jaxlib/tools:build_wheel", "--", - f"--output_path={output_path_jaxlib}", - f"--jaxlib_git_hash={get_githash()}", - f"--cpu={wheel_cpu}" + *command_base, + "//jaxlib/tools:build_wheel", + "--", + f"--output_path={output_path_jaxlib}", + f"--jaxlib_git_hash={utils.get_githash()}", + f"--cpu={wheel_cpu}", ] if args.build_gpu_plugin: build_cpu_wheel_command.append("--skip_gpu_kernels") if args.editable: build_cpu_wheel_command.append("--editable") print(" ".join(build_cpu_wheel_command)) - shell(build_cpu_wheel_command) + utils.shell(build_cpu_wheel_command) if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ (args.build_gpu_kernel_plugin == "rocm"): build_gpu_kernels_command = [ - *command_base, - "//jaxlib/tools:build_gpu_kernels_wheel", "--", - f"--output_path={output_path_jax_kernel}", - f"--jaxlib_git_hash={get_githash()}", - f"--cpu={wheel_cpu}", + *command_base, + "//jaxlib/tools:build_gpu_kernels_wheel", + "--", + f"--output_path={output_path_jax_kernel}", + f"--jaxlib_git_hash={utils.get_githash()}", + f"--cpu={wheel_cpu}", ] if args.enable_cuda: build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") @@ -708,15 +501,16 @@ def main(): if args.editable: build_gpu_kernels_command.append("--editable") print(" ".join(build_gpu_kernels_command)) - shell(build_gpu_kernels_command) + utils.shell(build_gpu_kernels_command) if args.build_gpu_plugin or args.build_gpu_pjrt_plugin: build_pjrt_plugin_command = [ - *command_base, - "//jaxlib/tools:build_gpu_plugin_wheel", "--", - f"--output_path={output_path_jax_pjrt}", - f"--jaxlib_git_hash={get_githash()}", - f"--cpu={wheel_cpu}", + *command_base, + "//jaxlib/tools:build_gpu_plugin_wheel", + "--", + f"--output_path={output_path_jax_pjrt}", + f"--jaxlib_git_hash={utils.get_githash()}", + f"--cpu={wheel_cpu}", ] if args.enable_cuda: build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}") @@ -729,9 +523,9 @@ def main(): if args.editable: build_pjrt_plugin_command.append("--editable") print(" ".join(build_pjrt_plugin_command)) - shell(build_pjrt_plugin_command) + utils.shell(build_pjrt_plugin_command) - shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) + utils.shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) if __name__ == "__main__": diff --git a/build/tools/utils.py b/build/tools/utils.py new file mode 100644 index 000000000000..4c8765371316 --- /dev/null +++ b/build/tools/utils.py @@ -0,0 +1,249 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Helper script for tools/utilities used by the JAX build CLI. +import collections +import hashlib +import logging +import os +import pathlib +import platform +import re +import shutil +import stat +import subprocess +import sys +import urllib.request + +logger = logging.getLogger(__name__) + +def is_windows(): + return sys.platform.startswith("win32") + +def shell(cmd): + try: + logger.info("shell(): %s", cmd) + output = subprocess.check_output(cmd) + except subprocess.CalledProcessError as e: + logger.info("subprocess raised: %s", e) + if e.output: + print(e.output) + raise + except Exception as e: + logger.info("subprocess raised: %s", e) + raise + return output.decode("UTF-8").strip() + + +# Bazel +BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" +BazelPackage = collections.namedtuple( + "BazelPackage", ["base_uri", "file", "sha256"] +) +bazel_packages = { + ("Linux", "x86_64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-linux-x86_64", + sha256=( + "a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307" + ), + ), + ("Linux", "aarch64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-linux-arm64", + sha256=( + "5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f" + ), + ), + ("Darwin", "x86_64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-darwin-x86_64", + sha256=( + "bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29" + ), + ), + ("Darwin", "arm64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-darwin-arm64", + sha256=( + "c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb" + ), + ), + ("Windows", "AMD64"): BazelPackage( + base_uri=None, + file="bazel-6.5.0-windows-x86_64.exe", + sha256=( + "6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6" + ), + ), +} + + +def download_and_verify_bazel(): + """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" + package = bazel_packages.get((platform.system(), platform.machine())) + if package is None: + return None + + if not os.access(package.file, os.X_OK): + uri = (package.base_uri or BAZEL_BASE_URI) + package.file + sys.stdout.write(f"Downloading bazel from: {uri}\n") + + def progress(block_count, block_size, total_size): + if total_size <= 0: + total_size = 170**6 + progress = (block_count * block_size) / total_size + num_chars = 40 + progress_chars = int(num_chars * progress) + sys.stdout.write( + "{} [{}{}] {}%\r".format( + package.file, + "#" * progress_chars, + "." * (num_chars - progress_chars), + int(progress * 100.0), + ) + ) + + tmp_path, _ = urllib.request.urlretrieve( + uri, None, progress if sys.stdout.isatty() else None + ) + sys.stdout.write("\n") + + # Verify that the downloaded Bazel binary has the expected SHA256. + with open(tmp_path, "rb") as downloaded_file: + contents = downloaded_file.read() + + digest = hashlib.sha256(contents).hexdigest() + if digest != package.sha256: + print( + "Checksum mismatch for downloaded bazel binary (expected {}; got {})." + .format(package.sha256, digest) + ) + sys.exit(-1) + + # Write the file as the bazel file name. + with open(package.file, "wb") as out_file: + out_file.write(contents) + + # Mark the file as executable. + st = os.stat(package.file) + os.chmod( + package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH + ) + + return os.path.join(".", package.file) + + +def get_bazel_paths(bazel_path_flag): + """Yields a sequence of guesses about bazel path. + + Some of sequence elements can be None. The resulting iterator is lazy and + potentially has a side effects. + """ + yield bazel_path_flag + yield shutil.which("bazel") + yield download_and_verify_bazel() + + +def get_bazel_path(bazel_path_flag): + """Returns the path to a Bazel binary, downloading Bazel if not found. + + Also, checks Bazel's version is at least newer than 6.5.0 + + A manual version check is needed only for really old bazel versions. + Newer bazel releases perform their own version check against .bazelversion + (see for details + https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes). + """ + for path in filter(None, get_bazel_paths(bazel_path_flag)): + version = get_bazel_version(path) + if version is not None and version >= (6, 5, 0): + return path, ".".join(map(str, version)) + + print( + "Cannot find or download a suitable version of bazel." + "Please install bazel >= 6.5.0." + ) + sys.exit(-1) + + +def get_bazel_version(bazel_path): + try: + version_output = shell([bazel_path, "--version"]) + except (subprocess.CalledProcessError, OSError): + return None + match = re.search(r"bazel *([0-9\\.]+)", version_output) + if match is None: + return None + return tuple(int(x) for x in match.group(1).split(".")) + + +def get_clang_path_or_exit(): + which_clang_output = shutil.which("clang") + if which_clang_output: + # If we've found a clang on the path, need to get the fully resolved path + # to ensure that system headers are found. + return str(pathlib.Path(which_clang_output).resolve()) + else: + print( + "--clang_path is unset and clang cannot be found" + " on the PATH. Please pass --clang_path directly." + ) + sys.exit(-1) + + +def get_clang_major_version(clang_path): + clang_version_proc = subprocess.run( + [clang_path, "-E", "-P", "-"], + input="__clang_major__", + check=True, + capture_output=True, + text=True, + ) + major_version = int(clang_version_proc.stdout) + + return major_version + + +# Python +def get_python_bin_path(python_bin_path_flag): + """Returns the path to the Python interpreter to use.""" + path = python_bin_path_flag or sys.executable + return path.replace(os.sep, "/") + + +def get_python_version(python_bin_path): + version_output = shell([ + python_bin_path, + "-c", + ( + 'import sys; print("{}.{}".format(sys.version_info[0], ' + "sys.version_info[1]))" + ), + ]) + major, minor = map(int, version_output.split(".")) + return major, minor + +def check_python_version(python_version): + if python_version < (3, 10): + print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) + sys.exit(-1) + +def get_githash(): + try: + return subprocess.run( + ["git", "rev-parse", "HEAD"], encoding="utf-8", capture_output=True + ).stdout.strip() + except OSError: + return "" From 6283eab2ffed1d4d9e845e14233be82b3e0f5472 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 10:12:47 -0700 Subject: [PATCH 126/698] [pallas] Added a flag disabling verbose error reporting PiperOrigin-RevId: 691463398 --- jax/_src/pallas/mosaic/lowering.py | 3 +++ jax/_src/pallas/mosaic_gpu/lowering.py | 3 +++ jax/_src/pallas/pallas_call.py | 11 +++++++++++ jax/_src/pallas/triton/lowering.py | 3 +++ 4 files changed, 20 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 13d321754b23..b630a14276e3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -50,6 +50,7 @@ from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core @@ -837,6 +838,8 @@ def write_env(var: jax_core.Var, val): except LoweringException: raise # We only add the extra info to the innermost exception. except Exception as e: + if not pallas_call._verbose_errors_enabled(): + raise msg = (f"{type(e).__name__}: {e}\n" + "Additional diagnostics: \n" + f"Failing jaxpr equation: {eqn}\n") diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6697b4ada895..c4447ae95435 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -40,6 +40,7 @@ from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -885,6 +886,8 @@ def write_env(var: jax_core.Var, val): except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: + if not pallas_call._verbose_errors_enabled(): + raise inval_types = map(lambda t: getattr(t, "type", None), invals) raise LoweringError( f"Exception while lowering eqn:\n {eqn}\nWith context:\n " diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 43cd4c2aca8e..f7bd0dd4e4d7 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1448,6 +1448,17 @@ def _trace_kernel_to_jaxpr( " dialect, instead of Trition IR." ), ) +_PALLAS_VERBOSE_ERRORS = config.bool_flag( + "jax_pallas_verbose_errors", + default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True), + help=( + "If True, print verbose error messages for Pallas kernels." + ), +) + + +def _verbose_errors_enabled() -> bool: + return _PALLAS_VERBOSE_ERRORS.value def _unsupported_lowering_error(platform: str) -> Exception: diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 605b975fce25..d3ca18ee507c 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -46,6 +46,7 @@ from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge @@ -390,6 +391,8 @@ def write_env(var: jax_core.Var, val): except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: + if not pallas_call._verbose_errors_enabled(): + raise inval_types = map(lambda t: getattr(t, "type", None), invals) raise LoweringError( f"Exception while lowering eqn:\n {eqn}\nWith context:\n " From 409517fcbcd51b38b9df13c4cb9663bcec2db54d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 30 Oct 2024 10:36:51 -0700 Subject: [PATCH 127/698] [pallas:mosaic_gpu] Disabled verbose lowering errors in Mosaic GPU tests PiperOrigin-RevId: 691472782 --- tests/pallas/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index b1f1b12a1d70..3166526dffcb 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -162,6 +162,7 @@ jax_multiplatform_test( ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "JAX_PALLAS_VERBOSE_ERRORS": "0", }, deps = [ "//jax:pallas", From 99ea4c1a4a41a5e5a05ec9bbbc4f10a6f3bd4d3c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 30 Oct 2024 10:46:38 -0700 Subject: [PATCH 128/698] [Fix] Put * packing into reshape no-op condition (Bug in my original CL) PiperOrigin-RevId: 691476663 --- jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index a9b5ed6876b2..b03b26edd11a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1439,7 +1439,8 @@ class VectorLayoutInferer { int64_t sublane_tiling = vreg_slice[0]; do { auto src_res_tiled_equal = src_tiled_ishape[1] == res_tiled_ishape[1]; - auto vreg_num_elements = target_shape_[0] * target_shape_[1]; + auto vreg_num_elements = + target_shape_[0] * target_shape_[1] * layout.packing(); auto single_subline_mod_1024 = (sublane_tiling == 1 && src_tiled_ishape[1] % vreg_num_elements == 0 && From 3904ced255bf5b2dda7e44c6d145422418550f47 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 30 Oct 2024 10:48:09 -0700 Subject: [PATCH 129/698] [Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only PiperOrigin-RevId: 691477248 --- tests/pallas/ops_test.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 10fac0ac1ade..0219f0a8a536 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1896,6 +1896,28 @@ def reduce(x_ref, y_ref): y_ref = jnp.cumsum(x, axis=axis) np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) + @parameterized.parameters( + (0, jnp.float32), + (0, jnp.bfloat16), + (1, jnp.float32), + (1, jnp.bfloat16), + (-1, jnp.float32), + (-1, jnp.bfloat16), + ) + def test_triu(self, k, dtype): + if dtype == jnp.bfloat16 and jtu.test_device_matches(["tpu"]): + # TODO(mvoz): b/376330700 + raise unittest.SkipTest('NYI - bf16 select') + + x = jnp.arange(128 * 256, dtype=dtype).reshape((128, 256)) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.triu(x_ref[...], k=k) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((128, 256), dtype) + )(x) + np.testing.assert_array_equal(out, np.triu(x, k=k)) class OpsInterpretTest(OpsTest): INTERPRET = True From 32bf19ac6f52a0f6776c730dd352d0530cc5bc9f Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Oct 2024 11:33:29 -0700 Subject: [PATCH 130/698] Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs. debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors. PiperOrigin-RevId: 691494516 --- jax/_src/core.py | 9 ++++++++- tests/debug_nans_test.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 4c006fc95b89..8fa8412660b5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -256,7 +256,14 @@ def _repr_pretty_(self, p, cycle): @curry def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): - return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args) + # TODO(dougalm): remove this hack when we add contexts to jaxpr. + # debug_nans is sometimes disabled locally at the traceable level by ops that + # work with nans internally, like jnp.var. The right thing to do is to add + # contexts to our jaxpr representation so that we can capture these local + # context modifications. In the meantime, disabling the checks when we + # round-trip prevents those ops producing spurious errors. + with config.debug_nans(False): + return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args) class JaxprEqnContext: diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 020c9f744833..4573f542c14f 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -75,6 +75,7 @@ def testSingleResultPrimitiveNaN(self): @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION) def testCallDeoptimized(self, jit): + raise SkipTest("re-enable once we handle contexts properly") # TODO(dougalm) @jit def f(x): return jax.lax.cond( From 44158ab0e4417342132c33b0d7386e4ec3f9911c Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Wed, 30 Oct 2024 11:39:50 -0700 Subject: [PATCH 131/698] #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases. Only test cases breaking on CPU are related to: - pure callbacks - export - shard alike Note that `layout_test` is broken on TPU, leaving a comment saying to enable it. Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function. PiperOrigin-RevId: 691496997 --- .../array_serialization/serialization_test.py | 5 ++++ tests/BUILD | 29 +++++++++++++++---- tests/compilation_cache_test.py | 4 +++ tests/mock_gpu_test.py | 15 +++++++--- tests/pallas/ops_test.py | 5 ++++ tests/shard_map_test.py | 6 +++- 6 files changed, 53 insertions(+), 11 deletions(-) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 3df3eb25c40d..c525185e6449 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -25,6 +25,7 @@ from absl.testing import parameterized import jax import jax.numpy as jnp +from jax._src import config from jax._src import test_util as jtu from jax._src import array from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding @@ -375,6 +376,8 @@ def cb1(index): @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) def test_checkpointing_with_int4(self, input_dtype): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -580,6 +583,8 @@ def test_load_with_layout(self): self.assertArraysEqual(s.data, np_inp[s.index]) def test_deserialization_with_int4(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") if jtu.test_device_matches(['gpu']): self.skipTest("Fails on GPU. Enable after it's fixed") dtype = jnp.int4 diff --git a/tests/BUILD b/tests/BUILD index 39e1d35a3407..58de3404979d 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -225,9 +225,7 @@ jax_multiplatform_test( "tpu_v4_2x2", "tpu_v5p_2x2", "tpu_v5e_4x2", - "cpu_shardy", "gpu_2gpu_shardy", - "tpu_v3_2x2_shardy", "tpu_v5e_4x2_shardy", ], shard_count = { @@ -246,10 +244,8 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "cpu_shardy", "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", - "tpu_v4_2x2_shardy", "tpu_v3_2x2", "gpu_2gpu", ], @@ -264,6 +260,7 @@ jax_multiplatform_test( ], ) +# TODO(b/355263220): enable on TPU once layouts is supported with Shardy. jax_multiplatform_test( name = "layout_test", srcs = ["layout_test.py"], @@ -279,6 +276,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], + disable_configs = [ + "cpu_shardy", # TODO(b/355263220): enable once shard_alike is supported. + ], enable_configs = [ "tpu_v3_2x2", "tpu_v5e_4x2", @@ -309,6 +309,9 @@ jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], enable_backends = ["gpu"], + enable_configs = [ + "gpu_2gpu_shardy", + ], tags = [ "config-cuda-only", ], @@ -997,6 +1000,9 @@ jax_multiplatform_test( "gpu": ["--jax_num_generated_cases=40"], "tpu": ["--jax_num_generated_cases=40"], }, + disable_configs = [ + "cpu_shardy", # TODO(b/376475853): array values mismatch, need to fix and re-enable. + ], shard_count = { "cpu": 50, "gpu": 50, @@ -1234,6 +1240,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], + disable_configs = [ + "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. + ], enable_configs = [ "cpu", "gpu_h100", @@ -1249,6 +1258,9 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + disable_configs = [ + "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. + ], enable_configs = [ "tpu_v2_1x1", "tpu_v3_2x2", @@ -1263,6 +1275,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], + disable_configs = [ + "cpu_shardy", # TODO(b/364547005): enable once pure callbacks are supported. + ], enable_configs = [ "cpu", "gpu_h100", @@ -1313,10 +1328,8 @@ jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], enable_configs = [ - "cpu_shardy", "gpu_2gpu_shardy", "tpu_v3_2x2_shardy", - "tpu_v4_2x2_shardy", ], shard_count = { "cpu": 50, @@ -1405,6 +1418,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], + disable_configs = [ + "cpu_shardy", # TODO(b/355263220): enable once export is supported. + ], enable_configs = [ "tpu_v3_2x2", ], @@ -1442,6 +1458,7 @@ jax_multiplatform_test( disable_configs = [ "gpu_a100", # TODO(b/269593297): matmul precision issues "gpu_h100", # Scarce resources. + "cpu_shardy", # TODO(b/355263220): enable once export is supported. ], shard_count = { "cpu": 40, diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index e5222814fb02..40c2181a9e3c 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -420,6 +420,8 @@ def test_persistent_cache_hit_no_logging(self): self.assertFalse(msg_exists_in_logs(msg, log.records, logging.WARNING)) def test_persistent_cache_miss_logging_with_explain(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") with (config.explain_cache_misses(True), config.compilation_cache_dir("jax-cache")): @@ -464,6 +466,8 @@ def test_persistent_cache_miss_logging_with_explain(self): def test_persistent_cache_miss_logging_with_no_explain(self): # test that cache failure messages do not get logged in WARNING + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") with (config.explain_cache_misses(False), config.compilation_cache_dir("jax-cache")): # omitting writing to cache because compilation is too fast diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index 1a4de7456167..b84903618fab 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding @@ -58,10 +59,16 @@ def f(x, y): hlo = f_lowered.compiler_ir() mocked_count = NUM_SHARDS * jax.local_device_count() - self.assertIn( - f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"', - str(hlo) - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}', + str(hlo) + ) + else: + self.assertIn( + f'sharding = "{{devices=[{mocked_count},1]<=[{mocked_count}]}}"', + str(hlo) + ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0219f0a8a536..598377b75a22 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -29,6 +29,7 @@ import jax.numpy as jnp from jax import lax from jax import random +from jax._src import config from jax._src import dtypes from jax._src import linear_util as lu from jax._src import state @@ -1241,6 +1242,8 @@ def kernel(x_ref, o_ref): "plgpu.TritonCompilerParams unavailable on Windows", ) def test_debug_print(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") if jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on TPU") @@ -1923,6 +1926,8 @@ class OpsInterpretTest(OpsTest): INTERPRET = True def test_debug_print(self): + if config.use_shardy_partitioner.value: + self.skipTest("TODO(b/364547005): pure callbacks not supported by Shardy yet") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index c5df1ca7c872..3541e331e869 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1235,7 +1235,11 @@ def foo(x): hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo')) if config.use_shardy_partitioner.value: - self.assertEqual(2, hlo_str.count('sdy.manual_computation')) + if len(jax.devices()) > 1: + self.assertEqual(2, hlo_str.count('sdy.manual_computation')) + else: + # When devices == 1, the `sdy.manual_computation` is inlined. + self.assertEqual(0, hlo_str.count('sdy.manual_computation')) else: self.assertIn('call @shmap_body', hlo_str) self.assertIn('call @shmap_body_0', hlo_str) From af14c43893971803993eb66ed33272385ca414c0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 30 Oct 2024 12:35:52 -0700 Subject: [PATCH 132/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/2d9d84487ef22d4d5358f20085234c1865b300f1. PiperOrigin-RevId: 691516089 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 47c5c832a19d..4169e30be21b 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b5690e93ea1e4da5ca9f135d0a0e5796694e706a" -XLA_SHA256 = "38505101e6c62b8afd29c31eba6ac7e4f0709aaba6d1c3006bd6afdb9757cf9b" +XLA_COMMIT = "2d9d84487ef22d4d5358f20085234c1865b300f1" +XLA_SHA256 = "a737b701870646278c69ab4388a1316be1467301a2a5ddad11978d619e3981d7" def repo(): tf_http_archive( From 242e6634ff14862789a2d3c1c0ba26817833c065 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Wed, 30 Oct 2024 13:22:56 -0700 Subject: [PATCH 133/698] [Mosaic] Add the core type enum The new attribute allows differentiating compilation by target core. PiperOrigin-RevId: 691531726 --- jaxlib/mosaic/dialect/tpu/tpu.td | 18 ++++++++++++++++++ jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 13 +++++++++++++ 2 files changed, 31 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 783101e839b1..c05b22c5aa88 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -31,6 +31,11 @@ def TPU_Dialect : Dialect { let cppNamespace = "::mlir::tpu"; let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let extraClassDeclaration = [{ + static StringRef GetCoreTypeKey() { return "tpu.core_type"; } + + static std::optional GetCoreTypeAttr(Operation *op); + }]; } class TPU_Attr traits = []> @@ -46,6 +51,19 @@ class TPU_Type traits = []> let mnemonic = mnemonic_; } +def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [ + I32EnumAttrCase<"kTc", 0, "tc">, + I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">, + I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_CoreTypeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>; def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>; def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index df00093fabe6..d884ef197cda 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -68,6 +69,18 @@ void TPUDialect::initialize() { >(); } +/* static */ std::optional TPUDialect::GetCoreTypeAttr( + Operation *op) { + Attribute attr = op->getAttr(GetCoreTypeKey()); + if (attr == nullptr) { + return std::nullopt; + } + if (!mlir::isa(attr)) { + return std::nullopt; + } + return mlir::cast(attr).getValue(); +} + void VectorLayoutAttr::print(AsmPrinter &printer) const { printer << '<'; printer << getLayout(); From 0181cb396d9956ea96f1d1de9695ceecdb1c7c5e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 30 Oct 2024 15:12:04 -0700 Subject: [PATCH 134/698] Re-land #24589 with fixes to handle `dtype` that is not compatible with NumPy. Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case. Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d PiperOrigin-RevId: 691568933 --- jax/_src/api.py | 11 +++++++++++ jax/_src/prng.py | 5 +++++ tests/random_test.py | 5 +++++ 3 files changed, 21 insertions(+) diff --git a/jax/_src/api.py b/jax/_src/api.py index 390d3ea337bb..652542571fa3 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2440,6 +2440,17 @@ def _device_put_replicated(x): def _device_get(x): if isinstance(x, core.Tracer): return x + + # Extended dtypes dispatch via their device_get rule. + if isinstance(x, basearray.Array) and dtypes.issubdtype(x.dtype, dtypes.extended): + try: + to_device = x.dtype._rules.device_get + except AttributeError: + pass + else: + return to_device(x) + + # Other types dispatch via their __array__ method. try: toarray = x.__array__ except AttributeError: diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 039b0a309775..0a5a1dff2659 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -400,6 +400,11 @@ def make_sharded_array(aval, sharding, arrays, committed): phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) + @staticmethod + def device_get(val): + buffer = api.device_get(random_unwrap(val)) + return random_wrap(buffer, impl=val.dtype._impl) + @staticmethod def device_put_sharded(vals, aval, sharding, devices): physical_aval = core.physical_aval(aval) diff --git a/tests/random_test.py b/tests/random_test.py index da182dbccae9..fed12792d5c6 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -936,6 +936,11 @@ def f(x): x = jnp.array([True, False, False]) f(x) # doesn't crash + def test_device_get(self): + keys = self.make_keys(4) + keys_on_host = jax.device_get(keys) + self.assertKeysEqual(keys, keys_on_host) + def test_device_put(self): device = jax.devices()[0] keys = self.make_keys(4) From 7f4a34e12beba1d3da446a23e9fc0f516eda3646 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 30 Oct 2024 16:11:34 -0700 Subject: [PATCH 135/698] Remove the `variant` since sparsecore is only on v5p and it's device kind is `TPU v5`. PiperOrigin-RevId: 691586791 --- tests/layout_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/layout_test.py b/tests/layout_test.py index 406d06dacc9f..7ccd0d7cddea 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -602,7 +602,8 @@ def g(x): g(jnp.arange(8)) def test_sparsecore_compute(self): - if not (jtu.is_device_tpu('5', 'f') or jtu.is_device_tpu_at_least(6)): + if not (jax.devices()[0].device_kind == 'TPU v5' or + jtu.is_device_tpu_at_least(6)): self.skipTest('Does not have a sparsecore present') shape = (128, 128) inp = jnp.arange(math.prod(shape)).reshape(shape) From f355dcf34b8118ed592f37907c28a4f109d25bcf Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 30 Oct 2024 18:53:16 -0700 Subject: [PATCH 136/698] Remove UnshapedArray values from JAX (it remains as an abstract class). Part of a plan to move away from our "abstract value" lattice to more traditional types. PiperOrigin-RevId: 691626481 --- jax/_src/core.py | 40 ++++++---------------------------------- jax/_src/prng.py | 2 +- tests/batching_test.py | 17 ----------------- tests/core_test.py | 3 +-- tests/state_test.py | 11 ----------- 5 files changed, 8 insertions(+), 65 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 8fa8412660b5..43cb5cc1e248 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1482,15 +1482,11 @@ class UnshapedArray(AbstractValue): array_abstraction_level = 4 def __init__(self, dtype, weak_type=False): + # Is it silly to initialize this object and then complain that we should + # never create one? Yes. But otherwise pytype complains. self.dtype = _dtype_object(dtype) self.weak_type = weak_type - - def update(self, dtype=None, weak_type=None): - if dtype is None: - dtype = self.dtype - if weak_type is None: - weak_type = self.weak_type - return UnshapedArray(dtype, weak_type) + raise Exception("We should never create an UnshapedArray object") def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and @@ -1517,19 +1513,6 @@ def __repr__(self): _oct = concretization_function_error(oct) _index = concretization_function_error(operator.index) - def to_tangent_aval(self) -> AbstractValue: - return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type) - - def join(self, other): - if self.dtype == other.dtype: - if self.weak_type == other.weak_type: - return self - else: - return UnshapedArray(self.dtype, weak_type=False) - else: - raise TypeError(self, other) - def str_short(self, short_dtypes=False) -> str: return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name @@ -1537,13 +1520,6 @@ def strip_weak_type(self): """Returns a copy of the aval with weak_type=False.""" return self.update(weak_type=False) - @property - def shape(self): - msg = ("UnshapedArray has no shape. Please open an issue at " - "https://github.com/jax-ml/jax/issues because it's unexpected for " - "UnshapedArray instances to ever be produced.") - raise TypeError(msg) - def _canonicalize_dimension(dim: DimSize) -> DimSize: # Dimensions are most commonly integral (by far), so we check that first. try: @@ -1670,8 +1646,6 @@ def join(self, other): if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype: weak_type = self.weak_type and other.weak_type return self.update(weak_type=weak_type) - elif self.dtype == other.dtype: - return UnshapedArray(self.dtype) else: raise TypeError(self, other) @@ -1753,8 +1727,6 @@ def join(self, other) -> AbstractValue: elif self.shape == other.shape and self.dtype == other.dtype: weak_type = self.weak_type and other.weak_type return ShapedArray(self.shape, self.dtype, weak_type=weak_type) - elif self.dtype == other.dtype: - return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type) else: raise TypeError(self, other) @@ -1838,8 +1810,6 @@ def join(self, other): self.dtype == other.dtype): weak_type = self.weak_type and other.weak_type return self.update(weak_type=weak_type) - elif self.dtype == other.dtype: - return UnshapedArray(self.dtype) else: raise TypeError(self, other) @@ -1996,6 +1966,8 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None): aval_type = type(aval) if aval_type is ShapedArray and weak_type is None: return aval + if aval_type is DShapedArray and weak_type is None: + return aval if weak_type is None: weak_type = getattr(aval, 'weak_type', False) for typ in aval_type.__mro__: @@ -2011,8 +1983,8 @@ def _shaped_array_mapping(aval, weak_type): raise_to_shaped_mappings: dict[type, Callable] = { AbstractToken: lambda aval, _: aval, Bot: lambda aval, _: aval, - UnshapedArray: lambda aval, _: aval, ShapedArray: _shaped_array_mapping, + DShapedArray: lambda aval, _: aval, DConcreteArray: lambda aval, weak_type: DShapedArray( aval.shape, aval.dtype, weak_type ), diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 0a5a1dff2659..8925a4342b29 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -812,7 +812,7 @@ def _threefry2x32_abstract_eval(*args): shape = lax_internal.broadcasting_shape_rule(*args) aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32)) else: - aval = core.UnshapedArray(jnp.dtype(jnp.uint32)) + raise TypeError(f"Arguments to threefry2x32 must all be arrays, got {args}") return (aval,) * 2 diff --git a/tests/batching_test.py b/tests/batching_test.py index 2b0b0d63a6f5..608053c23254 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -932,23 +932,6 @@ def f(scale): self.assertAllClose(ans, expected, check_dtypes=False, rtol=jtu.default_gradient_tolerance) - def testIssue387(self): - # https://github.com/jax-ml/jax/issues/387 - R = self.rng().rand(100, 2) - - def dist_sq(R): - dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :] - zero = jnp.zeros_like(dR) - dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR)) - return jnp.sum(dR ** 2, axis=2) - - @jit - def f(R): - _ = dist_sq(R) - return jnp.sum(R ** 2) - - _ = hessian(f)(R) # don't crash on UnshapedArray - @jax.legacy_prng_key('allow') def testIssue489(self): # https://github.com/jax-ml/jax/issues/489 diff --git a/tests/core_test.py b/tests/core_test.py index 94b7010907a9..38700037248d 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -33,14 +33,13 @@ from jax._src import linear_util as lu from jax._src import util from jax._src import test_util as jtu -from jax._src.core import UnshapedArray, ShapedArray, DBIdx +from jax._src.core import ShapedArray, DBIdx from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow config.parse_flags_with_absl() -_ = pe.PartialVal.unknown(UnshapedArray(np.float32)) __ = pe.PartialVal.unknown(ShapedArray((), np.float32)) def call(f, *args): diff --git a/tests/state_test.py b/tests/state_test.py index 92ea2473811c..36e93e88c5e0 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -1362,17 +1362,6 @@ def body(y, z): class GeneralRefTest(jtu.JaxTestCase): - def test_unshaped_ref(self): - def f(x_ref): - x = x_ref[...] - x_ref[...] = x - ref_addupdate(x_ref, (), x) - return [x] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))]) - self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray) - self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32")) - def test_token(self): def f(x_ref): x = x_ref[...] From 5aeffde7070db81235693a7198f961f1c62399af Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 31 Oct 2024 00:58:28 -0700 Subject: [PATCH 137/698] [Mosaic] Extend tpu matmulop to have dimension dims. Add support for batching and simple transposition. PiperOrigin-RevId: 691706218 --- jax/_src/pallas/mosaic/lowering.py | 1 + jaxlib/mosaic/dialect/tpu/tpu.td | 29 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 14 + jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 4 + jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 277 ++++++++++++++++++ .../tpu/transforms/apply_vector_layout.cc | 40 ++- .../tpu/transforms/canonicalize_mosaic.cc | 149 +++++++++- jaxlib/mosaic/dialect/tpu/util.cc | 28 ++ jaxlib/mosaic/dialect/tpu/util.h | 10 + 9 files changed, 541 insertions(+), 11 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b630a14276e3..0e1fe9a5b56a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1615,6 +1615,7 @@ def _dot_general_lowering_rule( ) return vector.shape_cast(out_type, red) + # TODO(mvoz): Plumb these into dot dimension numbers on the matmul op! if lhs_dims == (1,): transpose_lhs = False elif lhs_dims == (0,): diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index c05b22c5aa88..3bd4f651c0fc 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -384,22 +384,47 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> { }]; } -// TODO(apaszke): Add a verifier for this op + +def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> { + let parameters = (ins + ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims, + ArrayRefParameter<"int64_t", "">:$rhs_contracting_dims, + ArrayRefParameter<"int64_t", "">:$lhs_non_contracting_dims, + ArrayRefParameter<"int64_t", "">:$rhs_non_contracting_dims, + // The contract is a flattened structure, wherein, each element is half of a + // pair of indices. The first element is always 0 (lhs) or 1 (rhs) and the + // second index is the index from the lhs or rhs. + ArrayRefParameter<"int64_t", "">:$output_dim_order, + OptionalArrayRefParameter<"int64_t", "">:$lhs_batch_dims, + OptionalArrayRefParameter<"int64_t", "">:$rhs_batch_dims + ); + let assemblyFormat = "`<` `[` $lhs_contracting_dims `]` `,` `[` $rhs_contracting_dims `]` `,` " + "`[` $lhs_non_contracting_dims `]` `,` `[` $rhs_non_contracting_dims `]` `,` " + "`[` $output_dim_order `]` `,` " + "`[` (`]`):($lhs_batch_dims^ `]`)? `,` " + "`[` (`]`):($rhs_batch_dims^ `]`)? `>`"; +} + // TODO(apaszke): Think hard about precision def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { let arguments = (ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc, + // These flags are deprecated - if dimension_numbers are defined, + // these flags are ignored. They will always be false after canonicalize. DefaultValuedAttr:$transpose_lhs, DefaultValuedAttr:$transpose_rhs, - OptionalAttr:$precision + OptionalAttr:$precision, + // NOTE: User-level optional, once canonicalized, always present. + OptionalAttr:$dimension_numbers ); let results = (outs AnyVector:$result); let assemblyFormat = [{ $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result) }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index d884ef197cda..10ab154b7c10 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -223,4 +223,18 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { return false; } +DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, + bool transpose_lhs, + bool transpose_rhs) { + return tpu::DotDimensionNumbersAttr::get( + builder.getContext(), + /*lhs_contracting_dims=*/{transpose_lhs ? 0 : 1}, + /*rhs_contracting_dims=*/{transpose_rhs ? 1 : 0}, + /*lhs_non_contracting_dims=*/{transpose_lhs ? 1 : 0}, + /*rhs_non_contracting_dims=*/{transpose_rhs ? 0 : 1}, + /*output_dim_order=*/{0, transpose_lhs ? 1 : 0, 1, transpose_rhs ? 0 : 1}, + /*lhs_batch_dims=*/{}, + /*rhs_batch_dims=*/{}); +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 32ccd45f6e49..dbb2ddaa5853 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -104,6 +104,10 @@ MemRefType getMemRefType(Value value); bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8); +DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, + bool transpose_lhs, + bool transpose_rhs); + #define GEN_PASS_REGISTRATION #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 1d3ea99f4d4c..9a7f4f8a53e1 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -507,6 +510,280 @@ class CanonicalizeAddOfMatmul : public OpRewritePattern { } }; +LogicalResult MatmulOp::verify() { + // Note - this is not yet an exhaustive verification of matmul. Many of the + // invariants are spread across infer, apply, llo and below. This is, + // however, a good start and the recommended place to add more invariants. + const VectorType lhs_ty = getLhs().getType(); + const VectorType rhs_ty = getRhs().getType(); + + if (getTransposeLhs()) { + emitOpError( + "Lhs transpose not supported via this API - please use the " + "dimension numbers API."); + return failure(); + } + + if (getDimensionNumbers().has_value()) { + auto dimension_numbers = getDimensionNumbers().value(); + auto lhs_contracting_dims = dimension_numbers.getLhsContractingDims(); + auto rhs_contracting_dims = dimension_numbers.getRhsContractingDims(); + if (lhs_contracting_dims.size() != 1) { + emitOpError("Not implemented: lhs contracting dims must be of size 1"); + return failure(); + } + if (rhs_contracting_dims.size() != 1) { + emitOpError("Not implemented: rhs contracting dims must be of size 1"); + return failure(); + } + + auto lhs_contracting_dim = lhs_contracting_dims[0]; + auto rhs_contracting_dim = rhs_contracting_dims[0]; + + auto lhs_batch_dims = dimension_numbers.getLhsBatchDims(); + auto rhs_batch_dims = dimension_numbers.getRhsBatchDims(); + + auto lhs_non_contracting_dims = + dimension_numbers.getLhsNonContractingDims(); + auto rhs_non_contracting_dims = + dimension_numbers.getRhsNonContractingDims(); + + if (lhs_contracting_dims.size() + lhs_non_contracting_dims.size() + + lhs_batch_dims.size() != + lhs_ty.getShape().size()) { + emitOpError( + "Not implemented: lhs contracting + non contracting + batch dims " + "must be of the same size as the lhs shape"); + return failure(); + } + if (rhs_contracting_dims.size() + rhs_non_contracting_dims.size() + + rhs_batch_dims.size() != + rhs_ty.getShape().size()) { + emitOpError( + "Not implemented: rhs contracting + non contracting + batch dims " + "must be of the same size as the rhs shape"); + return failure(); + } + + if (lhs_ty.getShape()[lhs_contracting_dim] != + rhs_ty.getShape()[rhs_contracting_dim]) { + emitOpError( + "Not implemented: lhs and rhs contracting dims must be of the same " + "size"); + return failure(); + } + + if (lhs_batch_dims.size() != rhs_batch_dims.size()) { + emitOpError( + "Not implemented: lhs and rhs should have the same number of batch " + "dims"); + return failure(); + } + if (lhs_batch_dims.size() > 1) { + emitOpError("Not implemented: Up to 1 batch dim supported"); + return failure(); + } + + int64_t lhs_rank = lhs_ty.getShape().size(); + int64_t rhs_rank = rhs_ty.getShape().size(); + + std::vector seen_dims_lhs(lhs_rank, false); + std::vector seen_dims_rhs(rhs_rank, false); + + auto check_and_mark_dims = [&](const std::vector &dims, + std::vector &seen_dims, + const std::string_view operand) { + for (int64_t dim : dims) { + if (seen_dims[dim]) { + emitOpError("Illegal: Dim ") + << dim << " repeats in dimension numbers of " << operand; + return failure(); + } + seen_dims[dim] = true; + } + return success(); + }; + + if (failed( + check_and_mark_dims(lhs_contracting_dims, seen_dims_lhs, "lhs")) || + failed(check_and_mark_dims(lhs_non_contracting_dims, seen_dims_lhs, + "lhs")) || + failed(check_and_mark_dims(lhs_batch_dims, seen_dims_lhs, "lhs"))) { + return failure(); + } + + if (failed( + check_and_mark_dims(rhs_contracting_dims, seen_dims_rhs, "rhs")) || + failed(check_and_mark_dims(rhs_non_contracting_dims, seen_dims_rhs, + "rhs")) || + failed(check_and_mark_dims(rhs_batch_dims, seen_dims_rhs, "rhs"))) { + return failure(); + } + + for (int64_t dim = 0; dim < lhs_rank; ++dim) { + if (!seen_dims_lhs[dim]) { + emitOpError("Illegal: Dim ") + << dim << " is not seen in lhs dimension numbers"; + return failure(); + } + } + for (int64_t dim = 0; dim < rhs_rank; ++dim) { + if (!seen_dims_rhs[dim]) { + emitOpError("Illegal: Dim ") + << dim << " is not seen in rhs dimension numbers"; + } + } + + const std::optional batch_dim_lhs = + lhs_batch_dims.empty() ? std::nullopt + : std::optional(lhs_batch_dims[0]); + const std::optional batch_dim_rhs = + rhs_batch_dims.empty() ? std::nullopt + : std::optional(rhs_batch_dims[0]); + if (batch_dim_lhs != batch_dim_rhs) { + emitOpError("Not Implemented: batch dims must be equal"); + return failure(); + } + if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) { + emitOpError("Not Implemented: batch dims pos must be 0"); + return failure(); + } + // Invariant above enforces only 1 batch dim atm, and that both are eq + std::optional batch_size = std::nullopt; + if (batch_dim_lhs.has_value()) { + batch_size = lhs_ty.getShape()[batch_dim_lhs.value()]; + auto rhs_batch_size = rhs_ty.getShape()[batch_dim_rhs.value()]; + if (batch_size != rhs_batch_size) { + emitOpError("Not Implemented: batch dims must be equal"); + return failure(); + } + if (batch_size == 0) { + emitOpError("Illegal: batch size must be > 0"); + return failure(); + } + } + auto output_dim_order = dimension_numbers.getOutputDimOrder(); + if (output_dim_order.size() % 2 != 0) { + emitOpError( + "Illegal: output dim order must have an even number of elements."); + return failure(); + } + if (batch_size.has_value()) { + if (output_dim_order[0] != 0 || output_dim_order[1] != 0) { + emitOpError( + "Not implemented: Output with batch size must be the lhs 0 idx for " + "now."); + return failure(); + } + } + + // Invariants above enforce a single batch idx for now, and that it is in + // position 0. Future extensions to this will be to: + // 1. Support multiple batch dims + // 2. Support batch dims in any position in the output dim order + if (lhs_non_contracting_dims.size() != 1) { + emitOpError( + "Not implemented: lhs non contracting dims must be of size 1"); + return failure(); + } + if (rhs_non_contracting_dims.size() != 1) { + emitOpError( + "Not implemented: rhs non contracting dims must be of size 1"); + return failure(); + } + + // A bit long winded, but the invariants we enforce below are: + // 1. The output order idx is 0 (lhs) or 1 (rhs) + // 2. The output dim order is in valid bounds + // 3. We saw the rhs and lhs non contracting dims in the output dim order + // 4. We never see the contracting dims in the output dim order + // 5. We only see each of the non contracting dim once + std::vector lhs_dims_seen_in_output(lhs_rank, false); + std::vector rhs_dims_seen_in_output(rhs_rank, false); + + // Iterate over the output dimension order + for (int dim_pos = 0; dim_pos < output_dim_order.size(); dim_pos += 2) { + auto idx = output_dim_order[dim_pos]; + auto dim = output_dim_order[dim_pos + 1]; + + if (idx != 0 && idx != 1) { + emitOpError("Illegal: output dim order index must be 0 or 1"); + return failure(); + } + auto is_lhs = (idx == 0); + + if (is_lhs) { + if (dim < 0 || dim >= lhs_rank) { + emitOpError("Illegal: lhs dimension index out of bounds"); + return failure(); + } + if (lhs_dims_seen_in_output[dim]) { + emitOpError("Illegal: lhs dimension ") + << dim << " appears more than once in output dim order"; + return failure(); + } + if (dim == lhs_contracting_dim) { + emitOpError("Illegal: contracting dimension ") + << dim << " appears in lhs output dim order"; + return failure(); + } + // batch_dim_lhs is either 0 or nullopt + if (dim == batch_dim_lhs) { + // Upstream invariants enforce that batch dim is in position 0 + // of the output dim order. + rhs_dims_seen_in_output[dim] = true; + } + lhs_dims_seen_in_output[dim] = true; + } else { + if (dim < 0 || dim >= rhs_rank) { + emitOpError("Illegal: rhs dimension index out of bounds"); + return failure(); + } + if (rhs_dims_seen_in_output[dim]) { + emitOpError("Illegal: rhs dimension ") + << dim << " appears more than once in output dim order"; + return failure(); + } + if (dim == rhs_contracting_dim) { + emitOpError("Illegal: contracting dimension ") + << dim << " appears in rhs output dim order"; + return failure(); + } + if (dim == batch_dim_rhs) { + // Upstream invariants enforce that batch dim is in position 0 + // of the output dim order. + lhs_dims_seen_in_output[dim] = true; + } + rhs_dims_seen_in_output[dim] = true; + } + } + + // Check that all dims have been seen (except contracting dims) + for (int i = 0; i < lhs_rank; ++i) { + if (i == lhs_contracting_dim) { + continue; + } + if (!lhs_dims_seen_in_output[i]) { + emitOpError("Illegal: lhs non-contracting dimension ") + << i << " is not seen in output dim order"; + return failure(); + } + } + + for (int i = 0; i < rhs_rank; ++i) { + if (i == rhs_contracting_dim) { + continue; + } + if (!rhs_dims_seen_in_output[i]) { + emitOpError("Illegal: rhs non-contracting dimension ") + << i << " is not seen in output dim order"; + return failure(); + } + } + } + return success(); +} + void MatmulOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 1b4b2cad1f97..15660aa6cd0a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1696,15 +1696,36 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); })); TPU_ASSERT_OP(layouts_out.front().has_value()); auto matmul_op = cast(op); - const auto transpose_lhs = matmul_op.getTransposeLhs(); - const auto transpose_rhs = matmul_op.getTransposeRhs(); - const auto &layout_lhs = *layouts_in[0]; - const auto &layout_rhs = *layouts_in[1]; - const auto &layout_acc = *layouts_in[2]; - const auto &layout_out = *layouts_out[0]; + if (matmul_op.getTransposeRhs()) { + return op.emitOpError( + "Transposition must have been erased into dimension numbers during " + "canonicalization"); + } + + auto dimension_numbers = matmul_op.getDimensionNumbers(); + if (!dimension_numbers.has_value()) { + return op.emitOpError( + "Dimension numbers must be provided, ensure canonicalization has been " + "run."); + } + auto transposed_mkn = isTransposedMatmul(dimension_numbers.value()); + if (!transposed_mkn.has_value()) { + return op.emitOpError( + "Dimension numbers must be MKN, ensure canonicalization has been " + "run."); + } + auto [transpose_lhs, transpose_rhs] = transposed_mkn.value(); if (transpose_lhs) { - return op.emitOpError("Not implemented: Transposed LHS"); + return op.emitOpError( + "Transposition of LHS is not supported in apply_vector_layout, ensure " + "canonicalization has been run."); } + + auto &layout_lhs = *layouts_in[0]; + auto &layout_rhs = *layouts_in[1]; + auto &layout_acc = *layouts_in[2]; + auto &layout_out = *layouts_out[0]; + const std::array, 4> all_layouts = {layout_lhs, layout_rhs, layout_acc, layout_out}; for (const VectorLayout &layout : all_layouts) { @@ -1965,6 +1986,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, const tpu::ContractPrecisionAttr precision_attr = // May be null op.getAttrOfType("precision"); + const tpu::DotDimensionNumbersAttr dot_dimension_numbers_attr = + defaultDimensionNumbers(builder, false, transpose_rhs); for (int64_t j = 0; j < nj; ++j) { for (int64_t k = 0; k < nk; ++k) { // TODO(tlongeri): there should be a way to slice without copying @@ -1981,7 +2004,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, acc_col->setAttr("out_layout", acc_layout_attr); auto new_acc_col = builder.create( op.getLoc(), acc_col_ty, lhs_cols[k], rhs_rolled_group, acc_col, - transpose_lhs, transpose_rhs, precision_attr); + /*transpose_lhs=*/false, /*transpose_rhs=*/false, precision_attr, + dot_dimension_numbers_attr); auto new_acc_vregs = builder.create( op.getLoc(), TypeRange(ValueRange(XlaArrayToFlatArrayRef(acc_col_vregs))), diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index b95ff5067734..232121cc834c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1,6 +1,10 @@ +#include #include #include #include +#include +#include +#include #include #include "llvm/ADT/STLExtras.h" @@ -16,6 +20,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h" @@ -23,6 +29,7 @@ #include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/Block.h" #include "mlir/include/mlir/IR/Builders.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/Operation.h" @@ -40,6 +47,9 @@ namespace mlir::tpu { LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); + auto transpose_lhs = op.getTransposeLhs(); + auto transpose_rhs = op.getTransposeRhs(); + auto lhs = op.getLhs(); auto rhs = op.getRhs(); auto acc = op.getAcc(); @@ -52,6 +62,51 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { auto rhs_element_type = rhs_ty.getElementType(); auto acc_element_type = acc_ty.getElementType(); + // there are a few primary paths for dimension_numbers in matmul + // 1) No dimension numbers provided -> set to default + // 2) defined and not default -> verify and apply + // 3) defined and matching defaultDimensionNumbers -> no-op for + // canonicalization of dims + std::optional batch_size = std::nullopt; + + // MKN matmul - no dims or transpositions set + if (!op.getDimensionNumbers().has_value()) { + // Legacy API - convert it to dimension numbers + op.setDimensionNumbersAttr( + defaultDimensionNumbers(builder, transpose_lhs, transpose_rhs)); + } else if ( + // Dot dim API - dimensions are provided and are not default + (op.getDimensionNumbers().value() != + defaultDimensionNumbers(builder, false, false))) { + auto dimension_numbers = op.getDimensionNumbers(); + auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims(); + auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims(); + + auto lhs_batch_dims = dimension_numbers->getLhsBatchDims(); + auto rhs_batch_dims = dimension_numbers->getRhsBatchDims(); + + // Invariant in matmul verifier: <= 1 batch dim atm, and that lhs and rhs + // are the same + // Invariant in matmul verifier: Exactly one contracting and non contracting + // dim in each of lhs and rhs for now. + batch_size = + lhs_batch_dims.empty() + ? std::nullopt + : std::optional(lhs_ty.getShape()[lhs_batch_dims[0]]); + // Lower each dim in contracting dims by size(batch_dims) + auto batch_adjusted_lhs_contracting_dim = + lhs_contracting_dims[0] - lhs_batch_dims.size(); + auto batch_adjusted_rhs_contracting_dim = + rhs_contracting_dims[0] - rhs_batch_dims.size(); + + if (batch_adjusted_lhs_contracting_dim != 1) { + transpose_lhs = true; + } + if (batch_adjusted_rhs_contracting_dim != 0) { + transpose_rhs = true; + } + } + auto extsi_sitofp = [&builder, &op](TypedValue element) { const VectorType ty = element.getType(); auto shape = ty.getShape(); @@ -88,10 +143,12 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { if (lhs_element_type.isInteger()) { auto float_lhs = extsi_sitofp(lhs); op->setOperand(0, float_lhs); + lhs = cast>(float_lhs.getResult()); } if (rhs_element_type.isInteger()) { auto float_rhs = extsi_sitofp(rhs); op->setOperand(1, float_rhs); + rhs = cast>(float_rhs.getResult()); } } // TODO(mvoz): Add more invariants. @@ -114,6 +171,91 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) { return failure(); } } + + auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) { + auto precision_attr = op.getPrecisionAttr(); + + // If we are transposing the lhs, we need to transpose the lhs before + // matmul here, as we don't have lhs fusion implemented in apply. + if (transpose_lhs) { + auto lhs_ty = cast(lhs.getType()); + auto rank = lhs_ty.getShape().size(); + + // This transposition must run on vectors with rank >= 2 + CHECK_GE(rank, 2); + + std::vector perm(rank); + std::iota(perm.begin(), perm.end(), 0); + std::swap(perm[rank - 2], perm[rank - 1]); + + std::vector shape(lhs_ty.getShape()); + std::swap(shape[rank - 2], shape[rank - 1]); + + auto lhs_ty_transposed = VectorType::get(shape, lhs_ty.getElementType()); + + const SmallVector perm_vec = + SmallVector(perm.begin(), perm.end()); + lhs = builder.create( + lhs_ty_transposed, lhs, + DenseI64ArrayAttr::get(builder.getContext(), perm_vec)); + } + auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false, + transpose_rhs); + // transpose flags are always false here, because ddn takes precedence + // after this pass. + auto matmul_res = builder.create( + op.getLoc(), acc.getType(), lhs, rhs, acc, + /*transpose_lhs=*/false, + /*transpose_rhs=*/false, precision_attr, ddn); + return matmul_res; + }; + + // If we have a batch_size, we want to slice rhs and lhs [:batch_size], + // and then do O[i] = A[i] @ B[i] + // Produce an output shape of [batch_size, m, n] + if (batch_size.has_value()) { + std::vector outputs; + + for (int64_t i = 0; i < batch_size; ++i) { + auto sliced_lhs = builder.create(op.getLoc(), lhs, + ArrayRef{i}); + auto sliced_rhs = builder.create(op.getLoc(), rhs, + ArrayRef{i}); + + auto sliced_acc = builder.create(op.getLoc(), acc, + ArrayRef{i}); + + auto matmul_res = + dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(), + sliced_acc.getResult()); + auto res_ty = matmul_res.getType().cast(); + auto res_shape = res_ty.getShape(); + // reshape to 1x[prior_shape] + auto reshape_shape = llvm::to_vector(res_shape); + reshape_shape.insert(reshape_shape.begin(), 1); + auto shape_cast = builder.create( + op.getLoc(), VectorType::get(reshape_shape, res_ty.getElementType()), + matmul_res); + outputs.push_back(shape_cast); + } + // Technically almost identical to the case where batch_size is 1, but + // we want to avoid the spurious concat here. + if (batch_size == 1) { + op.replaceAllUsesWith(outputs[0]); + op.erase(); + return success(); + } + auto output = builder + .create(op.getLoc(), acc_ty, outputs, + /*dimension=*/0) + .getResult(); + op.replaceAllUsesWith(output); + op.erase(); + } else { + auto matmul_res = dot_dim_matmul(lhs, rhs, acc).getResult(); + op.replaceAllUsesWith(matmul_res); + op.erase(); + } return success(); }; @@ -309,9 +451,14 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) { } const tpu::ContractPrecisionAttr precision_attr = // May be null contraction_op->getAttrOfType("precision"); + + const auto dot_dimension_numbers_attr = + defaultDimensionNumbers(builder, false, transpose_rhs); + auto matmul_op = builder.create( contraction_op->getLoc(), acc_ty, lhs, rhs, acc, - /*transpose_lhs=*/false, transpose_rhs, precision_attr); + /*transpose_lhs=*/false, + /*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr); contraction_op.replaceAllUsesWith(matmul_op.getResult()); contraction_op.erase(); auto result = tpu_matmul_rule(matmul_op); diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 638a76fa5683..b74a43dce32f 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include "llvm/Support/MathExtras.h" #include "absl/types/span.h" @@ -42,6 +44,31 @@ SmallVector ComputeTileStrides(MemRefType memref_ty, return tile_strides; } +std::optional> isTransposedMatmul( + DotDimensionNumbersAttr dim_numbers) { + auto lhs_contracting_dims = dim_numbers.getLhsContractingDims(); + auto rhs_contracting_dims = dim_numbers.getRhsContractingDims(); + auto lhs_non_contracting_dims = dim_numbers.getLhsNonContractingDims(); + auto rhs_non_contracting_dims = dim_numbers.getRhsNonContractingDims(); + + if (lhs_contracting_dims.size() != 1 || rhs_contracting_dims.size() != 1 || + lhs_non_contracting_dims.size() != 1 || + rhs_non_contracting_dims.size() != 1) { + return std::nullopt; + } + + int64_t lhs_non_contracting_dim = lhs_non_contracting_dims[0]; + int64_t lhs_contracting_dim = lhs_contracting_dims[0]; + int64_t rhs_non_contracting_dim = rhs_non_contracting_dims[0]; + int64_t rhs_contracting_dim = rhs_contracting_dims[0]; + + bool lhs_transposed = lhs_non_contracting_dim > lhs_contracting_dim; + + bool rhs_transposed = rhs_contracting_dim > rhs_non_contracting_dim; + + return std::pair{lhs_transposed, rhs_transposed}; +} + bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty, const std::array& target_shape, bool allow_minormost_padding) { @@ -68,4 +95,5 @@ bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty, *(tiled_layout.getTileStrides().end() - 1) == 1 && *(tiled_layout.getTileStrides().end() - 2) == 1); } + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index f1771b948304..5b068fedd3fd 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -16,6 +16,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/types/span.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "tsl/platform/statusor.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with @@ -98,6 +99,14 @@ std::string shapeToString(const T &shape) { SmallVector ComputeTileStrides(MemRefType memref_ty, absl::Span tiling); +// Assuming MKN matmul - This function must only be called after +// canonicalization passes. +// +// Given a set of dimension numbers, Returns a pair of booleans, where the +// first is true if the lhs is transposed +// and the second is true if the rhs is transposed. +std::optional> isTransposedMatmul( + DotDimensionNumbersAttr dim_numbers); // Returns true if a >=2D memref has a tiled layout and can be equivalently // considered as an untiled memref, except for potential padding in the @@ -106,6 +115,7 @@ SmallVector ComputeTileStrides(MemRefType memref_ty, bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty, const std::array &target_shape, bool allow_minormost_padding = false); + } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ From 7d504cd95ad38def890bddd8c37c7d725b55cb9e Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 31 Oct 2024 01:28:38 -0700 Subject: [PATCH 138/698] [MOSAIC:GPU] Extend the mosaic mlir dialect with fragmented layouts. PiperOrigin-RevId: 691712579 --- jaxlib/mosaic/dialect/gpu/BUILD | 32 ++++++++++++++++++------- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 4 +++- jaxlib/mosaic/dialect/gpu/mosaic_gpu.h | 3 +++ jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 32 ++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 0fff8eee6529..940633182498 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -47,17 +47,11 @@ gentbl_cc_library( "mosaic_gpu_dialect.cc.inc", ), ( - [ - "-gen-op-decls", - "--typedefs-dialect=mosaic_gpu", - ], + ["-gen-op-decls"], "mosaic_gpu_ops.h.inc", ), ( - [ - "-gen-op-defs", - "--typedefs-dialect=mosaic_gpu", - ], + ["-gen-op-defs"], "mosaic_gpu_ops.cc.inc", ), ( @@ -74,6 +68,28 @@ gentbl_cc_library( ], "mosaic_gpu_types.cc.inc", ), + ( + ["-gen-enum-decls"], + "mosaic_gpu_enums.h.inc", + ), + ( + ["-gen-enum-defs"], + "mosaic_gpu_enums.cc.inc", + ), + ( + [ + "-gen-attrdef-decls", + "--attrdefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_attrdefs.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + "--attrdefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_attrdefs.cc.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mosaic_gpu.td", diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 933d798238e3..14aa7fbbe1bc 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -48,7 +48,9 @@ limitations under the License. // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" - +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_enums.cc.inc" +#define GET_ATTRDEF_CLASSES +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc" #define GET_TYPEDEF_CLASSES #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc" #define GET_OP_CLASSES diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index 1badcab28012..5caf773f12a6 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -30,6 +30,9 @@ limitations under the License. // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_enums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.h.inc" #define GET_TYPEDEF_CLASSES #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.h.inc" #define GET_OP_CLASSES diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index e7154f577a7a..2037cfd6566d 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -16,9 +16,12 @@ limitations under the License. #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AttrTypeBase.td" include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypeInterfaces.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/CommonAttrConstraints.td" include "third_party/llvm/llvm-project/mlir/include/mlir/IR/CommonTypeConstraints.td" include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectBase.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/EnumAttr.td" include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" def MosaicGPU_Dialect : Dialect { @@ -53,4 +56,31 @@ def MosaicGPU_InitializeBarrierOp : Op, + + // Convert the array to 1D and then shard across threads. + I32EnumAttrCase<"WGStridedFragLayout", 1>, + + // [m, n] matrix, where m % 64 == 0 == n % 8. + I32EnumAttrCase<"WGMMAFragLayout", 2>, + + // [m] vector, where m % 64 == 0. + I32EnumAttrCase<"WGMMARowFragLayout", 3> + ]> { + let cppNamespace = "::mosaic_gpu"; + let genSpecializedAttr = 0; +} + +def MosaicGPU_FragmentedLayoutAttr : EnumAttr< + MosaicGPU_Dialect, MosaicGPU_FragmentedLayout, "fragmented_layout"> { + let assemblyFormat = "`<` $value `>`"; +} + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ From 85662f6dd83baa92ae2e18620923d25fbdcac420 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 31 Oct 2024 02:20:33 -0700 Subject: [PATCH 139/698] [pallas:mosaic_gpu] `plgpu.copy_smem_to_gmem` no longer transparently commits SMEM Users are expected to call `pltpu.commit_smem` manually instead. PiperOrigin-RevId: 691724662 --- docs/jax.experimental.pallas.mosaic_gpu.rst | 1 + jax/_src/pallas/mosaic_gpu/pipeline.py | 1 + jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- jax/experimental/pallas/ops/gpu/attention_mgpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 3 +++ 5 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 71bf9c3ffae4..2d3452609c75 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -27,6 +27,7 @@ Functions barrier_arrive barrier_wait + commit_smem copy_gmem_to_smem copy_smem_to_gmem emit_pipeline diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 8d2274f1408c..9a17646f0758 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -158,6 +158,7 @@ def loop_body(step, _): ) # Copy the output from SMEM to GMEM. + gpu_primitives.commit_smem() map(lambda bref: bref.copy_out(slot, indices), out_brefs) fetch_step = step + max_concurrent_steps diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index f87e96a30c5f..1a5ed7f0d43e 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -66,7 +66,6 @@ def _copy_smem_to_gmem_lowering( dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) src, src_transforms = lowering._handle_indexing(src, src_transforms) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - mgpu.commit_shared() ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params) return () @@ -105,6 +104,7 @@ def copy_smem_to_gmem( See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` + :func:`jax.experimental.mosaic.gpu.commit_smem` """ if src.memory_space is not gpu_core.SMEM: raise TypeError(f"src must be a SMEM reference, got {src.memory_space}") diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 9320550969e8..1b240305aeff 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -144,6 +144,7 @@ def _wait(): # TODO(apaszke): Invert and multiply to avoid expensive divisions. acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)], ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 8d17d8458134..f60c6c7c6023 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -239,6 +239,7 @@ def test_copy_smem_to_gmem(self, indexer): ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 + plgpu.commit_smem() plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer]) plgpu.wait_smem_to_gmem(0) @@ -294,6 +295,7 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref) plgpu.barrier_wait(barrier_ref) else: + plgpu.commit_smem() plgpu.copy_smem_to_gmem(x_ref, o_ref) plgpu.wait_smem_to_gmem(0) @@ -1046,6 +1048,7 @@ def body(step, _): o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0 + plgpu.commit_smem() plgpu.copy_smem_to_gmem( o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)] ) From c708a04c6e43a52568eaff8709554eb42ba0d975 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 31 Oct 2024 02:46:55 -0700 Subject: [PATCH 140/698] [Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect. Also start moving the existing C++ tests to Python. PiperOrigin-RevId: 691729887 --- jax/_src/lib/BUILD | 1 + jax/_src/lib/__init__.py | 9 ++- jax/experimental/mosaic/gpu/__init__.py | 2 + jaxlib/mlir/_mlir_libs/BUILD.bazel | 12 ++++ jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc | 37 ++++++++++ jaxlib/mosaic/dialect/gpu/BUILD | 52 +++++++++++++- .../dialect/gpu/integrations/c/gpu_dialect.cc | 25 +++++++ .../dialect/gpu/integrations/c/gpu_dialect.h | 33 +++++++++ jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 44 ++++++------ jaxlib/mosaic/dialect/gpu/mosaic_gpu.h | 12 ++-- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 14 ++-- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 28 -------- jaxlib/mosaic/python/BUILD | 13 ++++ jaxlib/mosaic/python/gpu.py | 31 +++++++++ tests/mosaic/BUILD | 11 +++ tests/mosaic/gpu_dialect_test.py | 68 +++++++++++++++++++ 16 files changed, 327 insertions(+), 65 deletions(-) create mode 100644 jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc create mode 100644 jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc create mode 100644 jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h create mode 100644 jaxlib/mosaic/python/gpu.py create mode 100644 tests/mosaic/gpu_dialect_test.py diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 7068c0ef6732..1fcbd4b6b7ef 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -40,6 +40,7 @@ py_library_providing_imports_info( "//jax:version", ] + if_building_jaxlib([ "//jaxlib", + "//jaxlib/mosaic/python:gpu_dialect", "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib:cpu_feature_guard", "//jaxlib:utils", diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 9cc54a59f259..a3be2390d856 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -120,7 +120,14 @@ def _xla_gc_callback(*args): import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 -import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401 +try: + import jaxlib.mosaic.python.gpu as mosaic_gpu_dialect # pytype: disable=import-error +except ImportError: + # TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36. + # Jaxlib doesn't contain Mosaic GPU dialect bindings. + mosaic_gpu_dialect = None # type: ignore + +import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401 # Version number for MLIR:Python APIs, provided by jaxlib. mlir_api_version = xla_client.mlir_api_version diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 5d8a4dd9fc14..4feb12704f98 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== from jax import ShapeDtypeStruct as ShapeDtypeStruct +from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401 + from .core import ( Barrier as Barrier, ClusterBarrier as ClusterBarrier, diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 1c45a4ce9463..0b94f9d1d948 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -149,6 +149,18 @@ py_extension( ], ) +py_extension( + name = "_mosaic_gpu_ext", + srcs = ["mosaic_gpu_ext.cc"], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + ], +) + # This is here, instead of in jaxlib/mosaic/python, so it's in the same # directory as libjaxlib_mlir_capi.so (produced by # :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc new file mode 100644 index 000000000000..7204bbaa1658 --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// clang-format: off +// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h, +// otherwise this code will not build on Windows. +#include "pybind11/pybind11.h" +// clang-format: on + +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" // IWYU pragma: keep +#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" + +PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) { + m.def( + "register_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle dialect = mlirGetDialectHandle__mosaic_gpu__(); + mlirDialectHandleRegisterDialect(dialect, context); + if (load) { + mlirDialectHandleLoadDialect(dialect, context); + } + }, + py::arg("context"), py::arg("load") = true); +} diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 940633182498..4207e769e6a1 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load( + "@llvm-project//mlir:tblgen.bzl", + "gentbl_cc_library", + "gentbl_filegroup", + "td_library", +) package( default_applicable_licenses = [], @@ -143,3 +148,48 @@ cc_test( "@tsl//tsl/platform:errors", ], ) + +gentbl_filegroup( + name = "mosaic_gpu_python_gen_raw", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=mosaic_gpu", + ], + "_mosaic_gpu_gen_raw.py", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = ":mosaic_gpu.td", + deps = [ + ":mosaic_gpu_td_files", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +genrule( + name = "mosaic_gpu_python_gen", + srcs = ["_mosaic_gpu_gen_raw.py"], + outs = ["_mosaic_gpu_gen.py"], + cmd = "cat $(location _mosaic_gpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@", +) + +DIALECT_CAPI_SOURCES = [ + ":integrations/c/gpu_dialect.cc", +] + +DIALECT_CAPI_HEADERS = [ + ":integrations/c/gpu_dialect.h", +] + +cc_library( + name = "gpu_dialect_capi", + srcs = DIALECT_CAPI_SOURCES, + hdrs = DIALECT_CAPI_HEADERS, + deps = [ + ":mosaic_gpu", + ":mosaic_gpu_inc_gen", + "@llvm-project//mlir:CAPIIR", + ], +) diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc new file mode 100644 index 000000000000..1a854f395044 --- /dev/null +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc @@ -0,0 +1,25 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" + +#include "mlir/CAPI/Registration.h" +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" + +extern "C" { + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu, + mosaic_gpu::MosaicGPUDialect); +} diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h new file mode 100644 index 000000000000..bb6cf6e3af4a --- /dev/null +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_ +#define JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_ + +#include + +#include "mlir/CAPI/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu); + +#ifdef __cplusplus +} +#endif + +#endif // JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_ diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 14aa7fbbe1bc..8c5573bf1b80 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -22,28 +22,28 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/TypeSwitch.h" // IWYU pragma: keep -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Dialect.h" -#include "mlir/include/mlir/IR/DialectImplementation.h" // IWYU pragma: keep -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Location.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/TypeRange.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep +#include "llvm/Support/Casting.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "tsl/platform/statusor.h" // Generated definitions. diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index 5caf773f12a6..b46675d1c9a7 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -21,12 +21,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/raw_ostream.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 2037cfd6566d..b05e6ebd71b7 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -16,13 +16,13 @@ limitations under the License. #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AttrTypeBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypeInterfaces.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/CommonAttrConstraints.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/CommonTypeConstraints.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/EnumAttr.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/CommonAttrConstraints.td" +include "mlir/IR/CommonTypeConstraints.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" def MosaicGPU_Dialect : Dialect { let name = "mosaic_gpu"; diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index a1379619d922..3acacd2315b5 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -193,34 +193,6 @@ TEST_F(MosaicGpuTest, RuntimeFunctionsAreRegistered) { mosaic_gpu::kRuntimeMemcpyAsyncH2DName)); } -TEST_F(MosaicGpuTest, InitializeBarrierOpEnforcesRelevantInvariants) { - auto loc = builder_.getUnknownLoc(); - auto f32 = builder_.getF32Type(); - auto barrier = BarrierType::get(&context_); - - // InitializeBarrierOp requires a memref with type `BarrierType`. - auto initialize_op = builder_.create( - loc, mlir::MemRefType::get({1, 2}, f32), /*arrival_count=*/1); - EXPECT_FALSE(mlir::succeeded(mlir::verify(*module_))); - ExpectLastErrorContains("must be memref of barrier values"); - initialize_op->erase(); - - // InitializeBarrierOp requires a non-negative arrival count. - initialize_op = builder_.create( - loc, mlir::MemRefType::get({1, 2}, barrier), /*arrival_count=*/0); - EXPECT_FALSE(mlir::succeeded(mlir::verify(*module_))); - ExpectLastErrorContains("value is positive"); - initialize_op->erase(); - - // Checks that InitializeBarrierOp prints nicely. - initialize_op = builder_.create( - loc, mlir::MemRefType::get({1, 2}, barrier), /*arrival_count=*/1); - EXPECT_TRUE(mlir::succeeded(mlir::verify(*module_))); - EXPECT_THAT( - MlirToString(initialize_op), - HasSubstr( - "mosaic_gpu.initialize_barrier 1 : memref<1x2x!mosaic_gpu.barrier>")); -} } // anonymous namespace } // namespace mosaic_gpu diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 48268bfcf30a..6899914e6b89 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -17,6 +17,19 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") load("@rules_python//python:defs.bzl", "py_library") +py_library( + name = "gpu_dialect", + srcs = [ + "gpu.py", + "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//jaxlib/mlir", + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", + ], +) + gentbl_filegroup( name = "tpu_python_gen_raw", tbl_outs = [ diff --git a/jaxlib/mosaic/python/gpu.py b/jaxlib/mosaic/python/gpu.py new file mode 100644 index 000000000000..755a4d3eff7d --- /dev/null +++ b/jaxlib/mosaic/python/gpu.py @@ -0,0 +1,31 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python bindings for the MLIR Mosaic GPU dialect.""" + +# ruff: noqa: F401 +# ruff: noqa: F403 + + +# pylint: disable=g-bad-import-order +from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen import * # pylint: disable=wildcard-import # type: ignore[import-not-found] +from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import * # pylint: disable=wildcard-import # type: ignore[import-not-found] + +try: + from jaxlib.mlir.dialects._ods_common import _cext +except ImportError: + from mlir.dialects._ods_common import _cext # type: ignore[import-not-found] + + +_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python") diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 3d1348371f07..ca2c9a4bf27d 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -16,6 +16,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", "jax_multiplatform_test", + "jax_py_test", "py_deps", ) @@ -43,6 +44,16 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_py_test( + name = "gpu_dialect_test", + srcs = ["gpu_dialect_test.py"], + deps = [ + "//jax", + "//jax:mosaic_gpu", + "//jax:test_util", + ] + py_deps("absl/testing"), +) + jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py new file mode 100644 index 000000000000..19701012f706 --- /dev/null +++ b/tests/mosaic/gpu_dialect_test.py @@ -0,0 +1,68 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""(Deviceless) tests for the Mosaic GPU MLIR dialect.""" + +from absl.testing import parameterized +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member + + +config.parse_flags_with_absl() + + +def _make_ir_context(): + context = ir.Context() + mgpu.register_dialect(context) + return context + + +class DialectTest(parameterized.TestCase): + + def setUp(self): + if mgpu is None: + raise self.skipTest("Test requires Mosaic GPU dialect") + super().setUp() + self.enter_context(_make_ir_context()) + self.enter_context(ir.Location.unknown()) + self.module = ir.Module.create() + + def test_initialize_barrier_op_result_memref_must_wrap_barriers(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1) + with self.assertRaisesRegex( + ir.MLIRError, "must be memref of barrier values"): + self.module.operation.verify() + + def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + arrival_count=0) + with self.assertRaisesRegex(ir.MLIRError, "value is positive"): + self.module.operation.verify() + + def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + arrival_count=1) + self.assertTrue(self.module.operation.verify()) + + +if __name__ == "__main__": + parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 52ad60521cffaa4d175cb5dcc8c8a8c5bc3d6738 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 23 Oct 2024 09:17:27 -0400 Subject: [PATCH 141/698] Run dot algorithm tests with PJRT plugin. --- tests/lax_test.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 79ad9fcfa02c..f2ce0913e03a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -41,7 +41,6 @@ from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu -from jax._src import xla_bridge from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util @@ -1077,9 +1076,6 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): if jtu.dtypes.supported([dtype]) ]) def testDotAlgorithm(self, algorithm, dtype): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): if algorithm not in { lax.DotAlgorithmPreset.DEFAULT, @@ -1130,9 +1126,6 @@ def testDotAlgorithm(self, algorithm, dtype): self.assertEqual(lax.dot(*args_maker(), precision=algorithm).dtype, dtype) def testDotAlgorithmInvalidFloat8Type(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["cpu"]): raise SkipTest("Not supported on CPU.") lhs_shape = (3, 4) @@ -1143,9 +1136,6 @@ def testDotAlgorithmInvalidFloat8Type(self): lax.dot(lhs, rhs, precision="ANY_F8_ANY_F8_F32") def testDotAlgorithmCasting(self): - if xla_bridge.using_pjrt_c_api(): - raise SkipTest( - "The dot algorithm attribute is not supported by PJRT C API.") if jtu.test_device_matches(["tpu"]): raise SkipTest("F32_F32_F32 is not supported on TPU.") def fun(lhs, rhs): From 692e225657dbf72105898023bb270b39d18e9b2a Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 10:20:23 -0500 Subject: [PATCH 142/698] Add workflow for nightly pull from upstream --- .../workflows/rocm-nightly-upstream-sync.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .github/workflows/rocm-nightly-upstream-sync.yml diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml new file mode 100644 index 000000000000..880ea232d307 --- /dev/null +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -0,0 +1,18 @@ +# Pulls the latest changes from upstream into main and opens a PR to merge +# them into rocm-main. + +name: ROCm Nightly Upstream Sync +on: + schedule: + - cron: '0 6 * * *' +jobs: + sync-main: + runs-on: ubuntu-latest + steps: + - run: gh repo sync rocm/jax -b main + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + open-sync-pr: + runs-on: ubuntu-latest + steps: + - run: gh pr create --repo rocm/jax --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" From 1c877f5a7eecda5bdc834b4f9304946611524425 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 10:29:36 -0500 Subject: [PATCH 143/698] Only run on weekdays --- .github/workflows/rocm-nightly-upstream-sync.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index 880ea232d307..ba81edac5bc9 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -4,7 +4,7 @@ name: ROCm Nightly Upstream Sync on: schedule: - - cron: '0 6 * * *' + - cron: '0 6 * * 1-5' jobs: sync-main: runs-on: ubuntu-latest From 3361fca5b8300d17a007decc0139b0d384b8d9cd Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 10:49:29 -0500 Subject: [PATCH 144/698] Fix yaml checker --- .github/workflows/rocm-nightly-upstream-sync.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index ba81edac5bc9..dcfbc01d1db5 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -15,4 +15,5 @@ jobs: open-sync-pr: runs-on: ubuntu-latest steps: - - run: gh pr create --repo rocm/jax --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" + - run: | + gh pr create --repo $GITHUB_REPOSITORY --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" From ce8dba98fb2be27f939833487a817161e9387126 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 8 Oct 2024 10:03:58 -0400 Subject: [PATCH 145/698] Move the CUDA end-to-end example to FFI examples workflow + hosted runner. --- .github/workflows/ci-build.yaml | 22 ++- docs/cuda_custom_call/BUILD | 60 ------- docs/cuda_custom_call/Makefile | 35 ----- .../cuda_custom_call/cuda_custom_call_test.py | 147 ------------------ examples/ffi/CMakeLists.txt | 11 ++ .../ffi/src/jax_ffi_example/cuda_e2e.cu | 0 examples/ffi/src/jax_ffi_example/cuda_e2e.py | 68 ++++++++ examples/ffi/tests/attrs_test.py | 5 + examples/ffi/tests/cuda_e2e_test.py | 75 +++++++++ examples/ffi/tests/rms_norm_test.py | 5 + 10 files changed, 178 insertions(+), 250 deletions(-) delete mode 100644 docs/cuda_custom_call/BUILD delete mode 100644 docs/cuda_custom_call/Makefile delete mode 100644 docs/cuda_custom_call/cuda_custom_call_test.py rename docs/cuda_custom_call/foo.cu.cc => examples/ffi/src/jax_ffi_example/cuda_e2e.cu (100%) create mode 100644 examples/ffi/src/jax_ffi_example/cuda_e2e.py create mode 100644 examples/ffi/tests/cuda_e2e_test.py diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 581fb858732c..6cdc1175e600 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -61,7 +61,7 @@ jobs: - name: Image Setup run: | apt update - apt install -y libssl-dev + apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: @@ -217,14 +217,16 @@ jobs: ffi: name: FFI example - runs-on: ubuntu-latest - timeout-minutes: 5 + runs-on: linux-x86-g2-16-l4-1gpu + container: + image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12 + timeout-minutes: 30 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Set up Python 3.11 + - name: Set up Python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: - python-version: 3.11 + python-version: 3.12 - name: Get pip cache dir id: pip-cache run: | @@ -236,7 +238,7 @@ jobs: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} - name: Install JAX - run: pip install . + run: pip install .[cuda12] - name: Build and install example project run: python -m pip install -v ./examples/ffi[test] env: @@ -245,6 +247,10 @@ jobs: # a different toolchain. GCC is the default compiler on the # 'ubuntu-latest' runner, but we still set this explicitly just to be # clear. - CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ - - name: Run tests + CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON + - name: Run CPU tests + run: python -m pytest examples/ffi/tests + env: + JAX_PLATFORM_NAME: cpu + - name: Run GPU tests run: python -m pytest examples/ffi/tests diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD deleted file mode 100644 index 4954ce3db4fa..000000000000 --- a/docs/cuda_custom_call/BUILD +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load( - "//jaxlib:jax.bzl", - "cuda_library", - "jax_generate_backend_suites", - "jax_multiplatform_test", -) - -licenses(["notice"]) - -package( - default_applicable_licenses = [], - default_visibility = ["//visibility:private"], -) - -jax_generate_backend_suites() - -jax_multiplatform_test( - name = "cuda_custom_call_test", - srcs = ["cuda_custom_call_test.py"], - data = [":foo"], - enable_backends = ["gpu"], - tags = ["notap"], - deps = [ - "//jax:extend", - ], -) - -# this second target is needed to properly link in CUDA runtime symbols -# such as cudaLaunchKernel, even though we are only building one library. -cc_shared_library( - name = "foo", - deps = [ - ":foo_", - "@xla//xla/tsl/cuda:cudart", - ], -) - -cuda_library( - name = "foo_", - srcs = ["foo.cu.cc"], - deps = [ - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - ], -) diff --git a/docs/cuda_custom_call/Makefile b/docs/cuda_custom_call/Makefile deleted file mode 100644 index ca51b63b5eaf..000000000000 --- a/docs/cuda_custom_call/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# This Makefile is not used by Bazel for this test, it is intended to serve as -# documentation of build instructions for JAX users that are not using Bazel to -# build their custom call code. For that reason, this Makefile is likely subject -# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in -# this directory no longer runs the test to completion. -NVCC = nvcc -NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())') -NVCCFLAGS += -arch native -# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu -NVCCFLAGS += -x cu - -# depends on libfoo.so being in the same directory as cuda_custom_call_test.py -check: libfoo.so - python cuda_custom_call_test.py - -lib%.so: %.cu.cc - $(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $< - -clean: - rm -rf *.so diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py deleted file mode 100644 index f63bbd670bf5..000000000000 --- a/docs/cuda_custom_call/cuda_custom_call_test.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# This test is intentionally structured to stay close to what a standalone JAX -# custom call integration might look like. JAX test harness is in a separate -# section towards the end of this file. The test can be run standalone by typing -# "make" in the directory containing this file. - -import os -import ctypes -import unittest - -import numpy as np - -import jax -import jax.numpy as jnp -from jax.extend import ffi - -# start test boilerplate -from absl.testing import absltest -from jax._src import config -from jax._src import test_util as jtu - -config.parse_flags_with_absl() -# end test boilerplate - -# XLA needs uppercase, "cuda" isn't recognized -XLA_PLATFORM = "CUDA" - -# JAX needs lowercase, "CUDA" isn't recognized -JAX_PLATFORM = "cuda" - -# 0 = original ("opaque"), 1 = FFI -XLA_CUSTOM_CALL_API_VERSION = 1 - -# these strings are how we identify kernels to XLA: -# - first we register a pointer to the kernel with XLA under this name -# - then we "tell" JAX to emit StableHLO specifying this name to XLA -XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd" -XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd" - -# load the shared library with the FFI target definitions -if jtu.is_running_under_pytest(): - raise unittest.SkipTest("libfoo.so hasn't been built") -SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so") -library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) - -# register the custom calls targets with XLA, api_version=1 by default -ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_FWD, - fn=ffi.pycapsule(library.FooFwd), - platform=XLA_PLATFORM) -ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_BWD, - fn=ffi.pycapsule(library.FooBwd), - platform=XLA_PLATFORM) - -def foo_fwd(a, b): - assert a.dtype == jnp.float32 - assert a.shape == b.shape - assert a.dtype == b.dtype - n = np.prod(a.shape).astype(np.uint64) - out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) - c, b_plus_1 = ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_FWD, (out_type, out_type), - a, b, n=n) - return c, (a, b_plus_1) - - -def foo_bwd(res, c_grad): - a, b_plus_1 = res - assert c_grad.dtype == jnp.float32 - assert c_grad.shape == a.shape - assert a.shape == b_plus_1.shape - assert c_grad.dtype == a.dtype - assert a.dtype == b_plus_1.dtype - n = np.prod(a.shape).astype(np.uint64) - out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) - return ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_BWD, (out_type, out_type), - c_grad, a, b_plus_1, n=n) - - -@jax.custom_vjp -def foo(a, b): - c, _ = foo_fwd(a, b) - return c - - -foo.defvjp(foo_fwd, foo_bwd) - -#-----------------------------------------------------------------------------# -# Test # -#-----------------------------------------------------------------------------# - - -class CustomCallTest(jtu.JaxTestCase): - - def test_fwd_interpretable(self): - shape = (2, 3) - a = 2. * jnp.ones(shape) - b = 3. * jnp.ones(shape) - observed = jax.jit(foo)(a, b) - expected = (2. * (3. + 1.)) - self.assertArraysEqual(observed, expected) - - def test_bwd_interpretable(self): - shape = (2, 3) - a = 2. * jnp.ones(shape) - b = 3. * jnp.ones(shape) - - def loss(a, b): - return jnp.sum(foo(a, b)) - - da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b) - da_expected = b + 1 - db_expected = a - self.assertArraysEqual(da_observed, da_expected) - self.assertArraysEqual(db_observed, db_expected) - - def test_fwd_random(self): - shape = (2, 3) - akey, bkey = jax.random.split(jax.random.key(0)) - a = jax.random.normal(key=akey, shape=shape) - b = jax.random.normal(key=bkey, shape=shape) - observed = jax.jit(foo)(a, b) - expected = a * (b + 1) - self.assertAllClose(observed, expected) - - def test_bwd_random(self): - shape = (2, 3) - akey, bkey = jax.random.split(jax.random.key(0)) - a = jax.random.normal(key=akey, shape=shape) - b = jax.random.normal(key=bkey, shape=shape) - jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",)) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 62142fd49034..4179f4bd9ad4 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -1,6 +1,8 @@ cmake_minimum_required(VERSION 3.15...3.30) project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) +option(JAX_FFI_EXAMPLE_ENABLE_CUDA "Enable CUDA support" OFF) + find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) execute_process( COMMAND "${Python_EXECUTABLE}" @@ -17,3 +19,12 @@ install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc") target_include_directories(_attrs PUBLIC ${XLA_DIR}) install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) + +if(JAX_FFI_EXAMPLE_ENABLE_CUDA) + enable_language(CUDA) + add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu") + set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON + CUDA_STANDARD 17) + target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR}) + install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +endif() diff --git a/docs/cuda_custom_call/foo.cu.cc b/examples/ffi/src/jax_ffi_example/cuda_e2e.cu similarity index 100% rename from docs/cuda_custom_call/foo.cu.cc rename to examples/ffi/src/jax_ffi_example/cuda_e2e.cu diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.py b/examples/ffi/src/jax_ffi_example/cuda_e2e.py new file mode 100644 index 000000000000..500677050a4b --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/cuda_e2e.py @@ -0,0 +1,68 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An end-to-end example demonstrating the use of the JAX FFI with CUDA. + +The specifics of the kernels are not very important, but the general structure, +and packaging of the extension are useful for testing. +""" + +import os +import ctypes + +import numpy as np + +import jax +import jax.numpy as jnp +import jax.extend as jex + +# Load the shared library with the FFI target definitions +SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so") +library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) + +jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd), + platform="CUDA") +jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd), + platform="CUDA") + + +def foo_fwd(a, b): + assert a.dtype == jnp.float32 + assert a.shape == b.shape + assert a.dtype == b.dtype + n = np.prod(a.shape).astype(np.uint64) + out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) + c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n) + return c, (a, b_plus_1) + + +def foo_bwd(res, c_grad): + a, b_plus_1 = res + assert c_grad.dtype == jnp.float32 + assert c_grad.shape == a.shape + assert a.shape == b_plus_1.shape + assert c_grad.dtype == a.dtype + assert a.dtype == b_plus_1.dtype + n = np.prod(a.shape).astype(np.uint64) + out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) + return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1, + n=n) + + +@jax.custom_vjp +def foo(a, b): + c, _ = foo_fwd(a, b) + return c + + +foo.defvjp(foo_fwd, foo_bwd) diff --git a/examples/ffi/tests/attrs_test.py b/examples/ffi/tests/attrs_test.py index 0288b31cf9fa..2eef1f627006 100644 --- a/examples/ffi/tests/attrs_test.py +++ b/examples/ffi/tests/attrs_test.py @@ -24,6 +24,11 @@ class AttrsTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + def test_array_attr(self): self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum()) self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum()) diff --git a/examples/ffi/tests/cuda_e2e_test.py b/examples/ffi/tests/cuda_e2e_test.py new file mode 100644 index 000000000000..83397f7ff5d7 --- /dev/null +++ b/examples/ffi/tests/cuda_e2e_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +jax.config.parse_flags_with_absl() + + +class CudaE2eTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Unsupported platform") + + # Import here to avoid trying to load the library when it's not built. + from jax_ffi_example import cuda_e2e + self.foo = cuda_e2e.foo + + def test_fwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. * jnp.ones(shape) + observed = jax.jit(self.foo)(a, b) + expected = (2. * (3. + 1.)) + self.assertArraysEqual(observed, expected) + + def test_bwd_interpretable(self): + shape = (2, 3) + a = 2. * jnp.ones(shape) + b = 3. * jnp.ones(shape) + + def loss(a, b): + return jnp.sum(self.foo(a, b)) + + da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b) + da_expected = b + 1 + db_expected = a + self.assertArraysEqual(da_observed, da_expected) + self.assertArraysEqual(db_observed, db_expected) + + def test_fwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + observed = jax.jit(self.foo)(a, b) + expected = a * (b + 1) + self.assertAllClose(observed, expected) + + def test_bwd_random(self): + shape = (2, 3) + akey, bkey = jax.random.split(jax.random.key(0)) + a = jax.random.normal(key=akey, shape=shape) + b = jax.random.normal(key=bkey, shape=shape) + jtu.check_grads(f=jax.jit(self.foo), args=(a, b), order=1, + modes=("rev",)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/examples/ffi/tests/rms_norm_test.py b/examples/ffi/tests/rms_norm_test.py index aad5562629ed..bccd696c601e 100644 --- a/examples/ffi/tests/rms_norm_test.py +++ b/examples/ffi/tests/rms_norm_test.py @@ -29,6 +29,11 @@ def rms_norm_ref(x, eps=1e-5): class RmsNormTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + def test_basic(self): x = jnp.linspace(-0.5, 0.5, 15) self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x)) From 7d9f5656473eb3fc0a4265bca3eef999bf07db9a Mon Sep 17 00:00:00 2001 From: Praveen Batra Date: Thu, 31 Oct 2024 09:24:57 -0700 Subject: [PATCH 146/698] [Mosaic] Fix some imports. PiperOrigin-RevId: 691830491 --- jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 3acacd2315b5..34f6241661d5 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" +#include "llvm/include/llvm/ADT/ArrayRef.h" +#include "llvm/include/llvm/ADT/SmallVector.h" #include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h" #include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" From ca261ac59f4291b19eb9f5ab4baf7cfd770ae78c Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Thu, 31 Oct 2024 12:32:07 -0400 Subject: [PATCH 147/698] [Mosaic GPU] Improve correctness of benchmarking scripts --- jax/experimental/mosaic/gpu/examples/flash_attention.py | 4 +++- jax/experimental/mosaic/gpu/examples/matmul.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index daacefb135e9..04a64098ff17 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -649,7 +649,9 @@ def ref(q, k, v): matmul_flops = ( 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size ) - peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS + # Table 1 in + # https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper + peak_flops = 989.4 * 1e12 # f16 TensorCore peak optimal_time = matmul_flops / peak_flops * 1e6 # us achieved_tc_util = optimal_time / runtime_us * 100 has_tma_warp = impl == Implementation.TWO_COMPUTE_ONE_TMA_WG diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index ce99bf423bae..2ca22f54e1b4 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -378,7 +378,7 @@ def ref_f(x, y): x, y, dimension_numbers=dimension_numbers, - preferred_element_type=jnp.float32, + preferred_element_type=out_dtype, ).astype(out_dtype) ref, ref_runtime = profiler.measure(ref_f, x, y) From 1f66c29d0585fe64f60683fb9929b8a2f9abe7f5 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 11:39:28 -0500 Subject: [PATCH 148/698] Set runners for ROCM --- .github/workflows/ci-build.yaml | 2 +- .github/workflows/upstream-nightly.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 581fb858732c..7805e3206fcd 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -38,7 +38,7 @@ jobs: build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" - runs-on: linux-x86-n2-32 + runs-on: ROCM-Ubuntu container: image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 60 diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 04df278019a5..ada9b4e5825f 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -22,7 +22,7 @@ on: jobs: upstream-dev: - runs-on: ubuntu-20.04-16core + runs-on: ROCM-Ubuntu permissions: contents: read checks: write # for upload-artifact From a75d94622caa748f388a055a411caf81e7b309bc Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Thu, 31 Oct 2024 10:04:38 -0700 Subject: [PATCH 149/698] Reverts 72f9a493589a1046e6927a5f16d7dc71df530743 PiperOrigin-RevId: 691843537 --- build/requirements.in | 5 - build/requirements_lock_3_10.txt | 168 +++++++++++++++++-------------- build/requirements_lock_3_11.txt | 168 +++++++++++++++++-------------- build/requirements_lock_3_12.txt | 168 +++++++++++++++++-------------- build/requirements_lock_3_13.txt | 16 ++- build/test-requirements.txt | 7 +- 6 files changed, 298 insertions(+), 234 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index a8d81fa5c670..e122aaa4ad78 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -3,11 +3,6 @@ # -r test-requirements.txt -# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement -# below. -matplotlib~=3.8.4; python_version<="3.10" -matplotlib; python_version>="3.11" - # # build deps # diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index adabb0dd2e70..ccffa247f36d 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -12,6 +12,10 @@ attrs==23.2.0 \ --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 @@ -295,7 +299,7 @@ matplotlib==3.8.4 ; python_version <= "3.10" \ --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -323,7 +327,7 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0 \ +numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ @@ -371,7 +375,6 @@ numpy==2.0.0 \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -380,84 +383,93 @@ numpy==2.0.0 \ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 # via + # auditwheel # build # matplotlib # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a +pillow==11.0.0 \ + --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ + --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ + --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ + --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ + --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ + --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ + --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ + --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ + --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ + --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ + --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ + --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ + --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ + --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ + --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ + --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ + --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ + --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ + --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ + --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ + --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ + --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ + --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ + --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ + --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ + --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ + --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ + --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ + --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ + --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ + --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ + --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ + --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ + --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ + --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ + --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ + --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ + --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ + --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ + --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ + --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ + --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ + --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ + --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ + --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ + --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ + --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ + --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ + --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ + --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ + --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ + --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ + --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ + --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ + --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ + --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ + --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ + --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ + --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ + --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ + --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ + --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ + --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ + --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ + --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ + --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ + --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ + --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ + --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ + --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ + --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ + --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ + --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ + --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ + --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 # via # -r build/test-requirements.txt # matplotlib @@ -487,6 +499,10 @@ psutil==5.9.8 \ --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a @@ -621,4 +637,6 @@ zstandard==0.22.0 \ setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 053e996cefad..7f3ee61ff7f6 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -12,6 +12,10 @@ attrs==23.2.0 \ --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 @@ -290,7 +294,7 @@ matplotlib==3.9.0 ; python_version >= "3.11" \ --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -318,7 +322,7 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0 \ +numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ @@ -366,7 +370,6 @@ numpy==2.0.0 \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -375,84 +378,93 @@ numpy==2.0.0 \ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 # via + # auditwheel # build # matplotlib # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a +pillow==11.0.0 \ + --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ + --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ + --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ + --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ + --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ + --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ + --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ + --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ + --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ + --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ + --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ + --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ + --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ + --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ + --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ + --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ + --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ + --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ + --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ + --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ + --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ + --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ + --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ + --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ + --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ + --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ + --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ + --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ + --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ + --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ + --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ + --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ + --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ + --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ + --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ + --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ + --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ + --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ + --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ + --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ + --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ + --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ + --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ + --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ + --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ + --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ + --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ + --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ + --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ + --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ + --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ + --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ + --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ + --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ + --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ + --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ + --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ + --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ + --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ + --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ + --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ + --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ + --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ + --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ + --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ + --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ + --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ + --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ + --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ + --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ + --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ + --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ + --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ + --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ + --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 # via # -r build/test-requirements.txt # matplotlib @@ -482,6 +494,10 @@ psutil==5.9.8 \ --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a @@ -610,4 +626,6 @@ zstandard==0.22.0 \ setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 1468e64c29cd..bf22c3623b47 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -12,6 +12,10 @@ attrs==23.2.0 \ --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 @@ -290,7 +294,7 @@ matplotlib==3.9.0 ; python_version >= "3.11" \ --hash=sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a \ --hash=sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d \ --hash=sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241 - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -318,7 +322,7 @@ mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==2.0.0 \ +numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ @@ -366,7 +370,6 @@ numpy==2.0.0 \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -375,84 +378,93 @@ numpy==2.0.0 \ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 # via + # auditwheel # build # matplotlib # pytest -pillow==10.3.0 \ - --hash=sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c \ - --hash=sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2 \ - --hash=sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb \ - --hash=sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d \ - --hash=sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa \ - --hash=sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3 \ - --hash=sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1 \ - --hash=sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a \ - --hash=sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd \ - --hash=sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8 \ - --hash=sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999 \ - --hash=sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599 \ - --hash=sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936 \ - --hash=sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375 \ - --hash=sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d \ - --hash=sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b \ - --hash=sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60 \ - --hash=sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572 \ - --hash=sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3 \ - --hash=sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced \ - --hash=sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f \ - --hash=sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b \ - --hash=sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19 \ - --hash=sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f \ - --hash=sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d \ - --hash=sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383 \ - --hash=sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795 \ - --hash=sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355 \ - --hash=sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57 \ - --hash=sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09 \ - --hash=sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b \ - --hash=sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462 \ - --hash=sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf \ - --hash=sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f \ - --hash=sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a \ - --hash=sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad \ - --hash=sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9 \ - --hash=sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d \ - --hash=sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45 \ - --hash=sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994 \ - --hash=sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d \ - --hash=sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338 \ - --hash=sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463 \ - --hash=sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451 \ - --hash=sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591 \ - --hash=sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c \ - --hash=sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd \ - --hash=sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32 \ - --hash=sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9 \ - --hash=sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf \ - --hash=sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5 \ - --hash=sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828 \ - --hash=sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3 \ - --hash=sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5 \ - --hash=sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2 \ - --hash=sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b \ - --hash=sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2 \ - --hash=sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475 \ - --hash=sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3 \ - --hash=sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb \ - --hash=sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef \ - --hash=sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015 \ - --hash=sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002 \ - --hash=sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170 \ - --hash=sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84 \ - --hash=sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57 \ - --hash=sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f \ - --hash=sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27 \ - --hash=sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a +pillow==11.0.0 \ + --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ + --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ + --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ + --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ + --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ + --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ + --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ + --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ + --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ + --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ + --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ + --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ + --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ + --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ + --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ + --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ + --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ + --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ + --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ + --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ + --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ + --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ + --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ + --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ + --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ + --hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ + --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ + --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ + --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ + --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ + --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ + --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ + --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ + --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ + --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ + --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ + --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ + --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ + --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ + --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ + --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ + --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ + --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ + --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ + --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ + --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ + --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ + --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ + --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ + --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ + --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ + --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ + --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ + --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ + --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ + --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ + --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ + --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ + --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ + --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ + --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ + --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ + --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ + --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ + --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ + --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ + --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ + --hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ + --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ + --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ + --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ + --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ + --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ + --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ + --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 # via # -r build/test-requirements.txt # matplotlib @@ -482,6 +494,10 @@ psutil==5.9.8 \ --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a @@ -610,4 +626,6 @@ zstandard==0.22.0 \ setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r build/test-requirements.txt + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 019c088fbd91..9fa78c062ce9 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -12,6 +12,10 @@ attrs==24.2.0 \ --hash=sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346 \ --hash=sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2 # via hypothesis +auditwheel==6.1.0 \ + --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ + --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 + # via -r build/test-requirements.txt build==1.2.2.post1 \ --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 @@ -338,7 +342,7 @@ matplotlib==3.9.2 ; python_version >= "3.11" \ --hash=sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49 \ --hash=sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c \ --hash=sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413 - # via -r build/requirements.in + # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba @@ -426,7 +430,6 @@ numpy==2.1.2 ; python_version >= "3.13" \ --hash=sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648 # via # -r build/requirements.in - # -r build/test-requirements.txt # contourpy # matplotlib # ml-dtypes @@ -434,11 +437,14 @@ numpy==2.1.2 ; python_version >= "3.13" \ opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via -r build/requirements.in + # via + # -r build/requirements.in + # -r build/test-requirements.txt packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 # via + # auditwheel # build # matplotlib # pytest @@ -553,6 +559,10 @@ psutil==6.0.0 \ --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 # via portpicker +pyelftools==0.31 \ + --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ + --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 + # via auditwheel pygments==2.18.0 \ --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a diff --git a/build/test-requirements.txt b/build/test-requirements.txt index bec6afce1853..41a6ed4588a0 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -6,7 +6,6 @@ filelock flatbuffers hypothesis mpmath>=1.3 -numpy>=1.22 pillow>=10.4.0 portpicker pytest-xdist @@ -14,3 +13,9 @@ wheel rich # TODO(ybaturina): remove setuptools version setuptools<71.0.0 +# matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement +# below. +matplotlib~=3.8.4; python_version<="3.10" +matplotlib; python_version>="3.11" +opt-einsum +auditwheel From 8296f6e0ba4a1659f7a8f032d3de1c8666465ad6 Mon Sep 17 00:00:00 2001 From: Praveen Batra Date: Thu, 31 Oct 2024 11:07:52 -0700 Subject: [PATCH 150/698] [Mosaic] Add extension files for infer/apply vector layout. PiperOrigin-RevId: 691868278 --- jaxlib/jax.bzl | 1 + jaxlib/mosaic/BUILD | 18 +++- .../tpu/transforms/apply_vector_layout.cc | 91 ++++++++++--------- .../apply_vector_layout_extensions.h | 21 +++++ .../apply_vector_layout_extensions.cc | 19 ++++ .../infer_vector_layout_extensions.cc | 13 +++ .../tpu/transforms/infer_vector_layout.cc | 5 + .../infer_vector_layout_extensions.h | 15 +++ 8 files changed, 137 insertions(+), 46 deletions(-) create mode 100644 jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h create mode 100644 jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc create mode 100644 jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc create mode 100644 jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index d6811bf66b7b..40ec4ca7fe55 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -43,6 +43,7 @@ mosaic_gpu_internal_users = [] mosaic_internal_users = [] pallas_gpu_internal_users = [] pallas_tpu_internal_users = [] +pallas_extension_deps = [] jax_internal_export_back_compat_test_util_visibility = [] jax_internal_test_harnesses_visibility = [] diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 5452520204b8..14f3ee13c0f5 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -1,3 +1,6 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_python//python:defs.bzl", "py_library") + # Copyright 2023 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,9 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "pallas_extension_deps") licenses(["notice"]) @@ -41,6 +42,7 @@ cc_library( "dialect/tpu/tpu_dialect.cc", "dialect/tpu/tpu_ops.cc", "dialect/tpu/util.cc", + ":extension_srcs", ] + glob([ "dialect/tpu/transforms/*.cc", ]), @@ -83,7 +85,7 @@ cc_library( "@xla//xla:array", "@xla//xla:shape_util", "@xla//xla:util", - ], + ] + pallas_extension_deps, ) gentbl_cc_library( @@ -226,3 +228,11 @@ cc_library( ], alwayslink = True, ) + +filegroup( + name = "extension_srcs", + srcs = [ + "dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc", + "dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc", + ], +) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 15660aa6cd0a..b4b2280ceea8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -13,16 +13,15 @@ #include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Compiler.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -52,7 +51,6 @@ #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" -#include "llvm/ADT/ArrayRef.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" @@ -61,6 +59,7 @@ #include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" @@ -4586,45 +4585,53 @@ LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op, } const llvm::StringMap &rules() { - static auto rules = new llvm::StringMap{ - {arith::ConstantOp::getOperationName(), arith_constant_rule}, - {arith::ExtFOp::getOperationName(), arith_extf_rule}, - {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, - {arith::ExtUIOp::getOperationName(), arith_extui_rule}, - {arith::TruncFOp::getOperationName(), arith_truncf_rule}, - {arith::TruncIOp::getOperationName(), arith_trunci_rule}, - {func::ReturnOp::getOperationName(), func_return_rule}, - {scf::ForOp::getOperationName(), scf_for_rule}, - {scf::WhileOp::getOperationName(), scf_while_rule}, - {scf::ConditionOp::getOperationName(), scf_condition_rule}, - {scf::IfOp::getOperationName(), scf_if_rule}, - {scf::YieldOp::getOperationName(), yield_rule}, - {tpu::YieldOp::getOperationName(), yield_rule}, - {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, - {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, - {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, - {tpu::IotaOp::getOperationName(), tpu_iota_rule}, - {tpu::GatherOp::getOperationName(), tpu_gather_rule}, - {tpu::LoadOp::getOperationName(), tpu_load_rule}, - {tpu::StoreOp::getOperationName(), tpu_store_rule}, - {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, - {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, - {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, - {tpu::RegionOp::getOperationName(), tpu_region_rule}, - {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, - {tpu::TraceOp::getOperationName(), tpu_trace_rule}, - {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, - {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule}, - {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, - {vector::ExtractOp::getOperationName(), vector_extract_rule}, - {vector::LoadOp::getOperationName(), vector_load_rule}, - {vector::MultiDimReductionOp::getOperationName(), - vector_multi_reduction_rule}, - {vector::ExtractStridedSliceOp::getOperationName(), - vector_extract_strided_slice_rule}, - {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, - {vector::StoreOp::getOperationName(), vector_store_rule}, - {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; + static const llvm::StringMap *rules = [] { + static auto rules = new llvm::StringMap{ + {arith::ConstantOp::getOperationName(), arith_constant_rule}, + {arith::ExtFOp::getOperationName(), arith_extf_rule}, + {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, + {arith::ExtUIOp::getOperationName(), arith_extui_rule}, + {arith::TruncFOp::getOperationName(), arith_truncf_rule}, + {arith::TruncIOp::getOperationName(), arith_trunci_rule}, + {func::ReturnOp::getOperationName(), func_return_rule}, + {scf::ForOp::getOperationName(), scf_for_rule}, + {scf::WhileOp::getOperationName(), scf_while_rule}, + {scf::ConditionOp::getOperationName(), scf_condition_rule}, + {scf::IfOp::getOperationName(), scf_if_rule}, + {scf::YieldOp::getOperationName(), yield_rule}, + {tpu::YieldOp::getOperationName(), yield_rule}, + {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, + {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, + {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, + {tpu::IotaOp::getOperationName(), tpu_iota_rule}, + {tpu::GatherOp::getOperationName(), tpu_gather_rule}, + {tpu::LoadOp::getOperationName(), tpu_load_rule}, + {tpu::StoreOp::getOperationName(), tpu_store_rule}, + {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, + {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, + {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, + {tpu::RegionOp::getOperationName(), tpu_region_rule}, + {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, + {tpu::TraceOp::getOperationName(), tpu_trace_rule}, + {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, + {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule}, + {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, + {vector::ExtractOp::getOperationName(), vector_extract_rule}, + {vector::LoadOp::getOperationName(), vector_load_rule}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_reduction_rule}, + {vector::ExtractStridedSliceOp::getOperationName(), + vector_extract_strided_slice_rule}, + {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, + {vector::StoreOp::getOperationName(), vector_store_rule}, + {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; + + llvm::StringMap extended_rules = mlir::tpu::extensions::rules(); + for (auto &entry : extended_rules) { + rules->insert(&entry); + } + return rules; + }(); return *rules; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h new file mode 100644 index 000000000000..33c9e7421004 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -0,0 +1,21 @@ +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ + +#include + +#include "llvm/include/llvm/ADT/StringMap.h" +#include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" + +namespace mlir::tpu::extensions { + +const llvm::StringMap< + std::function, ArrayRef)>> & +rules(); + +} // namespace mlir::tpu::extensions + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc new file mode 100644 index 000000000000..e7528533938f --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -0,0 +1,19 @@ +#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" + +#include "llvm/include/llvm/ADT/StringMap.h" +#include "mlir/include/mlir/IR/Operation.h" + +namespace mlir::tpu::extensions { + +using RewriteContext = ApplyVectorLayoutContext; + +using rule_type = std::function, ArrayRef)>; + +const llvm::StringMap &rules() { + static const llvm::StringMap *rules = + new llvm::StringMap{}; + return *rules; +} + +} // namespace mlir::tpu::extensions \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc new file mode 100644 index 000000000000..a67728076de1 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -0,0 +1,13 @@ +#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" + +#include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/include/mlir/Support/LogicalResult.h" + +namespace mlir::tpu::extensions { + +bool canInferVectorLayout(const Operation &op) { return false; } + +LogicalResult inferVectorLayout(const Operation &op) { return failure(); } + +} // namespace mlir::tpu::extensions \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index b03b26edd11a..3f5d6262d13c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -50,6 +50,7 @@ limitations under the License. #include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" @@ -337,6 +338,10 @@ class VectorLayoutInferer { if (inferElementwise(&any_op).failed()) { return failure(); } + } else if (mlir::tpu::extensions::canInferVectorLayout(any_op)) { + if (mlir::tpu::extensions::inferVectorLayout(any_op).failed()) { + return failure(); + } } else { any_op.emitOpError("unsupported in vector layout inference"); return failure(); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h new file mode 100644 index 000000000000..dc16ddbdf26c --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -0,0 +1,15 @@ +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ + +#include "mlir/include/mlir/IR/Operation.h" +#include "mlir/include/mlir/Support/LLVM.h" + +namespace mlir::tpu::extensions { + +bool canInferVectorLayout(const Operation &op); + +LogicalResult inferVectorLayout(const Operation &op); + +} // namespace mlir::tpu::extensions + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ From c758373b9c44589d0c648dea432cbe3525609571 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 31 Oct 2024 11:29:27 -0700 Subject: [PATCH 151/698] Remove implicit sharding annotation for tpu custom call. PiperOrigin-RevId: 691876343 --- jax/_src/tpu_custom_call.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 6e7402c20a15..dae1bdd76b33 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -204,9 +204,6 @@ def to_json(self) -> bytes: if i + 1 != len(self.flags): config.write(b",") config.write(b"]") - # Prevent the compiler from sharding the custom call beyond what Mosaic does - # based on user annotations - config.write(b', "implicit_sharding": {"type": "MANUAL"}') config.write(b"}") return config.getvalue() From 8536eca46e21145605228013f9c6bdf07ebbaee5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 31 Oct 2024 13:05:26 -0700 Subject: [PATCH 152/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/edf18ce242f234fbd20d1fbf4e9c96dfa5be2847. PiperOrigin-RevId: 691908973 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4169e30be21b..c42f373418f2 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "2d9d84487ef22d4d5358f20085234c1865b300f1" -XLA_SHA256 = "a737b701870646278c69ab4388a1316be1467301a2a5ddad11978d619e3981d7" +XLA_COMMIT = "edf18ce242f234fbd20d1fbf4e9c96dfa5be2847" +XLA_SHA256 = "14294c2b264cb13102bc16a2f837e9f8ed7ba72eab9e3e9dc036be0b699c3c84" def repo(): tf_http_archive( From 7a15265542d850fd7859bba7619bcae6a7407bd5 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Thu, 31 Oct 2024 15:43:35 -0500 Subject: [PATCH 153/698] Allow devs to kick off sync job manually (#119) --- .github/workflows/rocm-nightly-upstream-sync.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index dcfbc01d1db5..98c958c3daa0 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -3,6 +3,7 @@ name: ROCm Nightly Upstream Sync on: + workflow_dispatch: schedule: - cron: '0 6 * * 1-5' jobs: From 48f24b6acb9fe67dfe227ff3349787b4045c09ff Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Thu, 31 Oct 2024 14:06:08 -0700 Subject: [PATCH 154/698] Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it. PiperOrigin-RevId: 691929385 --- jax/_src/abstract_arrays.py | 9 +- jax/_src/api.py | 8 +- jax/_src/array.py | 2 - jax/_src/core.py | 108 ++++++---------------- jax/_src/interpreters/ad.py | 3 + jax/_src/interpreters/mlir.py | 1 - jax/_src/interpreters/partial_eval.py | 3 +- jax/_src/interpreters/xla.py | 3 +- jax/_src/lax/control_flow/conditionals.py | 7 +- jax/_src/lax/control_flow/loops.py | 7 +- jax/_src/lax/lax.py | 22 ++--- jax/_src/lax/utils.py | 11 +-- jax/_src/lax/windowed_reductions.py | 11 ++- jax/_src/linear_util.py | 2 +- jax/_src/numpy/lax_numpy.py | 8 +- jax/_src/numpy/ufuncs.py | 2 +- jax/_src/state/indexing.py | 7 +- jax/core.py | 2 +- jax/experimental/shard_map.py | 14 ++- jax/experimental/slab/slab.py | 2 +- tests/api_test.py | 12 --- tests/core_test.py | 7 -- 22 files changed, 83 insertions(+), 168 deletions(-) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 9a49a09c7483..95216fb6fcb2 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -24,9 +24,7 @@ from jax._src import traceback_util traceback_util.register_exclusion(__file__) -UnshapedArray = core.UnshapedArray ShapedArray = core.ShapedArray -ConcreteArray = core.ConcreteArray AbstractToken = core.AbstractToken abstract_token = core.abstract_token canonicalize_shape = core.canonicalize_shape @@ -47,8 +45,11 @@ array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic def canonical_concrete_aval(val, weak_type=None): - return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val, - weak_type=weak_type) + weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type + dtype = dtypes.canonicalize_dtype(np.result_type(val)) + dtypes.check_valid_dtype(dtype) + sharding = core._get_abstract_sharding(val) + return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding) def masked_array_error(*args, **kwargs): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " diff --git a/jax/_src/api.py b/jax/_src/api.py index 652542571fa3..0b3cb08715e9 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -56,7 +56,7 @@ from jax._src import traceback_util from jax._src import pjit from jax._src import xla_bridge as xb -from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray +from jax._src.core import eval_jaxpr, ShapedArray from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, flatten_axes, donation_vector, @@ -2188,9 +2188,9 @@ def _infer_src_sharding(src, x) -> Sharding | None: if isinstance(x, array.ArrayImpl): return x.sharding elif isinstance(x, core.Tracer): - aval = core.get_aval(x) - if isinstance(aval, ConcreteArray) and isinstance(aval.val, array.ArrayImpl): - return aval.val.sharding + val = x.to_concrete_value() + if val is not None and isinstance(val, array.ArrayImpl): + return val.sharding return None diff --git a/jax/_src/array.py b/jax/_src/array.py index 2f29f137675b..30fedf4cff50 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1184,7 +1184,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed): global_aval, out_sharding, committed=committed, _skip_checks=True ) pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler -pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler # Only used for Arrays that come out of pmap. def _array_local_result_handler(aval, sharding, indices): @@ -1197,7 +1196,6 @@ def _array_local_result_handler(aval, sharding, indices): aval, sharding, committed=True, _skip_checks=True ) pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler -pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler # Token handlers diff --git a/jax/_src/core.py b/jax/_src/core.py index 43cb5cc1e248..7d912e3c207b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -19,7 +19,7 @@ from contextlib import contextmanager, ExitStack from dataclasses import dataclass import functools -from functools import partial, partialmethod, total_ordering +from functools import partial, total_ordering import gc import inspect import itertools as it @@ -696,6 +696,10 @@ def __reversed__(self): def __len__(self): return self.aval._len(self) + def to_concrete_value(self): + # Should return the concrete value if there is one, or else None. + return None + @property def sharding(self): # This attribute is part of the jax.Array API, but only defined on concrete arrays. @@ -739,10 +743,12 @@ def get_referent(self) -> Any: return self # Override for object equivalence checking def __bool__(self): + if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_bool_conversion(self) return self.aval._bool(self) def __int__(self): + if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_scalar_conversion(self) return self.aval._int(self) @@ -755,14 +761,17 @@ def __complex__(self): return self.aval._complex(self) def __hex__(self): + if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._hex(self) def __oct__(self): + if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._oct(self) def __index__(self): + if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) return self.aval._index(self) @@ -1393,12 +1402,16 @@ def get_aval(x): else: return concrete_aval(x) -def get_type(x): - aval = get_aval(x) - if isinstance(aval, ConcreteArray): - return raise_to_shaped(aval) +get_type = get_aval + +def is_concrete(x): + return to_concrete_value(x) is not None + +def to_concrete_value(x): + if isinstance(x, Tracer): + return x.to_concrete_value() else: - return aval + return x def concretization_function_error(fun, suggest_astype=False): fname = getattr(fun, "__name__", fun) @@ -1423,10 +1436,11 @@ def concrete_or_error(force: Any, val: Any, context=""): if force is None: force = lambda x: x if isinstance(val, Tracer): - if isinstance(val.aval, ConcreteArray): - return force(val.aval.val) - else: + maybe_concrete = val.to_concrete_value() + if maybe_concrete is None: raise ConcretizationTypeError(val, context) + else: + return force(maybe_concrete) else: return force(val) @@ -1578,7 +1592,7 @@ def _invalid_shape_error(shape: Shape, context: str=""): msg += f" {context}." if not config.dynamic_shapes.value and any( isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) - and not isinstance(get_aval(x), ConcreteArray) for x in shape): + and not is_concrete(x) for x in shape): msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to " "smaller subfunctions.") for x in shape: @@ -1677,10 +1691,6 @@ def _get_shape_sharding_str(shape, spec): else: yield f"{s1}@{s2}" - -def _forward_to_value(self, fun, ignored_tracer, *args): - return fun(self.val, *args) - def _get_abstract_sharding(val): from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error @@ -1690,59 +1700,6 @@ def _get_abstract_sharding(val): val.sharding._normalized_spec(val.ndim)) return None -class ConcreteArray(ShapedArray): - __slots__ = ['val'] - array_abstraction_level = 0 - - def __init__(self, dtype, val, weak_type=None): - super().__init__( - np.shape(val), dtype, - weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type, - sharding=_get_abstract_sharding(val)) - dtypes.check_valid_dtype(self.dtype) - # Note: canonicalized self.dtype doesn't necessarily match self.val - assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype) - self.val = val - - def update(self, dtype=None, val=None, weak_type=None): - dtype = self.dtype if dtype is None else dtype - val = self.val if val is None else val - weak_type = self.weak_type if weak_type is None else weak_type - return ConcreteArray(dtype, val, weak_type) - - def __eq__(self, other): - if (type(self) is type(other) and self.dtype == other.dtype - and self.shape == other.shape and self.weak_type == other.weak_type): - with eval_context(): # in case self.val is an Array - return (self.val == other.val).all() - else: - return False - - def __hash__(self): - return id(self.val) - - def join(self, other) -> AbstractValue: - if self == other: - return self - elif self.shape == other.shape and self.dtype == other.dtype: - weak_type = self.weak_type and other.weak_type - return ShapedArray(self.shape, self.dtype, weak_type=weak_type) - else: - raise TypeError(self, other) - - def str_short(self, short_dtypes=False) -> str: - dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - return f'{self.val}, dtype={dt_str}' - - _bool = partialmethod(_forward_to_value, bool) - _int = partialmethod(_forward_to_value, int) - _hex = partialmethod(_forward_to_value, hex) - _oct = partialmethod(_forward_to_value, oct) - _index = partialmethod(_forward_to_value, operator.index) - - _float = concretization_function_error(float, True) - _complex = concretization_function_error(complex, True) - def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): return primal_dtype._rules.tangent_dtype(primal_dtype) @@ -1817,14 +1774,6 @@ def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) -class DConcreteArray(DShapedArray): - __slots__ = ['val'] - array_abstraction_level = 1 - def __init__(self, shape, dtype, weak_type, val): - super().__init__(shape, dtype, weak_type) - self.val = val - - pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} @@ -1881,8 +1830,7 @@ def data(self): pytype_aval_mappings[DArray] = \ - lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, - x._data) + lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) @dataclass(frozen=True) class bint(dtypes.ExtendedDType): @@ -1984,10 +1932,7 @@ def _shaped_array_mapping(aval, weak_type): AbstractToken: lambda aval, _: aval, Bot: lambda aval, _: aval, ShapedArray: _shaped_array_mapping, - DShapedArray: lambda aval, _: aval, - DConcreteArray: lambda aval, weak_type: DShapedArray( - aval.shape, aval.dtype, weak_type - ), + DShapedArray: lambda aval, _: aval } ### Operations on shapes and dimension sizes. @@ -2323,7 +2268,6 @@ def _unmap_dshaped_array( aval_mapping_handlers: dict[type, AvalMapHandlerPair] = { DShapedArray: (_map_dshaped_array, _unmap_dshaped_array), ShapedArray: (_map_shaped_array, _unmap_shaped_array), - ConcreteArray: (_map_shaped_array, _unmap_shaped_array), AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a) } diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index b9cace3dec70..47c7882372ab 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -429,6 +429,9 @@ def full_lower(self): else: return self + def to_concrete_value(self): + return core.to_concrete_value(self.primal) + def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = raise_to_shaped(get_aval(primal), weak_type=False) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index c71e52385386..10154bbd661e 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -230,7 +230,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes: raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err ir_type_handlers[core.ShapedArray] = _array_ir_types -ir_type_handlers[core.ConcreteArray] = _array_ir_types ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get() ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 9e5e1ee9bd42..2f63eb386029 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -40,7 +40,7 @@ fun_sourceinfo) from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - ConcreteArray, Var, DropVar, raise_to_shaped, Atom, + Var, DropVar, raise_to_shaped, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) @@ -299,7 +299,6 @@ def process_call(self, primitive, f, tracers, params): # With dynamic shapes, we may need to substitute Tracers into avals. out_tracers = [] for aval, _ in out_type: - assert not isinstance(aval, ConcreteArray) if type(aval) is DShapedArray: shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val] if type(d) is InDBIdx else d for d in aval.shape] diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 14635a46ea33..46bc7bef7ca7 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -25,7 +25,7 @@ from jax._src import core from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -101,7 +101,6 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: _xla_shape_handlers: dict[type[core.AbstractValue], Callable[[Any], Sequence[xc.Shape]]] = { ShapedArray: _make_array_shape, - ConcreteArray: _make_array_shape, } _xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index d189dc0bd2cf..8dae3433e4f6 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -35,7 +35,7 @@ from jax._src import util from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect -from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects +from jax._src.core import raise_to_shaped, replace_jaxpr_effects from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -130,8 +130,7 @@ def switch(index, branches, *operands): hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) - if (config.disable_jit.value and - isinstance(core.get_aval(index), ConcreteArray)): + if (config.disable_jit.value and core.is_concrete(index)): return branches[int(index)](*operands) ops, ops_tree = tree_flatten(operands) @@ -220,7 +219,7 @@ def cond(pred, true_fun, false_fun, *operands): msg = ("Pred type must be either boolean or number, got {}.") raise TypeError(msg.format(pred_dtype)) - if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray): + if config.disable_jit.value and core.is_concrete(pred): if pred: return true_fun(*operands) else: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 6d6338b0bfd5..ddbbe0213f6f 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -35,7 +35,7 @@ from jax._src import state from jax._src import util from jax._src.api_util import shaped_abstractify -from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped +from jax._src.core import ShapedArray, raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -2015,12 +2015,11 @@ def fori_loop(lower, upper, body_fun, init_val): # If we can specialize on the trip count, call scan instead of a while_loop # to enable efficient reverse-mode differentiation. - if (isinstance(core.get_aval(lower), ConcreteArray) and - isinstance(core.get_aval(upper), ConcreteArray)): + if core.is_concrete(lower) and core.is_concrete(upper): try: lower_ = int(lower) upper_ = int(upper) - except TypeError: + except (TypeError, core.InconclusiveDimensionOperation): use_scan = False else: use_scan = True diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 2a44d9ec980d..e6dbcbb12a1c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -47,7 +47,7 @@ from jax._src import state from jax._src import util from jax._src.abstract_arrays import array_types -from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, +from jax._src.core import (Primitive, UnshapedArray, ShapedArray, raise_to_shaped, abstract_token, canonicalize_shape) from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -582,8 +582,7 @@ def _convert_element_type( if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array) and - not (isinstance(operand, core.Tracer) and - isinstance(core.get_aval(operand), core.ConcreteArray)) and + not (isinstance(operand, core.Tracer) and core.is_concrete(operand)) and (sharding is None or getattr(operand, 'sharding', None) == sharding)): return operand else: @@ -1438,23 +1437,24 @@ def _get_monoid_reducer(monoid_op: Callable, x, = xs aval = core.get_aval(x) dtype = _dtype(x) - if (type(aval) is ConcreteArray) and aval.shape == (): + if core.is_concrete(x) and aval.shape == (): + val = core.to_concrete_value(x) # allow bitwise reductions for boolean and integer types _is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer) if monoid_op is add: - return _reduce_sum if np.equal(aval.val, 0) else None + return _reduce_sum if np.equal(val, 0) else None elif monoid_op is mul: - return _reduce_prod if np.equal(aval.val, 1) else None + return _reduce_prod if np.equal(val, 1) else None elif monoid_op is bitwise_or and _is_intlike: - return _reduce_or if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None + return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is bitwise_and and _is_intlike: - return _reduce_and if np.equal(aval.val, _get_bitwise_and_identity(dtype)) else None + return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None elif monoid_op is bitwise_xor and _is_intlike: - return _reduce_xor if np.equal(aval.val, _get_bitwise_or_identity(dtype)) else None + return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is max: - return _reduce_max if np.equal(aval.val, _get_max_identity(dtype)) else None + return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None elif monoid_op is min: - return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None + return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None return None def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray: diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index deb3c19c0a61..82804c796e6e 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -52,10 +52,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) least_specialized = type(max(avals, key=_get_array_abstraction_level)) - if least_specialized is core.ConcreteArray: - out = prim.impl(*[x.val for x in avals], **kwargs) - return core.ConcreteArray(out.dtype, out, weak_type=weak_type) - elif least_specialized is core.ShapedArray: + if least_specialized is core.ShapedArray: return core.ShapedArray( shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs), weak_type=weak_type, @@ -77,11 +74,7 @@ def standard_multi_result_abstract_eval( assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals least_specialized = max(map(type, avals), key=_get_array_abstraction_level) weak_types = weak_type_rule(*avals, **kwargs) - if least_specialized is core.ConcreteArray: - out_vals = prim.impl(*[x.val for x in avals], **kwargs) - return [core.ConcreteArray(val.dtype, val, weak_type=weak_type) - for val, weak_type in zip(out_vals, weak_types)] - elif least_specialized is core.ShapedArray: + if least_specialized is core.ShapedArray: out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) return [core.ShapedArray(s, d, weak_type=weak_type) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 089a77de2949..462e5fbed1c5 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -23,7 +23,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import util -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -142,14 +142,15 @@ def _get_monoid_window_reducer( return None x, = xs aval = core.get_aval(x) - if (type(aval) is ConcreteArray) and aval.shape == (): + if core.is_concrete(x) and aval.shape == (): + val = core.to_concrete_value(x) if monoid_op is lax.add: - return aval.val == 0 and _reduce_window_sum + return val == 0 and _reduce_window_sum elif monoid_op is lax.max: - return (aval.val == lax._get_max_identity(aval.dtype) + return (val == lax._get_max_identity(aval.dtype) and _reduce_window_max) elif monoid_op is lax.min: - return (aval.val == lax._get_min_identity(aval.dtype) + return (val == lax._get_min_identity(aval.dtype) and _reduce_window_min) return None diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index dd8f671c639c..08f94c6e8eda 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -276,7 +276,7 @@ def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed assert type(in_type) is tuple and all(type(e) is tuple for e in in_type) assert all(isinstance(a, core.AbstractValue) and type(b) is bool - and not isinstance(a, core.ConcreteArray) for a, b in in_type) + for a, b in in_type) def valid_size(d) -> bool: if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ee33be8a10d8..f79d6bc0758f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -52,7 +52,7 @@ from jax._src import xla_bridge from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl -from jax._src.core import ConcreteArray, ShapedArray +from jax._src.core import ShapedArray from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax as lax_internal from jax._src.lax.lax import ( PrecisionLike,_array_copy, @@ -11789,7 +11789,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], except TypeError: abstract_i = None # Handle basic int indexes. - if isinstance(abstract_i, (ConcreteArray, ShapedArray)) and _int(abstract_i): + if isinstance(abstract_i, ShapedArray) and _int(abstract_i): if core.definitely_equal(x_shape[x_axis], 0): # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") @@ -11945,7 +11945,7 @@ def _expand_bool_indices(idx, shape): i = array(i) abstract_i = core.get_aval(i) - if not type(abstract_i) is ConcreteArray: + if not core.is_concrete(i): # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete raise errors.NonConcreteBooleanIndexError(abstract_i) elif _ndim(i) == 0: @@ -11975,7 +11975,7 @@ def _is_slice_element_none_or_constant_or_symbolic(elt): if elt is None: return True if core.is_symbolic_dim(elt): return True try: - return type(core.get_aval(elt)) is ConcreteArray + return core.is_concrete(elt) except TypeError: return False diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index ade9cb2062b8..8692c30a3e17 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2512,7 +2512,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # lax.pow. # Case 1: concrete integer scalar powers: - if isinstance(core.get_aval(x2), core.ConcreteArray): + if core.is_concrete(x2): try: x2 = operator.index(x2) # type: ignore[arg-type] except TypeError: diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index cb653547baff..538f3f8e4888 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -123,12 +123,7 @@ def _maybe_concretize(x: Any): # This is roughly the same logic as core.concrete_or_error, but we avoid # calling that because constructing the ConcretizationTypeError can be # expensive as the size of the tracing context (i.e. the jaxpr) grows. - if isinstance(x, core.Tracer): - if isinstance(x.aval, core.ConcreteArray): - return x.aval.val - else: - return None - return x + return core.to_concrete_value(x) @tree_util.register_pytree_node_class @dataclasses.dataclass diff --git a/jax/core.py b/jax/core.py index 6869f747b0d8..2880e42c681b 100644 --- a/jax/core.py +++ b/jax/core.py @@ -24,7 +24,6 @@ AxisName as AxisName, CallPrimitive as CallPrimitive, ClosedJaxpr as ClosedJaxpr, - ConcreteArray as ConcreteArray, ConcretizationTypeError as ConcretizationTypeError, DShapedArray as DShapedArray, DropVar as DropVar, @@ -84,6 +83,7 @@ get_aval as get_aval, get_type as get_type, get_referent as get_referent, + is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, jaxpr_as_fun as jaxpr_as_fun, diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 2fa028b2fe1e..615fd3128309 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -911,13 +911,14 @@ def __init__(self, trace, rep, val): @property def aval(self): aval = core.get_aval(self.val) - if (isinstance(aval, core.ConcreteArray) and - self.rep == set(self._trace.mesh.axis_names)): + return core.mapped_aval(self._trace.mesh.size, 0, aval) + + def to_concrete_value(self): + if self.rep == set(self._trace.mesh.axis_names): with core.eval_context(): - return core.get_aval(self.val[0]) + return core.to_concrete_value(self.val[0]) else: - aval = core.raise_to_shaped(aval) - return core.mapped_aval(self._trace.mesh.size, 0, aval) + return None def __str__(self) -> str: with core.eval_context(): @@ -1768,6 +1769,9 @@ def __init__(self, trace, rep, val): def aval(self) -> core.AbstractValue: return core.get_aval(self.val) + def to_concrete_value(self): + return core.to_concrete_value(self.val) + def __str__(self) -> str: return str(self.val) # TODO(mattjj): could show replication info here __repr__ = __str__ # for debuggers, like `p x` diff --git a/jax/experimental/slab/slab.py b/jax/experimental/slab/slab.py index af7b079eeb7f..8324e4c55457 100644 --- a/jax/experimental/slab/slab.py +++ b/jax/experimental/slab/slab.py @@ -89,7 +89,7 @@ def xprod(xs: Iterable[XInt]) -> XInt: return xmul(*list(xs)) def static_int(x: XInt) -> bool: - return isinstance(core.get_aval(x), core.ConcreteArray) + return core.is_concrete(x) def static_shape(s: DShape) -> bool: return all(map(static_int, s)) diff --git a/tests/api_test.py b/tests/api_test.py index 197784d99772..e98f4299c5e1 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3691,18 +3691,6 @@ def g(x): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): g(1) - def test_join_concrete_arrays_with_omnistaging(self): - # https://github.com/jax-ml/jax/issues/4622 - x = jnp.array([1., 2., 3.]) - y = jnp.array([1., 2., 4.]) - - @jit - def f(): - core.lattice_join(core.ConcreteArray(x.dtype, x), - core.ConcreteArray(y.dtype, y)) - - f() # doesn't crash - def test_linearize_aux(self): def fn(x): return x * 2 - 3, x > 0 diff --git a/tests/core_test.py b/tests/core_test.py index 38700037248d..1471e334c880 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -347,13 +347,6 @@ def g_vmap(x): 'This BatchTracer with object id'): g_vmap(jnp.ones((1, ))) - def test_concrete_array_string_representation(self): - # https://github.com/jax-ml/jax/issues/5364 - self.assertEqual( - str(core.ConcreteArray(np.dtype(np.int32), - np.array([1], dtype=np.int32))), - 'ConcreteArray([1], dtype=int32)') - def test_dropvar_avals(self): def f(x): def body(c, _): From 7af7a60dcc923df2c0c9821356132523247a447e Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Thu, 31 Oct 2024 14:37:02 -0700 Subject: [PATCH 155/698] [Pallas:TPU] Use arith.divui for uint32 div. PiperOrigin-RevId: 691939453 --- jax/_src/pallas/mosaic/lowering.py | 2 +- tests/pallas/ops_test.py | 16 ---------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 0e1fe9a5b56a..054ea4fa80e0 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1934,7 +1934,7 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out - if jnp.issubdtype(aval_out.dtype, jnp.integer): + if jnp.issubdtype(aval_out.dtype, jnp.signedinteger): return arith.divsi(x, y) if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger): return arith.divui(x, y) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 598377b75a22..aeb93860ee92 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1088,14 +1088,6 @@ def test_binary(self, f, dtype): if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") - # TODO(ayx): Fix these operations on TPU - if ( - jtu.test_device_matches(["tpu"]) - and f in (jnp.floor_divide, jnp.subtract) - and dtype == "uint32" - ): - self.skipTest("Not supported on TPU") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1 ) @@ -1121,14 +1113,6 @@ def test_binary_scalar(self, f, dtype): if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") - # TODO(ayx): Fix these operations on TPU - if ( - jtu.test_device_matches(["tpu"]) - and f in (jnp.floor_divide, jnp.subtract) - and dtype == "uint32" - ): - self.skipTest("Not supported on TPU") - @functools.partial( self.pallas_call, in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), From 17ad8a9582d69670ebf882e4ac325abe095624ce Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 31 Oct 2024 14:53:28 -0700 Subject: [PATCH 156/698] [array api] update test suite to latest commit --- .github/workflows/jax-array-api.yml | 2 +- tests/array_api_skips.txt | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 648ea0bbe26c..942034169e09 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09 + ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2ac2edcdfd99..4646d87f096c 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -7,8 +7,5 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking -# Returns wrong zero sign -array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] - # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted From 2b9c73d10d08d7415337ab71cb7718022e89c408 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 31 Oct 2024 15:40:54 -0700 Subject: [PATCH 157/698] Remove a number of expired deprecations. These APIs were all removed 3 or more months ago, and the registrations here cause them to raise informative AttributeErrors. Enough time has passed now that we can remove these. --- jax/__init__.py | 7 ------- jax/core.py | 22 ---------------------- jax/interpreters/ad.py | 11 ----------- jax/interpreters/xla.py | 36 ------------------------------------ jax/lax/__init__.py | 13 ------------- jax/nn/__init__.py | 14 -------------- jax/numpy/__init__.py | 5 ----- jax/random.py | 17 ----------------- 8 files changed, 125 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 4f5c256b0c9d..7916ef0e3962 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -223,13 +223,6 @@ "jax.clear_backends is deprecated.", _deprecated_clear_backends ), - # Remove after jax 0.4.35 release. - "xla_computation": ( - "jax.xla_computation is deleted. Please use the AOT APIs; see " - "https://jax.readthedocs.io/en/latest/aot.html. For example, replace " - "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See " - "CHANGELOG.md for 0.4.30 for more examples.", None - ), } import typing as _typing diff --git a/jax/core.py b/jax/core.py index 2880e42c681b..fb08763fd3a1 100644 --- a/jax/core.py +++ b/jax/core.py @@ -147,28 +147,6 @@ "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), - # Finalized 2024-05-13; remove after 2024-08-13 - "DimSize": ( - "jax.core.DimSize is deprecated. Use DimSize = int | Any.", - None, - ), - "Shape": ( - "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].", - None, - ), - # Finalized 2024-06-24; remove after 2024-09-24 - "canonicalize_shape": ( - "jax.core.canonicalize_shape is deprecated.", None, - ), - "dimension_as_value": ( - "jax.core.dimension_as_value is deprecated. Use jnp.array.", None, - ), - "definitely_equal": ( - "jax.core.definitely_equal is deprecated. Use ==.", None, - ), - "symbolic_equal_dim": ( - "jax.core.symbolic_equal_dim is deprecated. Use ==.", None, - ), # Added Jan 8, 2024 "non_negative_dim": ( "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim, diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 160a96fae368..4ded4a803ae0 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -68,17 +68,6 @@ zeros_like_p as zeros_like_p, ) -_deprecations = { - # Finalized Mar 18, 2024; remove after June 18, 2024 - "config": ( - "jax.interpreters.ad.config is deprecated. Use jax.config directly.", - None, - ), - "source_info_util": ( - "jax.interpreters.ad.source_info_util is deprecated. Use jax.extend.source_info_util.", - None, - ), -} def backward_pass(jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 2711bcfb80d5..b3a470f5e049 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -42,42 +42,6 @@ ("jax.interpreters.xla.xe was removed in JAX v0.4.36. " "Use jax.lib.xla_extension instead."), None ), - # Finalized 2024-05-13; remove after 2024-08-13 - "backend_specific_translations": ( - "jax.interpreters.xla.backend_specific_translations is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "translations": ( - "jax.interpreters.xla.translations is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "register_translation": ( - "jax.interpreters.xla.register_translation is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "xla_destructure": ( - "jax.interpreters.xla.xla_destructure is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "TranslationRule": ( - "jax.interpreters.xla.TranslationRule is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "TranslationContext": ( - "jax.interpreters.xla.TranslationContext is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), - "XlaOp": ( - "jax.interpreters.xla.XlaOp is deprecated. " - "Register custom primitives via jax.interpreters.mlir instead.", - None, - ), } from jax._src.deprecations import deprecation_getattr as _deprecation_getattr diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 5f3bfa057912..d2fb6a9bae3c 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -377,16 +377,3 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p - - -_deprecations = { - # Finalized 2024-05-13; remove after 2024-08-13 - "tie_in": ( - "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. " - "Replace z = tie_in(x, y) with z = y.", None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 496d03261384..ebe725c448ee 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -49,17 +49,3 @@ squareplus as squareplus, mish as mish, ) - -# Deprecations - -_deprecations = { - # Finalized 2024-05-13; remove after 2024-08-13 - "normalize": ( - "jax.nn.normalize is deprecated. Use jax.nn.standardize instead.", - None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 93405cc03ef7..9be73e96adcf 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -472,11 +472,6 @@ "jnp.round_ is deprecated; use jnp.round instead.", round ), - # Deprecated 18 Sept 2023 and removed 06 Feb 2024 - "trapz": ( - "jnp.trapz is deprecated; use jnp.trapezoid instead.", - None - ), } import typing diff --git a/jax/random.py b/jax/random.py index 29a625389811..b99cd531f18c 100644 --- a/jax/random.py +++ b/jax/random.py @@ -251,20 +251,3 @@ weibull_min as weibull_min, wrap_key_data as wrap_key_data, ) - -_deprecations = { - # Finalized Jul 26 2024; remove after Nov 2024. - "shuffle": ( - "jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.", - None, - ) -} - -import typing -if typing.TYPE_CHECKING: - pass -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing From 467bd09f03065d1d14f3e35c959aaab7739c6c40 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Thu, 31 Oct 2024 12:46:03 -0700 Subject: [PATCH 158/698] Add a register_dataclass example to the pytree tutorial. --- docs/jit-compilation.md | 2 ++ docs/working-with-pytrees.md | 43 ++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 59c7bbd8fb90..51322fda9476 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -192,6 +192,8 @@ def g_inner_jitted(x, n): g_inner_jitted(10, 20) ``` +(jit-marking-arguments-as-static)= + ## Marking arguments as static If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md index e41179996bc4..537a4df3e5a6 100644 --- a/docs/working-with-pytrees.md +++ b/docs/working-with-pytrees.md @@ -272,6 +272,49 @@ jax.tree.leaves([ Notice that the `name` field now appears as a leaf, because all tuple elements are children. This is what happens when you don't have to register the class the hard way. +Unlike `NamedTuple` subclasses, classes decorated with `@dataclass` are not automatically pytrees. However, they can be registered as pytrees using the {func}`jax.tree_util.register_dataclass` decorator: + +```{code-cell} +from dataclasses import dataclass +import functools + +@functools.partial(jax.tree_util.register_dataclass, + data_fields=['a', 'b', 'c'], + meta_fields=['name']) +@dataclass +class MyDataclassContainer(object): + name: str + a: Any + b: Any + c: Any + +# MyDataclassContainer is now a pytree node. +jax.tree.leaves([ + MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])), + MyDataclassContainer('banana', np.array([3, 4]), -1., 0.) +]) +``` + +Notice that the `name` field does not appear as a leaf. This is because we included it in the `meta_fields` argument to {func}`jax.tree_util.register_dataclass`, indicating that it should be treated as metadata/auxiliary data, just like `aux_data` in `RegisteredSpecial` above. Now instances of `MyDataclassContainer` can be passed into JIT-ed functions, and `name` will be treated as static (see {ref}`jit-marking-arguments-as-static` for more information on static args): + +```{code-cell} +@jax.jit +def f(x: MyDataclassContainer | MyOtherContainer): + return x.a + x.b + +# Works fine! `mdc.name` is static. +mdc = MyDataclassContainer('mdc', 1, 2, 3) +y = f(mdc) +``` + +Contrast this with `MyOtherContainer`, the `NamedTuple` subclass. Since the `name` field is a pytree leaf, JIT expects it to be convertible to {class}`jax.Array`, and the following raises an error: + +```{code-cell} +:tags: [raises-exception] + +moc = MyOtherContainer('moc', 1, 2, 3) +y = f(moc) +``` (pytree-and-jax-transformations)= ## Pytrees and JAX transformations From 423cd2ad5ef82d4680ad84e7137af0633c2b8c4b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 31 Oct 2024 16:27:31 -0700 Subject: [PATCH 159/698] Simplified conditional in flash attention. PiperOrigin-RevId: 691972341 --- jax/_src/cudnn/fused_attention_stablehlo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index e20271f66301..8ccf08ec643c 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -295,8 +295,8 @@ def check_is_flash_attention( _, T, _, H = query.shape _, S, _, _ = key.shape - if not ((H <= 128 and H % 8 == 0) and - (not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)): + if (H > 128 or H % 8 != 0 or + (is_training and has_bias and (T % 2 != 0 or S % 2 != 0))): # check if flash attention is supported # for training, for patterns with bias, seqlen should be divisible by 2 raise NotImplementedError( From 84c8794b30e5e79fb2fa9ef49bb0c1e5b7d4a0cf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 31 Oct 2024 17:16:09 -0700 Subject: [PATCH 160/698] Add a JaxIrContext that subclasses mlir.ir.Context and avoids calling ir.Context's __init__. mlir.ir.Context has the unfortunate behavior that it loads all dialects linked into the binary, even those we have no intention of using. This is fairly benign in JAX's usual configuration, but if JAX is linked together with other MLIR-using software it can be problematic. PiperOrigin-RevId: 691984229 --- jax/_src/interpreters/mlir.py | 11 ++++++++++- jax/_src/pallas/mosaic/pallas_call_registration.py | 2 +- jax/_src/pallas/triton/lowering.py | 3 ++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 10154bbd661e..2c0e26019e4d 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -586,9 +586,18 @@ def module_to_bytecode(module: ir.Module) -> bytes: return output.getvalue() # Translation rules + +class JaxIrContext(ir.Context): + def __init__(self, *args, **kwargs): + # Note: we're very intentionally *not* calling the __init__() of our + # immediate superclass ir.Context, whose __init__() has the unfortunate side + # effect of loading all the dialects linked into the binary into the + # context. We want to ensure that only the dialects we need are loaded. + super(ir.Context, self).__init__(*args, **kwargs) + def make_ir_context() -> ir.Context: """Creates an MLIR context suitable for JAX IR.""" - context = ir.Context() + context = JaxIrContext() context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index e34b5dbdd162..4382cea914f0 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -130,7 +130,7 @@ def pallas_call_tpu_lowering_rule( if axis_context is not None: if isinstance(axis_context, sharding_impls.SPMDAxisContext): mesh = axis_context.mesh - mlir_ctx = ir.Context() + mlir_ctx = mlir.JaxIrContext() mlir_ctx.append_dialect_registry(mlir.upstream_dialects) mlir_ctx.load_all_available_dialects() tpu.register_dialect(mlir_ctx) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index d3ca18ee507c..0a9c0a197740 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -248,7 +248,8 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): def _new_ir_context() -> ir.Context: - ctx = ir.Context() + ctx = mlir.JaxIrContext() + ctx.append_dialect_registry(mlir.upstream_dialects) tt_dialect.register_dialect(ctx) ctx.load_all_available_dialects() return ctx From f60b97cea1581059dfac5775448b316f2100eda3 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 31 Oct 2024 17:33:50 -0700 Subject: [PATCH 161/698] [Pallas TPU] Add lowering for `lax.nextafter` Also improved the corresponding test cases to ensure better coverage and accuracy. This PR is similar to https://github.com/jax-ml/jax/pull/22283, which adds lowering for `lax.sign`. PiperOrigin-RevId: 691988164 --- jax/_src/pallas/mosaic/lowering.py | 9 +++ jax/_src/pallas/triton/lowering.py | 2 +- jax/_src/pallas/utils.py | 88 ++++++++++++++++++++++++++++++ tests/pallas/ops_test.py | 30 ++++++---- 4 files changed, 117 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 054ea4fa80e0..f9014da221eb 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1997,6 +1997,15 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.sign_p] = _sign_lowering_rule +def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): + return lower_fun( + pallas_utils.nextafter_lowering_helper, multiple_results=False, + )(ctx, x, y) + + +lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule + + def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): return math.rsqrt(x) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 0a9c0a197740..19328b44800b 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -996,7 +996,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): lax.nextafter_p: _make_dispatch_table( "nextafter", cuda=[ - _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32 ), + _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32), _Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64), ], rocm=[ diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index e485537216ca..0dc19aa75fb6 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -301,3 +301,91 @@ def sign_lowering_helper(x): return jnp.where(jnp.isnan(x), jnp.nan, out) raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}") + + +# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L1339-L1422 +def nextafter_lowering_helper(x, y): + if x.dtype != y.dtype: + raise ValueError( + "The two inputs to `nextafter` must have the same dtype, but got" + f" {x.dtype} and {y.dtype}" + ) + + if x.dtype not in (jnp.float32, jnp.float64): + raise ValueError( + f"`nextafter` only supports float32 and float64, but got {x.dtype}" + ) + + jnp_float, jnp_uint, np_float, np_uint, np_int = ( + jnp.float32, jnp.uint32, np.float32, np.uint32, np.int32, + ) if x.dtype == jnp.float32 else ( + jnp.float64, jnp.uint64, np.float64, np.uint64, np.int64, + ) + + bitwidth = dtype_bitwidth(x.dtype) + + x_as_int = x.view(jnp_uint) + y_as_int = y.view(jnp_uint) + + # The result is NaN if either "x" or "y" are NaN. + nan_input = jnp.isnan(x) | jnp.isnan(y) + result_for_nan = jnp.full_like(x_as_int, np_float(np.nan).view(np_uint)) + + # The sign bit is the MSB. + sign_bit = jnp_uint(1 << (bitwidth - 1)) + # Discard the sign bit to make the result non-negative. + sign_mask = sign_bit + negated_sign_mask = ~sign_bit + x_abs = x_as_int & negated_sign_mask + y_abs = y_as_int & negated_sign_mask + + # When both "x" and "y" are equal, the result is "y". + x_and_y_are_equal = x == y + result_for_equal = y_as_int + + # When both "x" and "y" are 0, the result is "y". This is a separate case + # from above because "x" and "y" might have a different sign. + zero = jnp.zeros_like(x_as_int) + x_is_zero = x_abs == zero + y_is_zero = y_abs == zero + result_for_both_zero = y_as_int + + x_sign = x_as_int & sign_mask + y_sign = y_as_int & sign_mask + + # If x == 0 && y != 0, we need to return the smallest subnormal number + # signed like "y". + one = jnp.ones_like(x_as_int) + result_for_x_zero_y_non_zero = y_sign | one + + # If the sign of "x" and "y" disagree: + # - we need to make the magnitude of "from" smaller so that it is closer to + # zero. + # + # Otherwise the signs agree: + # - "x" with a magnitude larger than "y" means we need to make the magnitude + # smaller. + # - "x" with a magnitude smaller than "y" means we need to make the magnitude + # larger. + signs_disagree = x_sign != y_sign + x_magnitude_larger_than_y = x_abs > y_abs + result_has_smaller_magnitude = x_magnitude_larger_than_y | signs_disagree + minus_one = jnp.full_like(x_as_int, np_int(-1).view(np_uint)) + magnitude_adjustment = jnp.where(result_has_smaller_magnitude, minus_one, one) + result = x_as_int + magnitude_adjustment + + # Handle x == +-0. + result = jnp.where( + x_is_zero, + jnp.where(y_is_zero, result_for_both_zero, result_for_x_zero_y_non_zero), + result, + ) + + # Handle x == y. + result = jnp.where(x_and_y_are_equal, result_for_equal, result) + + # Handle isnan(x) || isnan(y). + result = jnp.where(nan_input, result_for_nan, result) + + # Cast back to the original type. + return result.view(jnp_float) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index aeb93860ee92..70ced6eb2be6 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -930,24 +930,32 @@ def kernel(x_ref, o_ref): x = jnp.array([1, 2, 3, 4]).astype(jnp.float32) / 10 np.testing.assert_allclose(kernel(x), lax.integer_pow(x, y)) - @parameterized.parameters("float32", "float64") - def test_nextafter(self, dtype): + _NEXTAFTER_VALUES = (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf) + + @parameterized.named_parameters( + (f"{dtype.__name__} ({x=}, {y=})", dtype, x, y) + for dtype, x, y in itertools.product( + (jnp.float32, jnp.float64), _NEXTAFTER_VALUES, _NEXTAFTER_VALUES, + ) + ) + def test_nextafter(self, dtype, x, y): if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") - # TODO: implement this on TPU - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not implemented: nextafter") - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1 + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), dtype), ) def kernel(x_ref, y_ref, o_ref): - o_ref[:] = jnp.nextafter(x_ref[...], y_ref[...]) + o_ref[...] = jnp.nextafter(x_ref[...], y_ref[...]) + + x = jnp.full((4,), x, dtype=dtype) + y = jnp.full((4,), y, dtype=dtype) + out = kernel(x, y) + expected = jnp.nextafter(x, y) - x = jnp.array([1, 2, 3, 4]).astype(dtype) - y = jnp.array([1, 2, 3, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y)) + # `nextafter` requires exact equality + self.assertArraysEqual(out, expected) COMPARISON_OPS = [ jnp.equal, From 14139a3f4a01c1d6b7d31baf34c05a2a3f2cc4ef Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Fri, 1 Nov 2024 10:05:43 -0500 Subject: [PATCH 162/698] Unpin container in CI build and remove libssl-dev install --- .github/workflows/ci-build.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 7805e3206fcd..6ac7a138d7da 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -39,8 +39,6 @@ jobs: build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" runs-on: ROCM-Ubuntu - container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 60 strategy: matrix: @@ -58,10 +56,6 @@ jobs: num_generated_cases: 1 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Image Setup - run: | - apt update - apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: From e083c0800170927ffaeade5b846c857673bf17cb Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 2 Oct 2024 12:44:21 +0200 Subject: [PATCH 163/698] Re-enable cudnn_fusion_test on A100. Check that the required cuDNN version is available. --- tests/BUILD | 1 + tests/cudnn_fusion_test.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 58de3404979d..9b6b0bf66c89 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1523,6 +1523,7 @@ jax_multiplatform_test( srcs = ["cudnn_fusion_test.py"], enable_backends = [], enable_configs = [ + "gpu_a100", "gpu_h100", ], tags = ["multiaccelerator"], diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index 151cb72be8dc..7dc0571bc172 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest, parameterized from unittest import SkipTest from jax._src import test_util as jtu +from jax._src.lib import cuda_versions import jax import jax.numpy as jnp from jax._src.cudnn import cudnn_fusion @@ -26,8 +27,9 @@ class CudnnFusionTest(jtu.JaxTestCase): def setUp(self): if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on >= sm90 GPUs") + not jtu.is_cuda_compute_capability_at_least("8.0") or + cuda_versions.cudnn_get_version() < 90110): + self.skipTest("Only works on >= sm80 GPUs with cuDNN 9.1.1+") super().setUp() @parameterized.parameters(["", "pmap"]) From 26f70c9c16c77462af4f468f362ce096803d3890 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 1 Nov 2024 16:37:46 +0000 Subject: [PATCH 164/698] remove busted example from shmap jep --- docs/jep/14273-shard-map.md | 111 ++---------------------------------- 1 file changed, 4 insertions(+), 107 deletions(-) diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 8e66a675a522..63742bc852c6 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -3,6 +3,9 @@ *January 2023* +**This was the design doc proposing `shard_map`. You may instead want +[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** + ## Motivation JAX supports two schools of thought for multi-device programming: @@ -374,114 +377,8 @@ One philosophy is: it is almost always simpler to write a program in `jit==pjit` — but if a given part of the program is less optimized by the compiler than it could be, drop into `shmap`! -### A realistic transformer example - -In fact, we can implement a simple version of the ["collective -matmul"](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959) algorithm -recently introduced in XLA to overlap communication and computation using `shmap` -and 30 lines of Python. The basic idea of the algorithm can be grasped with a -simple example. - -Suppose we want to compute `C = A @ B` where `A` is sharded by a 1D mesh on the -0-th dimension while `B` and `C` are replicated. - -```python -M, K, N = 4096, 2048, 1024 -A = jnp.arange(np.prod((M, K))).reshape((M, K)) -B = jnp.arange(np.prod((K, N))).reshape((K, N)) - -mesh = Mesh(np.array(jax.devices()), axis_names=('i')) -A_x = jax.device_put(A, NamedSharding(mesh, P('i', None))) - -@jax.jit -def f(lhs, rhs): - return lhs @ rhs - -C = f(A_x, B) -``` - -A profile shows the blocking all-gather across 8 devices before the matmul can -start. This is suboptimal because `A` is sharded on a non-contracting dimension, -and each shard of `A` can be matmul'ed with `B` independently and this chunked -computation can be overlapped with fetching of the next shard of `A` from -another device. - -image - -This overlap can be implemented using `shmap` and explicit collectives. - -```python -def collective_matmul_allgather_lhs_non_contracting(lhs, rhs): - # lhs is the looped operand; rhs is the local operand - axis_size = jax.lax.psum(1, axis_name='i') - axis_index = jax.lax.axis_index(axis_name='i') - chunk_size = lhs.shape[0] - - def f(i, carrys): - accum, lhs = carrys - # matmul for a chunk - update = lhs @ rhs - # circular shift to the left - lhs = jax.lax.ppermute( - lhs, - axis_name='i', - perm=[(j, (j - 1) % axis_size) for j in range(axis_size)] - ) - # device 0 computes chunks 0, 1, ... - # device 1 computes chunks 1, 2, ... - update_index = (((axis_index + i) % axis_size) * chunk_size, 0) - accum = jax.lax.dynamic_update_slice(accum, update, update_index) - return accum, lhs - - accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype) - # fori_loop cause a crash: hlo_sharding.cc:817 Check failed: !IsManual() - # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs)) - for i in range(0, axis_size - 1): - accum, lhs = f(i, (accum, lhs)) - - # compute the last chunk, without the ppermute - update = lhs @ rhs - i = axis_size - 1 - update_index = (((axis_index + i) % axis_size) * chunk_size, 0) - accum = jax.lax.dynamic_update_slice(accum, update, update_index) - - return accum -``` - -``` -jit_sharded_f = jax.jit(shard_map( - collective_matmul_allgather_lhs_non_contracting, mesh, - in_specs=(P('i', None), P()), out_specs=P())) -C = jit_sharded_f(A_x, B) -``` -A profile shows that the all-gather is gone, and replaced with overlapped matmul -with async collective permute. This profile matches very closely with the -collective matmul paper result. - -image - -This collective matmul technique can be used to speed up feedforward blocks in -transformer layers. This typically consists of two matrix multiplications -followed by a `ReduceScatter` (to resolve partial sums from a parallelized -matrix multiplication) and preceded by an `AllGather` (to collect the sharded -dimensions along some axes and allow partial sum computation). Together, the -`ReduceScatter` from one layer and the `AllGather` for the next amount to an -`AllReduce`. - -In a typical profile, the two matmuls will be followed by an `AllReduce`, and -they will not be overlapped. Collective matmul can be used to achieve the -overlap, but is difficult to trigger, has a minimum slice size and does not yet -cover all topologies, tensor shapes and variants of collective matmul (i.e -latency and throughput optimized variants). [In a recent -paper](https://arxiv.org/abs/2211.05102), we found a ~40% gain in many -circumstances from manually implementing collective matmul variants in `shmap` -style. - -But it isn’t always more complex! We expect this to be a much more natural way -to think about pipelined computation, and plan to do some demos of that soon! - -### Another realistic example +### A realistic example Here's how `shmap` might look in a transformer layer pass with a 2D weight gathered pattern ([paper](https://arxiv.org/abs/2211.05102), Sec 3.2.3 on p. 5): From e657a4b28309de8769a8dfb17dcad01185976111 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 1 Nov 2024 09:48:45 -0700 Subject: [PATCH 165/698] Fix array API tests. This is currently causing failures on main. --- tests/array_api_skips.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 4646d87f096c..2ac2edcdfd99 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -7,5 +7,8 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking +# Returns wrong zero sign +array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] + # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted From 5e1366c4ce2295e231a78591ef22662a1f183389 Mon Sep 17 00:00:00 2001 From: Li-Jesse-Jiaze <963204825@qq.com> Date: Fri, 1 Nov 2024 17:57:18 +0100 Subject: [PATCH 166/698] Fix #24661: Add zsh support to conda install documentation --- docs/installation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation.md b/docs/installation.md index 5b8893628d85..b7a56c48ec1f 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -244,7 +244,7 @@ conda install jax -c conda-forge To install it on a machine with an NVIDIA GPU, run: ```bash -conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia +conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia ``` Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which From f462d7e5865d4db512770ca2ca40d127a7308985 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Fri, 1 Nov 2024 10:02:07 -0700 Subject: [PATCH 167/698] [Mosaic] Set TPU CustomCall device type based on the core_type attribute This CL deprecates the device_type parameter of `tpu_custom_call.as_tpu_kernel()` in favour of the `tpu.core_type` annotation. The latter is more fine-grained: it is applied on `func.FuncOp` instead of the entire module, supports `tc`, `sc_scalar_subcore` and `sc_vector_subcore`. `device_type` of the TPU CustomCall HLO is set to `sparsecore` if `sc_scalar_subcore` or `sc_vector_subcore` annotation is provided. Otherwise, `device_type` is not set and the CustomCall targets TC. PiperOrigin-RevId: 692212644 --- jax/_src/tpu_custom_call.py | 40 ++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index dae1bdd76b33..f463986ffb50 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -453,6 +453,44 @@ def _lower_mosaic_module_to_asm( ) +def _get_device_type(module: ir.Module) -> str | None: + """Determines the device type based on the core_type annotations.""" + sparsecore_func_found = False + tensorcore_func_found = False + + def assign_device_type_based_on_core_type(op: ir.Operation) -> ir.WalkResult: + nonlocal sparsecore_func_found + nonlocal tensorcore_func_found + if op.name == "func.func": + if "tpu.core_type" in op.attributes: + core_type = op.attributes["tpu.core_type"] + if str(core_type) in [ + f"#tpu.core_type<{c}>" + for c in ["sc_scalar_subcore", "sc_vector_subcore"] + ]: + sparsecore_func_found = True + if tensorcore_func_found: + return ir.WalkResult.INTERRUPT + return ir.WalkResult.SKIP + if str(core_type) == "#tpu.core_type": + tensorcore_func_found = True + return ir.WalkResult.SKIP + raise ValueError(f"Unknown core type: {core_type}") + return ir.WalkResult.ADVANCE + + module.operation.walk( + assign_device_type_based_on_core_type, walk_order=ir.WalkOrder.PRE_ORDER + ) + if tensorcore_func_found and sparsecore_func_found: + raise ValueError( + "A single Mosaic kernel cannot contain both " + "TensorCore and SparseCore functions." + ) + if sparsecore_func_found: + return "sparsecore" + return None + + def _lower_to_custom_call_config( module: ir.Module, *, @@ -592,7 +630,6 @@ def as_tpu_kernel( *, cost_estimate: CostEstimate | None = None, backend: str | xla_client.Client = "tpu", - device_type: str | None = None, kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, @@ -604,6 +641,7 @@ def as_tpu_kernel( output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" + device_type = _get_device_type(module) config = _lower_to_custom_call_config( module, backend=backend, From b7bdee9056543d7a18a243d9198846f626134da2 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 25 Oct 2024 20:49:47 -0400 Subject: [PATCH 168/698] Update pre-commit workflow to cache on jax version --- .github/workflows/ci-build.yaml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 6cdc1175e600..5c786272ee3d 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -34,7 +34,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: 3.11 - - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 + - run: python -m pip install pre-commit + - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} + - run: pre-commit run --show-diff-on-failure --color=always build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" From 97e8a4c8c6dbe959d51af8d4ae451fe0edff231a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 1 Nov 2024 05:40:00 -0700 Subject: [PATCH 169/698] Fix signatures test: new axis argument in trim_zeros --- jax/_src/numpy/lax_numpy.py | 1 + tests/lax_numpy_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f79d6bc0758f..4200a9fdae72 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -8328,6 +8328,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res +# TODO(jakevdp): add support for N-dimensional inputs as in NumPy v2.2 def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: """Trim leading and/or trailing zeros of the input array. diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index b37237cae28c..61baa7c97df4 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6329,6 +6329,7 @@ def testWrappedSignaturesMatch(self): 'stack': ['casting'], 'std': ['mean'], 'tri': ['like'], + 'trim_zeros': ['axis'], 'var': ['mean'], 'vstack': ['casting'], 'zeros_like': ['subok', 'order'] From a0b0a8e5a1708317f2ddb7ad6c6694941079045a Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 1 Nov 2024 12:33:58 -0700 Subject: [PATCH 170/698] Set minimum supported Python version to 3.10 for matplotlib. Temporary fixes an issue with `python -m build` that fails when python 3.8 is used because `matplotlib~=3.8.4` is unavailable for this python version. We are working on creating Bazel build rule with the hermetic Python for JAX wheel ([we already have Jaxlib and plugins build rules ready](https://github.com/jax-ml/jax/pull/23276)). The required python modules are provided in requirements.in file, so when we implement Bazel build rule for JAX wheel, requirements.in will be the only source of dependencies, and test-requirements.txt won't be needed for building JAX wheel. PiperOrigin-RevId: 692260046 --- build/test-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 41a6ed4588a0..94b2bbb965dc 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -15,7 +15,7 @@ rich setuptools<71.0.0 # matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement # below. -matplotlib~=3.8.4; python_version<="3.10" +matplotlib~=3.8.4; python_version=="3.10" matplotlib; python_version>="3.11" opt-einsum auditwheel From d606c242938f15ce326bb38aaa15857fbcd747ab Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 1 Nov 2024 12:56:33 -0700 Subject: [PATCH 171/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8ec02b3611ffa4378ef2189842b5912241b604d0. PiperOrigin-RevId: 692266047 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c42f373418f2..32b8f9207059 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "edf18ce242f234fbd20d1fbf4e9c96dfa5be2847" -XLA_SHA256 = "14294c2b264cb13102bc16a2f837e9f8ed7ba72eab9e3e9dc036be0b699c3c84" +XLA_COMMIT = "8ec02b3611ffa4378ef2189842b5912241b604d0" +XLA_SHA256 = "d5f22ae989dfffda803c8493862733bdf105f63961fff115553ae2bd815436db" def repo(): tf_http_archive( From 07858fa98dbd9f2e84ef66a592fa630d99df4589 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 1 Nov 2024 13:24:26 -0700 Subject: [PATCH 172/698] [sharding_in_types] Allow `device_put` to reshard inputs. `device_put` is a good choice for resharding since it already handles transpose correctly because it tracks the `src` sharding too. PiperOrigin-RevId: 692274137 --- jax/_src/api.py | 5 +++-- jax/_src/dispatch.py | 19 ++++++++++++++++++- tests/pjit_test.py | 16 ++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 0b3cb08715e9..250743821604 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2183,11 +2183,12 @@ def make_jaxpr_f(*args, **kwargs): def _infer_src_sharding(src, x) -> Sharding | None: if src is not None: - # TODO(slebedev): This looks like an error and needs investigation. return src # pytype: disable=bad-return-type if isinstance(x, array.ArrayImpl): return x.sharding - elif isinstance(x, core.Tracer): + if config.sharding_in_types.value and hasattr(x, 'sharding'): + return x.sharding + if isinstance(x, core.Tracer): val = x.to_concrete_value() if val is not None and isinstance(val, array.ArrayImpl): return val.sharding diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 97e702a9f25c..b0b390773512 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -526,7 +526,12 @@ def _batched_device_put_impl( device_put_p = core.Primitive('device_put') device_put_p.multiple_results = True device_put_p.def_impl(_batched_device_put_impl) -device_put_p.def_abstract_eval(lambda *xs, devices, srcs, copy_semantics: xs) + +def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics): + if config.sharding_in_types.value: + return [x.update(sharding=s) for x, s in zip(xs, devices)] + return xs +device_put_p.def_abstract_eval(_device_put_abstract_eval) def _device_put_transpose(cts, *_, devices, srcs, copy_semantics): results = [None] * len(cts) @@ -567,6 +572,12 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's # being used inside jit? Atleast for now, this preserves the old behavior. if ctx.module_context.all_default_mem_kind: + if config.sharding_in_types.value: + return [ + mlir.wrap_with_sharding_op( + ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto()) + for x, a in zip(xs, ctx.avals_out) + ] return xs def lower(x, device, aval, out_aval): if (isinstance(device, (Sharding, TransferToMemoryKind)) and @@ -592,6 +603,12 @@ def lower(x, device, aval, out_aval): def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): + if config.sharding_in_types.value: + return [ + mlir.wrap_with_sharding_op( + ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto()) + for x, a in zip(xs, ctx.avals_out) + ] return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index df98f0156c92..3a0b6cc86114 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5156,6 +5156,22 @@ def f(pred, on_true, on_false): TypeError, "select cases must have the same shardings"): f(arr1 == arr2, arr1, arr3) + def test_device_put_reshard(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = jax.device_put(x, NamedSharding(x.sharding.mesh, P('x', None))) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertArraysEqual(out, np_inp) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From fff33f90b209bdc930e1164f0fa7eac92243dbdf Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 1 Nov 2024 14:00:10 -0700 Subject: [PATCH 173/698] Add `compiler_options` argument to `jax.jit`. This exists on `Compiled` object via AOT too i.e. `jit(f).lower(*args).compile(compiler_options={})` PiperOrigin-RevId: 692283964 --- jax/_src/api.py | 3 +- jax/_src/checkify.py | 4 +- jax/_src/interpreters/pxla.py | 38 ++++++------ jax/_src/pjit.py | 86 ++++++++++++++++++---------- jax/experimental/jax2tf/jax2tf.py | 1 + jax/experimental/sparse/transform.py | 5 +- tests/api_test.py | 28 ++++++++- 7 files changed, 109 insertions(+), 56 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 250743821604..cc42a37b0e7c 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -151,6 +151,7 @@ def jit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, + compiler_options: dict[str, Any] | None = None, ) -> pjit.JitWrapped: """Sets up ``fun`` for just-in-time compilation with XLA. @@ -280,7 +281,7 @@ def jit( return pjit.make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env=False) + keep_unused, inline, compiler_options, use_resource_env=False) @contextmanager diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 944bf303b8f6..55db5d13e848 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -898,7 +898,8 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, inline, keep_unused): + resource_env, donated_invars, name, inline, keep_unused, + compiler_options_kvs): # jaxpr to checked_jaxpr err_vals, err_tree = jtu.tree_flatten(error) new_vals_in = [*err_vals, *vals_in] @@ -929,6 +930,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, name=name, inline=inline, keep_unused=keep_unused, + compiler_options_kvs=compiler_options_kvs, ) return tree_unflatten(out_tree, err_and_out) error_checks[pjit.pjit_p] = pjit_error_check diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e59d8c89e3e8..04d479fb757c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2121,6 +2121,7 @@ def lower_sharding_computation( *, keep_unused: bool, context_mesh: mesh_lib.Mesh | None, + compiler_options_kvs: tuple[tuple[str, Any], ...], lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None, @@ -2247,6 +2248,7 @@ def lower_sharding_computation( module, donated_invars, platforms, + compiler_options_kvs, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2298,11 +2300,13 @@ class MeshComputation(stages.XlaLowering): def __init__(self, name: str, hlo: ir.Module, donated_invars: Sequence[bool], platforms: Sequence[str], + compiler_options_kvs: tuple[tuple[str, Any], ...], **compile_args): self._name = name self._hlo = hlo self._donated_invars = donated_invars self._platforms = platforms + self._compiler_options_kvs = compiler_options_kvs self.compile_args = compile_args self._executable = None @@ -2312,11 +2316,14 @@ def stablehlo(self) -> ir.Module: return self._hlo def compile(self, compiler_options=None) -> MeshExecutable: - if self._executable is None or compiler_options is not None: + t_compiler_options = (() if compiler_options is None else + tuple(compiler_options.items())) + compiler_options_kvs = self._compiler_options_kvs + t_compiler_options + if self._executable is None or compiler_options_kvs: executable = UnloadedMeshExecutable.from_hlo( self._name, self._hlo, **self.compile_args, - compiler_options=compiler_options) - if compiler_options is None: + compiler_options_kvs=compiler_options_kvs) + if not compiler_options_kvs: self._executable = executable return executable return self._executable @@ -2581,8 +2588,7 @@ def create_compile_options( else: xla_device_assignment = np_dev.reshape((num_replicas, num_partitions)) - fdo_profile = (None if compiler_options is None else - compiler_options.pop("fdo_profile", None)) + fdo_profile = compiler_options.pop("fdo_profile", None) compile_options = compiler.get_compile_options( num_replicas=num_replicas, @@ -2614,17 +2620,11 @@ def create_compile_options( def _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, - da, pmap_nreps, compiler_options_keys, - compiler_options_values, - pgle_profiler): + da, pmap_nreps, compiler_options_kvs, pgle_profiler): # One would normally just write: dev = np.array(device_assignment) # The formulation below is substantially faster if there are many devices. dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da))) - - if compiler_options_keys is None: - compiler_options = None - else: - compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) + compiler_options = dict(compiler_options_kvs) compile_options = create_compile_options( computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, @@ -2788,22 +2788,18 @@ def from_hlo(name: str, committed: bool, in_layouts: MaybeLayout, out_layouts: MaybeLayout, + compiler_options_kvs: tuple[tuple[str, Any], ...], pmap_nreps: int = 1, mut: MutationData | None = None, shape_poly_state: mlir.ShapePolyLoweringState | None = None, all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None, - compiler_options=None, pgle_profiler: profiler.PGLEProfiler | None = None, intermediate_shardings: Sequence[JSharding] | None = None, context_mesh: mesh_lib.Mesh | None = None ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) - compiler_options_keys = tuple( - compiler_options.keys()) if compiler_options is not None else None - compiler_options_values = tuple( - compiler_options.values()) if compiler_options is not None else None if isinstance(device_assignment, xc.DeviceList): da = device_assignment else: @@ -2826,7 +2822,7 @@ def from_hlo(name: str, hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, - compiler_options_keys, compiler_options_values, pgle_profiler) + compiler_options_kvs, pgle_profiler) if auto_spmd_lowering: assert mesh is not None @@ -2918,6 +2914,7 @@ class JitGlobalCppCacheKeys: out_layouts_treedef: PyTreeDef | None = None out_layouts_leaves: tuple[Any, ...] | None = None use_resource_env: bool = False + compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None @functools.cached_property def contains_explicit_attributes(self): @@ -2928,7 +2925,8 @@ def contains_explicit_attributes(self): any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or any(i is not None for i in self.in_layouts_leaves) or - any(o is not None for o in self.out_layouts_leaves)) + any(o is not None for o in self.out_layouts_leaves) or + self.compiler_options_kvs) def reflatten_outputs_for_dispatch(out_tree, out_flat): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 3d8df8664052..604acfb39c16 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -164,6 +164,7 @@ class PjitInfo(NamedTuple): inline: bool abstracted_axes: Any | None use_resource_env: bool # False for jit, True for pjit + compiler_options_kvs: tuple[tuple[str, Any], ...] # Hash and compare PjitInfo by identity when used as a cache key. def __hash__(self): @@ -357,7 +358,8 @@ def cache_miss(*args, **kwargs): in_layouts_leaves=jit_info.in_layouts_leaves, out_layouts_treedef=jit_info.out_layouts_treedef, out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env) + use_resource_env=jit_info.use_resource_env, + compiler_options_kvs=jit_info.compiler_options_kvs) cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, cache_key, tree_util.dispatch_registry, @@ -398,7 +400,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, static_argnames: str | Iterable[str] | None, device: xc.Device | None, backend: str | None, abstracted_axes: Any | None, keep_unused: bool, - inline: bool, use_resource_env: bool) -> PjitInfo: + inline: bool, compiler_options: dict[str, Any] | None, + use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. Performs any preprocessing and validation of the arguments that we can do @@ -453,6 +456,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) + compiler_options_kvs = (() if compiler_options is None else + tuple(compiler_options.items())) return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -470,7 +475,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - use_resource_env=use_resource_env) + use_resource_env=use_resource_env, + compiler_options_kvs=compiler_options_kvs) def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): @@ -514,12 +520,13 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, static_argnames: str | Iterable[str] | None, device: xc.Device | None, backend: str | None, abstracted_axes: Any | None, keep_unused: bool, - inline: bool, use_resource_env: bool) -> Any: + inline: bool, compiler_options: dict[str, Any] | None, + use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" jit_info = _parse_jit_arguments( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env) + keep_unused, inline, compiler_options, use_resource_env) return _make_jit_wrapper(fun, jit_info) @@ -676,6 +683,7 @@ def _infer_params_impl( name=fun_qual_name(flat_fun), keep_unused=ji.keep_unused, inline=ji.inline, + compiler_options_kvs=ji.compiler_options_kvs, ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), donated_invars, dbg.arg_names if dbg else None, len(consts), @@ -815,6 +823,7 @@ def pjit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, + compiler_options: dict[str, Any] | None = None, ) -> JitWrapped: """Makes ``fun`` compiled and automatically partitioned across multiple devices. @@ -987,7 +996,7 @@ def pjit( return make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env=True) + keep_unused, inline, compiler_options, use_resource_env=True) def hashable_pytree(pytree): @@ -1594,25 +1603,25 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_platforms, lowering_parameters, pgle_profiler): + lowering_platforms, lowering_parameters, pgle_profiler, + compiler_options_kvs): in_shardings = _resolve_in_shardings(args, in_shardings) in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) - lowered = _pjit_lower( + return _pjit_lower( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, name, keep_unused, inline, + donated_invars, name, keep_unused, inline, compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) - return lowered def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): global _most_recent_pjit_call_executable - compile_options = None - pgle_profiler = None + pgle_compile_options, pgle_profiler = {}, None pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: if jaxpr not in pgle_profiler_dict: @@ -1626,8 +1635,9 @@ def _pjit_call_impl_python( # be None. fdo_profile = pgle_profiler.consume_fdo_profile() if fdo_profile is not None: - compile_options = {'fdo_profile': fdo_profile} + pgle_compile_options['fdo_profile'] = fdo_profile + compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items()) # TODO(patrios): Do not pass mutable profile session through cached lowering # chain. Instead we need to move profilers dictionary to pxla module and use # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode. @@ -1638,8 +1648,9 @@ def _pjit_call_impl_python( donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline, lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(), - pgle_profiler=pgle_profiler - ).compile(compile_options) + pgle_profiler=pgle_profiler, + compiler_options_kvs=compiler_options_kvs, + ).compile() _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. @@ -1693,7 +1704,7 @@ def _pjit_call_impl_python( @weakref_lru_cache def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): + keep_unused, inline, compiler_options_kvs): # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to # the jaxpr defeating the purpose of weakref_lru_cache. So return a function @@ -1706,15 +1717,15 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, def _pjit_call_impl(*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, - donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): def call_impl_cache_miss(*args_, **kwargs_): out_flat, compiled = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, compiler_options_kvs=compiler_options_kvs) pgle_profiler = _read_pgle_profiler(jaxpr) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, @@ -1723,7 +1734,8 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline) + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs) donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) cache_key = pxla.JitGlobalCppCacheKeys( donate_argnums=donated_argnums, donate_argnames=None, @@ -1757,6 +1769,7 @@ def _pjit_lower_cached( name: str, keep_unused: bool, inline: bool, + compiler_options_kvs: tuple[tuple[str, Any], ...], *, lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, @@ -1767,6 +1780,7 @@ def _pjit_lower_cached( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), keep_unused=keep_unused, context_mesh=mesh, + compiler_options_kvs=compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) @@ -1911,7 +1925,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, keep_unused, inline): + donated_invars, keep_unused, inline, compiler_options_kvs): effects = list(ctx.tokens_in.effects()) output_types = map(mlir.aval_to_ir_type, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types @@ -1939,7 +1953,8 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, def _pjit_batcher(axis_data, vals_in, dims_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) @@ -1974,7 +1989,8 @@ def _pjit_batcher(axis_data, vals_in, dims_in, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs( vals_in, vals_out, axes_out) @@ -2024,7 +2040,8 @@ def _pjit_batcher_for_sharding( def _pjit_jvp(primals_in, tangents_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): if any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr) mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals) @@ -2056,7 +2073,8 @@ def _filter_zeros(is_nz_l, l): donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)]) assert len(primals_out) == len(jaxpr.jaxpr.outvars) @@ -2069,7 +2087,7 @@ def _filter_zeros(is_nz_l, l): def _pjit_partial_eval(trace, *in_tracers, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, - name, keep_unused, inline): + name, keep_unused, inline, compiler_options_kvs): in_pvals = [t.pval for t in in_tracers] known_ins = tuple(pv.is_known() for pv in in_pvals) @@ -2127,7 +2145,8 @@ def keep_where(l, should_keep): in_layouts=keep_where(in_layouts, known_ins), out_layouts=known_out_layouts, resource_env=resource_env, donated_invars=keep_where(donated_invars, known_ins), - name=name, keep_unused=keep_unused, inline=inline) + name=name, keep_unused=keep_unused, inline=inline, + compiler_options_kvs=compiler_options_kvs) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals) @@ -2161,7 +2180,8 @@ def keep_where(l, should_keep): (False,) * num_residuals), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()] unknown_out_avals = unknown_jaxpr.out_avals unknown_tracers_out = [ @@ -2241,7 +2261,8 @@ def _pjit_transpose_trace(fun, in_avals): def _pjit_transpose(cts_in, *primals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) @@ -2292,7 +2313,8 @@ def prune_type(ty, xs, maybe_zeros): donated_invars=(False,) * len(primals_and_nz_cts_in), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) if attrs_tracked: final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) @@ -2358,6 +2380,8 @@ def _pjit_pp_rule(eqn, context, settings): if (params['resource_env'] is None or params['resource_env'].physical_mesh.empty): del params['resource_env'] + if not params['compiler_options_kvs']: + del params['compiler_options_kvs'] # Move name= to the front to make the resulting equation easier to scan. del params["name"] diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 972d1b3dd570..783661e71c2e 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3581,6 +3581,7 @@ def _pjit(*args: TfVal, name: str, keep_unused: bool, inline: bool, + compiler_options_kvs, _in_avals: Sequence[core.ShapedArray], _out_aval: Sequence[core.ShapedArray]) -> TfVal: del donated_invars diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 5348dd62a32e..7c5a966500f7 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -762,7 +762,7 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): + keep_unused, inline, compiler_options_kvs): if any(donated_invars): raise NotImplementedError("sparse xla_call with donated_invars") @@ -798,7 +798,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat)) sparse_rules_bcoo[pjit.pjit_p] = _pjit_sparse diff --git a/tests/api_test.py b/tests/api_test.py index e98f4299c5e1..bb1d24729860 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1373,6 +1373,18 @@ def f(x): } ) + def test_compile_options_jit(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + f_jit = jit( + f, + compiler_options={ + "xla_embed_ir_in_executable": True, + "xla_dump_max_hlo_modules": 200, + "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5, + })(1.0) # doesn't crash. + def test_jit_lower_compile_with_compiler_options_invalid(self): def f(x): return jnp.sqrt(x ** 2) + 1. @@ -1390,7 +1402,21 @@ def f(x): lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) - def test_jit_lower_compile_with_compiler_options_multiple(self): + def test_jit_compile_with_compiler_options_multiple(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + with jtu.count_jit_compilation_cache_miss() as count: + jit(f, compiler_options={"xla_embed_ir_in_executable": True})(1.) + jit(f, compiler_options={"xla_embed_ir_in_executable": False})(1.) + self.assertEqual(count[0], 2) + + # We should still error on invalid options after some valid compiles + with self.assertRaisesRegex( + xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"): + jit(f, compiler_options={"invalid_key": "invalid_value"})(1.) + + def test_lower_compile_with_compiler_options_multiple(self): def f(x): return jnp.sqrt(x ** 2) + 1. From d38da5d1b4be5167c25ead8e6b1a6de04e531870 Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Fri, 1 Nov 2024 14:34:54 -0700 Subject: [PATCH 174/698] Verify we can offload more than a single computation to SparseCore PiperOrigin-RevId: 692293824 --- tests/layout_test.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/layout_test.py b/tests/layout_test.py index 7ccd0d7cddea..31f3d71d0537 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -625,6 +625,36 @@ def f(x, y): f(inp, sparecore_arr) + def test_sparsecore_compute_twice(self): + if not ( + jax.devices()[0].device_kind == 'TPU v5' + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest('Does not have a sparsecore present') + shape = (4096, 8) + inp = jnp.arange(math.prod(shape)).reshape(shape) + + dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) + s = SingleDeviceSharding(jax.devices()[0]) + sparse_layout = Layout(dll, s) + sparecore_arr = jax.device_put(inp, sparse_layout) + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_multiply(x, y): + return x * y + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_add(x, y): + return x + y + + @partial(jax.jit, donate_argnums=0, out_shardings=sparse_layout) + def f(x): + return sparsecore_multiply(sparsecore_add(x, x) + 1, x) + + f(sparecore_arr) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 0f3ba4250d0730a16017a5847868fea2cb142dbd Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 1 Nov 2024 16:24:42 -0700 Subject: [PATCH 175/698] support exec_time_optimization_effort and memory_fitting_effort xla compilation options PiperOrigin-RevId: 692322944 --- CHANGELOG.md | 5 +++++ jax/_src/compiler.py | 7 +++++++ tests/api_test.py | 22 +++++++++++++++++++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b629631ea4d..457107d8af9e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. +* New Features + * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for + passing compilation options to XLA. For the moment it's undocumented and + may be in flux. + ## jax 0.4.35 (Oct 22, 2024) * Breaking Changes diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 8a2d6047e9b8..113f7507c4b0 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -189,6 +189,13 @@ def get_compile_options( compile_options.device_assignment = device_assignment if env_options_overrides is not None: + # Some overrides are passed directly on build_options. + overrides_on_build_options = [ + 'exec_time_optimization_effort', 'memory_fitting_effort'] + env_options_overrides = dict(env_options_overrides) + for name in overrides_on_build_options: + if name in env_options_overrides: + setattr(build_options, name, env_options_overrides.pop(name)) compile_options.env_option_overrides = list(env_options_overrides.items()) debug_options = compile_options.executable_build_options.debug_options diff --git a/tests/api_test.py b/tests/api_test.py index bb1d24729860..8ab5d90f6e07 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -60,7 +60,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_extension +from jax._src.lib import xla_extension, xla_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -1385,6 +1385,26 @@ def f(x): "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5, })(1.0) # doesn't crash. + def test_exec_time_optimization_effort_compiler_option(self): + if xla_extension_version < 294: + raise unittest.SkipTest("test requires newer xla extension version") + + def f(x): + return jnp.sqrt(x ** 2) + 1. + + f_jit = jit( + f, + compiler_options={ + "exec_time_optimization_effort": 0.0, + })(1.0) # doesn't crash. + + with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"): + f_jit = jit( + f, + compiler_options={ + "exec_time_compilation_effort": 0.0, + })(1.0) + def test_jit_lower_compile_with_compiler_options_invalid(self): def f(x): return jnp.sqrt(x ** 2) + 1. From 292a00b35afcf9e46df3af0edd3d36987886a484 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 1 Nov 2024 22:55:56 -0700 Subject: [PATCH 176/698] [export] Cleanup in the export module. With jax.experimental.export gone we can now do some cleanup in the export module. In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024. PiperOrigin-RevId: 692398132 --- CHANGELOG.md | 11 + docs/export/export.md | 6 +- jax/_src/export/_export.py | 213 ++++-------------- jax/_src/export/serialization.fbs | 2 +- jax/_src/export/serialization.py | 9 +- .../export_back_compat_test_util.py | 2 +- jax/experimental/jax2tf/jax2tf.py | 15 +- jax/experimental/jax2tf/tests/call_tf_test.py | 2 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 2 +- tests/export_harnesses_multi_platform_test.py | 2 +- tests/export_test.py | 25 +- tests/pallas/export_pallas_test.py | 2 +- 12 files changed, 98 insertions(+), 193 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 457107d8af9e..b122675e1ffc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,17 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The deprecated module `jax.experimental.export` has been removed. It was replaced by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. + * The following deprecated methods and functions in {mod}`jax.export` have + been removed: + * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect + already. + * `jax.export.Exported.lowering_platforms`: use `platforms`. + * `jax.export.Exported.mlir_module_serialization_version`: + use `calling_convention_version`. + * `jax.export.Exported.uses_shape_polymorphism`: + use `uses_global_constants`. + * the `lowering_platforms` kwarg for {func}`jax.export.export`: use + `platforms` instead. * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for diff --git a/docs/export/export.md b/docs/export/export.md index 5960fcaea65a..b62cf9fe0113 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -273,7 +273,7 @@ ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used o >>> # compilation platform (which is the case for `cos` in this >>> # example): >>> exp_unsafe = export.export(jax.jit(lax.cos), -... lowering_platforms=['tpu'], +... platforms=['tpu'], ... disabled_checks=[export.DisabledSafetyCheck.platform()])(1.) >>> exp_unsafe.call(1.) @@ -281,7 +281,7 @@ Array(0.5403023, dtype=float32, weak_type=True) # and similarly with multi-platform lowering >>> exp_multi = export.export(jax.jit(lax.cos), -... lowering_platforms=['tpu', 'cpu', 'cuda'])(1.) +... platforms=['tpu', 'cpu', 'cuda'])(1.) >>> exp_multi.call(1.) Array(0.5403023, dtype=float32, weak_type=True) @@ -310,7 +310,7 @@ the same StableHLO as for the single-plaform export. 9220 >>> exp_multi = export.export(jax.jit(f), -... lowering_platforms=["cpu", "tpu", "cuda"])(1.) +... platforms=["cpu", "tpu", "cuda"])(1.) >>> len(exp_multi.mlir_module_serialized) # doctest: +SKIP 9282 diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 99794c8cc23c..fe4deed57d19 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -26,7 +26,6 @@ import json import re from typing import Any, Protocol, TypeVar, Union, cast -import warnings from absl import logging import numpy as np @@ -102,20 +101,6 @@ def custom_call(cls, target_name: str) -> DisabledSafetyCheck: """ return DisabledSafetyCheck(f"custom_call:{target_name}") - @classmethod - def shape_assertions(cls) -> DisabledSafetyCheck: - """DEPRECATED: A noop. - - Was used previously to allow invocations with shapes that do not meet the - constraints. Has no effect anymore, shape assertions cannot be disabled. - """ - # TODO(necula): remove this after compatibility period. Was deprecated in - # May 2024. - warnings.warn( - "DisabledSafetyCheck.shape_assertions is deprecated, has no effect anymore", - DeprecationWarning, stacklevel=2) - return DisabledSafetyCheck("shape_assertions") - def is_custom_call(self) -> str | None: """Returns the custom call target allowed by this directive.""" m = re.match(r'custom_call:(.+)$', self._impl) @@ -274,33 +259,6 @@ def out_shardings_jax( return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) for s in self.out_shardings_hlo) - # For backwards compatibility - # TODO(necula): remove after September 2024. - @property - def lowering_platforms(self): - """DEPRECATED.""" - warnings.warn("lowering_platform is deprecated. Use .platforms instead.", - DeprecationWarning, stacklevel=2) - return self.platforms - - # For backwards compatibility - # TODO(necula): remove after September 2024. - @property - def mlir_module_serialization_version(self): - """DEPRECATED.""" - warnings.warn("mlir_module_serialization_version is deprecated. Use .calling_convention_version instead.", - DeprecationWarning, stacklevel=2) - return self.calling_convention_version - - # For backwards compatibility - # TODO(necula): remove after September 2024. - @property - def uses_shape_polymorphism(self): - """DEPRECATED.""" - warnings.warn("uses_shape_polymorphism is deprecated. Use .uses_global_constants instead.", - DeprecationWarning, stacklevel=2) - return self.uses_global_constants - def has_vjp(self) -> bool: """Returns if this Exported supports VJP.""" return self._get_vjp is not None @@ -546,109 +504,11 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: aval = core.raise_to_shaped(core.get_aval(a)) return aval.shape, aval.dtype -def args_specs( - args, # pytree of arguments - polymorphic_shapes, # prefix pytree of strings - get_shape_and_dtype=shape_and_dtype_jax_array, -): - # TODO: deprecated in January 2024, to be removed. - warnings.warn( - "export.args_specs is deprecated in favor of export.symbolic_args_specs", - DeprecationWarning, stacklevel=2) - if get_shape_and_dtype is not shape_and_dtype_jax_array: - # This was needed in some older jax2tf implementations - args = tree_util.tree_map(lambda a: jax.ShapeDtypeStruct(* get_shape_and_dtype(a)), - args) - return shape_poly.symbolic_args_specs(args, polymorphic_shapes) - - -# TODO(necula): remove this once we remove jax.experimental.export. -def export_back_compat( - fun_jax: Callable, - *, - lowering_platforms: Sequence[str] | None = None, - disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, - ) -> Callable[..., Exported]: - """Exports native serialization for a JAX function. - - Note: this function exists only for internal usage by jax2tf and for - backwards compatibility with jax.experimental.export. Use - `jax.export` instead. - See https://jax.readthedocs.io/en/latest/export/export.html - - Args: - fun_jax: the function to lower and serialize. - lowering_platforms: - Optional sequence containing a subset of 'tpu', 'cpu', - 'cuda', 'rocm'. If more than one platform is specified, then - the lowered code takes an argument specifying the platform. - If None, then use the default JAX backend. - The calling convention for multiple platforms is explained - at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. - disabled_checks: the safety checks to disable. See docstring - of `DisabledSafetyCheck`. - - Returns: - a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, - or values with `.shape` and `.dtype` attributes, and returns an - `Exported`. - - Usage: - - def f_jax(*args, **kwargs): ... - exported = jax_export.export(f_jax)(*args, **kwargs) - """ - - def do_export(*args_specs, **kwargs_specs) -> Exported: - if hasattr(fun_jax, "trace"): - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. - wrapped_fun_jax = fun_jax - else: - # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also - # convert(f_jax), in which case a "jit" is implied. In that case we raise - # an error if the lowered function contains non-replicated sharding annotations. - wrapped_fun_jax = jax.jit(fun_jax) - - if lowering_platforms is not None: - actual_lowering_platforms = tuple(lowering_platforms) - else: - actual_lowering_platforms = (default_export_platform(),) - - # TODO: move to `lower` - symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] - for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: - # Static args may have no `shape` attribute. - if not hasattr(aval, "shape"): - continue - for d in aval.shape: - if shape_poly.is_symbolic_dim(d): - if symbolic_scope is None: - symbolic_scope = (d.scope, k_path) - continue - symbolic_scope[0]._check_same_scope( - d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}", - self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", - other_descr=shape_poly.args_kwargs_path_to_str(k_path)) - - traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs) - lowered = traced.lower( - lowering_platforms=actual_lowering_platforms, - _private_parameters=mlir.LoweringParameters( - for_export=True, - export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) - return _export_lowered( - lowered, traced.jaxpr, traced.fun_name, - disabled_checks=disabled_checks, - _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) - return do_export def export( fun_jit: stages.Wrapped, *, platforms: Sequence[str] | None = None, - lowering_platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), ) -> Callable[..., Exported]: """Exports a JAX function for persistent serialization. @@ -662,7 +522,6 @@ def export( If None, then use the default JAX backend. The calling convention for multiple platforms is explained at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. - lowering_platforms: DEPRECATED, use `platforms`. disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. @@ -689,34 +548,38 @@ def export( >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32)) Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) """ + return _export_internal(fun_jit, platforms=platforms, + disabled_checks=disabled_checks) + + +# TODO(necula): remove this once we improve the integration with jax2tf. +def _export_internal( + fun_jit: stages.Wrapped, + *, + platforms: Sequence[str] | None = None, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + _device_assignment_for_internal_jax2tf_use_only = None, + ) -> Callable[..., Exported]: + """Exports native serialization for a JAX function. + + Note: this function exists only for internal usage by jax2tf. Use + `jax.export` instead. + See https://jax.readthedocs.io/en/latest/export/export.html + + See docstring of `export` for more details. + """ if not isinstance(fun_jit, stages.Wrapped): raise ValueError( f"Function to be exported must be the result of `jit` but is: {fun_jit}") - if platforms is not None and lowering_platforms is not None: - raise ValueError("Cannot use both `platforms` and `lowering_platforms`") - if platforms is None and lowering_platforms is not None: - platforms = lowering_platforms - if platforms is not None: - actual_lowering_platforms = tuple(platforms) - else: - actual_lowering_platforms = (default_export_platform(),) def do_export(*args_specs, **kwargs_specs) -> Exported: + if platforms is not None: + actual_lowering_platforms = tuple(platforms) + else: + actual_lowering_platforms = (default_export_platform(),) + # TODO: move to `lower` - symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] - for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: - # Static args may have no `shape` attribute. - if not hasattr(aval, "shape"): - continue - for d in aval.shape: - if shape_poly.is_symbolic_dim(d): - if symbolic_scope is None: - symbolic_scope = (d.scope, k_path) - continue - symbolic_scope[0]._check_same_scope( - d, when=f"when exporting {util.fun_name(fun_jit)}", - self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", - other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + check_symbolic_scope_errors(fun_jit, args_specs, kwargs_specs) traced = fun_jit.trace(*args_specs, **kwargs_specs) lowered = traced.lower( @@ -726,12 +589,32 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) return _export_lowered( lowered, traced.jaxpr, traced.fun_name, - disabled_checks=disabled_checks) + disabled_checks=disabled_checks, + _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) return do_export + +def check_symbolic_scope_errors(fun_jax, args_specs, kwargs_specs): + symbolic_scope: tuple[shape_poly.SymbolicScope, tree_util.KeyPath] | None = None # type: ignore[invalid-annotation,unused-ignore] + for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: + # Static args may have no `shape` attribute. + if not hasattr(aval, "shape"): + continue + for d in aval.shape: + if shape_poly.is_symbolic_dim(d): + if symbolic_scope is None: + symbolic_scope = (d.scope, k_path) + continue + symbolic_scope[0]._check_same_scope( + d, when=f"when exporting {util.fun_name(fun_jax)}", + self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", + other_descr=shape_poly.args_kwargs_path_to_str(k_path)) + + def _export_lowered( lowered: stages.Lowered, - jaxpr: core.ClosedJaxpr, fun_name: str, + jaxpr: core.ClosedJaxpr, + fun_name: str, disabled_checks: Sequence[DisabledSafetyCheck] = (), _device_assignment_for_internal_jax2tf_use_only = None, ) -> Exported: diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index b72d0134cf1f..3198f83aa120 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -97,7 +97,7 @@ table Effect { enum DisabledSafetyCheckKind: byte { platform, custom_call, - shape_assertions, + shape_assertions, // unused } table DisabledSafetyCheck { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 434c4c5cf10c..e392289da64d 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -485,8 +485,6 @@ def _serialize_disabled_safety_check( custom_call_target = builder.CreateString(custom_call_target_str) elif check == _export.DisabledSafetyCheck.platform(): kind = ser_flatbuf.DisabledSafetyCheckKind.platform - elif check == _export.DisabledSafetyCheck.shape_assertions(): - kind = ser_flatbuf.DisabledSafetyCheckKind.shape_assertions else: raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}") @@ -510,5 +508,10 @@ def _deserialize_disabled_safety_check( if kind == ser_flatbuf.DisabledSafetyCheckKind.platform: return _export.DisabledSafetyCheck.platform() if kind == ser_flatbuf.DisabledSafetyCheckKind.shape_assertions: - return _export.DisabledSafetyCheck.shape_assertions() + # shape_assertions has been deprecated in June 2024 (turned into a no-op), + # and removed in November 2024. We deserialize it to a DisabledSafetyCheck + # that has no effect. + # TODO(necula): remove this after June 2025, when we should not have any + # more serialized artifacts with shape_assertions. + return _export.DisabledSafetyCheck.custom_call("no op") assert False, kind diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index 70826eec8806..5d5e95b5cb9a 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -294,7 +294,7 @@ def serialize(self, args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) exported = export.export( jax.jit(func), - lowering_platforms=(self.default_jax_backend(),), + platforms=(self.default_jax_backend(),), disabled_checks=tuple( export.DisabledSafetyCheck.custom_call(target) for target in allow_unstable_custom_call_targets) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 783661e71c2e..c6d920918074 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -532,7 +532,16 @@ def __init__(self, fun_jax, *, self.convert_kwargs = dict(native_serialization=True, native_serialization_platforms=native_serialization_platforms, native_serialization_disabled_checks=native_serialization_disabled_checks) - self.fun_jax = fun_jax + if hasattr(fun_jax, "trace"): + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. + fun_jit = fun_jax + else: + # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also + # convert(f_jax), in which case a "jit" is implied. In that case we raise + # an error if the lowered function contains non-replicated sharding annotations. + fun_jit = jax.jit(fun_jax) + self.fun_jax = fun_jit self.args_specs = args_specs self.kwargs_specs = kwargs_specs self.native_serialization_disabled_checks = native_serialization_disabled_checks @@ -547,9 +556,9 @@ def _restore_context(): self._restore_context = _restore_context _exported_device_assignment = [None] - self.exported = _export.export_back_compat( + self.exported = _export._export_internal( self.fun_jax, - lowering_platforms=self.native_serialization_platforms, + platforms=self.native_serialization_platforms, disabled_checks=self.native_serialization_disabled_checks, _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment, )(*self.args_specs, **self.kwargs_specs) diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 6d5efb7b1e66..f23bd58c48d3 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -803,7 +803,7 @@ def f_jax(x): lowering_platforms = ("tpu", "cpu", "cuda") exp = export.export(jax.jit(f_jax), - lowering_platforms=lowering_platforms)(x) + platforms=lowering_platforms)(x) for jax_platform in jax_and_tf_platforms: with self.subTest(jax_platform): jax_device = jax.devices(jax_platform)[0] diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 68c7b15383fe..e59084041306 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1531,7 +1531,7 @@ def apply_transform(func, transform: str): _ = func_to_convert(*args) exported = export.export( (jax.jit(func_to_convert) if not hasattr(func_to_convert, "trace") else func_to_convert), - lowering_platforms=("tpu",) + platforms=("tpu",) )(*(core.ShapedArray(a.shape, a.dtype) for a in args)) if transform1 == "shard_map": diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index c74eec550342..e8b1afc224b7 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -149,7 +149,7 @@ def export_and_compare_to_native( logging.info("Exporting harness for %s", lowering_platforms) exp = export.export(jax.jit(func_jax), - lowering_platforms=lowering_platforms)(*args) + platforms=lowering_platforms)(*args) for device in devices: if device.platform in skip_run_on_platforms: diff --git a/tests/export_test.py b/tests/export_test.py index b6dde23721a3..2946854aa549 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -196,7 +196,7 @@ def test_pytree_export_only(self): def f(a_b_pair, *, a, b): return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b)) - exp = get_exported(jax.jit(f), lowering_platforms=("cpu",))((a, b), a=a, b=b) + exp = get_exported(jax.jit(f), platforms=("cpu",))((a, b), a=a, b=b) a_aval = core.ShapedArray(a.shape, a.dtype) b_aval = core.ShapedArray(b.shape, b.dtype) self.assertEqual(exp.platforms, ("cpu",)) @@ -463,7 +463,7 @@ def test_default_export_platform(self): def test_error_wrong_platform(self, platform): a = np.arange(4, dtype=np.float32) - exp_f = get_exported(jnp.sin, lowering_platforms=(platform,))(a) + exp_f = get_exported(jnp.sin, platforms=(platform,))(a) if xb.canonicalize_platform(jtu.device_under_test()) == platform: raise unittest.SkipTest("Uninteresting scenario") @@ -473,7 +473,7 @@ def test_error_wrong_platform(self, platform): # Now try with the platform check disabled exp_f_no_platform_check = get_exported( - jnp.sin, lowering_platforms=(platform,), + jnp.sin, platforms=(platform,), disabled_checks=[export.DisabledSafetyCheck.platform()])(a) res = exp_f_no_platform_check.call(a) self.assertAllClose(res, jnp.sin(a)) @@ -1464,7 +1464,7 @@ def f(x): def test_multi_platform(self): x = np.arange(8, dtype=np.float32) exp = get_exported(jax.jit(_testing_multi_platform_func), - lowering_platforms=("tpu", "cpu", "cuda", "rocm"))(x) + platforms=("tpu", "cpu", "cuda", "rocm"))(x) self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "rocm")) module_str = str(exp.mlir_module()) expected_main_re = ( @@ -1487,14 +1487,14 @@ def test_multi_platform(self): def test_multi_platform_nested(self): x = np.arange(5, dtype=np.float32) exp = get_exported(jax.jit(lambda x: _testing_multi_platform_func(jnp.sin(x))), - lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + platforms=("cpu", "tpu", "cuda", "rocm"))(x) self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call to the exported using a different sequence of # lowering platforms, but included in the lowering platforms for the # nested exported. exp2 = get_exported(jax.jit(exp.call), - lowering_platforms=("cpu", "cuda", "rocm"))(x) + platforms=("cpu", "cuda", "rocm"))(x) # Ensure that we do not have multiple lowerings of the exported function exp2_module_str = str(exp2.mlir_module()) @@ -1513,7 +1513,7 @@ def test_multi_platform_nested(self): def test_multi_platform_nested_inside_single_platform_export(self): x = np.arange(5, dtype=np.float32) exp = get_exported(jax.jit(_testing_multi_platform_func), - lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x) + platforms=("cpu", "tpu", "cuda", "rocm"))(x) self.assertEqual(exp.platforms, ("cpu", "tpu", "cuda", "rocm")) # Now serialize the call for the current platform. @@ -1586,14 +1586,14 @@ def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, def f(x): return times_2_or_3_or_4.bind(x) x = np.float32(42.) - exp = export.export(f, lowering_platforms=["cpu", "cuda", "rocm", "tpu"])(x) + exp = export.export(f, platforms=["cpu", "cuda", "rocm", "tpu"])(x) expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()]) self.assertAllClose(exp.call(x), expected) def test_multi_platform_unknown_platform(self): x = np.arange(8, dtype=np.float32) exp = get_exported(jax.jit(jnp.sin), - lowering_platforms=("tpu", "cpu", "cuda", "other"))(x) + platforms=("tpu", "cpu", "cuda", "other"))(x) self.assertEqual(exp.platforms, ("tpu", "cpu", "cuda", "other")) @@ -1620,7 +1620,7 @@ def test_multi_platform_and_poly(self): # The export is not applicable to GPU raise unittest.SkipTest("Not intended for running on GPU") exp = get_exported(jax.jit(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,))), - lowering_platforms=("cpu", "tpu"))( + platforms=("cpu", "tpu"))( jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), np.float32) ) x = np.arange(12, dtype=np.float32).reshape((3, 4)) @@ -1643,8 +1643,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] return b * 2. res_native = f_jax(a) - exp = get_exported(f_jax, - lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(a) + exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a) # Call with argument placed on different plaforms for platform in self.__class__.platforms: @@ -1790,7 +1789,7 @@ def f_jax(x): # x: f32[b1, b2] effect_class_name="ForTestingOrderedEffect1") exp = get_exported( jax.jit(f_jax), - lowering_platforms=("cpu", "tpu") + platforms=("cpu", "tpu") )(jax.ShapeDtypeStruct(export.symbolic_shape("b1, b2"), x.dtype)) mlir_module_str = str(exp.mlir_module()) wrapped_main_expected_re = ( diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py index ee7fe4ffff47..70e40e1f2801 100644 --- a/tests/pallas/export_pallas_test.py +++ b/tests/pallas/export_pallas_test.py @@ -49,7 +49,7 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: a = np.arange(8 * 16, dtype=np.int32).reshape((8, 16)) exp = export.export( add_vectors, - lowering_platforms=["tpu", "cuda"], + platforms=["tpu", "cuda"], )(a, a) if (jtu.device_under_test() == "tpu" or From 72eb5088b722a34c4d81758385d77637d398e7c0 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Sat, 2 Nov 2024 12:12:41 -0700 Subject: [PATCH 177/698] [jax] Mesh discharge rule should return None for inputs it did not touch. PiperOrigin-RevId: 692519730 --- jax/_src/pallas/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index b697810b8967..14e9d72af186 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1112,7 +1112,7 @@ def body(*args): backend=backend, )(*args) # ``outs`` lacks the unmodified inputs. Add them back in. - all_outs = [*args] + all_outs = [None] * len(args) for out_idx, in_idx in enumerate(modified_idxs): all_outs[in_idx] = outs[out_idx] return all_outs, () From d679c0abaa4a064ad4fc30c1faf21624dc06fe0a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 2 Nov 2024 13:34:53 -0700 Subject: [PATCH 178/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/cbf3c8d5deae86c162725a30f2871648430990d0. PiperOrigin-RevId: 692531116 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 32b8f9207059..fffc475fa1b9 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "8ec02b3611ffa4378ef2189842b5912241b604d0" -XLA_SHA256 = "d5f22ae989dfffda803c8493862733bdf105f63961fff115553ae2bd815436db" +XLA_COMMIT = "cbf3c8d5deae86c162725a30f2871648430990d0" +XLA_SHA256 = "7bf0a7fd6bb2eeb54c386f47b353834d15eda4a2d7ca383a4d42c6d444b484eb" def repo(): tf_http_archive( From ec39b592f7c096b0b8183723feaab2ed0d001041 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Sat, 2 Nov 2024 17:02:02 -0700 Subject: [PATCH 179/698] Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat) PiperOrigin-RevId: 692557993 --- jax/_src/ad_util.py | 3 +- jax/_src/core.py | 112 ++++++---------------- jax/_src/custom_derivatives.py | 27 +++--- jax/_src/interpreters/ad.py | 9 +- jax/_src/interpreters/batching.py | 8 +- jax/_src/interpreters/partial_eval.py | 24 ++--- jax/_src/lax/control_flow/conditionals.py | 10 +- jax/_src/lax/control_flow/loops.py | 6 +- jax/_src/lax/control_flow/solves.py | 3 +- jax/_src/lax/lax.py | 8 +- jax/_src/lax/linalg.py | 4 +- jax/_src/lax/parallel.py | 15 ++- jax/_src/pallas/core.py | 19 ---- jax/_src/pallas/mosaic/core.py | 9 -- jax/_src/pallas/mosaic/verification.py | 3 +- jax/_src/pallas/mosaic_gpu/core.py | 10 -- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- jax/_src/state/primitives.py | 2 - jax/_src/state/types.py | 11 +-- jax/core.py | 1 - tests/core_test.py | 9 -- tests/lax_test.py | 2 +- tests/pallas/tpu_pallas_test.py | 5 +- tests/state_test.py | 5 +- 24 files changed, 96 insertions(+), 211 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index bd1427f59e01..02f3b0405e38 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -43,7 +43,8 @@ def add_impl(x, y): @add_jaxvals_p.def_abstract_eval def add_abstract(x, y): - return core.lattice_join(x, y) + assert core.typematch(x, y) + return x def zeros_like_aval(aval: core.AbstractValue) -> Array: return aval_zeros_likers[type(aval)](aval) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7d912e3c207b..1fecf3f18b43 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -368,7 +368,7 @@ class Var: def __init__(self, suffix: str, aval: AbstractValue): self.count = next(_var_counter) self.suffix = suffix - self.aval = raise_to_shaped(aval) + self.aval = aval # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not # care about variable ordering, but the downstream package kfac_jax does. @@ -662,7 +662,7 @@ def __init__(self, trace: Trace): def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" - return f"traced array with shape {raise_to_shaped(self.aval).str_short()}" + return f"traced array with shape {self.aval.str_short()}" def __array__(self, *args, **kw): raise TracerArrayConversionError(self) @@ -1302,11 +1302,11 @@ def __repr__(self): except AttributeError: return self.__class__.__name__ - def strip_weak_type(self) -> AbstractValue: + def update_weak_type(self, weak_type): return self - def join(self, other): - raise NotImplementedError("must override") + def strip_weak_type(self) -> AbstractValue: + return self.update_weak_type(False) def update(self, **kwargs): raise NotImplementedError("must override") @@ -1314,7 +1314,6 @@ def update(self, **kwargs): def str_short(self, short_dtypes=False): return str(self) - # For type signatures involving dynamic shapes, we use lists of abstract values # which may contain (reverse) de Bruijn indices in their shapes. class DBIdx(NamedTuple): @@ -1348,26 +1347,10 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -class Bot(AbstractValue): pass -bot = Bot() - - -def lattice_join(x: AbstractValue | None, - y: AbstractValue | None) -> AbstractValue: - if x is None: - assert y is not None - return y - elif y is None: - return x - elif isinstance(x, type(y)): - return y.join(x) - elif isinstance(y, type(x)): - return x.join(y) - elif isinstance(x, DShapedArray) and isinstance(y, ShapedArray): - # TODO(mattjj): remove this special case after dynamic shapes are integrated - return x.join(y) - else: - raise TypeError(x, y) +# TODO(dougalm): Deprecate. This is here for backwards compat. +def lattice_join(x, y): + assert typematch(x, y) + return x # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1530,9 +1513,8 @@ def __repr__(self): def str_short(self, short_dtypes=False) -> str: return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - def strip_weak_type(self): - """Returns a copy of the aval with weak_type=False.""" - return self.update(weak_type=False) + def update_weak_type(self, weak_type): + return self.update(weak_type=weak_type) def _canonicalize_dimension(dim: DimSize) -> DimSize: # Dimensions are most commonly integral (by far), so we check that first. @@ -1656,13 +1638,6 @@ def to_tangent_aval(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) - def join(self, other): - if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype: - weak_type = self.weak_type and other.weak_type - return self.update(weak_type=weak_type) - else: - raise TypeError(self, other) - def str_short(self, short_dtypes=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name) @@ -1762,14 +1737,6 @@ def __eq__(self, other): def __hash__(self): return hash((self.shape, self.dtype, self.weak_type)) - def join(self, other): - if (definitely_equal_shape(self.shape, other.shape) and - self.dtype == other.dtype): - weak_type = self.weak_type and other.weak_type - return self.update(weak_type=weak_type) - else: - raise TypeError(self, other) - def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1881,16 +1848,11 @@ def mutable_array_abstract_eval(init_aval): @mutable_array_p.def_impl def _mutable_array_impl(init_val): from jax._src.state.types import AbstractRef # pytype: disable=import-error - aval = raise_to_shaped(get_aval(init_val)) + aval = get_aval(init_val) return MutableArray(AbstractRef(aval), init_val) class AbstractToken(AbstractValue): - def join(self, other): - if isinstance(other, AbstractToken): - return self - else: - assert False, f"Cannot join {self} with {other}" def str_short(self, short_dtypes=False): return 'Tok' def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() @@ -1910,30 +1872,9 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -def raise_to_shaped(aval: AbstractValue, weak_type=None): - aval_type = type(aval) - if aval_type is ShapedArray and weak_type is None: - return aval - if aval_type is DShapedArray and weak_type is None: - return aval - if weak_type is None: - weak_type = getattr(aval, 'weak_type', False) - for typ in aval_type.__mro__: - handler = raise_to_shaped_mappings.get(typ) - if handler: return handler(aval, weak_type) - raise TypeError(type(aval)) - -def _shaped_array_mapping(aval, weak_type): - if config.sharding_in_types.value: - return ShapedArray(aval.shape, aval.dtype, weak_type, sharding=aval.sharding) - return ShapedArray(aval.shape, aval.dtype, weak_type) - -raise_to_shaped_mappings: dict[type, Callable] = { - AbstractToken: lambda aval, _: aval, - Bot: lambda aval, _: aval, - ShapedArray: _shaped_array_mapping, - DShapedArray: lambda aval, _: aval -} +# TODO(dougalm): Deprecate. This is just here for backwards compat. +def raise_to_shaped(aval): + return aval ### Operations on shapes and dimension sizes. @@ -2341,18 +2282,23 @@ def typecheck(aval: AbstractValue, x) -> bool: def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: """Determine whether `aval` conforms to `aval_ref`. Ignores weak_type.""" try: - return typematch(aval_ref, lattice_join(aval_ref, aval)) + return typematch(aval_ref, aval) except TypeError: return False -def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool: - """Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type.""" - if aval1 == aval2: return True - # unequal avals may still represent the same type, because type is represented - # by avals at the shaped level, and because weak type tags aren't considered - # part of the type - return (raise_to_shaped(aval1, weak_type=False) == - raise_to_shaped(aval2, weak_type=False)) +def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: + """Determine whether `t1` and `t2` are equivalent. Ignores weak_type.""" + t1 = t1.strip_weak_type() + t2 = t2.strip_weak_type() + if t1 == t2: + return True + elif (isinstance(t1, (ShapedArray, DShapedArray)) and + isinstance(t2, (ShapedArray, DShapedArray))): + # This case handles DShapedArray and shape polynomials. Alternatively we + # could try normalizing first and then doing simple equality. + return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + else: + return False class JaxprTypeError(TypeError): pass diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 77f73562aecd..375efeb712b8 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,7 +31,6 @@ stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs) -from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -81,7 +80,7 @@ def _flatten_fun_nokwargs(in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) ans = yield py_args, {} ans_flat, ans_tree = tree_flatten(ans) - ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat] + ans_avals = [core.get_aval(x) for x in ans_flat] yield ans_flat, (ans_tree, ans_avals) @@ -287,7 +286,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): py_primals_out, py_tangents_out = pair_out primals_out, out_tree = tree_flatten(py_primals_out) tangents_out, out_tree2 = tree_flatten(py_tangents_out) - primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] + primal_avals = [core.get_aval(x) for x in primals_out] if out_tree != out_tree2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must " "produce primal and tangent outputs with equal container (pytree) " @@ -327,11 +326,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] + primal_avals_out = [core.get_aval(x).strip_weak_type() for x in primals_out] expected_tangent_avals_out = [ - raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval() + core.get_aval(x).strip_weak_type().to_tangent_aval() for x in primals_out] - tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) + tangent_avals_out = [core.get_aval(t).strip_weak_type() if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] if expected_tangent_avals_out != tangent_avals_out: @@ -606,7 +605,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable f_, dyn_args = lu.wrap_init(self.fun), args fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd) args_flat, in_tree = tree_flatten(dyn_args) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + in_avals = [core.get_aval(x) for x in args_flat] flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name, fwd_name, in_tree, out_type) @@ -674,7 +673,7 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, py_primals_out, res = pair_out primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) - primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] + primal_avals = [core.get_aval(x) for x in primals_out] # If the primal function already ran, check out_tree agreement. try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None @@ -772,7 +771,7 @@ def append(x, d): msg = ("Custom VJP bwd rule must produce an output with the same " "shape/dtypes as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " - f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding " + f"shape/dtype {a_.str_short()} corresponding " f"to an input of shape/dtype {a.str_short()}.") raise ValueError(msg) results.append(ct) @@ -831,7 +830,7 @@ def _custom_vjp_call_jaxpr_jvp( _, res_tree = out_trees() res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, @@ -1110,7 +1109,7 @@ def merge(l1, l2): return out, merge def abstractify(x): - return core.raise_to_shaped(core.get_aval(x)) + return core.get_aval(x) ### Custom transposition @@ -1211,7 +1210,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args, lin_avals = map(abstractify, operands_lin) f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals)) f_jaxpr = _close_jaxpr(f_jaxpr) - out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals) + out_avals = f_jaxpr.out_avals t_in_tree = treedef_tuple((res_tree, out_tree())) t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree) @@ -1265,7 +1264,7 @@ def _linear_call_transpose_rule(cts, *args, callee, transpose, return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out def _linear_call_abstract_eval(*args, **kwargs): - return map(core.raise_to_shaped, kwargs['callee'].out_avals) + return kwargs['callee'].out_avals linear_call_p = core.Primitive('linear_call') linear_call_p.multiple_results = True @@ -1398,7 +1397,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: in_tree, out_type) flat_fwd = _fix_fwd_args(flat_fwd) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + in_avals = [core.get_aval(x) for x in args_flat] fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) prim_tree, res_tree = out_trees() diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 47c7882372ab..d080aae759a6 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -33,8 +33,7 @@ replace_rule_output_symbolic_zeros, Zero, zeros_like_aval) from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401 from jax._src.api_util import flatten_fun, flatten_fun_nokwargs -from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, - raise_to_shaped) +from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) from jax._src.dtypes import dtype, float0 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, @@ -362,7 +361,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! with core.set_current_trace(self.parent_trace): tangents_in = map(instantiate_zeros, tangents_in) @@ -434,8 +433,8 @@ def to_concrete_value(self): def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: - primal_aval = raise_to_shaped(get_aval(primal), weak_type=False) - tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False) + primal_aval = get_aval(primal).strip_weak_type() + tangent_aval = get_aval(tangent).strip_weak_type() assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 2ff27f0c5d74..590e60383b90 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -29,7 +29,7 @@ from jax._src.ad_util import (Zero, instantiate, SymbolicZero, replace_rule_output_symbolic_zeros, add_jaxvals, add_jaxvals_p) -from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName +from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) @@ -217,7 +217,7 @@ def __init__(self, a): self.a = a for d in a.shape)) if type(a) is core.DShapedArray else a for a, e in orig_type if e] - new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens] + new_avals = [core.get_aval(s) for s in segment_lens] sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size for a, d in zip(avals, explicit_in_dims): if isinstance(d, RaggedAxis): @@ -387,7 +387,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, if config.enable_checks.value: assert type(batch_dim) in (NotMapped, int, RaggedAxis) if type(batch_dim) is int: - aval = raise_to_shaped(core.get_aval(val)) + aval = core.get_aval(val) assert 0 <= batch_dim < len(aval.shape) self._trace = trace self.val = val @@ -396,7 +396,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, @property def aval(self): - aval = raise_to_shaped(core.get_aval(self.val)) + aval = core.get_aval(self.val) if self.batch_dim is not_mapped: return aval elif type(self.batch_dim) is int: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 2f63eb386029..c09a8c711984 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -40,7 +40,7 @@ fun_sourceinfo) from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - Var, DropVar, raise_to_shaped, Atom, + Var, DropVar, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) @@ -162,8 +162,7 @@ def new_const(self, val) -> JaxprTracer: def new_instantiated_literal(self, val) -> JaxprTracer: aval = get_aval(val) - return JaxprTracer(self, PartialVal.unknown(aval), - Literal(val, raise_to_shaped(aval))) + return JaxprTracer(self, PartialVal.unknown(aval), Literal(val, aval)) def new_instantiated_const(self, val) -> JaxprTracer: aval = get_aval(val) @@ -201,7 +200,7 @@ def instantiate_const_abstracted(self, tracer) -> JaxprTracer: if const is None: return tracer else: - aval = raise_to_shaped(get_aval(const), np.isscalar(const)) + aval = get_aval(const).update_weak_type(np.isscalar(const)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) def process_primitive(self, primitive, tracers, params): @@ -715,7 +714,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], len(params["in_axes"]) == len(params["call_jaxpr"].invars)) assert ("donated_invars" in params and len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) - out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] + out_avals = [t.aval for t in out_tracers] ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), config.threefry_partitionable.value, @@ -936,7 +935,7 @@ def fun(*known_vals_in): f, in_pvals, instantiate=instantiate) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] - res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals] + res_avals = [core.get_aval(r) for r in residuals] cell.append((out_unknowns, jaxpr_unknown, res_avals)) known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] @@ -1567,7 +1566,7 @@ def get_referent(self): return self if val is None else get_referent(val) def _dynamic_jaxpr_tracer_shaped_abstractify(x): - return core.raise_to_shaped(x.aval) + return x.aval api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: @@ -1827,7 +1826,9 @@ def new_const(self, c): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: - aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c)) + aval = get_aval(c) + if hasattr(aval, "weak_type"): + aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) aval = self._lift_tracers_in_aval(aval) tracer = self._new_const(aval, c) return tracer @@ -1892,8 +1893,7 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f, explicit_tracers, params): if f.in_type is None: - f = lu.annotate(f, tuple((raise_to_shaped(get_aval(t)), True) - for t in explicit_tracers)) + f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation @@ -2291,7 +2291,7 @@ def _collect_implicit( for i, name in spec.items(): if name not in idxs and id(x.shape[i]) not in explicit_tracers: idxs[name] = DBIdx(next(counter)) - implicit_types.append(raise_to_shaped(get_aval(x.shape[i]))) + implicit_types.append(get_aval(x.shape[i])) if isinstance(x, Tracer): explicit_tracers.setdefault(id(x), explicit_idx) # use the first @@ -2310,7 +2310,7 @@ def _arg_type( ) -> AbstractValue: # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames. aval = get_aval(x) # aval.shape could contain Tracers - if not spec: return core.raise_to_shaped(aval) + if not spec: return aval shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d for i, d in enumerate(aval.shape)] assert not any(isinstance(d, Tracer) for d in shape) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 8dae3433e4f6..6333638deae6 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -35,7 +35,7 @@ from jax._src import util from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect -from jax._src.core import raise_to_shaped, replace_jaxpr_effects +from jax._src.core import replace_jaxpr_effects from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -328,7 +328,7 @@ def _cond_abstract_eval(*avals, branches, **_): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') - return map(raise_to_shaped, branches[0].out_avals), joined_effects + return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): if np.ndim(pred) != np.ndim(on_true): @@ -676,7 +676,6 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, def _transpose_cond_jaxpr(jaxpr, num_res): res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) - primal_avals = map(raise_to_shaped, primal_avals) @lu.wrap_init def transposed(*args): @@ -693,7 +692,7 @@ def _cond_transpose(cts, *args, branches): index, *ops = args assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] - in_avals = map(raise_to_shaped, branches[0].in_avals) + in_avals = branches[0].in_avals num_res = len(ops) - sum(linear) if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): @@ -701,8 +700,7 @@ def _cond_transpose(cts, *args, branches): branches_trans = tuple( _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches) - lin_in_avals = [raise_to_shaped(a, weak_type=False) - for a, l in zip(in_avals, linear) if l] + lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l] assert all(core.typematch(out_aval, lin_in_aval) for jaxpr in branches_trans for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index ddbbe0213f6f..19d3429d2675 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -35,7 +35,7 @@ from jax._src import state from jax._src import util from jax._src.api_util import shaped_abstractify -from jax._src.core import ShapedArray, raise_to_shaped +from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -262,7 +262,7 @@ def scan(f, init, xs, length=None): stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y - xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat] + xs_avals = [core.get_aval(x) for x in xs_flat] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] def _create_jaxpr(init): @@ -1370,7 +1370,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') - return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects + return body_jaxpr.out_avals, joined_effects def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr, diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 9a5a01e3987d..f97377b2df6c 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -23,7 +23,6 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import linear_util as lu -from jax._src.core import raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -300,7 +299,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return _map(raise_to_shaped, args_to_raise) + return args_to_raise def _custom_linear_solve_impl(*args, const_lengths, jaxprs): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e6dbcbb12a1c..8b6a517a54b3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -48,7 +48,7 @@ from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.core import (Primitive, UnshapedArray, ShapedArray, - raise_to_shaped, abstract_token, canonicalize_shape) + abstract_token, canonicalize_shape) from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -3044,7 +3044,7 @@ def _to_edtype_abstract_eval(x, *, edtype): f" has a representation shape {rep_aval.shape} while the given " f"representation array has shape {x.shape}, so the shape suffix " f"does not match: given {shape_suffix} but required {rep_aval.shape}.") - return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype) + return x.update(shape=shape_prefix, dtype=edtype) to_edtype_p = Primitive('to_edtype') to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) @@ -5246,7 +5246,7 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): def _sort_abstract_eval(*args, **kwargs): - args = tuple(raise_to_shaped(arg) for arg in args) + args = tuple(args) if any(arg.shape != args[0].shape for arg in args[1:]): shapes = " ".join(str(a.shape) for a in args) raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") @@ -6196,7 +6196,7 @@ def _eq_meet(a, b): def _abstractify(x): - return raise_to_shaped(core.get_aval(x)) + return core.get_aval(x) def empty(dtype): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index dc1d1d472ae2..0e0390abc78f 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -33,7 +33,7 @@ from jax._src import dtypes from jax._src import util from jax._src.core import ( - Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape) + Primitive, ShapedArray, is_constant_dim, is_constant_shape) from jax._src.extend import ffi from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -1289,7 +1289,6 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size): - pivots = raise_to_shaped(pivots) if isinstance(pivots, ShapedArray): if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32): raise ValueError( @@ -1421,7 +1420,6 @@ def _lu_impl(operand): return lu, pivot, perm def _lu_abstract_eval(operand): - operand = raise_to_shaped(operand) if isinstance(operand, ShapedArray): if operand.ndim < 2: raise ValueError("Argument to LU decomposition must have ndims >= 2") diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 932fd4b88c08..3a1c1ef3bcf1 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -27,7 +27,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import sharding_impls -from jax._src.core import AxisName, ShapedArray, raise_to_shaped +from jax._src.core import AxisName, ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -636,7 +636,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes), + ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype) for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} @@ -817,7 +817,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): _check_axis_names(axis_name) - return raise_to_shaped(x) + return x ppermute_p = core.Primitive('ppermute') ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval) @@ -1019,13 +1019,12 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, def _all_to_all_effectful_abstract_eval( - x, axis_name, split_axis, concat_axis, axis_index_groups, tiled + input_aval, axis_name, split_axis, concat_axis, axis_index_groups, tiled ): del tiled # expand_dims and squeeze is done in `all_to_all` if `True` if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) - input_aval = raise_to_shaped(x) shape = list(input_aval.shape) axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) @@ -1169,12 +1168,11 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, def _all_gather_effectful_abstract_eval( - x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled + x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) - x_aval = raise_to_shaped(x) new_shape = list(x_aval.shape) if tiled: new_shape[all_gather_dimension] *= axis_size @@ -1298,12 +1296,11 @@ def _reduce_scatter_lowering( def _reduce_scatter_effectful_abstract_eval( - x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled + x_aval, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) - x_aval = core.raise_to_shaped(x) new_shape = list(x_aval.shape) scatter_dim_input_size = x_aval.shape[scatter_dimension] if tiled: diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 14e9d72af186..4526f12f3cca 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -140,13 +140,6 @@ def __hash__(self): self.memory_space, )) - def at_least_vspace(self): - """Vector space method needed for AD.""" - raise NotImplementedError - - def join(self, other): - raise NotImplementedError - def str_short(self, short_dtypes=False): dt_str = \ dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name @@ -226,11 +219,6 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' - def join(self, other): - assert isinstance(other, AbstractMemoryRef) - return AbstractMemoryRef(self.inner_aval.join(other.inner_aval), - self.memory_space) - def update(self, inner_aval=None, memory_space=None): inner_aval = self.inner_aval if inner_aval is None else inner_aval memory_space = self.memory_space if memory_space is None else memory_space @@ -262,13 +250,6 @@ def __str__(self) -> str: return self.value -def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type): - return AbstractMemoryRef( - jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type), - ref_aval.memory_space) -jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped - - @dataclasses.dataclass(frozen=True) class PallasGridContext: grid: GridMappingGrid diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 12ae5350e725..ad9a6cb13f42 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -174,15 +174,6 @@ def get_ref_aval(self) -> AbstractMemoryRef: class AbstractSemaphore(jax_core.AbstractValue): sem_type: SemaphoreType - def join(self, other): - if not isinstance(other, AbstractSemaphore): - raise ValueError - if other.sem_type != self.sem_type: - raise ValueError - return self - -jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval - @dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index bae87226c664..61caa4087d99 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -529,7 +529,8 @@ def export_promela_model( @assume_p.def_abstract_eval def _assume_abstract_eval(x, y): - return x.join(y) + assert jax_core.typematch(x, y) + return x def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y): return y if ctx.lowering_context.for_verification else x diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 7f9e0bef822e..6b2aa5e7219e 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -458,15 +458,9 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): def __repr__(self) -> str: return f'Accumulator{{{self.inner_aval.str_short()}}}' - def join(self, other): - return _as_accum(super().join(other)) - def update(self, inner_aval=None, memory_space=None): return _as_accum(super().update(inner_aval=None, memory_space=None)) - def at_least_vspace(self): - return _as_accum(super().at_least_vspace()) - def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error arr = wgmma_accumulator_deref(tracer) @@ -483,10 +477,6 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: memory_space=ref.memory_space, # pytype: disable=attribute-error ) -def _ref_raise_to_shaped(ref_aval, weak_type): - return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type)) -jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped - _WARPGROUP_AXIS_NAME = object() diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1a5ed7f0d43e..1ced213394ff 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -567,7 +567,7 @@ def wgmma_accumulator_deref(acc): @wgmma_accumulator_deref_p.def_effectful_abstract_eval def _wgmma_accumulator_deref_abstract_eval(acc): # Dereferencing implies flushing so we have a wgmma pipeline effect. - ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc + ret = acc.inner_aval if isinstance(acc, state.AbstractRef) else acc assert isinstance(ret, jax_core.ShapedArray), acc return ret, {gpu_core._wgmma_pipeline_effect} diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 7724466d3110..0897e778d079 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -230,7 +230,6 @@ def _swap_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - val_aval = core.raise_to_shaped(val_aval) assert isinstance(val_aval, core.ShapedArray) expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) @@ -262,7 +261,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - val_aval = core.raise_to_shaped(val_aval) out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) assert isinstance(val_aval, core.ShapedArray) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 634617102d6c..df3c63606ba4 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -291,15 +291,14 @@ def weak_type(self) -> bool: raise AttributeError return self.inner_aval.weak_type + def update_weak_type(self, weak_type): + return AbstractRef(self.inner_aval.update_weak_type(weak_type)) + def update(self, inner_aval=None): if inner_aval is None: return AbstractRef(self.inner_aval) return AbstractRef(inner_aval) - def join(self, other): - assert isinstance(other, AbstractRef) - return AbstractRef(self.inner_aval.join(other.inner_aval)) - ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) @@ -365,10 +364,6 @@ def __eq__(self, other): def __hash__(self): return hash((self.__class__, self.inner_aval)) -def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type): - return AbstractRef(core.raise_to_shaped(ref_aval.inner_aval, weak_type)) -core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped - def _map_ref(size, axis, ref_aval): return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval)) diff --git a/jax/core.py b/jax/core.py index fb08763fd3a1..4c9e696db998 100644 --- a/jax/core.py +++ b/jax/core.py @@ -105,7 +105,6 @@ primitive_uses_outfeed as primitive_uses_outfeed, pytype_aval_mappings as pytype_aval_mappings, raise_to_shaped as raise_to_shaped, - raise_to_shaped_mappings as raise_to_shaped_mappings, reset_trace_state as reset_trace_state, set_current_trace as set_current_trace, str_eqn_compact as str_eqn_compact, diff --git a/tests/core_test.py b/tests/core_test.py index 1471e334c880..7ca941c69c7b 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -533,15 +533,6 @@ def test_jaxpr_undefined_eqn_invar(self): r"Variable '.+_test' not defined\n\nin equation:", lambda: core.check_jaxpr(jaxpr)) - @parameterized.parameters( - {'value': 0, 'weak_type': True}, - {'value': np.int32(0), 'weak_type': False}, - {'value': np.array([0]), 'weak_type': False} - ) - def test_raise_to_shaped_weak_type(self, value, weak_type): - aval = core.raise_to_shaped(core.get_aval(value)) - self.assertEqual(aval.weak_type, weak_type) - @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): diff --git a/tests/lax_test.py b/tests/lax_test.py index f2ce0913e03a..12149700cb30 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3821,7 +3821,7 @@ def shard_foo_array_handler(xs, shardings, layouts): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment - aval = core.raise_to_shaped(core.get_aval(x.data)) + aval = core.get_aval(x.data) results.append(pxla.batched_device_put( aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) return results diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index d92991caa6fe..544ed1ac3ecc 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -870,8 +870,9 @@ def scope(): pl.run_scoped(scope) return [] - aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) - in_avals = [aref, aref] + aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + in_avals = [aref1, aref2] stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) discharged_jaxpr, _ = state_discharge.discharge_state( diff --git a/tests/state_test.py b/tests/state_test.py index 36e93e88c5e0..c8458742619d 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -746,9 +746,10 @@ def f(a_ref, b_ref): b_ref[...] = jnp.array(1., dtype=jnp.float32) return a_ref[...], b_ref[...] - scalar_ref = shaped_array_ref((), jnp.float32) + scalar_ref_1 = shaped_array_ref((), jnp.float32) + scalar_ref_2 = shaped_array_ref((), jnp.float32) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [scalar_ref, scalar_ref]) + lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) From d4ddabb1d1cc97746d759913152f1d71b0fda43d Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Sun, 3 Nov 2024 08:38:52 -0800 Subject: [PATCH 180/698] Disable breaking testAutoPgleWithPersistentCache. The test is unstable and it seems to be caused by temporary directory not being properly allocated/released in case of parallel execution of tests via pytest-xdist. Way to reproduce: ``` python3.10 -m pytest -n 2 --tb=short --maxfail=20 tests/pgle_test.py ``` Note, if `-n` argument is 1 (i.e. behave like regular single-threaded pytest) the tests pass. It also passes in bazel, since bazel better isolates different tests from each other. PiperOrigin-RevId: 692707058 --- tests/pgle_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index cf248d1c2cea..5f0c28541b62 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -177,6 +177,7 @@ def f(x): self.assertArraysEqual(compiled(x), expected) self.assertEqual(cache_miss_count[0], 0) + @unittest.skip("Test failing in CI") def testAutoPgleWithPersistentCache(self): its = 50 mesh = jtu.create_mesh((2,), ('x',)) From e1270533045ee17007247afea55a556750793e02 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Sun, 3 Nov 2024 08:40:39 -0800 Subject: [PATCH 181/698] Fix ColocatedPythonTest. The test has been failing only on pytest nighlies because there are more GPU devices than CPU devices available, but the tests was making assumption that number of cpu devices is always bigger. PiperOrigin-RevId: 692707314 --- tests/colocated_python_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 602904757d7a..9f65e3aeced4 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -36,11 +36,12 @@ def _colocated_cpu_devices( # PjRt-IFRT prepares CPU devices by its own. cpu_backend_devices = jax.local_devices(backend="cpu") device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + + available_devices = devices[:min(len(cpu_backend_devices), len(devices))] return [ - cpu_backend_devices[device_index_map[device.id]] for device in devices + cpu_backend_devices[device_index_map[d.id]] for d in available_devices ] - @contextlib.contextmanager def _count_colocated_python_specialization_cache_miss() -> list[int]: """Counts the number of cache misses for colocated_python specialization.""" From 2d94914c80a492bc722658e24702579332adbd6a Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 3 Nov 2024 13:00:32 -0800 Subject: [PATCH 182/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/336cc6ede42b22e6b1336e5e12c23a8b9e37db20. PiperOrigin-RevId: 692744387 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fffc475fa1b9..72fa6fef9bff 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "cbf3c8d5deae86c162725a30f2871648430990d0" -XLA_SHA256 = "7bf0a7fd6bb2eeb54c386f47b353834d15eda4a2d7ca383a4d42c6d444b484eb" +XLA_COMMIT = "336cc6ede42b22e6b1336e5e12c23a8b9e37db20" +XLA_SHA256 = "2853a33bf320f1b61ba82d5adc9b1cf92ab9ea28f392f37df423bbcecd192140" def repo(): tf_http_archive( From c52b3227d1469beb664920c270a581960613437c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 4 Nov 2024 06:37:38 -0800 Subject: [PATCH 183/698] [pallas:mosaic_gpu] Added a 2D test for `emit_pipeline` PiperOrigin-RevId: 692945663 --- tests/pallas/mosaic_gpu_test.py | 42 +++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f60c6c7c6023..f427b91b5e94 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1025,7 +1025,9 @@ def kernel(x_ref, o_ref): class PipelineTest(PallasTest): - def test_manual(self, max_concurrent_steps=2, num_steps=4): + def test_manual(self): + max_concurrent_steps = 2 + num_steps = 4 def kernel(x_gmem, o_gmem): return pl.run_scoped( @@ -1089,14 +1091,16 @@ def body(step, _): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) - def test_emit(self, max_concurrent_steps=2, num_steps=4): + def test_emit(self): + num_steps = 4 + def kernel(x_gmem, o_gmem): plgpu.emit_pipeline( kernel_body, in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], grid=(num_steps,), - max_concurrent_steps=max_concurrent_steps, + max_concurrent_steps=2, )(x_gmem, o_gmem) def kernel_body(x_smem, o_smem): @@ -1112,9 +1116,11 @@ def kernel_body(x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) - def test_emit_with_parallel_grid(self, max_concurrent_steps=2, num_steps=4): + def test_emit_with_parallel_grid(self): self.skipTest("Enable once we support multiple levels of indexing") + num_steps = 4 + def kernel(x_gmem, o_gmem): gmem_slice = pl.ds(pl.program_id(0) * 32, 32) plgpu.emit_pipeline( @@ -1122,7 +1128,7 @@ def kernel(x_gmem, o_gmem): in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], grid=(num_steps,), - max_concurrent_steps=max_concurrent_steps, + max_concurrent_steps=2, )(x_gmem.at[gmem_slice], o_gmem.at[gmem_slice]) def kernel_body(x_smem, o_smem): @@ -1139,6 +1145,32 @@ def kernel_body(x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit_with_2d_grid(self): + num_steps1 = 4 + num_steps2 = 5 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], + out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], + grid=(num_steps1, num_steps2), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) + x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + class CoreMapTest(PallasTest): From f281c6f46475270a57a02416469226315377592c Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 4 Nov 2024 06:52:49 -0800 Subject: [PATCH 184/698] Reverts ec39b592f7c096b0b8183723feaab2ed0d001041 PiperOrigin-RevId: 692949053 --- jax/_src/ad_util.py | 3 +- jax/_src/core.py | 112 ++++++++++++++++------ jax/_src/custom_derivatives.py | 27 +++--- jax/_src/interpreters/ad.py | 9 +- jax/_src/interpreters/batching.py | 8 +- jax/_src/interpreters/partial_eval.py | 24 ++--- jax/_src/lax/control_flow/conditionals.py | 10 +- jax/_src/lax/control_flow/loops.py | 6 +- jax/_src/lax/control_flow/solves.py | 3 +- jax/_src/lax/lax.py | 8 +- jax/_src/lax/linalg.py | 4 +- jax/_src/lax/parallel.py | 15 +-- jax/_src/pallas/core.py | 19 ++++ jax/_src/pallas/mosaic/core.py | 9 ++ jax/_src/pallas/mosaic/verification.py | 3 +- jax/_src/pallas/mosaic_gpu/core.py | 10 ++ jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- jax/_src/state/primitives.py | 2 + jax/_src/state/types.py | 11 ++- jax/core.py | 1 + tests/core_test.py | 9 ++ tests/lax_test.py | 2 +- tests/pallas/tpu_pallas_test.py | 5 +- tests/state_test.py | 5 +- 24 files changed, 211 insertions(+), 96 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 02f3b0405e38..bd1427f59e01 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -43,8 +43,7 @@ def add_impl(x, y): @add_jaxvals_p.def_abstract_eval def add_abstract(x, y): - assert core.typematch(x, y) - return x + return core.lattice_join(x, y) def zeros_like_aval(aval: core.AbstractValue) -> Array: return aval_zeros_likers[type(aval)](aval) diff --git a/jax/_src/core.py b/jax/_src/core.py index 1fecf3f18b43..7d912e3c207b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -368,7 +368,7 @@ class Var: def __init__(self, suffix: str, aval: AbstractValue): self.count = next(_var_counter) self.suffix = suffix - self.aval = aval + self.aval = raise_to_shaped(aval) # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not # care about variable ordering, but the downstream package kfac_jax does. @@ -662,7 +662,7 @@ def __init__(self, trace: Trace): def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" - return f"traced array with shape {self.aval.str_short()}" + return f"traced array with shape {raise_to_shaped(self.aval).str_short()}" def __array__(self, *args, **kw): raise TracerArrayConversionError(self) @@ -1302,11 +1302,11 @@ def __repr__(self): except AttributeError: return self.__class__.__name__ - def update_weak_type(self, weak_type): + def strip_weak_type(self) -> AbstractValue: return self - def strip_weak_type(self) -> AbstractValue: - return self.update_weak_type(False) + def join(self, other): + raise NotImplementedError("must override") def update(self, **kwargs): raise NotImplementedError("must override") @@ -1314,6 +1314,7 @@ def update(self, **kwargs): def str_short(self, short_dtypes=False): return str(self) + # For type signatures involving dynamic shapes, we use lists of abstract values # which may contain (reverse) de Bruijn indices in their shapes. class DBIdx(NamedTuple): @@ -1347,10 +1348,26 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -# TODO(dougalm): Deprecate. This is here for backwards compat. -def lattice_join(x, y): - assert typematch(x, y) - return x +class Bot(AbstractValue): pass +bot = Bot() + + +def lattice_join(x: AbstractValue | None, + y: AbstractValue | None) -> AbstractValue: + if x is None: + assert y is not None + return y + elif y is None: + return x + elif isinstance(x, type(y)): + return y.join(x) + elif isinstance(y, type(x)): + return x.join(y) + elif isinstance(x, DShapedArray) and isinstance(y, ShapedArray): + # TODO(mattjj): remove this special case after dynamic shapes are integrated + return x.join(y) + else: + raise TypeError(x, y) # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1513,8 +1530,9 @@ def __repr__(self): def str_short(self, short_dtypes=False) -> str: return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - def update_weak_type(self, weak_type): - return self.update(weak_type=weak_type) + def strip_weak_type(self): + """Returns a copy of the aval with weak_type=False.""" + return self.update(weak_type=False) def _canonicalize_dimension(dim: DimSize) -> DimSize: # Dimensions are most commonly integral (by far), so we check that first. @@ -1638,6 +1656,13 @@ def to_tangent_aval(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) + def join(self, other): + if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype: + weak_type = self.weak_type and other.weak_type + return self.update(weak_type=weak_type) + else: + raise TypeError(self, other) + def str_short(self, short_dtypes=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name) @@ -1737,6 +1762,14 @@ def __eq__(self, other): def __hash__(self): return hash((self.shape, self.dtype, self.weak_type)) + def join(self, other): + if (definitely_equal_shape(self.shape, other.shape) and + self.dtype == other.dtype): + weak_type = self.weak_type and other.weak_type + return self.update(weak_type=weak_type) + else: + raise TypeError(self, other) + def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1848,11 +1881,16 @@ def mutable_array_abstract_eval(init_aval): @mutable_array_p.def_impl def _mutable_array_impl(init_val): from jax._src.state.types import AbstractRef # pytype: disable=import-error - aval = get_aval(init_val) + aval = raise_to_shaped(get_aval(init_val)) return MutableArray(AbstractRef(aval), init_val) class AbstractToken(AbstractValue): + def join(self, other): + if isinstance(other, AbstractToken): + return self + else: + assert False, f"Cannot join {self} with {other}" def str_short(self, short_dtypes=False): return 'Tok' def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() @@ -1872,9 +1910,30 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -# TODO(dougalm): Deprecate. This is just here for backwards compat. -def raise_to_shaped(aval): - return aval +def raise_to_shaped(aval: AbstractValue, weak_type=None): + aval_type = type(aval) + if aval_type is ShapedArray and weak_type is None: + return aval + if aval_type is DShapedArray and weak_type is None: + return aval + if weak_type is None: + weak_type = getattr(aval, 'weak_type', False) + for typ in aval_type.__mro__: + handler = raise_to_shaped_mappings.get(typ) + if handler: return handler(aval, weak_type) + raise TypeError(type(aval)) + +def _shaped_array_mapping(aval, weak_type): + if config.sharding_in_types.value: + return ShapedArray(aval.shape, aval.dtype, weak_type, sharding=aval.sharding) + return ShapedArray(aval.shape, aval.dtype, weak_type) + +raise_to_shaped_mappings: dict[type, Callable] = { + AbstractToken: lambda aval, _: aval, + Bot: lambda aval, _: aval, + ShapedArray: _shaped_array_mapping, + DShapedArray: lambda aval, _: aval +} ### Operations on shapes and dimension sizes. @@ -2282,23 +2341,18 @@ def typecheck(aval: AbstractValue, x) -> bool: def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: """Determine whether `aval` conforms to `aval_ref`. Ignores weak_type.""" try: - return typematch(aval_ref, aval) + return typematch(aval_ref, lattice_join(aval_ref, aval)) except TypeError: return False -def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: - """Determine whether `t1` and `t2` are equivalent. Ignores weak_type.""" - t1 = t1.strip_weak_type() - t2 = t2.strip_weak_type() - if t1 == t2: - return True - elif (isinstance(t1, (ShapedArray, DShapedArray)) and - isinstance(t2, (ShapedArray, DShapedArray))): - # This case handles DShapedArray and shape polynomials. Alternatively we - # could try normalizing first and then doing simple equality. - return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) - else: - return False +def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool: + """Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type.""" + if aval1 == aval2: return True + # unequal avals may still represent the same type, because type is represented + # by avals at the shaped level, and because weak type tags aren't considered + # part of the type + return (raise_to_shaped(aval1, weak_type=False) == + raise_to_shaped(aval2, weak_type=False)) class JaxprTypeError(TypeError): pass diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 375efeb712b8..77f73562aecd 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,6 +31,7 @@ stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs) +from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -80,7 +81,7 @@ def _flatten_fun_nokwargs(in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) ans = yield py_args, {} ans_flat, ans_tree = tree_flatten(ans) - ans_avals = [core.get_aval(x) for x in ans_flat] + ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat] yield ans_flat, (ans_tree, ans_avals) @@ -286,7 +287,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): py_primals_out, py_tangents_out = pair_out primals_out, out_tree = tree_flatten(py_primals_out) tangents_out, out_tree2 = tree_flatten(py_tangents_out) - primal_avals = [core.get_aval(x) for x in primals_out] + primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] if out_tree != out_tree2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must " "produce primal and tangent outputs with equal container (pytree) " @@ -326,11 +327,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - primal_avals_out = [core.get_aval(x).strip_weak_type() for x in primals_out] + primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] expected_tangent_avals_out = [ - core.get_aval(x).strip_weak_type().to_tangent_aval() + raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval() for x in primals_out] - tangent_avals_out = [core.get_aval(t).strip_weak_type() + tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] if expected_tangent_avals_out != tangent_avals_out: @@ -605,7 +606,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable f_, dyn_args = lu.wrap_init(self.fun), args fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd) args_flat, in_tree = tree_flatten(dyn_args) - in_avals = [core.get_aval(x) for x in args_flat] + in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name, fwd_name, in_tree, out_type) @@ -673,7 +674,7 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, py_primals_out, res = pair_out primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) - primal_avals = [core.get_aval(x) for x in primals_out] + primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] # If the primal function already ran, check out_tree agreement. try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None @@ -771,7 +772,7 @@ def append(x, d): msg = ("Custom VJP bwd rule must produce an output with the same " "shape/dtypes as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " - f"shape/dtype {a_.str_short()} corresponding " + f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding " f"to an input of shape/dtype {a.str_short()}.") raise ValueError(msg) results.append(ct) @@ -830,7 +831,7 @@ def _custom_vjp_call_jaxpr_jvp( _, res_tree = out_trees() res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, @@ -1109,7 +1110,7 @@ def merge(l1, l2): return out, merge def abstractify(x): - return core.get_aval(x) + return core.raise_to_shaped(core.get_aval(x)) ### Custom transposition @@ -1210,7 +1211,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args, lin_avals = map(abstractify, operands_lin) f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals)) f_jaxpr = _close_jaxpr(f_jaxpr) - out_avals = f_jaxpr.out_avals + out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals) t_in_tree = treedef_tuple((res_tree, out_tree())) t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree) @@ -1264,7 +1265,7 @@ def _linear_call_transpose_rule(cts, *args, callee, transpose, return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out def _linear_call_abstract_eval(*args, **kwargs): - return kwargs['callee'].out_avals + return map(core.raise_to_shaped, kwargs['callee'].out_avals) linear_call_p = core.Primitive('linear_call') linear_call_p.multiple_results = True @@ -1397,7 +1398,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: in_tree, out_type) flat_fwd = _fix_fwd_args(flat_fwd) - in_avals = [core.get_aval(x) for x in args_flat] + in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) prim_tree, res_tree = out_trees() diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index d080aae759a6..47c7882372ab 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -33,7 +33,8 @@ replace_rule_output_symbolic_zeros, Zero, zeros_like_aval) from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401 from jax._src.api_util import flatten_fun, flatten_fun_nokwargs -from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) +from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, + raise_to_shaped) from jax._src.dtypes import dtype, float0 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, @@ -361,7 +362,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! with core.set_current_trace(self.parent_trace): tangents_in = map(instantiate_zeros, tangents_in) @@ -433,8 +434,8 @@ def to_concrete_value(self): def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: - primal_aval = get_aval(primal).strip_weak_type() - tangent_aval = get_aval(tangent).strip_weak_type() + primal_aval = raise_to_shaped(get_aval(primal), weak_type=False) + tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False) assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 590e60383b90..2ff27f0c5d74 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -29,7 +29,7 @@ from jax._src.ad_util import (Zero, instantiate, SymbolicZero, replace_rule_output_symbolic_zeros, add_jaxvals, add_jaxvals_p) -from jax._src.core import Trace, Tracer, TraceTag, AxisName +from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) @@ -217,7 +217,7 @@ def __init__(self, a): self.a = a for d in a.shape)) if type(a) is core.DShapedArray else a for a, e in orig_type if e] - new_avals = [core.get_aval(s) for s in segment_lens] + new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens] sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size for a, d in zip(avals, explicit_in_dims): if isinstance(d, RaggedAxis): @@ -387,7 +387,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, if config.enable_checks.value: assert type(batch_dim) in (NotMapped, int, RaggedAxis) if type(batch_dim) is int: - aval = core.get_aval(val) + aval = raise_to_shaped(core.get_aval(val)) assert 0 <= batch_dim < len(aval.shape) self._trace = trace self.val = val @@ -396,7 +396,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, @property def aval(self): - aval = core.get_aval(self.val) + aval = raise_to_shaped(core.get_aval(self.val)) if self.batch_dim is not_mapped: return aval elif type(self.batch_dim) is int: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index c09a8c711984..2f63eb386029 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -40,7 +40,7 @@ fun_sourceinfo) from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - Var, DropVar, Atom, + Var, DropVar, raise_to_shaped, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) @@ -162,7 +162,8 @@ def new_const(self, val) -> JaxprTracer: def new_instantiated_literal(self, val) -> JaxprTracer: aval = get_aval(val) - return JaxprTracer(self, PartialVal.unknown(aval), Literal(val, aval)) + return JaxprTracer(self, PartialVal.unknown(aval), + Literal(val, raise_to_shaped(aval))) def new_instantiated_const(self, val) -> JaxprTracer: aval = get_aval(val) @@ -200,7 +201,7 @@ def instantiate_const_abstracted(self, tracer) -> JaxprTracer: if const is None: return tracer else: - aval = get_aval(const).update_weak_type(np.isscalar(const)) + aval = raise_to_shaped(get_aval(const), np.isscalar(const)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) def process_primitive(self, primitive, tracers, params): @@ -714,7 +715,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], len(params["in_axes"]) == len(params["call_jaxpr"].invars)) assert ("donated_invars" in params and len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) - out_avals = [t.aval for t in out_tracers] + out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), config.threefry_partitionable.value, @@ -935,7 +936,7 @@ def fun(*known_vals_in): f, in_pvals, instantiate=instantiate) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] - res_avals = [core.get_aval(r) for r in residuals] + res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals] cell.append((out_unknowns, jaxpr_unknown, res_avals)) known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] @@ -1566,7 +1567,7 @@ def get_referent(self): return self if val is None else get_referent(val) def _dynamic_jaxpr_tracer_shaped_abstractify(x): - return x.aval + return core.raise_to_shaped(x.aval) api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: @@ -1826,9 +1827,7 @@ def new_const(self, c): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: - aval = get_aval(c) - if hasattr(aval, "weak_type"): - aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) + aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c)) aval = self._lift_tracers_in_aval(aval) tracer = self._new_const(aval, c) return tracer @@ -1893,7 +1892,8 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f, explicit_tracers, params): if f.in_type is None: - f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) + f = lu.annotate(f, tuple((raise_to_shaped(get_aval(t)), True) + for t in explicit_tracers)) implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation @@ -2291,7 +2291,7 @@ def _collect_implicit( for i, name in spec.items(): if name not in idxs and id(x.shape[i]) not in explicit_tracers: idxs[name] = DBIdx(next(counter)) - implicit_types.append(get_aval(x.shape[i])) + implicit_types.append(raise_to_shaped(get_aval(x.shape[i]))) if isinstance(x, Tracer): explicit_tracers.setdefault(id(x), explicit_idx) # use the first @@ -2310,7 +2310,7 @@ def _arg_type( ) -> AbstractValue: # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames. aval = get_aval(x) # aval.shape could contain Tracers - if not spec: return aval + if not spec: return core.raise_to_shaped(aval) shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d for i, d in enumerate(aval.shape)] assert not any(isinstance(d, Tracer) for d in shape) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 6333638deae6..8dae3433e4f6 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -35,7 +35,7 @@ from jax._src import util from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect -from jax._src.core import replace_jaxpr_effects +from jax._src.core import raise_to_shaped, replace_jaxpr_effects from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -328,7 +328,7 @@ def _cond_abstract_eval(*avals, branches, **_): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') - return branches[0].out_avals, joined_effects + return map(raise_to_shaped, branches[0].out_avals), joined_effects def _bcast_select(pred, on_true, on_false): if np.ndim(pred) != np.ndim(on_true): @@ -676,6 +676,7 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, def _transpose_cond_jaxpr(jaxpr, num_res): res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) + primal_avals = map(raise_to_shaped, primal_avals) @lu.wrap_init def transposed(*args): @@ -692,7 +693,7 @@ def _cond_transpose(cts, *args, branches): index, *ops = args assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] - in_avals = branches[0].in_avals + in_avals = map(raise_to_shaped, branches[0].in_avals) num_res = len(ops) - sum(linear) if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): @@ -700,7 +701,8 @@ def _cond_transpose(cts, *args, branches): branches_trans = tuple( _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches) - lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l] + lin_in_avals = [raise_to_shaped(a, weak_type=False) + for a, l in zip(in_avals, linear) if l] assert all(core.typematch(out_aval, lin_in_aval) for jaxpr in branches_trans for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 19d3429d2675..ddbbe0213f6f 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -35,7 +35,7 @@ from jax._src import state from jax._src import util from jax._src.api_util import shaped_abstractify -from jax._src.core import ShapedArray +from jax._src.core import ShapedArray, raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -262,7 +262,7 @@ def scan(f, init, xs, length=None): stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y - xs_avals = [core.get_aval(x) for x in xs_flat] + xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] def _create_jaxpr(init): @@ -1370,7 +1370,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') - return body_jaxpr.out_avals, joined_effects + return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr, diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index f97377b2df6c..9a5a01e3987d 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -23,6 +23,7 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import linear_util as lu +from jax._src.core import raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -299,7 +300,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return args_to_raise + return _map(raise_to_shaped, args_to_raise) def _custom_linear_solve_impl(*args, const_lengths, jaxprs): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8b6a517a54b3..e6dbcbb12a1c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -48,7 +48,7 @@ from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.core import (Primitive, UnshapedArray, ShapedArray, - abstract_token, canonicalize_shape) + raise_to_shaped, abstract_token, canonicalize_shape) from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -3044,7 +3044,7 @@ def _to_edtype_abstract_eval(x, *, edtype): f" has a representation shape {rep_aval.shape} while the given " f"representation array has shape {x.shape}, so the shape suffix " f"does not match: given {shape_suffix} but required {rep_aval.shape}.") - return x.update(shape=shape_prefix, dtype=edtype) + return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype) to_edtype_p = Primitive('to_edtype') to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) @@ -5246,7 +5246,7 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): def _sort_abstract_eval(*args, **kwargs): - args = tuple(args) + args = tuple(raise_to_shaped(arg) for arg in args) if any(arg.shape != args[0].shape for arg in args[1:]): shapes = " ".join(str(a.shape) for a in args) raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") @@ -6196,7 +6196,7 @@ def _eq_meet(a, b): def _abstractify(x): - return core.get_aval(x) + return raise_to_shaped(core.get_aval(x)) def empty(dtype): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0e0390abc78f..dc1d1d472ae2 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -33,7 +33,7 @@ from jax._src import dtypes from jax._src import util from jax._src.core import ( - Primitive, ShapedArray, is_constant_dim, is_constant_shape) + Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape) from jax._src.extend import ffi from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -1289,6 +1289,7 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size): + pivots = raise_to_shaped(pivots) if isinstance(pivots, ShapedArray): if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32): raise ValueError( @@ -1420,6 +1421,7 @@ def _lu_impl(operand): return lu, pivot, perm def _lu_abstract_eval(operand): + operand = raise_to_shaped(operand) if isinstance(operand, ShapedArray): if operand.ndim < 2: raise ValueError("Argument to LU decomposition must have ndims >= 2") diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 3a1c1ef3bcf1..932fd4b88c08 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -27,7 +27,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import sharding_impls -from jax._src.core import AxisName, ShapedArray +from jax._src.core import AxisName, ShapedArray, raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -636,7 +636,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), + ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes), arg.dtype) for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} @@ -817,7 +817,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): _check_axis_names(axis_name) - return x + return raise_to_shaped(x) ppermute_p = core.Primitive('ppermute') ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval) @@ -1019,12 +1019,13 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, def _all_to_all_effectful_abstract_eval( - input_aval, axis_name, split_axis, concat_axis, axis_index_groups, tiled + x, axis_name, split_axis, concat_axis, axis_index_groups, tiled ): del tiled # expand_dims and squeeze is done in `all_to_all` if `True` if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) + input_aval = raise_to_shaped(x) shape = list(input_aval.shape) axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) @@ -1168,11 +1169,12 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, def _all_gather_effectful_abstract_eval( - x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled + x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) + x_aval = raise_to_shaped(x) new_shape = list(x_aval.shape) if tiled: new_shape[all_gather_dimension] *= axis_size @@ -1296,11 +1298,12 @@ def _reduce_scatter_lowering( def _reduce_scatter_effectful_abstract_eval( - x_aval, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled + x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) + x_aval = core.raise_to_shaped(x) new_shape = list(x_aval.shape) scatter_dim_input_size = x_aval.shape[scatter_dimension] if tiled: diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 4526f12f3cca..14e9d72af186 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -140,6 +140,13 @@ def __hash__(self): self.memory_space, )) + def at_least_vspace(self): + """Vector space method needed for AD.""" + raise NotImplementedError + + def join(self, other): + raise NotImplementedError + def str_short(self, short_dtypes=False): dt_str = \ dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name @@ -219,6 +226,11 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' + def join(self, other): + assert isinstance(other, AbstractMemoryRef) + return AbstractMemoryRef(self.inner_aval.join(other.inner_aval), + self.memory_space) + def update(self, inner_aval=None, memory_space=None): inner_aval = self.inner_aval if inner_aval is None else inner_aval memory_space = self.memory_space if memory_space is None else memory_space @@ -250,6 +262,13 @@ def __str__(self) -> str: return self.value +def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type): + return AbstractMemoryRef( + jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type), + ref_aval.memory_space) +jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped + + @dataclasses.dataclass(frozen=True) class PallasGridContext: grid: GridMappingGrid diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index ad9a6cb13f42..12ae5350e725 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -174,6 +174,15 @@ def get_ref_aval(self) -> AbstractMemoryRef: class AbstractSemaphore(jax_core.AbstractValue): sem_type: SemaphoreType + def join(self, other): + if not isinstance(other, AbstractSemaphore): + raise ValueError + if other.sem_type != self.sem_type: + raise ValueError + return self + +jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval + @dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index 61caa4087d99..bae87226c664 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -529,8 +529,7 @@ def export_promela_model( @assume_p.def_abstract_eval def _assume_abstract_eval(x, y): - assert jax_core.typematch(x, y) - return x + return x.join(y) def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y): return y if ctx.lowering_context.for_verification else x diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 6b2aa5e7219e..7f9e0bef822e 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -458,9 +458,15 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): def __repr__(self) -> str: return f'Accumulator{{{self.inner_aval.str_short()}}}' + def join(self, other): + return _as_accum(super().join(other)) + def update(self, inner_aval=None, memory_space=None): return _as_accum(super().update(inner_aval=None, memory_space=None)) + def at_least_vspace(self): + return _as_accum(super().at_least_vspace()) + def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error arr = wgmma_accumulator_deref(tracer) @@ -477,6 +483,10 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: memory_space=ref.memory_space, # pytype: disable=attribute-error ) +def _ref_raise_to_shaped(ref_aval, weak_type): + return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type)) +jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped + _WARPGROUP_AXIS_NAME = object() diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1ced213394ff..1a5ed7f0d43e 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -567,7 +567,7 @@ def wgmma_accumulator_deref(acc): @wgmma_accumulator_deref_p.def_effectful_abstract_eval def _wgmma_accumulator_deref_abstract_eval(acc): # Dereferencing implies flushing so we have a wgmma pipeline effect. - ret = acc.inner_aval if isinstance(acc, state.AbstractRef) else acc + ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc assert isinstance(ret, jax_core.ShapedArray), acc return ret, {gpu_core._wgmma_pipeline_effect} diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 0897e778d079..7724466d3110 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -230,6 +230,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): + val_aval = core.raise_to_shaped(val_aval) assert isinstance(val_aval, core.ShapedArray) expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) @@ -261,6 +262,7 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): + val_aval = core.raise_to_shaped(val_aval) out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) assert isinstance(val_aval, core.ShapedArray) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index df3c63606ba4..634617102d6c 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -291,14 +291,15 @@ def weak_type(self) -> bool: raise AttributeError return self.inner_aval.weak_type - def update_weak_type(self, weak_type): - return AbstractRef(self.inner_aval.update_weak_type(weak_type)) - def update(self, inner_aval=None): if inner_aval is None: return AbstractRef(self.inner_aval) return AbstractRef(inner_aval) + def join(self, other): + assert isinstance(other, AbstractRef) + return AbstractRef(self.inner_aval.join(other.inner_aval)) + ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) @@ -364,6 +365,10 @@ def __eq__(self, other): def __hash__(self): return hash((self.__class__, self.inner_aval)) +def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type): + return AbstractRef(core.raise_to_shaped(ref_aval.inner_aval, weak_type)) +core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped + def _map_ref(size, axis, ref_aval): return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval)) diff --git a/jax/core.py b/jax/core.py index 4c9e696db998..fb08763fd3a1 100644 --- a/jax/core.py +++ b/jax/core.py @@ -105,6 +105,7 @@ primitive_uses_outfeed as primitive_uses_outfeed, pytype_aval_mappings as pytype_aval_mappings, raise_to_shaped as raise_to_shaped, + raise_to_shaped_mappings as raise_to_shaped_mappings, reset_trace_state as reset_trace_state, set_current_trace as set_current_trace, str_eqn_compact as str_eqn_compact, diff --git a/tests/core_test.py b/tests/core_test.py index 7ca941c69c7b..1471e334c880 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -533,6 +533,15 @@ def test_jaxpr_undefined_eqn_invar(self): r"Variable '.+_test' not defined\n\nin equation:", lambda: core.check_jaxpr(jaxpr)) + @parameterized.parameters( + {'value': 0, 'weak_type': True}, + {'value': np.int32(0), 'weak_type': False}, + {'value': np.array([0]), 'weak_type': False} + ) + def test_raise_to_shaped_weak_type(self, value, weak_type): + aval = core.raise_to_shaped(core.get_aval(value)) + self.assertEqual(aval.weak_type, weak_type) + @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): diff --git a/tests/lax_test.py b/tests/lax_test.py index 12149700cb30..f2ce0913e03a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3821,7 +3821,7 @@ def shard_foo_array_handler(xs, shardings, layouts): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment - aval = core.get_aval(x.data) + aval = core.raise_to_shaped(core.get_aval(x.data)) results.append(pxla.batched_device_put( aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) return results diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 544ed1ac3ecc..d92991caa6fe 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -870,9 +870,8 @@ def scope(): pl.run_scoped(scope) return [] - aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) - aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) - in_avals = [aref1, aref2] + aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + in_avals = [aref, aref] stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) discharged_jaxpr, _ = state_discharge.discharge_state( diff --git a/tests/state_test.py b/tests/state_test.py index c8458742619d..36e93e88c5e0 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -746,10 +746,9 @@ def f(a_ref, b_ref): b_ref[...] = jnp.array(1., dtype=jnp.float32) return a_ref[...], b_ref[...] - scalar_ref_1 = shaped_array_ref((), jnp.float32) - scalar_ref_2 = shaped_array_ref((), jnp.float32) + scalar_ref = shaped_array_ref((), jnp.float32) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) + lu.wrap_init(f), [scalar_ref, scalar_ref]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) From 95146deb6b4123742f946f05a4dea1727867a11e Mon Sep 17 00:00:00 2001 From: jiaxi98 Date: Mon, 4 Nov 2024 23:52:54 +0800 Subject: [PATCH 185/698] issue #24691 --- jax/_src/numpy/linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 14798f6f6913..03f864919887 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -91,8 +91,8 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: Args: a: input array, representing a (batched) positive-definite hermitian matrix. Must have shape ``(..., N, N)``. - upper: if True, compute the upper Cholesky decomposition `L`. if False - (default), compute the lower Cholesky decomposition `U`. + upper: if True, compute the upper Cholesky decomposition `U`. if False + (default), compute the lower Cholesky decomposition `L`. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition From 9df719f83f5f3eee27b22786c3b9f9d2833767d6 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Mon, 4 Nov 2024 11:17:25 +0200 Subject: [PATCH 186/698] Fix `_cuda_path` for case when `cuda_nvcc` is a namespace package `cuda_nvcc`, when installed e.g. via `pip` in a `venv` comes out as a namespace package. The previous logic found the `cuda_nvcc` import but failed because `cuda_nvcc.__file__ is None`. --- jax/_src/lib/__init__.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index a3be2390d856..7f9936a7d180 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -150,7 +150,20 @@ def _try_cuda_nvcc_import() -> str | None: from nvidia import cuda_nvcc # pytype: disable=import-error except ImportError: return None - cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent + + if hasattr(cuda_nvcc, '__file__') and cuda_nvcc.__file__ is not None: + # `cuda_nvcc` is a regular package. + cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent + elif hasattr(cuda_nvcc, '__path__') and cuda_nvcc.__path__ is not None: + # `cuda_nvcc` is a namespace package, which might have multiple paths. + cuda_nvcc_path = None + for path in cuda_nvcc.__path__: + if (pathlib.Path(path) / 'bin' / 'ptxas').exists(): + cuda_nvcc_path = pathlib.Path(path) + break + else: + return None + return str(cuda_nvcc_path) if (path := _try_cuda_root_environment_variable()) is not None: From 3544efcade293ec3244ee2c85b431bcf955dc9b9 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 4 Nov 2024 08:12:00 -0800 Subject: [PATCH 187/698] #sdy Fix Shardy bug where we weren't setting shmap in/out shardings as open. If I revert the change in `shard_map.py`, then the unit test added `test_partial_auto_propagate_through` fails with: ``` self.assertEqual(actual.sharding, sharding) AssertionError: Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec(), memory_kind=device) != Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec('i',), memory_kind=device) ``` PiperOrigin-RevId: 692971413 --- jax/experimental/shard_map.py | 12 +++++++---- tests/shard_map_test.py | 39 +++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 615fd3128309..c67b4f68cc9b 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -606,14 +606,18 @@ def _rule_missing(prim: core.Primitive, *_, **__): # Lowering def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, names, aval_in + ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in ) -> ir.Attribute: axes = {name: i for i, ns in names.items() for name in ns} ns = _make_scoped_manual_sharding(ctx, mesh, axes) if dtypes.issubdtype(aval_in.dtype, dtypes.extended): ns = sharding_impls.physical_sharding(aval_in, ns) aval_in = core.physical_aval(aval_in) - return ns._to_sdy_sharding(aval_in.ndim).build() + sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) + if auto: + for dim_sharding in sdy_sharding.dimension_shardings: + dim_sharding.is_closed = False + return sdy_sharding.build() def _shard_map_lowering_shardy( @@ -643,10 +647,10 @@ def _shard_map_lowering_shardy( return out_nodes in_shardings = sdy.TensorShardingPerValueAttr.get(map( - partial(_shardy_shard_map_sharding, ctx, mesh), + partial(_shardy_shard_map_sharding, ctx, mesh, auto), in_names, ctx.avals_in)) out_shardings = sdy.TensorShardingPerValueAttr.get(map( - partial(_shardy_shard_map_sharding, ctx, mesh), + partial(_shardy_shard_map_sharding, ctx, mesh, auto), out_names, ctx.avals_out)) output_types = map(mlir.aval_to_ir_type, ctx.avals_out) manual_computation_op = sdy.ManualComputationOp( diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 3541e331e869..48850c8da66a 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1824,8 +1824,8 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) if config.use_shardy_partitioner.value: self.assertIn( - 'in_shardings=[<@mesh, [{"i"}, {}]>] out_shardings=[<@mesh, [{"i"},' - ' {}]>] manual_axes={"i"}', + 'in_shardings=[<@mesh, [{"i", ?}, {?}]>]' + ' out_shardings=[<@mesh, [{"i", ?}, {?}]>] manual_axes={"i"}', f.lower(v).as_text(), ) else: @@ -1836,6 +1836,41 @@ def f(x): ) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_partial_auto_propagate_through(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) + sharding = jax.sharding.NamedSharding(mesh, P('i')) + + def g(x): + return jax.lax.with_sharding_constraint(x * x, sharding) + + @jax.jit + def f(x): + return shard_map( + g, + mesh, + in_specs=P(), + out_specs=P(), + check_rep=False, + auto=frozenset({'i'}), + )(x) + + v = jnp.arange(32.0).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i'))) + if config.use_shardy_partitioner.value: + self.assertIn( + 'in_shardings=[<@mesh, [{?}, {?}]>]' + ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j"}', + f.lower(v).as_text(), + ) + else: + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[2,2]T(1,0) last_tile_dims={manual, replicated}}', + f.lower(v).as_text('hlo'), + ) + actual = f(v) + self.assertAllClose(v * v, actual, check_dtypes=False) + self.assertEqual(actual.sharding, sharding) + def test_sharded_prng_with_abstract_mesh(self): shape = (8, 2, 2) mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) From d2bbd5640592281c2725fdd96cdb88710e061a38 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 4 Nov 2024 08:28:01 -0800 Subject: [PATCH 188/698] [pallas:mosaic_gpu] `lax.fori_loop` lowering now promotes the carry to `mgpu.FragmentedArray`s PiperOrigin-RevId: 692976037 --- jax/_src/pallas/mosaic_gpu/lowering.py | 14 +++++++---- tests/pallas/mosaic_gpu_test.py | 32 +++++++++++++++++++++----- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c4447ae95435..ba343cd923c3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1389,16 +1389,23 @@ def _lower_jaxpr_to_for_loop( has_loop_index: bool, ): - @mgpu.fori(length, [*args]) + _consts_avals, arg_avals = util.split_list(ctx.avals_in, [len(consts)]) + arg_avals = arg_avals[has_loop_index:] + out_avals = [] + if arg_avals: + out_avals = ctx.avals_out[-len(arg_avals):] + + @mgpu.fori(length, [*map(_ensure_fa, args, arg_avals)]) def loop(loop_index, body_args): if has_loop_index: loop_index = arith_dialect.addi(loop_index, start) jaxpr_args = [*consts, loop_index, *body_args] else: jaxpr_args = [*consts, *body_args] - return lower_jaxpr_to_mosaic_gpu( + outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args ) + return map(_ensure_fa, outs, out_avals) return loop.results @@ -1437,13 +1444,12 @@ def _scan_lowering_rule( _consts_avals, arg_avals = util.split_list(ctx.avals_in, [num_consts]) if has_loop_index: start, *args = args - index_aval, *arg_avals = arg_avals + index_aval, *_ = arg_avals start: ir.Value = _ensure_ir_value(start, index_aval.dtype) length = _ir_constant(length, start.type) else: start = _i32_constant(0) length = _i32_constant(length) - args = map(lambda arg, aval: _ensure_fa(arg, aval.dtype), args, arg_avals) for_out = _lower_jaxpr_to_for_loop( ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f427b91b5e94..e0945b0265fb 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -649,28 +649,48 @@ def kernel(x_ref, o_ref): def test_fori_loop_array(self): @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...]) - x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0) + x = jnp.arange(256).astype(jnp.int32) + np.testing.assert_array_equal(kernel(x), x + 2 + 3) def test_fori_loop_scalar(self): + @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(o_ref): # Equivalent to 2 + 3. o_ref[...] = jax.lax.broadcast( - jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0.0), o_ref.shape + jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0), o_ref.shape + ) + + np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) + + def test_fori_loop_tuple(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(o_ref): + def body(step, xs): + return tuple( + jax.lax.cond(step % 2 == 0, lambda x: x + 1, lambda x: x, x) + for x in xs + ) + + # Equivalent to 3 * (0 + 1). + o_ref[...] = jax.lax.broadcast( + sum(jax.lax.fori_loop(2, 4, body, (0, 0, 0))), o_ref.shape ) np.testing.assert_array_equal( - kernel(), jnp.full([256], 5.0, dtype=jnp.float32) + kernel(), jnp.full([256], 3 * (0 + 1), dtype=jnp.int32) ) def test_fori_loop_indexed_store(self): From f95417006f1b479fb81caaf875e1374fac052969 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 Nov 2024 08:39:04 -0800 Subject: [PATCH 189/698] [tpu] Disable a cumulative reduction test on TPU v6e that currently hits an unimplemented case in XLA. PiperOrigin-RevId: 692979420 --- tests/lax_numpy_reducers_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 98c8785f0fb1..be6208e6e305 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -886,6 +886,8 @@ def testCumulativeSumBool(self): @jtu.ignore_warning(category=NumpyComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): + if jtu.is_device_tpu(6): + raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e") rng = jtu.rand_some_zero(self.rng()) # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as From e9acaa8484914594e570d0093d5985a1bb3171a3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 4 Nov 2024 10:54:41 -0800 Subject: [PATCH 190/698] Remove the `initial` argument to `jax.nn.softmax` and `jax.nn.log_softmax`. This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later. PiperOrigin-RevId: 693023366 --- CHANGELOG.md | 2 ++ jax/_src/nn/functions.py | 15 ++++++--------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b122675e1ffc..8c34b8e369c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The deprecated module `jax.experimental.export` has been removed. It was replaced by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. + * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` + has been removed, after being deprecated in v0.4.27. * The following deprecated methods and functions in {mod}`jax.export` have been removed: * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 861e3d0123ff..5dfaa7b7e5f7 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -22,7 +22,6 @@ import math import numpy as np from typing import Any, Literal -import warnings import jax import jax.numpy as jnp @@ -502,7 +501,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: + initial: Unspecified = _UNSPECIFIED) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -528,10 +527,9 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ + # TODO(jakevdp): remove the initial argument after JAX v0.4.40. if initial is not _UNSPECIFIED: - # Added 2024-4-10 - warnings.warn("The initial argument to log_softmax is deprecated, and no longer has any effect.", - DeprecationWarning, stacklevel=2) + raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.") del initial numpy_util.check_arraylike("log_softmax", x) x_arr = jnp.asarray(x) @@ -551,7 +549,7 @@ def log_softmax(x: ArrayLike, def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, where: ArrayLike | None = None, - initial: ArrayLike | None | Unspecified = _UNSPECIFIED) -> Array: + initial: Unspecified = _UNSPECIFIED) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -577,10 +575,9 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ + # TODO(jakevdp): remove the initial argument after JAX v0.4.40. if initial is not _UNSPECIFIED: - # Added 2024-4-10 - warnings.warn("The initial argument to softmax is deprecated, and no longer has any effect.", - DeprecationWarning, stacklevel=2) + raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.") del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition From 38b4d00100634c506663d54f2e5b08dc01bb43f6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 4 Nov 2024 13:01:06 -0800 Subject: [PATCH 191/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c66d74e5b3ef0d64c43cdd99c8e6aac8512adb6a. PiperOrigin-RevId: 693065128 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 72fa6fef9bff..93ac0a9eae0c 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "336cc6ede42b22e6b1336e5e12c23a8b9e37db20" -XLA_SHA256 = "2853a33bf320f1b61ba82d5adc9b1cf92ab9ea28f392f37df423bbcecd192140" +XLA_COMMIT = "c66d74e5b3ef0d64c43cdd99c8e6aac8512adb6a" +XLA_SHA256 = "1b036e7adc0d408b76ab4f67705704ad7d95e4070c8e8e905315f678d3f7f1df" def repo(): tf_http_archive( From 74da736e0ea8bf4a93f71db41c8c9e3e5cc59afe Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Mon, 4 Nov 2024 18:02:44 -0500 Subject: [PATCH 192/698] Update link to algebraic_simplifier.cc to point to OpenXLA instead of TF --- docs/faq.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/faq.rst b/docs/faq.rst index af14f382b1d7..1d2bb204f24c 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -847,6 +847,6 @@ see the page on `JAX GPU memory allocation`_. .. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function -.. _algebraic_simplifier.cc: https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/compiler/xla/service/algebraic_simplifier.cc#L3266 +.. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 .. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html .. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html From 700f3bdccc720e5381dfb242927ddc1e50429a6c Mon Sep 17 00:00:00 2001 From: charleshofer Date: Mon, 4 Nov 2024 17:10:03 -0600 Subject: [PATCH 193/698] Rename the CI flow to 'ROCm CI' and only run it on PRs to rocm-main (#126) * Rename the CI flow to 'ROCm CI' and only run it on PRs to the rocm-main branch * Change name to 'ROCm CPU CI' --- .github/workflows/ci-build.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 6ac7a138d7da..2183aaf8948d 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,4 +1,4 @@ -name: CI +name: ROCm CPU CI # We test all supported Python versions as follows: # - 3.10 : Documentation build @@ -11,10 +11,10 @@ on: # but only for the main branch push: branches: - - main + - rocm-main pull_request: branches: - - main + - rocm-main permissions: contents: read # to fetch code From ab47d4687f647de3aa145a9a782fb7b4aaf92af4 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 Nov 2024 15:38:25 -0800 Subject: [PATCH 194/698] [JAX] [XLA:Python] Move JAX configuration objects into C++. A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code. There are two main ways we can get a speedup: * Python thread-local state is based around a dictionary and isn't terribly fast. * we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand. PiperOrigin-RevId: 693114411 --- jax/_src/api.py | 4 +- jax/_src/compute_on.py | 7 +- jax/_src/config.py | 649 +++++++++++++++++++++++---------------- jax/_src/core.py | 15 +- jax/_src/mesh.py | 14 +- jax/_src/xla_metadata.py | 12 +- tests/pmap_test.py | 2 - 7 files changed, 406 insertions(+), 297 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index cc42a37b0e7c..a902c7de4c3e 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -123,8 +123,8 @@ def _update_debug_special_global(_): jax_jit.global_state().post_hook = None def _update_debug_special_thread_local(_): - if (getattr(config._thread_local_state, "jax_debug_nans", False) or - getattr(config._thread_local_state, "jax_debug_infs", False)): + if (config.debug_nans.get_local() == True or + config.debug_infs.get_local() == True): jax_jit.thread_local_state().post_hook = _nan_check_posthook else: jax_jit.thread_local_state().post_hook = None diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index b5194ddad21d..7bd9b9b08b7b 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -29,8 +29,8 @@ def __init__(self): @contextmanager def extend_compute_type(c_type: str): compute_on_context.stack.append(c_type) - config.update_thread_local_jit_state( - compute_on_context_manager=tuple(compute_on_context.stack)) + config.compute_on_context_manager.set_local( + tuple(compute_on_context.stack)) try: if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: raise NotImplementedError( @@ -39,8 +39,7 @@ def extend_compute_type(c_type: str): yield compute_on_context.stack[-1] finally: compute_on_context.stack.pop() - config.update_thread_local_jit_state( - compute_on_context_manager=tuple(compute_on_context.stack)) + config.compute_on_context_manager.set_local(tuple(compute_on_context.stack)) def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None diff --git a/jax/_src/config.py b/jax/_src/config.py index 0860168b23a3..0e113c894695 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -22,14 +22,23 @@ import os import sys import threading -from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast +from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING from jax._src import lib from jax._src.lib import guard_lib from jax._src.lib import jax_jit from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax._src import logging_config +# TODO(phawkins): reenable pytype after xla_extension_version >= 295 +# pytype: skip-file + +if xla_extension_version >= 295: + config_ext = xla_client._xla.config +else: + config_ext = None + logger = logging.getLogger(__name__) _T = TypeVar('_T') @@ -191,49 +200,79 @@ def parse_flags_with_absl(self): already_configured_with_absl = True -def trace_context(): - """Returns a tuple of configuration values that affect tracing. +if xla_extension_version >= 295: + def trace_context(): + """Returns a tuple of configuration values that affect tracing. - These values are included in the cache key for linear_util.cache. + These values are included in the cache key for linear_util.cache. - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately. - """ - tls = jax_jit.thread_local_state() - axis_env_state = () - mesh_context_manager = () - xla_metadata_context_manager = () - compute_on_context_manager = () - - context: Any = tls.extra_jit_context - if context and context.axis_env_state is not None: - axis_env_state = context.axis_env_state - if context and context.mesh_context_manager: - mesh_context_manager = context.mesh_context_manager - if context and context.xla_metadata_context_manager: - xla_metadata_context_manager = context.xla_metadata_context_manager - if context and context.compute_on_context_manager: - compute_on_context_manager = context.compute_on_context_manager - return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, - compute_on_context_manager, enable_x64.value, - numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, - eager_constant_folding.value, - numpy_dtype_promotion.value, - default_device.value, random_seed_offset.value, - threefry_partitionable.value, - threefry_gpu_kernel_lowering.value, - sharding_in_types.value, - softmax_custom_jvp.value, - enable_memories.value, - disable_jit.value, - debug_key_reuse.value, - jax_xla_profile_version.value, - # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value, - pgle_profiling_runs.value, - enable_pgle.value, - use_shardy_partitioner.value) + Values included in this set should also most likely be included in + the C++ JIT state, which is handled separately. + """ + return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value, + compute_on_context_manager.value, enable_x64.value, + numpy_rank_promotion.value, default_matmul_precision.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, + default_device.value, random_seed_offset.value, + threefry_partitionable.value, + threefry_gpu_kernel_lowering.value, + sharding_in_types.value, + softmax_custom_jvp.value, + enable_memories.value, + disable_jit.value, + debug_key_reuse.value, + jax_xla_profile_version.value, + # Technically this affects jaxpr->stablehlo lowering, not tracing. + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value, + use_shardy_partitioner.value) +else: + def trace_context(): + """Returns a tuple of configuration values that affect tracing. + + These values are included in the cache key for linear_util.cache. + + Values included in this set should also most likely be included in + the C++ JIT state, which is handled separately. + """ + tls = jax_jit.thread_local_state() + axis_env_state = () + mesh_context_manager = () + xla_metadata_context_manager = () + compute_on_context_manager = () + + context: Any = tls.extra_jit_context + if context and context.axis_env_state is not None: + axis_env_state = context.axis_env_state + if context and context.mesh_context_manager: + mesh_context_manager = context.mesh_context_manager + if context and context.xla_metadata_context_manager: + xla_metadata_context_manager = context.xla_metadata_context_manager + if context and context.compute_on_context_manager: + compute_on_context_manager = context.compute_on_context_manager + return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, + compute_on_context_manager, enable_x64.value, + numpy_rank_promotion.value, default_matmul_precision.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, + default_device.value, random_seed_offset.value, + threefry_partitionable.value, + threefry_gpu_kernel_lowering.value, + sharding_in_types.value, + softmax_custom_jvp.value, + enable_memories.value, + disable_jit.value, + debug_key_reuse.value, + jax_xla_profile_version.value, + # Technically this affects jaxpr->stablehlo lowering, not tracing. + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value, + use_shardy_partitioner.value) config = Config() @@ -245,94 +284,185 @@ def trace_context(): class NoDefault: pass no_default = NoDefault() +if xla_extension_version >= 295: + class State(config_ext.Config[_T]): -class _Unset: pass -unset = _Unset() - -_thread_local_state = threading.local() + __slots__ = ( + '_name', '_update_thread_local_hook', '_update_global_hook', + '_validator', '_default_context_manager_value', '__doc__', '__name__', + ) -class State(Generic[_T]): + def __init__( + self, + name: str, + default: _T, + help, + update_global_hook: Callable[[_T], None] | None = None, + update_thread_local_hook: Callable[[_T | None], None] | None = None, + validator: Callable[[Any], None] | None = None, + extra_description: str = '', + default_context_manager_value: Any = no_default, + include_in_jit_key: bool = False, + ): + super().__init__(default, include_in_jit_key) + self._name = name + self.__name__ = name[4:] if name.startswith('jax_') else name + self.__doc__ = (f"Context manager for `{name}` config option" + f"{extra_description}.\n\n{help}") + self._update_global_hook = update_global_hook + self._update_thread_local_hook = update_thread_local_hook + self._validator = validator + self._default_context_manager_value = default_context_manager_value + if self._validator: + self._validator(default) + if self._update_global_hook: + self._update_global_hook(default) + + def __bool__(self) -> NoReturn: + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__)) + + def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) + self.set_global(value) + if self._update_global_hook: + self._update_global_hook(value) + + @contextlib.contextmanager + def __call__(self, new_val: Any = no_default): + if new_val is no_default: + if self._default_context_manager_value is not no_default: + new_val = self._default_context_manager_value # default_context_manager_value provided to constructor + else: + # no default_value provided to constructor and no value provided as an + # argument, so we raise an error + raise TypeError(f"Context manager for {self.__name__} config option " + "requires an argument representing the new value for " + "the config option.") + if self._validator: + self._validator(new_val) + prev_val = self.swap_local(new_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(new_val) + try: + yield + finally: + self.set_local(prev_val) + if self._update_thread_local_hook: + if prev_val is config_ext.unset: + self._update_thread_local_hook(None) + else: + self._update_thread_local_hook(cast(Optional[Any], prev_val)) - __slots__ = ( - '_name', '_value', '_update_thread_local_hook', '_update_global_hook', - '_validator', '_default_context_manager_value', '__doc__', '__name__', - ) + def _add_hooks(self, update_global_hook, update_thread_local_hook): + """Private method that adds hooks to an existing context-manager. - def __init__( - self, - name: str, - default: _T, - help, - update_global_hook: Callable[[_T], None] | None = None, - update_thread_local_hook: Callable[[_T | None], None] | None = None, - validator: Callable[[Any], None] | None = None, - extra_description: str = '', - default_context_manager_value: Any = no_default, - ): - self._name = name - self.__name__ = name[4:] if name.startswith('jax_') else name - self.__doc__ = (f"Context manager for `{name}` config option" - f"{extra_description}.\n\n{help}") - self._update_global_hook = update_global_hook - self._update_thread_local_hook = update_thread_local_hook - self._validator = validator - self._default_context_manager_value = default_context_manager_value - self._set(default) + Used to avoid cyclic import dependencies.""" + self._update_thread_local_hook = update_thread_local_hook + self._update_global_hook = update_global_hook + update_global_hook(self.get_global()) - def __bool__(self) -> NoReturn: - raise TypeError( - "bool() not supported for instances of type '{0}' " - "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) +else: + class _Unset: pass + unset = _Unset() - def _set(self, value: _T) -> None: - if self._validator: - self._validator(value) - self._value = value - if self._update_global_hook: - self._update_global_hook(value) + _thread_local_state = threading.local() - @property - def value(self) -> _T: - val = _thread_local_state.__dict__.get(self._name, unset) - return cast(_T, val) if val is not unset else self._value - - @contextlib.contextmanager - def __call__(self, new_val: Any = no_default): - if new_val is no_default: - if self._default_context_manager_value is not no_default: - new_val = self._default_context_manager_value # default_context_manager_value provided to constructor - else: - # no default_value provided to constructor and no value provided as an - # argument, so we raise an error - raise TypeError(f"Context manager for {self.__name__} config option " - "requires an argument representing the new value for " - "the config option.") - if self._validator: - self._validator(new_val) - prev_val = getattr(_thread_local_state, self._name, unset) - setattr(_thread_local_state, self._name, new_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(new_val) - try: - yield - finally: - if prev_val is unset: - delattr(_thread_local_state, self._name) - if self._update_thread_local_hook: - self._update_thread_local_hook(None) - else: - setattr(_thread_local_state, self._name, prev_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(cast(_T, prev_val)) + class State(Generic[_T]): - def _add_hooks(self, update_global_hook, update_thread_local_hook): - """Private method that adds hooks to an existing context-manager. + __slots__ = ( + '_name', '_value', '_update_thread_local_hook', '_update_global_hook', + '_validator', '_default_context_manager_value', '__doc__', '__name__', + ) - Used to avoid cyclic import dependencies.""" - self._update_thread_local_hook = update_thread_local_hook - self._update_global_hook = update_global_hook - update_global_hook(self._value) + def __init__( + self, + name: str, + default: _T, + help, + update_global_hook: Callable[[_T], None] | None = None, + update_thread_local_hook: Callable[[_T | None], None] | None = None, + validator: Callable[[Any], None] | None = None, + extra_description: str = '', + default_context_manager_value: Any = no_default, + include_in_jit_key: bool = False, + ): + self._name = name + self.__name__ = name[4:] if name.startswith('jax_') else name + self.__doc__ = (f"Context manager for `{name}` config option" + f"{extra_description}.\n\n{help}") + if include_in_jit_key: + assert update_global_hook is None + assert update_thread_local_hook is None + update_global_hook = lambda val: _update_global_jit_state( + **{self.__name__: val}) + update_thread_local_hook = lambda val: update_thread_local_jit_state( + **{self.__name__: val}) + self._update_global_hook = update_global_hook + self._update_thread_local_hook = update_thread_local_hook + self._validator = validator + self._default_context_manager_value = default_context_manager_value + self._set(default) + def __bool__(self) -> NoReturn: + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__)) + + def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) + self._value = value + if self._update_global_hook: + self._update_global_hook(value) + + @property + def value(self) -> _T: + val = _thread_local_state.__dict__.get(self._name, unset) + return cast(_T, val) if val is not unset else self._value + + def get_local(self) -> Any: + return _thread_local_state.__dict__.get(self._name, unset) + + @contextlib.contextmanager + def __call__(self, new_val: Any = no_default): + if new_val is no_default: + if self._default_context_manager_value is not no_default: + new_val = self._default_context_manager_value # default_context_manager_value provided to constructor + else: + # no default_value provided to constructor and no value provided as an + # argument, so we raise an error + raise TypeError(f"Context manager for {self.__name__} config option " + "requires an argument representing the new value for " + "the config option.") + if self._validator: + self._validator(new_val) + prev_val = getattr(_thread_local_state, self._name, unset) + setattr(_thread_local_state, self._name, new_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(new_val) + try: + yield + finally: + if prev_val is unset: + delattr(_thread_local_state, self._name) + if self._update_thread_local_hook: + self._update_thread_local_hook(None) + else: + setattr(_thread_local_state, self._name, prev_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(cast(_T, prev_val)) + + def _add_hooks(self, update_global_hook, update_thread_local_hook): + """Private method that adds hooks to an existing context-manager. + + Used to avoid cyclic import dependencies.""" + self._update_thread_local_hook = update_thread_local_hook + self._update_global_hook = update_global_hook + update_global_hook(self._value) UPGRADE_BOOL_HELP = ( @@ -353,6 +483,7 @@ def bool_state( update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', + include_in_jit_key: bool = False, ) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. @@ -417,7 +548,8 @@ def bool_state( s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, - extra_description=extra_description, default_context_manager_value=True) + extra_description=extra_description, default_context_manager_value=True, + include_in_jit_key=include_in_jit_key) config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -431,6 +563,7 @@ def enum_state( *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. @@ -470,6 +603,7 @@ def validator(new_val): update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator, + include_in_jit_key=include_in_jit_key, ) config.add_option( name, s, 'enum', @@ -488,6 +622,7 @@ def optional_enum_state( *, update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. @@ -523,7 +658,7 @@ def validate(new_val): s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, - validate + validate, include_in_jit_key=include_in_jit_key, ) config.add_option( name, s, 'enum', @@ -541,6 +676,7 @@ def int_state( *, update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. @@ -575,7 +711,8 @@ def validate(new_val): f'got {new_val} of type {type(new_val)}') s = State[int](name, default, help, update_global_hook, - update_thread_local_hook, validate) + update_thread_local_hook, validate, + include_in_jit_key=include_in_jit_key) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -826,92 +963,119 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: already_configured_with_absl = False -# The C++ JIT maintains its own copy of several configuration items as -# a global/thread-local state. These methods allow updates to part of the -# state when a configuration value changes. -class _GlobalExtraJitContext(NamedTuple): - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool = False - eager_constant_folding: bool = False - random_seed_offset: int = 0 - threefry_partitionable: bool = False - threefry_gpu_kernel_lowering: bool = False - sharding_in_types: bool = False - softmax_custom_jvp: bool = False - xla_profile_version: int = 0 - pgle_profiling_runs: int = 0 - enable_pgle: bool = False - use_shardy_partitioner: bool = False - - -def _update_global_jit_state(**kw): - gs = jax_jit.global_state() - context = gs.extra_jit_context or _GlobalExtraJitContext() - gs.extra_jit_context = context._replace(**kw) - - -class _ThreadLocalExtraJitContext(NamedTuple): - """A namedtuple containing states to add to the cache key. - - Just in time compilation (for jit, pmap, etc) behavior is configurable through - global and thread-local options, used in the cache key. - - The initialization, which uses both config.py and core.py is done using - `_update_thread_local_jit_state` in core.py to prevent circular imports. - """ - trace_state: Any | None = None - axis_env_state: Hashable = () - mesh_context_manager: Hashable = () - compute_on_context_manager: Hashable = () - xla_metadata_context_manager: Hashable = () - - # Values set by _StateContextManager context managers. - # CAUTION: these must be initialized to `None`! The state context manager - # restores these to None on exit. If the object default is not `None`, the - # context manager is not a no-op, which leads to problems with stale state - # (e.g. spurious cache misses in tests). - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool | None = None - eager_constant_folding : bool | None = None - random_seed_offset: int | None = None - threefry_partitionable: bool | None = None - threefry_gpu_kernel_lowering: bool | None = None - sharding_in_types: bool | None = None - softmax_custom_jvp: bool | None = None - xla_profile_version: int | None = None - pgle_profiling_runs: int | None = None - enable_pgle: bool | None = None - use_shardy_partitioner: bool | None = None - - -class _ThreadLocalStateCache(threading.local): - """"A thread local cache for _ThreadLocalExtraJitContext - - The extra_jit_context in jax_jit.thread_local_state() may get updated and thus - incurring dispatch overhead for comparing this python object during jit calls. - We want to deduplicate the objects that have the same hash/equality to also - have the same object ID, since the equality check is much faster if the object - IDs match. - """ - def __init__(self): - self.canonicalize = functools.lru_cache(128)(lambda x: x) +if xla_extension_version >= 295: + trace_state = config_ext.Config(None, include_in_jit_key=True) + axis_env_state = config_ext.Config((), include_in_jit_key=True) + mesh_context_manager = config_ext.Config((), include_in_jit_key=True) + compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) + xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) +else: + # The C++ JIT maintains its own copy of several configuration items as + # a global/thread-local state. These methods allow updates to part of the + # state when a configuration value changes. + class _GlobalExtraJitContext(NamedTuple): + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None + dynamic_shapes: bool = False + eager_constant_folding: bool = False + random_seed_offset: int = 0 + threefry_partitionable: bool = False + threefry_gpu_kernel_lowering: bool = False + sharding_in_types: bool = False + softmax_custom_jvp: bool = False + xla_profile_version: int = 0 + pgle_profiling_runs: int = 0 + enable_pgle: bool = False + use_shardy_partitioner: bool = False + + + def _update_global_jit_state(**kw): + gs = jax_jit.global_state() + context = gs.extra_jit_context or _GlobalExtraJitContext() + gs.extra_jit_context = context._replace(**kw) + + + class _ThreadLocalExtraJitContext(NamedTuple): + """A namedtuple containing states to add to the cache key. + + Just in time compilation (for jit, pmap, etc) behavior is configurable through + global and thread-local options, used in the cache key. + + The initialization, which uses both config.py and core.py is done using + `_update_thread_local_jit_state` in core.py to prevent circular imports. + """ + trace_state: Any | None = None + axis_env_state: Hashable = () + mesh_context_manager: Hashable = () + compute_on_context_manager: Hashable = () + xla_metadata_context_manager: Hashable = () + + # Values set by _StateContextManager context managers. + # CAUTION: these must be initialized to `None`! The state context manager + # restores these to None on exit. If the object default is not `None`, the + # context manager is not a no-op, which leads to problems with stale state + # (e.g. spurious cache misses in tests). + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None + dynamic_shapes: bool | None = None + eager_constant_folding : bool | None = None + random_seed_offset: int | None = None + threefry_partitionable: bool | None = None + threefry_gpu_kernel_lowering: bool | None = None + sharding_in_types: bool | None = None + softmax_custom_jvp: bool | None = None + xla_profile_version: int | None = None + pgle_profiling_runs: int | None = None + enable_pgle: bool | None = None + use_shardy_partitioner: bool | None = None + + + class _ThreadLocalStateCache(threading.local): + """"A thread local cache for _ThreadLocalExtraJitContext + + The extra_jit_context in jax_jit.thread_local_state() may get updated and thus + incurring dispatch overhead for comparing this python object during jit calls. + We want to deduplicate the objects that have the same hash/equality to also + have the same object ID, since the equality check is much faster if the object + IDs match. + """ + def __init__(self): + self.canonicalize = functools.lru_cache(128)(lambda x: x) -_thread_local_state_cache = _ThreadLocalStateCache() + _thread_local_state_cache = _ThreadLocalStateCache() -def update_thread_local_jit_state(**kw): - tls = jax_jit.thread_local_state() - # After xla_client._version >= 70, the thread_local object will necessarily - # be initialized when accessed. The following line can be removed when the - # minimum jaxlib version is past version 70 - context = tls.extra_jit_context or _ThreadLocalExtraJitContext() - tmp = context._replace(**kw) - tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) + def update_thread_local_jit_state(**kw): + tls = jax_jit.thread_local_state() + # After xla_client._version >= 70, the thread_local object will necessarily + # be initialized when accessed. The following line can be removed when the + # minimum jaxlib version is past version 70 + context = tls.extra_jit_context or _ThreadLocalExtraJitContext() + tmp = context._replace(**kw) + tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) + + class JitConfig: + def __init__(self, name): + self._name = name + + def value(self): + return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) + + def get_local(self): + return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) + + def set_local(self, value): + update_thread_local_jit_state(**{self._name: value}) + + trace_state = JitConfig('trace_state') + axis_env_state = JitConfig('axis_env_state') + mesh_context_manager = JitConfig('mesh_context_manager') + compute_on_context_manager = JitConfig('compute_on_context_manager') + xla_metadata_context_manager = JitConfig('xla_metadata_context_manager') + # TODO(b/214340779): remove flag when XLA:CPU is improved. jax2tf_associative_scan_reductions = bool_state( @@ -1102,10 +1266,7 @@ def _update_jax_memories_thread_local(val): name='jax_random_seed_offset', default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), - update_global_hook=lambda val: _update_global_jit_state( - random_seed_offset=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - random_seed_offset=val) + include_in_jit_key=True, ) legacy_prng_key = enum_state( @@ -1140,10 +1301,7 @@ def _update_jax_memories_thread_local(val): 'may result in extraneous communication and/or redundant distributed ' 'computation. With this flag, the communication overheads disappear ' 'in some cases.'), - update_global_hook=lambda val: _update_global_jit_state( - threefry_partitionable=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - threefry_partitionable=val)) + include_in_jit_key=True) threefry_gpu_kernel_lowering = bool_state( name='jax_threefry_gpu_kernel_lowering', @@ -1151,20 +1309,14 @@ def _update_jax_memories_thread_local(val): help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' 'This makes compile times faster at a potential runtime memory ' 'cost.'), - update_global_hook=lambda val: _update_global_jit_state( - threefry_gpu_kernel_lowering=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - threefry_gpu_kernel_lowering=val)) + include_in_jit_key=True) sharding_in_types = bool_state( name='jax_sharding_in_types', default=False, help=('When True, enables forward only sharding propagation in JAX and ' 'avals have sharding on them.'), - update_global_hook=lambda val: _update_global_jit_state( - sharding_in_types=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - sharding_in_types=val)) + include_in_jit_key=True) data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', @@ -1179,10 +1331,7 @@ def _update_jax_memories_thread_local(val): help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' 'behavior. See https://github.com/jax-ml/jax/pull/15677'), - update_global_hook=lambda val: _update_global_jit_state( - softmax_custom_jvp=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - softmax_custom_jvp=val)) + include_in_jit_key=True) enable_custom_vjp_by_custom_transpose = bool_state( @@ -1298,9 +1447,7 @@ def _update_jax_memories_thread_local(val): 'number times with collected data provided to the profile guided latency ' 'estimator.' ), - update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - enable_pgle=val), + include_in_jit_key=True, ) pgle_profiling_runs = int_state( @@ -1310,12 +1457,7 @@ def _update_jax_memories_thread_local(val): 'Amount of times module should be profiled before recompilation when ' 'PGLE is used.' ), - update_global_hook=lambda val: _update_global_jit_state( - pgle_profiling_runs=val - ), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - pgle_profiling_runs=val - ), + include_in_jit_key=True, ) pgle_aggregation_percentile = int_state( @@ -1381,10 +1523,7 @@ def _update_jax_memories_thread_local(val): 'between arrays. Options are "standard" or "strict"; in strict-mode, ' 'binary operations between arrays of differing strongly-specified ' 'dtypes will result in an error.'), - update_global_hook=lambda val: \ - _update_global_jit_state(numpy_dtype_promotion=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(numpy_dtype_promotion=val)) + include_in_jit_key=True) disallow_mesh_context_manager = bool_state( name='jax_disallow_mesh_context_manager', @@ -1470,10 +1609,7 @@ def _update_disable_jit_thread_local(val): default='allow', help=('Control NumPy-style automatic rank promotion broadcasting ' '("allow", "warn", or "raise").'), - update_global_hook=lambda val: \ - _update_global_jit_state(numpy_rank_promotion=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(numpy_rank_promotion=val)) + include_in_jit_key=True) default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', @@ -1509,10 +1645,7 @@ def _update_disable_jit_thread_local(val): '"algorithm" for functions that perform matrix multiplications, like ' ':func:`jax.lax.dot`. To specify an algorithm, set this option to ' 'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'), - update_global_hook=lambda val: \ - _update_global_jit_state(default_matmul_precision=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(default_matmul_precision=val)) + include_in_jit_key=True) traceback_filtering = enum_state( name = 'jax_traceback_filtering', @@ -1547,20 +1680,14 @@ def _update_disable_jit_thread_local(val): default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), help=('Enables experimental features for staging out computations with ' 'dynamic shapes.'), - update_global_hook=lambda val: \ - _update_global_jit_state(dynamic_shapes=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(dynamic_shapes=val)) + include_in_jit_key=True) # This is for stackless backward compat with e.g. equinox eager_constant_folding = bool_state( name='eager_constant_folding', default=False, help=('Attempt constant folding during staging.'), - update_global_hook=lambda val: \ - _update_global_jit_state(eager_constant_folding=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(eager_constant_folding=val)) + include_in_jit_key=True) # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. @@ -1619,10 +1746,7 @@ def _update_disable_jit_thread_local(val): 'Optional profile version for XLA compilation. This is meaningful ' 'only when XLA is configured to support the remote compilation ' 'profile feature.'), - update_global_hook=lambda val: _update_global_jit_state( - xla_profile_version=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - xla_profile_version=val), + include_in_jit_key=True, ) @contextlib.contextmanager @@ -1821,10 +1945,5 @@ def _update_debug_log_modules(module_names_str: str | None): 'framework for MLIR. Currently Shardy is experimental in JAX. See ' 'www.github.com/openxla/shardy' ), - update_global_hook=lambda val: _update_global_jit_state( - use_shardy_partitioner=val - ), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - use_shardy_partitioner=val - ), + include_in_jit_key=True, ) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7d912e3c207b..7c92fa1b5236 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1016,18 +1016,16 @@ def is_top_level(self) -> bool: def set_trace(self, trace): self.trace = trace ts = ref(trace) if trace is not None else None - config.update_thread_local_jit_state(trace_state=ts) + config.trace_state.set_local(ts) def set_axis_env(self, axis_env): self.axis_env = axis_env - config.update_thread_local_jit_state( - axis_env_state=self.axis_env.as_hashable_key()) + config.axis_env_state.set_local(axis_env.as_hashable_key()) def update_thread_local_jit_state(self): ts = ref(self.trace) if self.trace is not None else None - config.update_thread_local_jit_state( - trace_state=ts, - axis_env_state=self.axis_env.as_hashable_key()) + config.trace_state.set_local(ts) + config.axis_env_state.set_local(self.axis_env.as_hashable_key()) trace_ctx = TracingContext() @@ -1071,10 +1069,7 @@ def _initialize_jax_jit_thread_local_state(): This function does not live in `config.py`, to prevent circular imports. """ - tls = jax_jit.thread_local_state() - - if tls.extra_jit_context is None: - trace_ctx.update_thread_local_jit_state() + trace_ctx.update_thread_local_jit_state() jax_jit.set_thread_local_state_initialization_callback( _initialize_jax_jit_thread_local_state) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 8cb508378129..43791f2e5f72 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -224,17 +224,17 @@ def __enter__(self): new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.mesh_context_manager.set_local( + tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return self def __exit__(self, exc_type, exc_value, traceback): thread_resources.stack.pop() thread_resources.env = thread_resources.stack[-1] - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.mesh_context_manager.set_local( + tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return False @property @@ -410,7 +410,7 @@ def __exit__(self, exc_type, exc_value, traceback): @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): - jax_config.update_thread_local_jit_state(mesh_context_manager=mesh) + jax_config.mesh_context_manager.set_local(mesh) return diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 94b482e2dea4..94e26eeefa65 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -41,15 +41,13 @@ def set_xla_metadata(*args, **kwargs): thread_local_metadata.val, new_metadata, ) - config.update_thread_local_jit_state( - xla_metadata_context_manager=tuple( - (v, k) for k, v in sorted(new_metadata.items()))) + config.xla_metadata_context_manager.set_local( + tuple((v, k) for k, v in sorted(new_metadata.items())) + ) try: yield finally: thread_local_metadata.val = prev_metadata - config.update_thread_local_jit_state( - xla_metadata_context_manager=tuple( - (v, k) for k, v in sorted(prev_metadata.items()) - ) + config.xla_metadata_context_manager.set_local( + tuple((v, k) for k, v in sorted(prev_metadata.items())) ) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 6e0e795df334..0df5d99715e6 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2215,8 +2215,6 @@ def test_cache_uses_jax_key(self): pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) - config.update_thread_local_jit_state() - pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) From 9f0e6237a3175d9e3e130ee3137da81a1c3a21a1 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 4 Nov 2024 22:43:37 +0000 Subject: [PATCH 195/698] Update JAX landing page - Flax --- docs/index.rst | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 2dd856ab88ef..7d555fe85eb4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -65,7 +65,6 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-outlined:`hub;2em` **Neural networks** - Flax_ - - NNX_ - Equinox_ - Keras_ @@ -79,8 +78,8 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-outlined:`storage;2em` **Data loading** - Grain_ - - `Tensorflow datasets`_ - - `Hugging Face datasets`_ + - `TensorFlow Datasets`_ + - `Hugging Face Datasets`_ .. grid-item:: :material-regular:`construction;2em` **Miscellaneous tools** @@ -95,7 +94,7 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling** - - `Tensorflow probabilty`_ + - `TensorFlow Probabilty`_ - Distrax_ .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation** @@ -164,17 +163,16 @@ maintains an up-to-date list. .. _Equinox: https://docs.kidger.site/equinox/ .. _Flax: https://flax.readthedocs.io/ .. _Grain: https://github.com/google/grain -.. _Hugging Face datasets: https://huggingface.co/docs/datasets/ +.. _Hugging Face Datasets: https://huggingface.co/docs/datasets/ .. _JAX MD: https://jax-md.readthedocs.io/ .. _Keras: https://keras.io/ .. _Levanter: https://github.com/stanford-crfm/levanter .. _Lineax: https://github.com/patrick-kidger/lineax .. _MaxText: https://github.com/google/maxtext/ -.. _NNX: https://flax.readthedocs.io/en/latest/nnx/ .. _Numpyro: https://num.pyro.ai/en/latest/index.html .. _Optax: https://optax.readthedocs.io/ .. _Optimistix: https://github.com/patrick-kidger/optimistix .. _Orbax: https://orbax.readthedocs.io/ .. _PyMC: https://www.pymc.io/ -.. _Tensorflow datasets: https://www.tensorflow.org/datasets -.. _Tensorflow probabilty: https://www.tensorflow.org/probability +.. _TensorFlow Datasets: https://www.tensorflow.org/datasets +.. _TensorFlow Probabilty: https://www.tensorflow.org/probability From b976b1ab275045eb1ec20d57103d392adfd3c510 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 4 Nov 2024 22:35:26 +0000 Subject: [PATCH 196/698] Update details on JAX libraries in JAX README.md --- README.md | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index c99d3db10a2a..89fe51212638 100644 --- a/README.md +++ b/README.md @@ -411,23 +411,18 @@ community-supported conda build, and answers to some frequently-asked questions. ## Neural network libraries -Multiple Google research groups develop and share libraries for training neural -networks in JAX. If you want a fully featured library for neural network +Multiple Google research groups at Google DeepMind and Alphabet develop and share libraries +for training neural networks in JAX. If you want a fully featured library for neural network training with examples and how-to guides, try -[Flax](https://github.com/google/flax). Check out the new [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) API for a -simplified development experience. - -Google X maintains the neural network library -[Equinox](https://github.com/patrick-kidger/equinox). This is used as the -foundation for several other libraries in the JAX ecosystem. - -In addition, DeepMind has open-sourced an [ecosystem of libraries around -JAX](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) -including [Optax](https://github.com/deepmind/optax) for gradient processing and -optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and -[chex](https://github.com/deepmind/chex) for reliable code and testing. (Watch -the NeurIPS 2020 JAX Ecosystem at DeepMind talk -[here](https://www.youtube.com/watch?v=iDxJxIyzSiM)) +[Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). + +Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) +on the JAX documentation site for a list of JAX-based network libraries, which includes +[Optax](https://github.com/deepmind/optax) for gradient processing and +optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and +[Equinox](https://github.com/patrick-kidger/equinox) for neural networks. +(Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk +[here](https://www.youtube.com/watch?v=iDxJxIyzSiM) for additional details.) ## Citing JAX From a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 4 Nov 2024 21:04:48 -0800 Subject: [PATCH 197/698] rollback due to data race Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4 PiperOrigin-RevId: 693191298 --- jax/_src/api.py | 4 +- jax/_src/compute_on.py | 7 +- jax/_src/config.py | 649 ++++++++++++++++----------------------- jax/_src/core.py | 15 +- jax/_src/mesh.py | 14 +- jax/_src/xla_metadata.py | 12 +- tests/pmap_test.py | 2 + 7 files changed, 297 insertions(+), 406 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index a902c7de4c3e..cc42a37b0e7c 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -123,8 +123,8 @@ def _update_debug_special_global(_): jax_jit.global_state().post_hook = None def _update_debug_special_thread_local(_): - if (config.debug_nans.get_local() == True or - config.debug_infs.get_local() == True): + if (getattr(config._thread_local_state, "jax_debug_nans", False) or + getattr(config._thread_local_state, "jax_debug_infs", False)): jax_jit.thread_local_state().post_hook = _nan_check_posthook else: jax_jit.thread_local_state().post_hook = None diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 7bd9b9b08b7b..b5194ddad21d 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -29,8 +29,8 @@ def __init__(self): @contextmanager def extend_compute_type(c_type: str): compute_on_context.stack.append(c_type) - config.compute_on_context_manager.set_local( - tuple(compute_on_context.stack)) + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) try: if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: raise NotImplementedError( @@ -39,7 +39,8 @@ def extend_compute_type(c_type: str): yield compute_on_context.stack[-1] finally: compute_on_context.stack.pop() - config.compute_on_context_manager.set_local(tuple(compute_on_context.stack)) + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None diff --git a/jax/_src/config.py b/jax/_src/config.py index 0e113c894695..0860168b23a3 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -22,23 +22,14 @@ import os import sys import threading -from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING +from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast from jax._src import lib from jax._src.lib import guard_lib from jax._src.lib import jax_jit from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version from jax._src import logging_config -# TODO(phawkins): reenable pytype after xla_extension_version >= 295 -# pytype: skip-file - -if xla_extension_version >= 295: - config_ext = xla_client._xla.config -else: - config_ext = None - logger = logging.getLogger(__name__) _T = TypeVar('_T') @@ -200,79 +191,49 @@ def parse_flags_with_absl(self): already_configured_with_absl = True -if xla_extension_version >= 295: - def trace_context(): - """Returns a tuple of configuration values that affect tracing. +def trace_context(): + """Returns a tuple of configuration values that affect tracing. - These values are included in the cache key for linear_util.cache. + These values are included in the cache key for linear_util.cache. - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately. - """ - return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value, - compute_on_context_manager.value, enable_x64.value, - numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, - eager_constant_folding.value, - numpy_dtype_promotion.value, - default_device.value, random_seed_offset.value, - threefry_partitionable.value, - threefry_gpu_kernel_lowering.value, - sharding_in_types.value, - softmax_custom_jvp.value, - enable_memories.value, - disable_jit.value, - debug_key_reuse.value, - jax_xla_profile_version.value, - # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value, - pgle_profiling_runs.value, - enable_pgle.value, - use_shardy_partitioner.value) -else: - def trace_context(): - """Returns a tuple of configuration values that affect tracing. - - These values are included in the cache key for linear_util.cache. - - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately. - """ - tls = jax_jit.thread_local_state() - axis_env_state = () - mesh_context_manager = () - xla_metadata_context_manager = () - compute_on_context_manager = () - - context: Any = tls.extra_jit_context - if context and context.axis_env_state is not None: - axis_env_state = context.axis_env_state - if context and context.mesh_context_manager: - mesh_context_manager = context.mesh_context_manager - if context and context.xla_metadata_context_manager: - xla_metadata_context_manager = context.xla_metadata_context_manager - if context and context.compute_on_context_manager: - compute_on_context_manager = context.compute_on_context_manager - return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, - compute_on_context_manager, enable_x64.value, - numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, - eager_constant_folding.value, - numpy_dtype_promotion.value, - default_device.value, random_seed_offset.value, - threefry_partitionable.value, - threefry_gpu_kernel_lowering.value, - sharding_in_types.value, - softmax_custom_jvp.value, - enable_memories.value, - disable_jit.value, - debug_key_reuse.value, - jax_xla_profile_version.value, - # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value, - pgle_profiling_runs.value, - enable_pgle.value, - use_shardy_partitioner.value) + Values included in this set should also most likely be included in + the C++ JIT state, which is handled separately. + """ + tls = jax_jit.thread_local_state() + axis_env_state = () + mesh_context_manager = () + xla_metadata_context_manager = () + compute_on_context_manager = () + + context: Any = tls.extra_jit_context + if context and context.axis_env_state is not None: + axis_env_state = context.axis_env_state + if context and context.mesh_context_manager: + mesh_context_manager = context.mesh_context_manager + if context and context.xla_metadata_context_manager: + xla_metadata_context_manager = context.xla_metadata_context_manager + if context and context.compute_on_context_manager: + compute_on_context_manager = context.compute_on_context_manager + return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, + compute_on_context_manager, enable_x64.value, + numpy_rank_promotion.value, default_matmul_precision.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, + default_device.value, random_seed_offset.value, + threefry_partitionable.value, + threefry_gpu_kernel_lowering.value, + sharding_in_types.value, + softmax_custom_jvp.value, + enable_memories.value, + disable_jit.value, + debug_key_reuse.value, + jax_xla_profile_version.value, + # Technically this affects jaxpr->stablehlo lowering, not tracing. + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value, + use_shardy_partitioner.value) config = Config() @@ -284,185 +245,94 @@ def trace_context(): class NoDefault: pass no_default = NoDefault() -if xla_extension_version >= 295: - class State(config_ext.Config[_T]): - __slots__ = ( - '_name', '_update_thread_local_hook', '_update_global_hook', - '_validator', '_default_context_manager_value', '__doc__', '__name__', - ) +class _Unset: pass +unset = _Unset() - def __init__( - self, - name: str, - default: _T, - help, - update_global_hook: Callable[[_T], None] | None = None, - update_thread_local_hook: Callable[[_T | None], None] | None = None, - validator: Callable[[Any], None] | None = None, - extra_description: str = '', - default_context_manager_value: Any = no_default, - include_in_jit_key: bool = False, - ): - super().__init__(default, include_in_jit_key) - self._name = name - self.__name__ = name[4:] if name.startswith('jax_') else name - self.__doc__ = (f"Context manager for `{name}` config option" - f"{extra_description}.\n\n{help}") - self._update_global_hook = update_global_hook - self._update_thread_local_hook = update_thread_local_hook - self._validator = validator - self._default_context_manager_value = default_context_manager_value - if self._validator: - self._validator(default) - if self._update_global_hook: - self._update_global_hook(default) - - def __bool__(self) -> NoReturn: - raise TypeError( - "bool() not supported for instances of type '{0}' " - "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) - - def _set(self, value: _T) -> None: - if self._validator: - self._validator(value) - self.set_global(value) - if self._update_global_hook: - self._update_global_hook(value) - - @contextlib.contextmanager - def __call__(self, new_val: Any = no_default): - if new_val is no_default: - if self._default_context_manager_value is not no_default: - new_val = self._default_context_manager_value # default_context_manager_value provided to constructor - else: - # no default_value provided to constructor and no value provided as an - # argument, so we raise an error - raise TypeError(f"Context manager for {self.__name__} config option " - "requires an argument representing the new value for " - "the config option.") - if self._validator: - self._validator(new_val) - prev_val = self.swap_local(new_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(new_val) - try: - yield - finally: - self.set_local(prev_val) - if self._update_thread_local_hook: - if prev_val is config_ext.unset: - self._update_thread_local_hook(None) - else: - self._update_thread_local_hook(cast(Optional[Any], prev_val)) +_thread_local_state = threading.local() - def _add_hooks(self, update_global_hook, update_thread_local_hook): - """Private method that adds hooks to an existing context-manager. +class State(Generic[_T]): - Used to avoid cyclic import dependencies.""" - self._update_thread_local_hook = update_thread_local_hook - self._update_global_hook = update_global_hook - update_global_hook(self.get_global()) + __slots__ = ( + '_name', '_value', '_update_thread_local_hook', '_update_global_hook', + '_validator', '_default_context_manager_value', '__doc__', '__name__', + ) -else: - class _Unset: pass - unset = _Unset() + def __init__( + self, + name: str, + default: _T, + help, + update_global_hook: Callable[[_T], None] | None = None, + update_thread_local_hook: Callable[[_T | None], None] | None = None, + validator: Callable[[Any], None] | None = None, + extra_description: str = '', + default_context_manager_value: Any = no_default, + ): + self._name = name + self.__name__ = name[4:] if name.startswith('jax_') else name + self.__doc__ = (f"Context manager for `{name}` config option" + f"{extra_description}.\n\n{help}") + self._update_global_hook = update_global_hook + self._update_thread_local_hook = update_thread_local_hook + self._validator = validator + self._default_context_manager_value = default_context_manager_value + self._set(default) - _thread_local_state = threading.local() + def __bool__(self) -> NoReturn: + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__)) - class State(Generic[_T]): + def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) + self._value = value + if self._update_global_hook: + self._update_global_hook(value) - __slots__ = ( - '_name', '_value', '_update_thread_local_hook', '_update_global_hook', - '_validator', '_default_context_manager_value', '__doc__', '__name__', - ) + @property + def value(self) -> _T: + val = _thread_local_state.__dict__.get(self._name, unset) + return cast(_T, val) if val is not unset else self._value + + @contextlib.contextmanager + def __call__(self, new_val: Any = no_default): + if new_val is no_default: + if self._default_context_manager_value is not no_default: + new_val = self._default_context_manager_value # default_context_manager_value provided to constructor + else: + # no default_value provided to constructor and no value provided as an + # argument, so we raise an error + raise TypeError(f"Context manager for {self.__name__} config option " + "requires an argument representing the new value for " + "the config option.") + if self._validator: + self._validator(new_val) + prev_val = getattr(_thread_local_state, self._name, unset) + setattr(_thread_local_state, self._name, new_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(new_val) + try: + yield + finally: + if prev_val is unset: + delattr(_thread_local_state, self._name) + if self._update_thread_local_hook: + self._update_thread_local_hook(None) + else: + setattr(_thread_local_state, self._name, prev_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(cast(_T, prev_val)) - def __init__( - self, - name: str, - default: _T, - help, - update_global_hook: Callable[[_T], None] | None = None, - update_thread_local_hook: Callable[[_T | None], None] | None = None, - validator: Callable[[Any], None] | None = None, - extra_description: str = '', - default_context_manager_value: Any = no_default, - include_in_jit_key: bool = False, - ): - self._name = name - self.__name__ = name[4:] if name.startswith('jax_') else name - self.__doc__ = (f"Context manager for `{name}` config option" - f"{extra_description}.\n\n{help}") - if include_in_jit_key: - assert update_global_hook is None - assert update_thread_local_hook is None - update_global_hook = lambda val: _update_global_jit_state( - **{self.__name__: val}) - update_thread_local_hook = lambda val: update_thread_local_jit_state( - **{self.__name__: val}) - self._update_global_hook = update_global_hook - self._update_thread_local_hook = update_thread_local_hook - self._validator = validator - self._default_context_manager_value = default_context_manager_value - self._set(default) - def __bool__(self) -> NoReturn: - raise TypeError( - "bool() not supported for instances of type '{0}' " - "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) - - def _set(self, value: _T) -> None: - if self._validator: - self._validator(value) - self._value = value - if self._update_global_hook: - self._update_global_hook(value) - - @property - def value(self) -> _T: - val = _thread_local_state.__dict__.get(self._name, unset) - return cast(_T, val) if val is not unset else self._value - - def get_local(self) -> Any: - return _thread_local_state.__dict__.get(self._name, unset) - - @contextlib.contextmanager - def __call__(self, new_val: Any = no_default): - if new_val is no_default: - if self._default_context_manager_value is not no_default: - new_val = self._default_context_manager_value # default_context_manager_value provided to constructor - else: - # no default_value provided to constructor and no value provided as an - # argument, so we raise an error - raise TypeError(f"Context manager for {self.__name__} config option " - "requires an argument representing the new value for " - "the config option.") - if self._validator: - self._validator(new_val) - prev_val = getattr(_thread_local_state, self._name, unset) - setattr(_thread_local_state, self._name, new_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(new_val) - try: - yield - finally: - if prev_val is unset: - delattr(_thread_local_state, self._name) - if self._update_thread_local_hook: - self._update_thread_local_hook(None) - else: - setattr(_thread_local_state, self._name, prev_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(cast(_T, prev_val)) - - def _add_hooks(self, update_global_hook, update_thread_local_hook): - """Private method that adds hooks to an existing context-manager. - - Used to avoid cyclic import dependencies.""" - self._update_thread_local_hook = update_thread_local_hook - self._update_global_hook = update_global_hook - update_global_hook(self._value) + def _add_hooks(self, update_global_hook, update_thread_local_hook): + """Private method that adds hooks to an existing context-manager. + + Used to avoid cyclic import dependencies.""" + self._update_thread_local_hook = update_thread_local_hook + self._update_global_hook = update_global_hook + update_global_hook(self._value) UPGRADE_BOOL_HELP = ( @@ -483,7 +353,6 @@ def bool_state( update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', - include_in_jit_key: bool = False, ) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. @@ -548,8 +417,7 @@ def bool_state( s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, - extra_description=extra_description, default_context_manager_value=True, - include_in_jit_key=include_in_jit_key) + extra_description=extra_description, default_context_manager_value=True) config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -563,7 +431,6 @@ def enum_state( *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, - include_in_jit_key: bool = False, ) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. @@ -603,7 +470,6 @@ def validator(new_val): update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator, - include_in_jit_key=include_in_jit_key, ) config.add_option( name, s, 'enum', @@ -622,7 +488,6 @@ def optional_enum_state( *, update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, - include_in_jit_key: bool = False, ) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. @@ -658,7 +523,7 @@ def validate(new_val): s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, - validate, include_in_jit_key=include_in_jit_key, + validate ) config.add_option( name, s, 'enum', @@ -676,7 +541,6 @@ def int_state( *, update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, - include_in_jit_key: bool = False, ) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. @@ -711,8 +575,7 @@ def validate(new_val): f'got {new_val} of type {type(new_val)}') s = State[int](name, default, help, update_global_hook, - update_thread_local_hook, validate, - include_in_jit_key=include_in_jit_key) + update_thread_local_hook, validate) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -963,119 +826,92 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: already_configured_with_absl = False -if xla_extension_version >= 295: - trace_state = config_ext.Config(None, include_in_jit_key=True) - axis_env_state = config_ext.Config((), include_in_jit_key=True) - mesh_context_manager = config_ext.Config((), include_in_jit_key=True) - compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) - xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) -else: - # The C++ JIT maintains its own copy of several configuration items as - # a global/thread-local state. These methods allow updates to part of the - # state when a configuration value changes. - class _GlobalExtraJitContext(NamedTuple): - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool = False - eager_constant_folding: bool = False - random_seed_offset: int = 0 - threefry_partitionable: bool = False - threefry_gpu_kernel_lowering: bool = False - sharding_in_types: bool = False - softmax_custom_jvp: bool = False - xla_profile_version: int = 0 - pgle_profiling_runs: int = 0 - enable_pgle: bool = False - use_shardy_partitioner: bool = False - - - def _update_global_jit_state(**kw): - gs = jax_jit.global_state() - context = gs.extra_jit_context or _GlobalExtraJitContext() - gs.extra_jit_context = context._replace(**kw) - - - class _ThreadLocalExtraJitContext(NamedTuple): - """A namedtuple containing states to add to the cache key. - - Just in time compilation (for jit, pmap, etc) behavior is configurable through - global and thread-local options, used in the cache key. - - The initialization, which uses both config.py and core.py is done using - `_update_thread_local_jit_state` in core.py to prevent circular imports. - """ - trace_state: Any | None = None - axis_env_state: Hashable = () - mesh_context_manager: Hashable = () - compute_on_context_manager: Hashable = () - xla_metadata_context_manager: Hashable = () - - # Values set by _StateContextManager context managers. - # CAUTION: these must be initialized to `None`! The state context manager - # restores these to None on exit. If the object default is not `None`, the - # context manager is not a no-op, which leads to problems with stale state - # (e.g. spurious cache misses in tests). - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool | None = None - eager_constant_folding : bool | None = None - random_seed_offset: int | None = None - threefry_partitionable: bool | None = None - threefry_gpu_kernel_lowering: bool | None = None - sharding_in_types: bool | None = None - softmax_custom_jvp: bool | None = None - xla_profile_version: int | None = None - pgle_profiling_runs: int | None = None - enable_pgle: bool | None = None - use_shardy_partitioner: bool | None = None - - - class _ThreadLocalStateCache(threading.local): - """"A thread local cache for _ThreadLocalExtraJitContext - - The extra_jit_context in jax_jit.thread_local_state() may get updated and thus - incurring dispatch overhead for comparing this python object during jit calls. - We want to deduplicate the objects that have the same hash/equality to also - have the same object ID, since the equality check is much faster if the object - IDs match. - """ - def __init__(self): - self.canonicalize = functools.lru_cache(128)(lambda x: x) - - - _thread_local_state_cache = _ThreadLocalStateCache() - - - def update_thread_local_jit_state(**kw): - tls = jax_jit.thread_local_state() - # After xla_client._version >= 70, the thread_local object will necessarily - # be initialized when accessed. The following line can be removed when the - # minimum jaxlib version is past version 70 - context = tls.extra_jit_context or _ThreadLocalExtraJitContext() - tmp = context._replace(**kw) - tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) - - class JitConfig: - def __init__(self, name): - self._name = name - - def value(self): - return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) +# The C++ JIT maintains its own copy of several configuration items as +# a global/thread-local state. These methods allow updates to part of the +# state when a configuration value changes. +class _GlobalExtraJitContext(NamedTuple): + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None + dynamic_shapes: bool = False + eager_constant_folding: bool = False + random_seed_offset: int = 0 + threefry_partitionable: bool = False + threefry_gpu_kernel_lowering: bool = False + sharding_in_types: bool = False + softmax_custom_jvp: bool = False + xla_profile_version: int = 0 + pgle_profiling_runs: int = 0 + enable_pgle: bool = False + use_shardy_partitioner: bool = False + + +def _update_global_jit_state(**kw): + gs = jax_jit.global_state() + context = gs.extra_jit_context or _GlobalExtraJitContext() + gs.extra_jit_context = context._replace(**kw) + + +class _ThreadLocalExtraJitContext(NamedTuple): + """A namedtuple containing states to add to the cache key. + + Just in time compilation (for jit, pmap, etc) behavior is configurable through + global and thread-local options, used in the cache key. + + The initialization, which uses both config.py and core.py is done using + `_update_thread_local_jit_state` in core.py to prevent circular imports. + """ + trace_state: Any | None = None + axis_env_state: Hashable = () + mesh_context_manager: Hashable = () + compute_on_context_manager: Hashable = () + xla_metadata_context_manager: Hashable = () + + # Values set by _StateContextManager context managers. + # CAUTION: these must be initialized to `None`! The state context manager + # restores these to None on exit. If the object default is not `None`, the + # context manager is not a no-op, which leads to problems with stale state + # (e.g. spurious cache misses in tests). + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None + dynamic_shapes: bool | None = None + eager_constant_folding : bool | None = None + random_seed_offset: int | None = None + threefry_partitionable: bool | None = None + threefry_gpu_kernel_lowering: bool | None = None + sharding_in_types: bool | None = None + softmax_custom_jvp: bool | None = None + xla_profile_version: int | None = None + pgle_profiling_runs: int | None = None + enable_pgle: bool | None = None + use_shardy_partitioner: bool | None = None + + +class _ThreadLocalStateCache(threading.local): + """"A thread local cache for _ThreadLocalExtraJitContext + + The extra_jit_context in jax_jit.thread_local_state() may get updated and thus + incurring dispatch overhead for comparing this python object during jit calls. + We want to deduplicate the objects that have the same hash/equality to also + have the same object ID, since the equality check is much faster if the object + IDs match. + """ + def __init__(self): + self.canonicalize = functools.lru_cache(128)(lambda x: x) - def get_local(self): - return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) - def set_local(self, value): - update_thread_local_jit_state(**{self._name: value}) +_thread_local_state_cache = _ThreadLocalStateCache() - trace_state = JitConfig('trace_state') - axis_env_state = JitConfig('axis_env_state') - mesh_context_manager = JitConfig('mesh_context_manager') - compute_on_context_manager = JitConfig('compute_on_context_manager') - xla_metadata_context_manager = JitConfig('xla_metadata_context_manager') +def update_thread_local_jit_state(**kw): + tls = jax_jit.thread_local_state() + # After xla_client._version >= 70, the thread_local object will necessarily + # be initialized when accessed. The following line can be removed when the + # minimum jaxlib version is past version 70 + context = tls.extra_jit_context or _ThreadLocalExtraJitContext() + tmp = context._replace(**kw) + tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) # TODO(b/214340779): remove flag when XLA:CPU is improved. jax2tf_associative_scan_reductions = bool_state( @@ -1266,7 +1102,10 @@ def _update_jax_memories_thread_local(val): name='jax_random_seed_offset', default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), - include_in_jit_key=True, + update_global_hook=lambda val: _update_global_jit_state( + random_seed_offset=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + random_seed_offset=val) ) legacy_prng_key = enum_state( @@ -1301,7 +1140,10 @@ def _update_jax_memories_thread_local(val): 'may result in extraneous communication and/or redundant distributed ' 'computation. With this flag, the communication overheads disappear ' 'in some cases.'), - include_in_jit_key=True) + update_global_hook=lambda val: _update_global_jit_state( + threefry_partitionable=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + threefry_partitionable=val)) threefry_gpu_kernel_lowering = bool_state( name='jax_threefry_gpu_kernel_lowering', @@ -1309,14 +1151,20 @@ def _update_jax_memories_thread_local(val): help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' 'This makes compile times faster at a potential runtime memory ' 'cost.'), - include_in_jit_key=True) + update_global_hook=lambda val: _update_global_jit_state( + threefry_gpu_kernel_lowering=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + threefry_gpu_kernel_lowering=val)) sharding_in_types = bool_state( name='jax_sharding_in_types', default=False, help=('When True, enables forward only sharding propagation in JAX and ' 'avals have sharding on them.'), - include_in_jit_key=True) + update_global_hook=lambda val: _update_global_jit_state( + sharding_in_types=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + sharding_in_types=val)) data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', @@ -1331,7 +1179,10 @@ def _update_jax_memories_thread_local(val): help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' 'behavior. See https://github.com/jax-ml/jax/pull/15677'), - include_in_jit_key=True) + update_global_hook=lambda val: _update_global_jit_state( + softmax_custom_jvp=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + softmax_custom_jvp=val)) enable_custom_vjp_by_custom_transpose = bool_state( @@ -1447,7 +1298,9 @@ def _update_jax_memories_thread_local(val): 'number times with collected data provided to the profile guided latency ' 'estimator.' ), - include_in_jit_key=True, + update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + enable_pgle=val), ) pgle_profiling_runs = int_state( @@ -1457,7 +1310,12 @@ def _update_jax_memories_thread_local(val): 'Amount of times module should be profiled before recompilation when ' 'PGLE is used.' ), - include_in_jit_key=True, + update_global_hook=lambda val: _update_global_jit_state( + pgle_profiling_runs=val + ), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + pgle_profiling_runs=val + ), ) pgle_aggregation_percentile = int_state( @@ -1523,7 +1381,10 @@ def _update_jax_memories_thread_local(val): 'between arrays. Options are "standard" or "strict"; in strict-mode, ' 'binary operations between arrays of differing strongly-specified ' 'dtypes will result in an error.'), - include_in_jit_key=True) + update_global_hook=lambda val: \ + _update_global_jit_state(numpy_dtype_promotion=val), + update_thread_local_hook=lambda val: \ + update_thread_local_jit_state(numpy_dtype_promotion=val)) disallow_mesh_context_manager = bool_state( name='jax_disallow_mesh_context_manager', @@ -1609,7 +1470,10 @@ def _update_disable_jit_thread_local(val): default='allow', help=('Control NumPy-style automatic rank promotion broadcasting ' '("allow", "warn", or "raise").'), - include_in_jit_key=True) + update_global_hook=lambda val: \ + _update_global_jit_state(numpy_rank_promotion=val), + update_thread_local_hook=lambda val: \ + update_thread_local_jit_state(numpy_rank_promotion=val)) default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', @@ -1645,7 +1509,10 @@ def _update_disable_jit_thread_local(val): '"algorithm" for functions that perform matrix multiplications, like ' ':func:`jax.lax.dot`. To specify an algorithm, set this option to ' 'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'), - include_in_jit_key=True) + update_global_hook=lambda val: \ + _update_global_jit_state(default_matmul_precision=val), + update_thread_local_hook=lambda val: \ + update_thread_local_jit_state(default_matmul_precision=val)) traceback_filtering = enum_state( name = 'jax_traceback_filtering', @@ -1680,14 +1547,20 @@ def _update_disable_jit_thread_local(val): default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), help=('Enables experimental features for staging out computations with ' 'dynamic shapes.'), - include_in_jit_key=True) + update_global_hook=lambda val: \ + _update_global_jit_state(dynamic_shapes=val), + update_thread_local_hook=lambda val: \ + update_thread_local_jit_state(dynamic_shapes=val)) # This is for stackless backward compat with e.g. equinox eager_constant_folding = bool_state( name='eager_constant_folding', default=False, help=('Attempt constant folding during staging.'), - include_in_jit_key=True) + update_global_hook=lambda val: \ + _update_global_jit_state(eager_constant_folding=val), + update_thread_local_hook=lambda val: \ + update_thread_local_jit_state(eager_constant_folding=val)) # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. @@ -1746,7 +1619,10 @@ def _update_disable_jit_thread_local(val): 'Optional profile version for XLA compilation. This is meaningful ' 'only when XLA is configured to support the remote compilation ' 'profile feature.'), - include_in_jit_key=True, + update_global_hook=lambda val: _update_global_jit_state( + xla_profile_version=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + xla_profile_version=val), ) @contextlib.contextmanager @@ -1945,5 +1821,10 @@ def _update_debug_log_modules(module_names_str: str | None): 'framework for MLIR. Currently Shardy is experimental in JAX. See ' 'www.github.com/openxla/shardy' ), - include_in_jit_key=True, + update_global_hook=lambda val: _update_global_jit_state( + use_shardy_partitioner=val + ), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + use_shardy_partitioner=val + ), ) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7c92fa1b5236..7d912e3c207b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1016,16 +1016,18 @@ def is_top_level(self) -> bool: def set_trace(self, trace): self.trace = trace ts = ref(trace) if trace is not None else None - config.trace_state.set_local(ts) + config.update_thread_local_jit_state(trace_state=ts) def set_axis_env(self, axis_env): self.axis_env = axis_env - config.axis_env_state.set_local(axis_env.as_hashable_key()) + config.update_thread_local_jit_state( + axis_env_state=self.axis_env.as_hashable_key()) def update_thread_local_jit_state(self): ts = ref(self.trace) if self.trace is not None else None - config.trace_state.set_local(ts) - config.axis_env_state.set_local(self.axis_env.as_hashable_key()) + config.update_thread_local_jit_state( + trace_state=ts, + axis_env_state=self.axis_env.as_hashable_key()) trace_ctx = TracingContext() @@ -1069,7 +1071,10 @@ def _initialize_jax_jit_thread_local_state(): This function does not live in `config.py`, to prevent circular imports. """ - trace_ctx.update_thread_local_jit_state() + tls = jax_jit.thread_local_state() + + if tls.extra_jit_context is None: + trace_ctx.update_thread_local_jit_state() jax_jit.set_thread_local_state_initialization_callback( _initialize_jax_jit_thread_local_state) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 43791f2e5f72..8cb508378129 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -224,17 +224,17 @@ def __enter__(self): new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env - jax_config.mesh_context_manager.set_local( - tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.update_thread_local_jit_state( + mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return self def __exit__(self, exc_type, exc_value, traceback): thread_resources.stack.pop() thread_resources.env = thread_resources.stack[-1] - jax_config.mesh_context_manager.set_local( - tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.update_thread_local_jit_state( + mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return False @property @@ -410,7 +410,7 @@ def __exit__(self, exc_type, exc_value, traceback): @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): - jax_config.mesh_context_manager.set_local(mesh) + jax_config.update_thread_local_jit_state(mesh_context_manager=mesh) return diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 94e26eeefa65..94b482e2dea4 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -41,13 +41,15 @@ def set_xla_metadata(*args, **kwargs): thread_local_metadata.val, new_metadata, ) - config.xla_metadata_context_manager.set_local( - tuple((v, k) for k, v in sorted(new_metadata.items())) - ) + config.update_thread_local_jit_state( + xla_metadata_context_manager=tuple( + (v, k) for k, v in sorted(new_metadata.items()))) try: yield finally: thread_local_metadata.val = prev_metadata - config.xla_metadata_context_manager.set_local( - tuple((v, k) for k, v in sorted(prev_metadata.items())) + config.update_thread_local_jit_state( + xla_metadata_context_manager=tuple( + (v, k) for k, v in sorted(prev_metadata.items()) + ) ) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 0df5d99715e6..6e0e795df334 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2215,6 +2215,8 @@ def test_cache_uses_jax_key(self): pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) + config.update_thread_local_jit_state() + pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) From a80d027dd7dbf164c757c74774dd62cf135ff9cb Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Tue, 5 Nov 2024 12:29:20 +0530 Subject: [PATCH 198/698] Fix Typos --- docs/export/export.md | 4 ++-- docs/export/shape_poly.md | 12 ++++++------ jax/experimental/jax2tf/README.md | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/export/export.md b/docs/export/export.md index b62cf9fe0113..aa686b03e2b2 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -247,7 +247,7 @@ for which the code was exported. You can specify explicitly for what platforms the code should be exported. This allows you to specify a different accelerator than you have available at export time, -and it even allows you to specify multi-platform lexport to +and it even allows you to specify multi-platform export to obtain an `Exported` object that can be compiled and executed on multiple platforms. @@ -293,7 +293,7 @@ resulting module size should be only marginally larger than the size of a module with default export. As an extreme case, when serializing a module without any primitives with platform-specific lowering, you will get -the same StableHLO as for the single-plaform export. +the same StableHLO as for the single-platform export. ```python >>> import jax diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index b1ce80638706..6ad7fb5c2b09 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -44,7 +44,7 @@ following example: ``` Note that such functions are still re-compiled on demand for -each concrete input shapes they are invoked on. Only the +each concrete input shape they are invoked on. Only the tracing and the lowering are saved. The {func}`jax.export.symbolic_shape` is used in the above @@ -98,7 +98,7 @@ A few examples of shape specifications: arguments. Note that the same specification would work if the first argument is a pytree of 3D arrays, all with the same leading dimension but possibly with different trailing dimensions. - The value `None` for the second arugment means that the argument + The value `None` for the second argument means that the argument is not symbolic. Equivalently, one can use `...`. * `("(batch, ...)", "(batch,)")` specifies that the two arguments @@ -256,7 +256,7 @@ as follows: integers. E.g., `b >= 1`, `b >= 0`, `2 * a + b >= 3` are `True`, while `b >= 2`, `a >= b`, `a - b >= 0` are inconclusive and result in an exception. -In cases where a comparison operation cannot be resolve to a boolean, +In cases where a comparison operation cannot be resolved to a boolean, we raise {class}`InconclusiveDimensionOperation`. E.g., ```python @@ -351,7 +351,7 @@ symbolic constraints: is encountered, it is rewritten to the expression on the right. E.g., `floordiv(a, b) == c` works by replacing all - occurences of `floordiv(a, b)` with `c`. + occurrences of `floordiv(a, b)` with `c`. Equality constraints must not contain addition or subtraction at the top-level on the left-hand-side. Examples of valid left-hand-sides are `a * b`, or `4 * a`, or @@ -498,11 +498,11 @@ This works well for most use cases, and it mirrors the calling convention of JIT functions. Sometimes you may want to export a function parameterized -by an integer values that determines some shapes in the program. +by an integer value that determines some shapes in the program. For example, we may want to export the function `my_top_k` defined below, parameterized by the -value of `k`, which determined the shape of the result. +value of `k`, which determines the shape of the result. The following attempt will lead to an error since the dimension variable `k` cannot be derived from the shape of the input `x: i32[4, 10]`: diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index caf63df17bf1..da33a677ba07 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -237,7 +237,7 @@ params_vars = tf.nest.map_structure(tf.Variable, params) prediction_tf = lambda inputs: jax2tf.convert(model_jax)(params_vars, inputs) my_model = tf.Module() -# Tell the model saver what are the variables. +# Tell the model saver what the variables are. my_model._variables = tf.nest.flatten(params_vars) my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False) tf.saved_model.save(my_model) @@ -760,7 +760,7 @@ symbolic constraints: We plan to improve somewhat this area in the future. * Equality constraints are treated as normalization rules. E.g., `floordiv(a, b) = c` works by replacing all - occurences of the left-hand-side with the right-hand-side. + occurrences of the left-hand-side with the right-hand-side. You can only have equality constraints where the left-hand-side is a multiplication of factors, e.g, `a * b`, or `4 * a`, or `floordiv(a, b)`. Thus, the left-hand-side cannot contain @@ -1048,7 +1048,7 @@ jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32 tf.function(jax2tf.convert(jnp.sin), autograph=False)(tf.Variable(3.14, dtype=tf.float64)) ``` -When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types +When the `JAX_ENABLE_X64` flag is set, JAX uses 64-bit types for Python scalars and respects the explicit 64-bit types: ```python @@ -1245,7 +1245,7 @@ Applies to both native and non-native serialization. trackable classes during attribute assignment. Python Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper classes. -In most situation, these Wrapper classes work exactly as the standard +In most situations, these Wrapper classes work exactly as the standard Python data types. However, the low-level pytree data structures are different and this can lead to errors. @@ -1499,7 +1499,7 @@ during lowering we try to generate one TensorFlow op for one JAX primitive. We expect that the lowering that XLA does is similar to that done by JAX before conversion. (This is a hypothesis, we have not yet verified it extensively.) -There is one know case when the performance of the lowered code will be different. +There is one known case when the performance of the lowered code will be different. JAX programs use a [stateless deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md) and it has an internal JAX primitive for it. From 63e59c5fd73cea18831cd13b5867dae06607e3de Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 5 Nov 2024 00:46:40 -0800 Subject: [PATCH 199/698] [Mosaic GPU] Ensure that the dialect module can be loaded successfully. This requires that the file providing the bindings has the same name as the dialect it defines, since dialect search looks for a module path of the form `.`. PiperOrigin-RevId: 693241875 --- jax/_src/lib/__init__.py | 2 +- jaxlib/mosaic/python/BUILD | 2 +- jaxlib/mosaic/python/{gpu.py => mosaic_gpu.py} | 6 +++++- tests/mosaic/gpu_dialect_test.py | 8 ++++++++ 4 files changed, 15 insertions(+), 3 deletions(-) rename jaxlib/mosaic/python/{gpu.py => mosaic_gpu.py} (84%) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 7f9936a7d180..2810002013ac 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -121,7 +121,7 @@ def _xla_gc_callback(*args): import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 try: - import jaxlib.mosaic.python.gpu as mosaic_gpu_dialect # pytype: disable=import-error + import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error except ImportError: # TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36. # Jaxlib doesn't contain Mosaic GPU dialect bindings. diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 6899914e6b89..ef6230f70321 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -20,7 +20,7 @@ load("@rules_python//python:defs.bzl", "py_library") py_library( name = "gpu_dialect", srcs = [ - "gpu.py", + "mosaic_gpu.py", "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py", ], visibility = ["//visibility:public"], diff --git a/jaxlib/mosaic/python/gpu.py b/jaxlib/mosaic/python/mosaic_gpu.py similarity index 84% rename from jaxlib/mosaic/python/gpu.py rename to jaxlib/mosaic/python/mosaic_gpu.py index 755a4d3eff7d..3157242e48a8 100644 --- a/jaxlib/mosaic/python/gpu.py +++ b/jaxlib/mosaic/python/mosaic_gpu.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Python bindings for the MLIR Mosaic GPU dialect.""" +"""Python bindings for the MLIR Mosaic GPU dialect. + +Note: this file *must* be called `mosaic_gpu.py`, in order to match the dialect +name. Otherwise, MLIR is unable to find the module during dialect search. +""" # ruff: noqa: F401 # ruff: noqa: F403 diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 19701012f706..d6428a98c96a 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -21,6 +21,9 @@ from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member +_cext = mgpu._cext if mgpu is not None else None + + config.parse_flags_with_absl() @@ -40,6 +43,9 @@ def setUp(self): self.enter_context(ir.Location.unknown()) self.module = ir.Module.create() + def test_dialect_module_is_loaded(self): + self.assertTrue(_cext.globals._check_dialect_module_loaded("mosaic_gpu")) + def test_initialize_barrier_op_result_memref_must_wrap_barriers(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( @@ -62,6 +68,8 @@ def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self): ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), arrival_count=1) self.assertTrue(self.module.operation.verify()) + self.assertIsInstance(self.module.body.operations[0], + mgpu.InitializeBarrierOp) if __name__ == "__main__": From 34b4787e2eff9edbd8eca242a74f1c165388b871 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 5 Nov 2024 01:59:21 -0800 Subject: [PATCH 200/698] [mosaic_gpu] Check the return code of `gpuEventCreate` and `gpuEventDestroy` PiperOrigin-RevId: 693260326 --- jaxlib/mosaic/gpu/BUILD | 1 + jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 42 ++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 4ec643dc63c8..875dc1d151ba 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -184,6 +184,7 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", + "@com_google_absl//absl/strings", "@nanobind", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 55801ebdb8d4..922d13d213f5 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include "nanobind/nanobind.h" +#include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/service/custom_call_status.h" @@ -23,33 +26,52 @@ limitations under the License. namespace jax::cuda { namespace { -namespace nb = nanobind; +static std::string ToString(CUresult result) { + const char* error_name; + if (cuGetErrorName(result, &error_name)) { + return absl::StrCat("UNKNOWN ERROR (", static_cast(result), ")"); + } + const char* error_string; + if (cuGetErrorString(result, &error_string)) { + return error_name; + } + return absl::StrCat(error_name, ": ", error_string); +} void EventRecordCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto* event = reinterpret_cast(opaque); - if (gpuEventRecord(**event, reinterpret_cast(stream)) != - gpuSuccess) { - const char message[] = "Failed to record event"; - XlaCustomCallStatusSetFailure(status, message, sizeof(message)); + if (auto res = gpuEventRecord(**event, reinterpret_cast(stream)); + res) { + auto message = absl::StrCat("Failed to record event: ", ToString(res)); + XlaCustomCallStatusSetFailure(status, message.c_str(), message.size()); } } NB_MODULE(_mosaic_gpu_ext, m) { m.def("_gpu_event_create", []() { gpuEvent_t* event = new gpuEvent_t(); - gpuEventCreate(event, GPU_EVENT_DEFAULT); + if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); res) { + throw std::runtime_error( + absl::StrCat("Failed to create event: ", ToString(res))); + } return reinterpret_cast(event); }); m.def("_gpu_event_destroy", [](uintptr_t event) { - gpuEventDestroy(*reinterpret_cast(event)); + if (auto res = gpuEventDestroy(*reinterpret_cast(event)); + res) { + throw std::runtime_error( + absl::StrCat("Failed to destroy event: ", ToString(res))); + } }); m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) { float elapsed_ms = -1; - if (gpuEventElapsedTime( + if (auto res = gpuEventElapsedTime( &elapsed_ms, *reinterpret_cast(start_event), - *reinterpret_cast(end_event)) != gpuSuccess) { - throw std::runtime_error("Failed to get elapsed time between events"); + *reinterpret_cast(end_event)); + res) { + throw std::runtime_error(absl::StrCat( + "Failed to get elapsed time between events: ", ToString(res))); } return elapsed_ms; }); From 5f90f63d19d33c13f6c917431862014d7d2d92c0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 5 Nov 2024 05:13:57 -0800 Subject: [PATCH 201/698] Improve efficiency of jax.scipy.stats.rankdata --- jax/_src/scipy/stats/_core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 08d1c0b6b538..f7b28d3ac301 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -198,13 +198,12 @@ def rankdata( return jnp.apply_along_axis(rankdata, axis, a, method) arr = jnp.ravel(a) - sorter = jnp.argsort(arr) + arr, sorter = jax.lax.sort_key_val(arr, jnp.arange(len(arr))) inv = invert_permutation(sorter) if method == "ordinal": return inv + 1 - arr = arr[sorter] - obs = jnp.insert(arr[1:] != arr[:-1], 0, True) + obs = jnp.concatenate([jnp.array([True]), arr[1:] != arr[:-1]]) dense = obs.cumsum()[inv] if method == "dense": return dense From 478b750c290376f6ef820fae45d84096928b5096 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 5 Nov 2024 07:16:32 -0800 Subject: [PATCH 202/698] Reverts f281c6f46475270a57a02416469226315377592c PiperOrigin-RevId: 693339094 --- jax/_src/ad_util.py | 3 +- jax/_src/core.py | 116 ++++++---------------- jax/_src/custom_derivatives.py | 27 +++-- jax/_src/interpreters/ad.py | 9 +- jax/_src/interpreters/batching.py | 8 +- jax/_src/interpreters/partial_eval.py | 24 ++--- jax/_src/lax/control_flow/conditionals.py | 10 +- jax/_src/lax/control_flow/loops.py | 6 +- jax/_src/lax/control_flow/solves.py | 3 +- jax/_src/lax/lax.py | 8 +- jax/_src/lax/linalg.py | 4 +- jax/_src/lax/parallel.py | 15 ++- jax/_src/pallas/core.py | 25 ++--- jax/_src/pallas/mosaic/core.py | 9 -- jax/_src/pallas/mosaic/verification.py | 3 +- jax/_src/pallas/mosaic_gpu/core.py | 11 +- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- jax/_src/state/primitives.py | 2 - jax/_src/state/types.py | 11 +- tests/core_test.py | 9 -- tests/lax_test.py | 2 +- tests/pallas/tpu_pallas_test.py | 5 +- tests/state_test.py | 5 +- 23 files changed, 109 insertions(+), 208 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index bd1427f59e01..02f3b0405e38 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -43,7 +43,8 @@ def add_impl(x, y): @add_jaxvals_p.def_abstract_eval def add_abstract(x, y): - return core.lattice_join(x, y) + assert core.typematch(x, y) + return x def zeros_like_aval(aval: core.AbstractValue) -> Array: return aval_zeros_likers[type(aval)](aval) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7d912e3c207b..08ac0668265a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -368,7 +368,7 @@ class Var: def __init__(self, suffix: str, aval: AbstractValue): self.count = next(_var_counter) self.suffix = suffix - self.aval = raise_to_shaped(aval) + self.aval = aval # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not # care about variable ordering, but the downstream package kfac_jax does. @@ -662,7 +662,7 @@ def __init__(self, trace: Trace): def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" - return f"traced array with shape {raise_to_shaped(self.aval).str_short()}" + return f"traced array with shape {self.aval.str_short()}" def __array__(self, *args, **kw): raise TracerArrayConversionError(self) @@ -1302,11 +1302,14 @@ def __repr__(self): except AttributeError: return self.__class__.__name__ - def strip_weak_type(self) -> AbstractValue: + def update_weak_type(self, weak_type): return self - def join(self, other): - raise NotImplementedError("must override") + def strip_weak_type(self) -> AbstractValue: + return self.update_weak_type(False) + + def normalize(self) -> AbstractValue: + return self.strip_weak_type() def update(self, **kwargs): raise NotImplementedError("must override") @@ -1314,7 +1317,6 @@ def update(self, **kwargs): def str_short(self, short_dtypes=False): return str(self) - # For type signatures involving dynamic shapes, we use lists of abstract values # which may contain (reverse) de Bruijn indices in their shapes. class DBIdx(NamedTuple): @@ -1348,26 +1350,10 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -class Bot(AbstractValue): pass -bot = Bot() - - -def lattice_join(x: AbstractValue | None, - y: AbstractValue | None) -> AbstractValue: - if x is None: - assert y is not None - return y - elif y is None: - return x - elif isinstance(x, type(y)): - return y.join(x) - elif isinstance(y, type(x)): - return x.join(y) - elif isinstance(x, DShapedArray) and isinstance(y, ShapedArray): - # TODO(mattjj): remove this special case after dynamic shapes are integrated - return x.join(y) - else: - raise TypeError(x, y) +# TODO(dougalm): Deprecate. This is here for backwards compat. +def lattice_join(x, y): + assert typematch(x, y) + return x # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1530,9 +1516,8 @@ def __repr__(self): def str_short(self, short_dtypes=False) -> str: return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name - def strip_weak_type(self): - """Returns a copy of the aval with weak_type=False.""" - return self.update(weak_type=False) + def update_weak_type(self, weak_type): + return self.update(weak_type=weak_type) def _canonicalize_dimension(dim: DimSize) -> DimSize: # Dimensions are most commonly integral (by far), so we check that first. @@ -1656,13 +1641,6 @@ def to_tangent_aval(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) - def join(self, other): - if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype: - weak_type = self.weak_type and other.weak_type - return self.update(weak_type=weak_type) - else: - raise TypeError(self, other) - def str_short(self, short_dtypes=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name) @@ -1762,14 +1740,6 @@ def __eq__(self, other): def __hash__(self): return hash((self.shape, self.dtype, self.weak_type)) - def join(self, other): - if (definitely_equal_shape(self.shape, other.shape) and - self.dtype == other.dtype): - weak_type = self.weak_type and other.weak_type - return self.update(weak_type=weak_type) - else: - raise TypeError(self, other) - def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1881,16 +1851,11 @@ def mutable_array_abstract_eval(init_aval): @mutable_array_p.def_impl def _mutable_array_impl(init_val): from jax._src.state.types import AbstractRef # pytype: disable=import-error - aval = raise_to_shaped(get_aval(init_val)) + aval = get_aval(init_val) return MutableArray(AbstractRef(aval), init_val) class AbstractToken(AbstractValue): - def join(self, other): - if isinstance(other, AbstractToken): - return self - else: - assert False, f"Cannot join {self} with {other}" def str_short(self, short_dtypes=False): return 'Tok' def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() @@ -1910,30 +1875,10 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -def raise_to_shaped(aval: AbstractValue, weak_type=None): - aval_type = type(aval) - if aval_type is ShapedArray and weak_type is None: - return aval - if aval_type is DShapedArray and weak_type is None: - return aval - if weak_type is None: - weak_type = getattr(aval, 'weak_type', False) - for typ in aval_type.__mro__: - handler = raise_to_shaped_mappings.get(typ) - if handler: return handler(aval, weak_type) - raise TypeError(type(aval)) - -def _shaped_array_mapping(aval, weak_type): - if config.sharding_in_types.value: - return ShapedArray(aval.shape, aval.dtype, weak_type, sharding=aval.sharding) - return ShapedArray(aval.shape, aval.dtype, weak_type) - -raise_to_shaped_mappings: dict[type, Callable] = { - AbstractToken: lambda aval, _: aval, - Bot: lambda aval, _: aval, - ShapedArray: _shaped_array_mapping, - DShapedArray: lambda aval, _: aval -} +# TODO(dougalm): Deprecate these. They're just here for backwards compat. +def raise_to_shaped(aval): + return aval +raise_to_shaped_mappings: dict[type, Callable] = {} ### Operations on shapes and dimension sizes. @@ -2341,18 +2286,23 @@ def typecheck(aval: AbstractValue, x) -> bool: def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: """Determine whether `aval` conforms to `aval_ref`. Ignores weak_type.""" try: - return typematch(aval_ref, lattice_join(aval_ref, aval)) + return typematch(aval_ref, aval) except TypeError: return False -def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool: - """Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type.""" - if aval1 == aval2: return True - # unequal avals may still represent the same type, because type is represented - # by avals at the shaped level, and because weak type tags aren't considered - # part of the type - return (raise_to_shaped(aval1, weak_type=False) == - raise_to_shaped(aval2, weak_type=False)) +def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: + """Determine whether `t1` and `t2` are equivalent. Ignores weak_type.""" + t1 = t1.normalize() + t2 = t2.normalize() + if t1 == t2: + return True + elif (isinstance(t1, (ShapedArray, DShapedArray)) and + isinstance(t2, (ShapedArray, DShapedArray))): + # This case handles DShapedArray and shape polynomials. Alternatively we + # could try normalizing first and then doing simple equality. + return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + else: + return False class JaxprTypeError(TypeError): pass diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 77f73562aecd..375efeb712b8 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,7 +31,6 @@ stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs) -from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -81,7 +80,7 @@ def _flatten_fun_nokwargs(in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) ans = yield py_args, {} ans_flat, ans_tree = tree_flatten(ans) - ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat] + ans_avals = [core.get_aval(x) for x in ans_flat] yield ans_flat, (ans_tree, ans_avals) @@ -287,7 +286,7 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): py_primals_out, py_tangents_out = pair_out primals_out, out_tree = tree_flatten(py_primals_out) tangents_out, out_tree2 = tree_flatten(py_tangents_out) - primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] + primal_avals = [core.get_aval(x) for x in primals_out] if out_tree != out_tree2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must " "produce primal and tangent outputs with equal container (pytree) " @@ -327,11 +326,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] + primal_avals_out = [core.get_aval(x).strip_weak_type() for x in primals_out] expected_tangent_avals_out = [ - raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval() + core.get_aval(x).strip_weak_type().to_tangent_aval() for x in primals_out] - tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) + tangent_avals_out = [core.get_aval(t).strip_weak_type() if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] if expected_tangent_avals_out != tangent_avals_out: @@ -606,7 +605,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable f_, dyn_args = lu.wrap_init(self.fun), args fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd) args_flat, in_tree = tree_flatten(dyn_args) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + in_avals = [core.get_aval(x) for x in args_flat] flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name, fwd_name, in_tree, out_type) @@ -674,7 +673,7 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, py_primals_out, res = pair_out primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) - primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] + primal_avals = [core.get_aval(x) for x in primals_out] # If the primal function already ran, check out_tree agreement. try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None @@ -772,7 +771,7 @@ def append(x, d): msg = ("Custom VJP bwd rule must produce an output with the same " "shape/dtypes as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " - f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding " + f"shape/dtype {a_.str_short()} corresponding " f"to an input of shape/dtype {a.str_short()}.") raise ValueError(msg) results.append(ct) @@ -831,7 +830,7 @@ def _custom_vjp_call_jaxpr_jvp( _, res_tree = out_trees() res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, @@ -1110,7 +1109,7 @@ def merge(l1, l2): return out, merge def abstractify(x): - return core.raise_to_shaped(core.get_aval(x)) + return core.get_aval(x) ### Custom transposition @@ -1211,7 +1210,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args, lin_avals = map(abstractify, operands_lin) f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals)) f_jaxpr = _close_jaxpr(f_jaxpr) - out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals) + out_avals = f_jaxpr.out_avals t_in_tree = treedef_tuple((res_tree, out_tree())) t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree) @@ -1265,7 +1264,7 @@ def _linear_call_transpose_rule(cts, *args, callee, transpose, return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out def _linear_call_abstract_eval(*args, **kwargs): - return map(core.raise_to_shaped, kwargs['callee'].out_avals) + return kwargs['callee'].out_avals linear_call_p = core.Primitive('linear_call') linear_call_p.multiple_results = True @@ -1398,7 +1397,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: in_tree, out_type) flat_fwd = _fix_fwd_args(flat_fwd) - in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] + in_avals = [core.get_aval(x) for x in args_flat] fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) prim_tree, res_tree = out_trees() diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 47c7882372ab..d080aae759a6 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -33,8 +33,7 @@ replace_rule_output_symbolic_zeros, Zero, zeros_like_aval) from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401 from jax._src.api_util import flatten_fun, flatten_fun_nokwargs -from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, - raise_to_shaped) +from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) from jax._src.dtypes import dtype, float0 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, @@ -362,7 +361,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! with core.set_current_trace(self.parent_trace): tangents_in = map(instantiate_zeros, tangents_in) @@ -434,8 +433,8 @@ def to_concrete_value(self): def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: - primal_aval = raise_to_shaped(get_aval(primal), weak_type=False) - tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False) + primal_aval = get_aval(primal).strip_weak_type() + tangent_aval = get_aval(tangent).strip_weak_type() assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 2ff27f0c5d74..590e60383b90 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -29,7 +29,7 @@ from jax._src.ad_util import (Zero, instantiate, SymbolicZero, replace_rule_output_symbolic_zeros, add_jaxvals, add_jaxvals_p) -from jax._src.core import raise_to_shaped, Trace, Tracer, TraceTag, AxisName +from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) @@ -217,7 +217,7 @@ def __init__(self, a): self.a = a for d in a.shape)) if type(a) is core.DShapedArray else a for a, e in orig_type if e] - new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens] + new_avals = [core.get_aval(s) for s in segment_lens] sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size for a, d in zip(avals, explicit_in_dims): if isinstance(d, RaggedAxis): @@ -387,7 +387,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, if config.enable_checks.value: assert type(batch_dim) in (NotMapped, int, RaggedAxis) if type(batch_dim) is int: - aval = raise_to_shaped(core.get_aval(val)) + aval = core.get_aval(val) assert 0 <= batch_dim < len(aval.shape) self._trace = trace self.val = val @@ -396,7 +396,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, @property def aval(self): - aval = raise_to_shaped(core.get_aval(self.val)) + aval = core.get_aval(self.val) if self.batch_dim is not_mapped: return aval elif type(self.batch_dim) is int: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 2f63eb386029..c09a8c711984 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -40,7 +40,7 @@ fun_sourceinfo) from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, - Var, DropVar, raise_to_shaped, Atom, + Var, DropVar, Atom, JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) @@ -162,8 +162,7 @@ def new_const(self, val) -> JaxprTracer: def new_instantiated_literal(self, val) -> JaxprTracer: aval = get_aval(val) - return JaxprTracer(self, PartialVal.unknown(aval), - Literal(val, raise_to_shaped(aval))) + return JaxprTracer(self, PartialVal.unknown(aval), Literal(val, aval)) def new_instantiated_const(self, val) -> JaxprTracer: aval = get_aval(val) @@ -201,7 +200,7 @@ def instantiate_const_abstracted(self, tracer) -> JaxprTracer: if const is None: return tracer else: - aval = raise_to_shaped(get_aval(const), np.isscalar(const)) + aval = get_aval(const).update_weak_type(np.isscalar(const)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) def process_primitive(self, primitive, tracers, params): @@ -715,7 +714,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], len(params["in_axes"]) == len(params["call_jaxpr"].invars)) assert ("donated_invars" in params and len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) - out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] + out_avals = [t.aval for t in out_tracers] ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), config.threefry_partitionable.value, @@ -936,7 +935,7 @@ def fun(*known_vals_in): f, in_pvals, instantiate=instantiate) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] - res_avals = [core.raise_to_shaped(core.get_aval(r)) for r in residuals] + res_avals = [core.get_aval(r) for r in residuals] cell.append((out_unknowns, jaxpr_unknown, res_avals)) known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] @@ -1567,7 +1566,7 @@ def get_referent(self): return self if val is None else get_referent(val) def _dynamic_jaxpr_tracer_shaped_abstractify(x): - return core.raise_to_shaped(x.aval) + return x.aval api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: @@ -1827,7 +1826,9 @@ def new_const(self, c): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: - aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c)) + aval = get_aval(c) + if hasattr(aval, "weak_type"): + aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) aval = self._lift_tracers_in_aval(aval) tracer = self._new_const(aval, c) return tracer @@ -1892,8 +1893,7 @@ def default_process_primitive(self, primitive, tracers, params): def process_call(self, call_primitive, f, explicit_tracers, params): if f.in_type is None: - f = lu.annotate(f, tuple((raise_to_shaped(get_aval(t)), True) - for t in explicit_tracers)) + f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation @@ -2291,7 +2291,7 @@ def _collect_implicit( for i, name in spec.items(): if name not in idxs and id(x.shape[i]) not in explicit_tracers: idxs[name] = DBIdx(next(counter)) - implicit_types.append(raise_to_shaped(get_aval(x.shape[i]))) + implicit_types.append(get_aval(x.shape[i])) if isinstance(x, Tracer): explicit_tracers.setdefault(id(x), explicit_idx) # use the first @@ -2310,7 +2310,7 @@ def _arg_type( ) -> AbstractValue: # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames. aval = get_aval(x) # aval.shape could contain Tracers - if not spec: return core.raise_to_shaped(aval) + if not spec: return aval shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d for i, d in enumerate(aval.shape)] assert not any(isinstance(d, Tracer) for d in shape) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 8dae3433e4f6..6333638deae6 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -35,7 +35,7 @@ from jax._src import util from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect -from jax._src.core import raise_to_shaped, replace_jaxpr_effects +from jax._src.core import replace_jaxpr_effects from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -328,7 +328,7 @@ def _cond_abstract_eval(*avals, branches, **_): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') - return map(raise_to_shaped, branches[0].out_avals), joined_effects + return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): if np.ndim(pred) != np.ndim(on_true): @@ -676,7 +676,6 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, def _transpose_cond_jaxpr(jaxpr, num_res): res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) - primal_avals = map(raise_to_shaped, primal_avals) @lu.wrap_init def transposed(*args): @@ -693,7 +692,7 @@ def _cond_transpose(cts, *args, branches): index, *ops = args assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] - in_avals = map(raise_to_shaped, branches[0].in_avals) + in_avals = branches[0].in_avals num_res = len(ops) - sum(linear) if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): @@ -701,8 +700,7 @@ def _cond_transpose(cts, *args, branches): branches_trans = tuple( _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches) - lin_in_avals = [raise_to_shaped(a, weak_type=False) - for a, l in zip(in_avals, linear) if l] + lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l] assert all(core.typematch(out_aval, lin_in_aval) for jaxpr in branches_trans for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index ddbbe0213f6f..19d3429d2675 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -35,7 +35,7 @@ from jax._src import state from jax._src import util from jax._src.api_util import shaped_abstractify -from jax._src.core import ShapedArray, raise_to_shaped +from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -262,7 +262,7 @@ def scan(f, init, xs, length=None): stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y - xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat] + xs_avals = [core.get_aval(x) for x in xs_flat] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] def _create_jaxpr(init): @@ -1370,7 +1370,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') - return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects + return body_jaxpr.out_avals, joined_effects def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr, diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 9a5a01e3987d..f97377b2df6c 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -23,7 +23,6 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import linear_util as lu -from jax._src.core import raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -300,7 +299,7 @@ def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return _map(raise_to_shaped, args_to_raise) + return args_to_raise def _custom_linear_solve_impl(*args, const_lengths, jaxprs): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e6dbcbb12a1c..8b6a517a54b3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -48,7 +48,7 @@ from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.core import (Primitive, UnshapedArray, ShapedArray, - raise_to_shaped, abstract_token, canonicalize_shape) + abstract_token, canonicalize_shape) from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -3044,7 +3044,7 @@ def _to_edtype_abstract_eval(x, *, edtype): f" has a representation shape {rep_aval.shape} while the given " f"representation array has shape {x.shape}, so the shape suffix " f"does not match: given {shape_suffix} but required {rep_aval.shape}.") - return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype) + return x.update(shape=shape_prefix, dtype=edtype) to_edtype_p = Primitive('to_edtype') to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) @@ -5246,7 +5246,7 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): def _sort_abstract_eval(*args, **kwargs): - args = tuple(raise_to_shaped(arg) for arg in args) + args = tuple(args) if any(arg.shape != args[0].shape for arg in args[1:]): shapes = " ".join(str(a.shape) for a in args) raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") @@ -6196,7 +6196,7 @@ def _eq_meet(a, b): def _abstractify(x): - return raise_to_shaped(core.get_aval(x)) + return core.get_aval(x) def empty(dtype): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index dc1d1d472ae2..0e0390abc78f 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -33,7 +33,7 @@ from jax._src import dtypes from jax._src import util from jax._src.core import ( - Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape) + Primitive, ShapedArray, is_constant_dim, is_constant_shape) from jax._src.extend import ffi from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -1289,7 +1289,6 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size): - pivots = raise_to_shaped(pivots) if isinstance(pivots, ShapedArray): if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32): raise ValueError( @@ -1421,7 +1420,6 @@ def _lu_impl(operand): return lu, pivot, perm def _lu_abstract_eval(operand): - operand = raise_to_shaped(operand) if isinstance(operand, ShapedArray): if operand.ndim < 2: raise ValueError("Argument to LU decomposition must have ndims >= 2") diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 932fd4b88c08..3a1c1ef3bcf1 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -27,7 +27,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import sharding_impls -from jax._src.core import AxisName, ShapedArray, raise_to_shaped +from jax._src.core import AxisName, ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -636,7 +636,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes), + ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype) for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} @@ -817,7 +817,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): _check_axis_names(axis_name) - return raise_to_shaped(x) + return x ppermute_p = core.Primitive('ppermute') ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval) @@ -1019,13 +1019,12 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, def _all_to_all_effectful_abstract_eval( - x, axis_name, split_axis, concat_axis, axis_index_groups, tiled + input_aval, axis_name, split_axis, concat_axis, axis_index_groups, tiled ): del tiled # expand_dims and squeeze is done in `all_to_all` if `True` if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) - input_aval = raise_to_shaped(x) shape = list(input_aval.shape) axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) @@ -1169,12 +1168,11 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, def _all_gather_effectful_abstract_eval( - x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled + x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) - x_aval = raise_to_shaped(x) new_shape = list(x_aval.shape) if tiled: new_shape[all_gather_dimension] *= axis_size @@ -1298,12 +1296,11 @@ def _reduce_scatter_lowering( def _reduce_scatter_effectful_abstract_eval( - x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled + x_aval, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled ): if not isinstance(axis_name, (list, tuple)): axis_name = (axis_name,) _check_axis_names(axis_name) - x_aval = core.raise_to_shaped(x) new_shape = list(x_aval.shape) scatter_dim_input_size = x_aval.shape[scatter_dimension] if tiled: diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 14e9d72af186..72ed07674f1f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -140,13 +140,6 @@ def __hash__(self): self.memory_space, )) - def at_least_vspace(self): - """Vector space method needed for AD.""" - raise NotImplementedError - - def join(self, other): - raise NotImplementedError - def str_short(self, short_dtypes=False): dt_str = \ dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name @@ -226,10 +219,9 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' - def join(self, other): - assert isinstance(other, AbstractMemoryRef) - return AbstractMemoryRef(self.inner_aval.join(other.inner_aval), - self.memory_space) + def update_weak_type(self, weak_type): + return AbstractMemoryRef( + self.inner_aval.update_weak_type(weak_type), self.memory_space) def update(self, inner_aval=None, memory_space=None): inner_aval = self.inner_aval if inner_aval is None else inner_aval @@ -240,6 +232,10 @@ def to_tangent_aval(self): return AbstractMemoryRef( self.inner_aval.to_tangent_aval(), self.memory_space) + # TODO(dougalm, sharadmv): figure out how to avoid needing this + def normalize(self): + return state.AbstractRef(self.inner_aval).normalize() + def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval and self.memory_space == other.memory_space) @@ -262,13 +258,6 @@ def __str__(self) -> str: return self.value -def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type): - return AbstractMemoryRef( - jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type), - ref_aval.memory_space) -jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped - - @dataclasses.dataclass(frozen=True) class PallasGridContext: grid: GridMappingGrid diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 12ae5350e725..ad9a6cb13f42 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -174,15 +174,6 @@ def get_ref_aval(self) -> AbstractMemoryRef: class AbstractSemaphore(jax_core.AbstractValue): sem_type: SemaphoreType - def join(self, other): - if not isinstance(other, AbstractSemaphore): - raise ValueError - if other.sem_type != self.sem_type: - raise ValueError - return self - -jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval - @dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index bae87226c664..61caa4087d99 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -529,7 +529,8 @@ def export_promela_model( @assume_p.def_abstract_eval def _assume_abstract_eval(x, y): - return x.join(y) + assert jax_core.typematch(x, y) + return x def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y): return y if ctx.lowering_context.for_verification else x diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 7f9e0bef822e..d1f75009c33d 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -458,15 +458,12 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): def __repr__(self) -> str: return f'Accumulator{{{self.inner_aval.str_short()}}}' - def join(self, other): - return _as_accum(super().join(other)) + def update_weak_type(self, weak_type): + return _as_accum(super().update_weak_type(weak_type)) def update(self, inner_aval=None, memory_space=None): return _as_accum(super().update(inner_aval=None, memory_space=None)) - def at_least_vspace(self): - return _as_accum(super().at_least_vspace()) - def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error arr = wgmma_accumulator_deref(tracer) @@ -483,10 +480,6 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: memory_space=ref.memory_space, # pytype: disable=attribute-error ) -def _ref_raise_to_shaped(ref_aval, weak_type): - return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type)) -jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped - _WARPGROUP_AXIS_NAME = object() diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1a5ed7f0d43e..1ced213394ff 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -567,7 +567,7 @@ def wgmma_accumulator_deref(acc): @wgmma_accumulator_deref_p.def_effectful_abstract_eval def _wgmma_accumulator_deref_abstract_eval(acc): # Dereferencing implies flushing so we have a wgmma pipeline effect. - ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc + ret = acc.inner_aval if isinstance(acc, state.AbstractRef) else acc assert isinstance(ret, jax_core.ShapedArray), acc return ret, {gpu_core._wgmma_pipeline_effect} diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 7724466d3110..0897e778d079 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -230,7 +230,6 @@ def _swap_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - val_aval = core.raise_to_shaped(val_aval) assert isinstance(val_aval, core.ShapedArray) expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) @@ -262,7 +261,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - val_aval = core.raise_to_shaped(val_aval) out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) assert isinstance(val_aval, core.ShapedArray) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 634617102d6c..df3c63606ba4 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -291,15 +291,14 @@ def weak_type(self) -> bool: raise AttributeError return self.inner_aval.weak_type + def update_weak_type(self, weak_type): + return AbstractRef(self.inner_aval.update_weak_type(weak_type)) + def update(self, inner_aval=None): if inner_aval is None: return AbstractRef(self.inner_aval) return AbstractRef(inner_aval) - def join(self, other): - assert isinstance(other, AbstractRef) - return AbstractRef(self.inner_aval.join(other.inner_aval)) - ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) @@ -365,10 +364,6 @@ def __eq__(self, other): def __hash__(self): return hash((self.__class__, self.inner_aval)) -def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type): - return AbstractRef(core.raise_to_shaped(ref_aval.inner_aval, weak_type)) -core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped - def _map_ref(size, axis, ref_aval): return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval)) diff --git a/tests/core_test.py b/tests/core_test.py index 1471e334c880..7ca941c69c7b 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -533,15 +533,6 @@ def test_jaxpr_undefined_eqn_invar(self): r"Variable '.+_test' not defined\n\nin equation:", lambda: core.check_jaxpr(jaxpr)) - @parameterized.parameters( - {'value': 0, 'weak_type': True}, - {'value': np.int32(0), 'weak_type': False}, - {'value': np.array([0]), 'weak_type': False} - ) - def test_raise_to_shaped_weak_type(self, value, weak_type): - aval = core.raise_to_shaped(core.get_aval(value)) - self.assertEqual(aval.weak_type, weak_type) - @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): diff --git a/tests/lax_test.py b/tests/lax_test.py index f2ce0913e03a..12149700cb30 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3821,7 +3821,7 @@ def shard_foo_array_handler(xs, shardings, layouts): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment - aval = core.raise_to_shaped(core.get_aval(x.data)) + aval = core.get_aval(x.data) results.append(pxla.batched_device_put( aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) return results diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index d92991caa6fe..544ed1ac3ecc 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -870,8 +870,9 @@ def scope(): pl.run_scoped(scope) return [] - aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) - in_avals = [aref, aref] + aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + in_avals = [aref1, aref2] stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) discharged_jaxpr, _ = state_discharge.discharge_state( diff --git a/tests/state_test.py b/tests/state_test.py index 36e93e88c5e0..c8458742619d 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -746,9 +746,10 @@ def f(a_ref, b_ref): b_ref[...] = jnp.array(1., dtype=jnp.float32) return a_ref[...], b_ref[...] - scalar_ref = shaped_array_ref((), jnp.float32) + scalar_ref_1 = shaped_array_ref((), jnp.float32) + scalar_ref_2 = shaped_array_ref((), jnp.float32) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f), [scalar_ref, scalar_ref]) + lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) From 0e8acff5c6ec63343e4d31eec8d982a98cb4b898 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 5 Nov 2024 08:31:12 -0800 Subject: [PATCH 203/698] Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461 PiperOrigin-RevId: 693360032 --- jax/_src/api.py | 4 +- jax/_src/compute_on.py | 7 +- jax/_src/config.py | 649 +++++++++++++++++++++++---------------- jax/_src/core.py | 15 +- jax/_src/mesh.py | 14 +- jax/_src/xla_metadata.py | 12 +- tests/pmap_test.py | 2 - 7 files changed, 406 insertions(+), 297 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index cc42a37b0e7c..a902c7de4c3e 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -123,8 +123,8 @@ def _update_debug_special_global(_): jax_jit.global_state().post_hook = None def _update_debug_special_thread_local(_): - if (getattr(config._thread_local_state, "jax_debug_nans", False) or - getattr(config._thread_local_state, "jax_debug_infs", False)): + if (config.debug_nans.get_local() == True or + config.debug_infs.get_local() == True): jax_jit.thread_local_state().post_hook = _nan_check_posthook else: jax_jit.thread_local_state().post_hook = None diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index b5194ddad21d..7bd9b9b08b7b 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -29,8 +29,8 @@ def __init__(self): @contextmanager def extend_compute_type(c_type: str): compute_on_context.stack.append(c_type) - config.update_thread_local_jit_state( - compute_on_context_manager=tuple(compute_on_context.stack)) + config.compute_on_context_manager.set_local( + tuple(compute_on_context.stack)) try: if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: raise NotImplementedError( @@ -39,8 +39,7 @@ def extend_compute_type(c_type: str): yield compute_on_context.stack[-1] finally: compute_on_context.stack.pop() - config.update_thread_local_jit_state( - compute_on_context_manager=tuple(compute_on_context.stack)) + config.compute_on_context_manager.set_local(tuple(compute_on_context.stack)) def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None diff --git a/jax/_src/config.py b/jax/_src/config.py index 0860168b23a3..0e113c894695 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -22,14 +22,23 @@ import os import sys import threading -from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast +from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING from jax._src import lib from jax._src.lib import guard_lib from jax._src.lib import jax_jit from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax._src import logging_config +# TODO(phawkins): reenable pytype after xla_extension_version >= 295 +# pytype: skip-file + +if xla_extension_version >= 295: + config_ext = xla_client._xla.config +else: + config_ext = None + logger = logging.getLogger(__name__) _T = TypeVar('_T') @@ -191,49 +200,79 @@ def parse_flags_with_absl(self): already_configured_with_absl = True -def trace_context(): - """Returns a tuple of configuration values that affect tracing. +if xla_extension_version >= 295: + def trace_context(): + """Returns a tuple of configuration values that affect tracing. - These values are included in the cache key for linear_util.cache. + These values are included in the cache key for linear_util.cache. - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately. - """ - tls = jax_jit.thread_local_state() - axis_env_state = () - mesh_context_manager = () - xla_metadata_context_manager = () - compute_on_context_manager = () - - context: Any = tls.extra_jit_context - if context and context.axis_env_state is not None: - axis_env_state = context.axis_env_state - if context and context.mesh_context_manager: - mesh_context_manager = context.mesh_context_manager - if context and context.xla_metadata_context_manager: - xla_metadata_context_manager = context.xla_metadata_context_manager - if context and context.compute_on_context_manager: - compute_on_context_manager = context.compute_on_context_manager - return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, - compute_on_context_manager, enable_x64.value, - numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, - eager_constant_folding.value, - numpy_dtype_promotion.value, - default_device.value, random_seed_offset.value, - threefry_partitionable.value, - threefry_gpu_kernel_lowering.value, - sharding_in_types.value, - softmax_custom_jvp.value, - enable_memories.value, - disable_jit.value, - debug_key_reuse.value, - jax_xla_profile_version.value, - # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value, - pgle_profiling_runs.value, - enable_pgle.value, - use_shardy_partitioner.value) + Values included in this set should also most likely be included in + the C++ JIT state, which is handled separately. + """ + return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value, + compute_on_context_manager.value, enable_x64.value, + numpy_rank_promotion.value, default_matmul_precision.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, + default_device.value, random_seed_offset.value, + threefry_partitionable.value, + threefry_gpu_kernel_lowering.value, + sharding_in_types.value, + softmax_custom_jvp.value, + enable_memories.value, + disable_jit.value, + debug_key_reuse.value, + jax_xla_profile_version.value, + # Technically this affects jaxpr->stablehlo lowering, not tracing. + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value, + use_shardy_partitioner.value) +else: + def trace_context(): + """Returns a tuple of configuration values that affect tracing. + + These values are included in the cache key for linear_util.cache. + + Values included in this set should also most likely be included in + the C++ JIT state, which is handled separately. + """ + tls = jax_jit.thread_local_state() + axis_env_state = () + mesh_context_manager = () + xla_metadata_context_manager = () + compute_on_context_manager = () + + context: Any = tls.extra_jit_context + if context and context.axis_env_state is not None: + axis_env_state = context.axis_env_state + if context and context.mesh_context_manager: + mesh_context_manager = context.mesh_context_manager + if context and context.xla_metadata_context_manager: + xla_metadata_context_manager = context.xla_metadata_context_manager + if context and context.compute_on_context_manager: + compute_on_context_manager = context.compute_on_context_manager + return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, + compute_on_context_manager, enable_x64.value, + numpy_rank_promotion.value, default_matmul_precision.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, + default_device.value, random_seed_offset.value, + threefry_partitionable.value, + threefry_gpu_kernel_lowering.value, + sharding_in_types.value, + softmax_custom_jvp.value, + enable_memories.value, + disable_jit.value, + debug_key_reuse.value, + jax_xla_profile_version.value, + # Technically this affects jaxpr->stablehlo lowering, not tracing. + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value, + use_shardy_partitioner.value) config = Config() @@ -245,94 +284,185 @@ def trace_context(): class NoDefault: pass no_default = NoDefault() +if xla_extension_version >= 295: + class State(config_ext.Config[_T]): -class _Unset: pass -unset = _Unset() - -_thread_local_state = threading.local() + __slots__ = ( + '_name', '_update_thread_local_hook', '_update_global_hook', + '_validator', '_default_context_manager_value', '__doc__', '__name__', + ) -class State(Generic[_T]): + def __init__( + self, + name: str, + default: _T, + help, + update_global_hook: Callable[[_T], None] | None = None, + update_thread_local_hook: Callable[[_T | None], None] | None = None, + validator: Callable[[Any], None] | None = None, + extra_description: str = '', + default_context_manager_value: Any = no_default, + include_in_jit_key: bool = False, + ): + super().__init__(default, include_in_jit_key) + self._name = name + self.__name__ = name[4:] if name.startswith('jax_') else name + self.__doc__ = (f"Context manager for `{name}` config option" + f"{extra_description}.\n\n{help}") + self._update_global_hook = update_global_hook + self._update_thread_local_hook = update_thread_local_hook + self._validator = validator + self._default_context_manager_value = default_context_manager_value + if self._validator: + self._validator(default) + if self._update_global_hook: + self._update_global_hook(default) + + def __bool__(self) -> NoReturn: + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__)) + + def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) + self.set_global(value) + if self._update_global_hook: + self._update_global_hook(value) + + @contextlib.contextmanager + def __call__(self, new_val: Any = no_default): + if new_val is no_default: + if self._default_context_manager_value is not no_default: + new_val = self._default_context_manager_value # default_context_manager_value provided to constructor + else: + # no default_value provided to constructor and no value provided as an + # argument, so we raise an error + raise TypeError(f"Context manager for {self.__name__} config option " + "requires an argument representing the new value for " + "the config option.") + if self._validator: + self._validator(new_val) + prev_val = self.swap_local(new_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(new_val) + try: + yield + finally: + self.set_local(prev_val) + if self._update_thread_local_hook: + if prev_val is config_ext.unset: + self._update_thread_local_hook(None) + else: + self._update_thread_local_hook(cast(Optional[Any], prev_val)) - __slots__ = ( - '_name', '_value', '_update_thread_local_hook', '_update_global_hook', - '_validator', '_default_context_manager_value', '__doc__', '__name__', - ) + def _add_hooks(self, update_global_hook, update_thread_local_hook): + """Private method that adds hooks to an existing context-manager. - def __init__( - self, - name: str, - default: _T, - help, - update_global_hook: Callable[[_T], None] | None = None, - update_thread_local_hook: Callable[[_T | None], None] | None = None, - validator: Callable[[Any], None] | None = None, - extra_description: str = '', - default_context_manager_value: Any = no_default, - ): - self._name = name - self.__name__ = name[4:] if name.startswith('jax_') else name - self.__doc__ = (f"Context manager for `{name}` config option" - f"{extra_description}.\n\n{help}") - self._update_global_hook = update_global_hook - self._update_thread_local_hook = update_thread_local_hook - self._validator = validator - self._default_context_manager_value = default_context_manager_value - self._set(default) + Used to avoid cyclic import dependencies.""" + self._update_thread_local_hook = update_thread_local_hook + self._update_global_hook = update_global_hook + update_global_hook(self.get_global()) - def __bool__(self) -> NoReturn: - raise TypeError( - "bool() not supported for instances of type '{0}' " - "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) +else: + class _Unset: pass + unset = _Unset() - def _set(self, value: _T) -> None: - if self._validator: - self._validator(value) - self._value = value - if self._update_global_hook: - self._update_global_hook(value) + _thread_local_state = threading.local() - @property - def value(self) -> _T: - val = _thread_local_state.__dict__.get(self._name, unset) - return cast(_T, val) if val is not unset else self._value - - @contextlib.contextmanager - def __call__(self, new_val: Any = no_default): - if new_val is no_default: - if self._default_context_manager_value is not no_default: - new_val = self._default_context_manager_value # default_context_manager_value provided to constructor - else: - # no default_value provided to constructor and no value provided as an - # argument, so we raise an error - raise TypeError(f"Context manager for {self.__name__} config option " - "requires an argument representing the new value for " - "the config option.") - if self._validator: - self._validator(new_val) - prev_val = getattr(_thread_local_state, self._name, unset) - setattr(_thread_local_state, self._name, new_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(new_val) - try: - yield - finally: - if prev_val is unset: - delattr(_thread_local_state, self._name) - if self._update_thread_local_hook: - self._update_thread_local_hook(None) - else: - setattr(_thread_local_state, self._name, prev_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(cast(_T, prev_val)) + class State(Generic[_T]): - def _add_hooks(self, update_global_hook, update_thread_local_hook): - """Private method that adds hooks to an existing context-manager. + __slots__ = ( + '_name', '_value', '_update_thread_local_hook', '_update_global_hook', + '_validator', '_default_context_manager_value', '__doc__', '__name__', + ) - Used to avoid cyclic import dependencies.""" - self._update_thread_local_hook = update_thread_local_hook - self._update_global_hook = update_global_hook - update_global_hook(self._value) + def __init__( + self, + name: str, + default: _T, + help, + update_global_hook: Callable[[_T], None] | None = None, + update_thread_local_hook: Callable[[_T | None], None] | None = None, + validator: Callable[[Any], None] | None = None, + extra_description: str = '', + default_context_manager_value: Any = no_default, + include_in_jit_key: bool = False, + ): + self._name = name + self.__name__ = name[4:] if name.startswith('jax_') else name + self.__doc__ = (f"Context manager for `{name}` config option" + f"{extra_description}.\n\n{help}") + if include_in_jit_key: + assert update_global_hook is None + assert update_thread_local_hook is None + update_global_hook = lambda val: _update_global_jit_state( + **{self.__name__: val}) + update_thread_local_hook = lambda val: update_thread_local_jit_state( + **{self.__name__: val}) + self._update_global_hook = update_global_hook + self._update_thread_local_hook = update_thread_local_hook + self._validator = validator + self._default_context_manager_value = default_context_manager_value + self._set(default) + def __bool__(self) -> NoReturn: + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__)) + + def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) + self._value = value + if self._update_global_hook: + self._update_global_hook(value) + + @property + def value(self) -> _T: + val = _thread_local_state.__dict__.get(self._name, unset) + return cast(_T, val) if val is not unset else self._value + + def get_local(self) -> Any: + return _thread_local_state.__dict__.get(self._name, unset) + + @contextlib.contextmanager + def __call__(self, new_val: Any = no_default): + if new_val is no_default: + if self._default_context_manager_value is not no_default: + new_val = self._default_context_manager_value # default_context_manager_value provided to constructor + else: + # no default_value provided to constructor and no value provided as an + # argument, so we raise an error + raise TypeError(f"Context manager for {self.__name__} config option " + "requires an argument representing the new value for " + "the config option.") + if self._validator: + self._validator(new_val) + prev_val = getattr(_thread_local_state, self._name, unset) + setattr(_thread_local_state, self._name, new_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(new_val) + try: + yield + finally: + if prev_val is unset: + delattr(_thread_local_state, self._name) + if self._update_thread_local_hook: + self._update_thread_local_hook(None) + else: + setattr(_thread_local_state, self._name, prev_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(cast(_T, prev_val)) + + def _add_hooks(self, update_global_hook, update_thread_local_hook): + """Private method that adds hooks to an existing context-manager. + + Used to avoid cyclic import dependencies.""" + self._update_thread_local_hook = update_thread_local_hook + self._update_global_hook = update_global_hook + update_global_hook(self._value) UPGRADE_BOOL_HELP = ( @@ -353,6 +483,7 @@ def bool_state( update_thread_local_hook: Callable[[bool | None], None] | None = None, upgrade: bool = False, extra_description: str = '', + include_in_jit_key: bool = False, ) -> State[bool]: """Set up thread-local state and return a contextmanager for managing it. @@ -417,7 +548,8 @@ def bool_state( s = State[bool]( name, default, help, update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, - extra_description=extra_description, default_context_manager_value=True) + extra_description=extra_description, default_context_manager_value=True, + include_in_jit_key=include_in_jit_key) config.add_option(name, s, bool, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -431,6 +563,7 @@ def enum_state( *, update_global_hook: Callable[[str], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[str]: """Set up thread-local state and return a contextmanager for managing it. @@ -470,6 +603,7 @@ def validator(new_val): update_global_hook=update_global_hook, update_thread_local_hook=update_thread_local_hook, validator=validator, + include_in_jit_key=include_in_jit_key, ) config.add_option( name, s, 'enum', @@ -488,6 +622,7 @@ def optional_enum_state( *, update_global_hook: Callable[[str | None], None] | None = None, update_thread_local_hook: Callable[[str | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[str | None]: """Set up thread-local state and return a contextmanager for managing it. @@ -523,7 +658,7 @@ def validate(new_val): s = State['str | None']( name, default, help, update_global_hook, update_thread_local_hook, - validate + validate, include_in_jit_key=include_in_jit_key, ) config.add_option( name, s, 'enum', @@ -541,6 +676,7 @@ def int_state( *, update_global_hook: Callable[[int], None] | None = None, update_thread_local_hook: Callable[[int | None], None] | None = None, + include_in_jit_key: bool = False, ) -> State[int]: """Set up thread-local state and return a contextmanager for managing it. @@ -575,7 +711,8 @@ def validate(new_val): f'got {new_val} of type {type(new_val)}') s = State[int](name, default, help, update_global_hook, - update_thread_local_hook, validate) + update_thread_local_hook, validate, + include_in_jit_key=include_in_jit_key) config.add_option(name, s, int, meta_args=[], meta_kwargs={"help": help}) setattr(Config, name, property(lambda _: s.value)) return s @@ -826,92 +963,119 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: already_configured_with_absl = False -# The C++ JIT maintains its own copy of several configuration items as -# a global/thread-local state. These methods allow updates to part of the -# state when a configuration value changes. -class _GlobalExtraJitContext(NamedTuple): - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool = False - eager_constant_folding: bool = False - random_seed_offset: int = 0 - threefry_partitionable: bool = False - threefry_gpu_kernel_lowering: bool = False - sharding_in_types: bool = False - softmax_custom_jvp: bool = False - xla_profile_version: int = 0 - pgle_profiling_runs: int = 0 - enable_pgle: bool = False - use_shardy_partitioner: bool = False - - -def _update_global_jit_state(**kw): - gs = jax_jit.global_state() - context = gs.extra_jit_context or _GlobalExtraJitContext() - gs.extra_jit_context = context._replace(**kw) - - -class _ThreadLocalExtraJitContext(NamedTuple): - """A namedtuple containing states to add to the cache key. - - Just in time compilation (for jit, pmap, etc) behavior is configurable through - global and thread-local options, used in the cache key. - - The initialization, which uses both config.py and core.py is done using - `_update_thread_local_jit_state` in core.py to prevent circular imports. - """ - trace_state: Any | None = None - axis_env_state: Hashable = () - mesh_context_manager: Hashable = () - compute_on_context_manager: Hashable = () - xla_metadata_context_manager: Hashable = () - - # Values set by _StateContextManager context managers. - # CAUTION: these must be initialized to `None`! The state context manager - # restores these to None on exit. If the object default is not `None`, the - # context manager is not a no-op, which leads to problems with stale state - # (e.g. spurious cache misses in tests). - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool | None = None - eager_constant_folding : bool | None = None - random_seed_offset: int | None = None - threefry_partitionable: bool | None = None - threefry_gpu_kernel_lowering: bool | None = None - sharding_in_types: bool | None = None - softmax_custom_jvp: bool | None = None - xla_profile_version: int | None = None - pgle_profiling_runs: int | None = None - enable_pgle: bool | None = None - use_shardy_partitioner: bool | None = None - - -class _ThreadLocalStateCache(threading.local): - """"A thread local cache for _ThreadLocalExtraJitContext - - The extra_jit_context in jax_jit.thread_local_state() may get updated and thus - incurring dispatch overhead for comparing this python object during jit calls. - We want to deduplicate the objects that have the same hash/equality to also - have the same object ID, since the equality check is much faster if the object - IDs match. - """ - def __init__(self): - self.canonicalize = functools.lru_cache(128)(lambda x: x) +if xla_extension_version >= 295: + trace_state = config_ext.Config(None, include_in_jit_key=True) + axis_env_state = config_ext.Config((), include_in_jit_key=True) + mesh_context_manager = config_ext.Config((), include_in_jit_key=True) + compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) + xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) +else: + # The C++ JIT maintains its own copy of several configuration items as + # a global/thread-local state. These methods allow updates to part of the + # state when a configuration value changes. + class _GlobalExtraJitContext(NamedTuple): + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None + dynamic_shapes: bool = False + eager_constant_folding: bool = False + random_seed_offset: int = 0 + threefry_partitionable: bool = False + threefry_gpu_kernel_lowering: bool = False + sharding_in_types: bool = False + softmax_custom_jvp: bool = False + xla_profile_version: int = 0 + pgle_profiling_runs: int = 0 + enable_pgle: bool = False + use_shardy_partitioner: bool = False + + + def _update_global_jit_state(**kw): + gs = jax_jit.global_state() + context = gs.extra_jit_context or _GlobalExtraJitContext() + gs.extra_jit_context = context._replace(**kw) + + + class _ThreadLocalExtraJitContext(NamedTuple): + """A namedtuple containing states to add to the cache key. + + Just in time compilation (for jit, pmap, etc) behavior is configurable through + global and thread-local options, used in the cache key. + + The initialization, which uses both config.py and core.py is done using + `_update_thread_local_jit_state` in core.py to prevent circular imports. + """ + trace_state: Any | None = None + axis_env_state: Hashable = () + mesh_context_manager: Hashable = () + compute_on_context_manager: Hashable = () + xla_metadata_context_manager: Hashable = () + + # Values set by _StateContextManager context managers. + # CAUTION: these must be initialized to `None`! The state context manager + # restores these to None on exit. If the object default is not `None`, the + # context manager is not a no-op, which leads to problems with stale state + # (e.g. spurious cache misses in tests). + numpy_rank_promotion: str | None = None + numpy_dtype_promotion: str | None = None + default_matmul_precision: Any | None = None + dynamic_shapes: bool | None = None + eager_constant_folding : bool | None = None + random_seed_offset: int | None = None + threefry_partitionable: bool | None = None + threefry_gpu_kernel_lowering: bool | None = None + sharding_in_types: bool | None = None + softmax_custom_jvp: bool | None = None + xla_profile_version: int | None = None + pgle_profiling_runs: int | None = None + enable_pgle: bool | None = None + use_shardy_partitioner: bool | None = None + + + class _ThreadLocalStateCache(threading.local): + """"A thread local cache for _ThreadLocalExtraJitContext + + The extra_jit_context in jax_jit.thread_local_state() may get updated and thus + incurring dispatch overhead for comparing this python object during jit calls. + We want to deduplicate the objects that have the same hash/equality to also + have the same object ID, since the equality check is much faster if the object + IDs match. + """ + def __init__(self): + self.canonicalize = functools.lru_cache(128)(lambda x: x) -_thread_local_state_cache = _ThreadLocalStateCache() + _thread_local_state_cache = _ThreadLocalStateCache() -def update_thread_local_jit_state(**kw): - tls = jax_jit.thread_local_state() - # After xla_client._version >= 70, the thread_local object will necessarily - # be initialized when accessed. The following line can be removed when the - # minimum jaxlib version is past version 70 - context = tls.extra_jit_context or _ThreadLocalExtraJitContext() - tmp = context._replace(**kw) - tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) + def update_thread_local_jit_state(**kw): + tls = jax_jit.thread_local_state() + # After xla_client._version >= 70, the thread_local object will necessarily + # be initialized when accessed. The following line can be removed when the + # minimum jaxlib version is past version 70 + context = tls.extra_jit_context or _ThreadLocalExtraJitContext() + tmp = context._replace(**kw) + tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) + + class JitConfig: + def __init__(self, name): + self._name = name + + def value(self): + return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) + + def get_local(self): + return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) + + def set_local(self, value): + update_thread_local_jit_state(**{self._name: value}) + + trace_state = JitConfig('trace_state') + axis_env_state = JitConfig('axis_env_state') + mesh_context_manager = JitConfig('mesh_context_manager') + compute_on_context_manager = JitConfig('compute_on_context_manager') + xla_metadata_context_manager = JitConfig('xla_metadata_context_manager') + # TODO(b/214340779): remove flag when XLA:CPU is improved. jax2tf_associative_scan_reductions = bool_state( @@ -1102,10 +1266,7 @@ def _update_jax_memories_thread_local(val): name='jax_random_seed_offset', default=0, help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), - update_global_hook=lambda val: _update_global_jit_state( - random_seed_offset=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - random_seed_offset=val) + include_in_jit_key=True, ) legacy_prng_key = enum_state( @@ -1140,10 +1301,7 @@ def _update_jax_memories_thread_local(val): 'may result in extraneous communication and/or redundant distributed ' 'computation. With this flag, the communication overheads disappear ' 'in some cases.'), - update_global_hook=lambda val: _update_global_jit_state( - threefry_partitionable=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - threefry_partitionable=val)) + include_in_jit_key=True) threefry_gpu_kernel_lowering = bool_state( name='jax_threefry_gpu_kernel_lowering', @@ -1151,20 +1309,14 @@ def _update_jax_memories_thread_local(val): help=('On GPU, lower threefry PRNG operations to a kernel implementation. ' 'This makes compile times faster at a potential runtime memory ' 'cost.'), - update_global_hook=lambda val: _update_global_jit_state( - threefry_gpu_kernel_lowering=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - threefry_gpu_kernel_lowering=val)) + include_in_jit_key=True) sharding_in_types = bool_state( name='jax_sharding_in_types', default=False, help=('When True, enables forward only sharding propagation in JAX and ' 'avals have sharding on them.'), - update_global_hook=lambda val: _update_global_jit_state( - sharding_in_types=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - sharding_in_types=val)) + include_in_jit_key=True) data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', @@ -1179,10 +1331,7 @@ def _update_jax_memories_thread_local(val): help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' 'behavior. See https://github.com/jax-ml/jax/pull/15677'), - update_global_hook=lambda val: _update_global_jit_state( - softmax_custom_jvp=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - softmax_custom_jvp=val)) + include_in_jit_key=True) enable_custom_vjp_by_custom_transpose = bool_state( @@ -1298,9 +1447,7 @@ def _update_jax_memories_thread_local(val): 'number times with collected data provided to the profile guided latency ' 'estimator.' ), - update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - enable_pgle=val), + include_in_jit_key=True, ) pgle_profiling_runs = int_state( @@ -1310,12 +1457,7 @@ def _update_jax_memories_thread_local(val): 'Amount of times module should be profiled before recompilation when ' 'PGLE is used.' ), - update_global_hook=lambda val: _update_global_jit_state( - pgle_profiling_runs=val - ), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - pgle_profiling_runs=val - ), + include_in_jit_key=True, ) pgle_aggregation_percentile = int_state( @@ -1381,10 +1523,7 @@ def _update_jax_memories_thread_local(val): 'between arrays. Options are "standard" or "strict"; in strict-mode, ' 'binary operations between arrays of differing strongly-specified ' 'dtypes will result in an error.'), - update_global_hook=lambda val: \ - _update_global_jit_state(numpy_dtype_promotion=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(numpy_dtype_promotion=val)) + include_in_jit_key=True) disallow_mesh_context_manager = bool_state( name='jax_disallow_mesh_context_manager', @@ -1470,10 +1609,7 @@ def _update_disable_jit_thread_local(val): default='allow', help=('Control NumPy-style automatic rank promotion broadcasting ' '("allow", "warn", or "raise").'), - update_global_hook=lambda val: \ - _update_global_jit_state(numpy_rank_promotion=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(numpy_rank_promotion=val)) + include_in_jit_key=True) default_matmul_precision = optional_enum_state( name='jax_default_matmul_precision', @@ -1509,10 +1645,7 @@ def _update_disable_jit_thread_local(val): '"algorithm" for functions that perform matrix multiplications, like ' ':func:`jax.lax.dot`. To specify an algorithm, set this option to ' 'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'), - update_global_hook=lambda val: \ - _update_global_jit_state(default_matmul_precision=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(default_matmul_precision=val)) + include_in_jit_key=True) traceback_filtering = enum_state( name = 'jax_traceback_filtering', @@ -1547,20 +1680,14 @@ def _update_disable_jit_thread_local(val): default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')), help=('Enables experimental features for staging out computations with ' 'dynamic shapes.'), - update_global_hook=lambda val: \ - _update_global_jit_state(dynamic_shapes=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(dynamic_shapes=val)) + include_in_jit_key=True) # This is for stackless backward compat with e.g. equinox eager_constant_folding = bool_state( name='eager_constant_folding', default=False, help=('Attempt constant folding during staging.'), - update_global_hook=lambda val: \ - _update_global_jit_state(eager_constant_folding=val), - update_thread_local_hook=lambda val: \ - update_thread_local_jit_state(eager_constant_folding=val)) + include_in_jit_key=True) # This flag is temporary during rollout of the remat barrier. # TODO(parkers): Remove if there are no complaints. @@ -1619,10 +1746,7 @@ def _update_disable_jit_thread_local(val): 'Optional profile version for XLA compilation. This is meaningful ' 'only when XLA is configured to support the remote compilation ' 'profile feature.'), - update_global_hook=lambda val: _update_global_jit_state( - xla_profile_version=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - xla_profile_version=val), + include_in_jit_key=True, ) @contextlib.contextmanager @@ -1821,10 +1945,5 @@ def _update_debug_log_modules(module_names_str: str | None): 'framework for MLIR. Currently Shardy is experimental in JAX. See ' 'www.github.com/openxla/shardy' ), - update_global_hook=lambda val: _update_global_jit_state( - use_shardy_partitioner=val - ), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - use_shardy_partitioner=val - ), + include_in_jit_key=True, ) diff --git a/jax/_src/core.py b/jax/_src/core.py index 08ac0668265a..cf570b8eb424 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1016,18 +1016,16 @@ def is_top_level(self) -> bool: def set_trace(self, trace): self.trace = trace ts = ref(trace) if trace is not None else None - config.update_thread_local_jit_state(trace_state=ts) + config.trace_state.set_local(ts) def set_axis_env(self, axis_env): self.axis_env = axis_env - config.update_thread_local_jit_state( - axis_env_state=self.axis_env.as_hashable_key()) + config.axis_env_state.set_local(axis_env.as_hashable_key()) def update_thread_local_jit_state(self): ts = ref(self.trace) if self.trace is not None else None - config.update_thread_local_jit_state( - trace_state=ts, - axis_env_state=self.axis_env.as_hashable_key()) + config.trace_state.set_local(ts) + config.axis_env_state.set_local(self.axis_env.as_hashable_key()) trace_ctx = TracingContext() @@ -1071,10 +1069,7 @@ def _initialize_jax_jit_thread_local_state(): This function does not live in `config.py`, to prevent circular imports. """ - tls = jax_jit.thread_local_state() - - if tls.extra_jit_context is None: - trace_ctx.update_thread_local_jit_state() + trace_ctx.update_thread_local_jit_state() jax_jit.set_thread_local_state_initialization_callback( _initialize_jax_jit_thread_local_state) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 8cb508378129..43791f2e5f72 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -224,17 +224,17 @@ def __enter__(self): new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.mesh_context_manager.set_local( + tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return self def __exit__(self, exc_type, exc_value, traceback): thread_resources.stack.pop() thread_resources.env = thread_resources.stack[-1] - jax_config.update_thread_local_jit_state( - mesh_context_manager=tuple(t.physical_mesh for t in thread_resources.stack - if not t.physical_mesh.empty)) + jax_config.mesh_context_manager.set_local( + tuple(t.physical_mesh for t in thread_resources.stack + if not t.physical_mesh.empty)) return False @property @@ -410,7 +410,7 @@ def __exit__(self, exc_type, exc_value, traceback): @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): - jax_config.update_thread_local_jit_state(mesh_context_manager=mesh) + jax_config.mesh_context_manager.set_local(mesh) return diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 94b482e2dea4..94e26eeefa65 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -41,15 +41,13 @@ def set_xla_metadata(*args, **kwargs): thread_local_metadata.val, new_metadata, ) - config.update_thread_local_jit_state( - xla_metadata_context_manager=tuple( - (v, k) for k, v in sorted(new_metadata.items()))) + config.xla_metadata_context_manager.set_local( + tuple((v, k) for k, v in sorted(new_metadata.items())) + ) try: yield finally: thread_local_metadata.val = prev_metadata - config.update_thread_local_jit_state( - xla_metadata_context_manager=tuple( - (v, k) for k, v in sorted(prev_metadata.items()) - ) + config.xla_metadata_context_manager.set_local( + tuple((v, k) for k, v in sorted(prev_metadata.items())) ) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 6e0e795df334..0df5d99715e6 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2215,8 +2215,6 @@ def test_cache_uses_jax_key(self): pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) - config.update_thread_local_jit_state() - pmaped_f(inputs) self.assertEqual(pmaped_f._cache_size, 1) From 7bdb2bf998b02cf1022e1e3851eaf7184fe03a44 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Fri, 27 Sep 2024 09:26:14 +0000 Subject: [PATCH 204/698] [jax.distributed] Enable grpc channel compression --- jax/_src/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index e0155e012736..5b9130fc0455 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -123,7 +123,7 @@ def initialize(self, self.client = xla_extension.get_distributed_runtime_client( coordinator_address, process_id, init_timeout=initialization_timeout, heartbeat_interval=client_heartbeat_interval_seconds, - max_missing_heartbeats=client_max_missing_heartbeats) + max_missing_heartbeats=client_max_missing_heartbeats, use_compression=True) logger.info('Connecting to JAX distributed service on %s', coordinator_address) self.client.connect() From 095bb0e742e0c2624fc5bfcb6e2ab91f5b093638 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 5 Nov 2024 09:08:33 -0800 Subject: [PATCH 205/698] Make Tracers non-hashable --- CHANGELOG.md | 2 ++ jax/_src/basearray.py | 1 + jax/_src/core.py | 14 +------------- tests/array_test.py | 14 ++++---------- 4 files changed, 8 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c34b8e369c3..bafe551100d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. use `uses_global_constants`. * the `lowering_platforms` kwarg for {func}`jax.export.export`: use `platforms` instead. + * Hashing of tracers, which has been deprecated since version 0.4.30, now + results in a `TypeError`. * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index c3145f32e8bf..a89d4a2949be 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -53,6 +53,7 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace # associated basearray.pyi file. __slots__ = ['__weakref__'] + __hash__ = None @property @abc.abstractmethod diff --git a/jax/_src/core.py b/jax/_src/core.py index cf570b8eb424..6f96dc760cc0 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -34,7 +34,6 @@ import numpy as np -from jax._src import deprecations from jax._src import dtypes from jax._src import config from jax._src import effects @@ -638,24 +637,13 @@ def _aval_property(name): class Tracer(typing.Array, metaclass=StrictABCMeta): __array_priority__ = 1000 __slots__ = ['_trace', '_line_info'] + __hash__ = None # type: ignore dtype = _aval_property('dtype') ndim = _aval_property('ndim') size = _aval_property('size') shape = _aval_property('shape') - def __hash__(self): - # TODO(jakevdp) finalize this deprecation and set __hash__ = None - # Warning added 2024-06-13 - if deprecations.is_accelerated('tracer-hash'): - raise TypeError(f"unhashable type: {type(self)}") - # Use FutureWarning rather than DeprecationWarning because hash is likely - # not called directly by the user, so we want to warn at all stacklevels. - warnings.warn( - f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an" - " error in a future JAX release.", category=FutureWarning) - return super().__hash__() - def __init__(self, trace: Trace): self._trace = trace diff --git a/tests/array_test.py b/tests/array_test.py index e7aad59b1ad5..9618a8cf4665 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -24,7 +24,6 @@ import jax import jax.numpy as jnp from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import op_shardings from jax._src import test_util as jtu @@ -608,16 +607,11 @@ def test_array_not_hashable(self): with self.assertRaisesRegex(TypeError, "unhashable type"): hash(x) - @jax.jit - def check_tracer_hash(x): - self.assertIsInstance(hash(x), int) + with self.assertRaisesRegex(TypeError, "unhashable type"): + jax.jit(hash)(x) - if deprecations.is_accelerated('tracer-hash'): - with self.assertRaisesRegex(TypeError, "unhashable type"): - check_tracer_hash(x) - else: - with self.assertWarnsRegex(FutureWarning, "unhashable type"): - check_tracer_hash(x) + with self.assertRaisesRegex(TypeError, "unhashable type"): + jax.vmap(hash)(x) def test_shape_dtype_struct_sharding_jit(self): mesh = jtu.create_mesh((8,), ('x')) From 650e70ab2e481242bc8bfde5dd47fc43e3dc56ed Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 5 Nov 2024 11:32:09 -0600 Subject: [PATCH 206/698] Fix nightly sync permissions (#124) --- .github/workflows/rocm-nightly-upstream-sync.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index 98c958c3daa0..a15e49c2e87b 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -8,13 +8,19 @@ on: - cron: '0 6 * * 1-5' jobs: sync-main: + permissions: + contents: write runs-on: ubuntu-latest steps: - run: gh repo sync rocm/jax -b main env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} open-sync-pr: + permissions: + pull-requests: write runs-on: ubuntu-latest steps: - run: | gh pr create --repo $GITHUB_REPOSITORY --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} From 04f2ef9e930829f64be67f404f41d175261de9e6 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 5 Sep 2024 18:22:15 -0700 Subject: [PATCH 207/698] Adding `JAX_LOGGING_LEVEL` configuration option --- docs/persistent_compilation_cache.md | 13 +- jax/_src/config.py | 24 ++-- jax/_src/logging_config.py | 108 +++++++++++++--- tests/logging_test.py | 178 ++++++++++++++++++++++++--- 4 files changed, 282 insertions(+), 41 deletions(-) diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 47a7587b620f..c49e18394e9a 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -70,11 +70,11 @@ cc.set_cache_dir("/tmp/jax_cache") * `jax_persistent_cache_min_entry_size_bytes`: The minimum size (in bytes) of an entry that will be cached in the persistent compilation cache: - * `-1`: disable the size restriction and prevent overrides. + * `-1`: disable the size restriction and prevent overrides. * Leave at default (`0`) to allow for overrides. The override will typically ensure that the minimum size is optimal for the file system - being used for the cache. + being used for the cache. * `> 0`: the actual minimum size desired; no overrides. @@ -155,7 +155,14 @@ import os os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache" ``` -on the top of the script. +on the top of the script. Alternatively, you can change the global jax logging level with + +```python +import os +os.environ["JAX_LOGGING_LEVEL"] = "DEBUG" +# or locally with +jax.config.update("jax_logging_level", "DEBUG") +``` ### Examining cache misses diff --git a/jax/_src/config.py b/jax/_src/config.py index 0e113c894695..215ef443c799 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1647,6 +1647,7 @@ def _update_disable_jit_thread_local(val): 'the name of a :class:`~jax.lax.DotAlgorithmPreset`.\n\n'), include_in_jit_key=True) + traceback_filtering = enum_state( name = 'jax_traceback_filtering', enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", @@ -1913,14 +1914,6 @@ def _update_garbage_collection_guard(state, key, val): ), ) -def _update_debug_log_modules(module_names_str: str | None): - logging_config.disable_all_debug_logging() - if not module_names_str: - return - module_names = module_names_str.split(',') - for module_name in module_names: - logging_config.enable_debug_logging(module_name) - # Don't define a context manager since this isn't threadsafe. string_state( name='jax_debug_log_modules', @@ -1928,7 +1921,20 @@ def _update_debug_log_modules(module_names_str: str | None): help=('Comma-separated list of module names (e.g. "jax" or ' '"jax._src.xla_bridge,jax._src.dispatch") to enable debug logging ' 'for.'), - update_global_hook=_update_debug_log_modules) + update_global_hook=logging_config.update_debug_log_modules) + +# Don't define a context manager since this isn't threadsafe. +optional_enum_state( + name='jax_logging_level', + enum_values=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default=logging.getLevelName(logging.getLogger("jax").level), + help=('Set the corresponding logging level on all jax loggers. Only string' + ' values from ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR",' + ' "CRITICAL"] are accepted. If None, the logging level will not be' + ' set. Includes C++ logging.'), + update_global_hook=lambda logging_level: \ + logging_config.update_logging_level_global(logging_level=logging_level) +) pmap_no_rank_reduction = bool_state( name='jax_pmap_no_rank_reduction', diff --git a/jax/_src/logging_config.py b/jax/_src/logging_config.py index d2f9d9c8fb1f..bdf588d2054a 100644 --- a/jax/_src/logging_config.py +++ b/jax/_src/logging_config.py @@ -13,19 +13,92 @@ # limitations under the License. import logging +import os import sys -_debug_handler = logging.StreamHandler(sys.stderr) -_debug_handler.setLevel(logging.DEBUG) # Example log message: # DEBUG:2023-06-07 00:14:40,280:jax._src.xla_bridge:590: Initializing backend 'cpu' -_debug_handler.setFormatter(logging.Formatter( - "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{')) +logging_formatter = logging.Formatter( + "{levelname}:{asctime}:{name}:{lineno}: {message}", style='{') -_debug_enabled_loggers = [] +_logging_level_set: dict[str, int] = {} +_default_TF_CPP_MIN_LOG_LEVEL = os.environ.get("TF_CPP_MIN_LOG_LEVEL", "1") + +_jax_logger_handler = logging.StreamHandler(sys.stderr) +_jax_logger_handler.setFormatter(logging_formatter) + +_nameToLevel = { + 'CRITICAL': logging.CRITICAL, + 'FATAL': logging.FATAL, + 'ERROR': logging.ERROR, + 'WARN': logging.WARNING, + 'WARNING': logging.WARNING, + 'INFO': logging.INFO, + 'DEBUG': logging.DEBUG, + 'NOTSET': logging.NOTSET, +} + +_tf_cpp_map = { + 'CRITICAL': 3, + 'FATAL': 3, + 'ERROR': 2, + 'WARN': 1, + 'WARNING': 1, + 'INFO': 0, + 'DEBUG': 0, +} + +def _set_TF_CPP_MIN_LOG_LEVEL(logging_level: str | None = None): + if logging_level in (None, "NOTSET"): + # resetting to user-default TF_CPP_MIN_LOG_LEVEL + # this is typically "1", but if the user overrode it, it can be != "1" + os.environ["TF_CPP_MIN_LOG_LEVEL"] = _default_TF_CPP_MIN_LOG_LEVEL + else: + # set cpp runtime logging level if the level is anything but NOTSET + if logging_level not in _tf_cpp_map: + raise ValueError(f"Attempting to set log level \"{logging_level}\" which" + f" isn't one of the supported:" + f" {list(_tf_cpp_map.keys())}.") + # config the CPP logging level 0 - debug, 1 - info, 2 - warning, 3 - error + os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(_tf_cpp_map[logging_level]) + +def update_logging_level_global(logging_level: str | None) -> None: + # remove previous handlers + for logger_name, level in _logging_level_set.items(): + logger = logging.getLogger(logger_name) + logger.removeHandler(_jax_logger_handler) + logger.setLevel(level) + _logging_level_set.clear() + _set_TF_CPP_MIN_LOG_LEVEL(logging_level) + + if logging_level is None: + return + + logging_level_num = _nameToLevel[logging_level] + # update jax and jaxlib root loggers for propagation + root_loggers = [logging.getLogger("jax"), logging.getLogger("jaxlib")] + for logger in root_loggers: + logger.setLevel(logging_level_num) + logger.addHandler(_jax_logger_handler) + _logging_level_set[logger.name] = logger.level -def enable_debug_logging(logger_name): +# per-module debug logging + +_jax_logger = logging.getLogger("jax") + +class _DebugHandlerFilter(logging.Filter): + def filter(self, _): + return _jax_logger.level > logging.DEBUG + +_debug_handler = logging.StreamHandler(sys.stderr) +_debug_handler.setLevel(logging.DEBUG) +_debug_handler.setFormatter(logging_formatter) +_debug_handler.addFilter(_DebugHandlerFilter()) + +_debug_enabled_loggers = [] + +def _enable_debug_logging(logger_name): """Makes the specified logger log everything to stderr. Also adds more useful debug information to the log messages, e.g. the time. @@ -34,21 +107,28 @@ def enable_debug_logging(logger_name): logger_name: the name of the logger, e.g. "jax._src.xla_bridge". """ logger = logging.getLogger(logger_name) + _debug_enabled_loggers.append((logger, logger.level)) + logger.addHandler(_debug_handler) logger.setLevel(logging.DEBUG) - _debug_enabled_loggers.append(logger) -def disable_all_debug_logging(): +def _disable_all_debug_logging(): """Disables all debug logging enabled via `enable_debug_logging`. The default logging behavior will still be in effect, i.e. WARNING and above will be logged to stderr without extra message formatting. """ - for logger in _debug_enabled_loggers: + for logger, prev_level in _debug_enabled_loggers: + logger: logging.Logger logger.removeHandler(_debug_handler) - # Assume that the default non-debug log level is always WARNING. In theory - # we could keep track of what it was set to before. This shouldn't make a - # difference if not other handlers are attached, but set it back in case - # something else gets attached (e.g. absl logger) and for consistency. - logger.setLevel(logging.WARNING) + logger.setLevel(prev_level) + _debug_enabled_loggers.clear() + +def update_debug_log_modules(module_names_str: str | None): + _disable_all_debug_logging() + if not module_names_str: + return + module_names = module_names_str.split(',') + for module_name in module_names: + _enable_debug_logging(module_name) diff --git a/tests/logging_test.py b/tests/logging_test.py index a1d6695a1e37..05d874b73ac3 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -15,8 +15,9 @@ import contextlib import io import logging -import os import platform +import re +import shlex import subprocess import sys import tempfile @@ -49,10 +50,23 @@ def jax_debug_log_modules(value): finally: jax.config.update("jax_debug_log_modules", original_value) +@contextlib.contextmanager +def jax_logging_level(value): + # jax_logging_level doesn't have a context manager, because it's + # not thread-safe. But since tests are always single-threaded, we + # can define one here. + original_value = jax.config.jax_logging_level + jax.config.update("jax_logging_level", value) + try: + yield + finally: + jax.config.update("jax_logging_level", original_value) + @contextlib.contextmanager def capture_jax_logs(): log_output = io.StringIO() + handler = logging.StreamHandler(log_output) logger = logging.getLogger("jax") @@ -91,21 +105,8 @@ def test_no_log_spam(self): """)) python = sys.executable assert "python" in python - env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"} - if os.getenv("ASAN_OPTIONS"): - env_variables["ASAN_OPTIONS"] = os.getenv("ASAN_OPTIONS") - if os.getenv("PYTHONPATH"): - env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH") - if os.getenv("LD_LIBRARY_PATH"): - env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH") - if os.getenv("LD_PRELOAD"): - env_variables["LD_PRELOAD"] = os.getenv("LD_PRELOAD") # Make sure C++ logging is at default level for the test process. - proc = subprocess.run( - [python, f.name], - capture_output=True, - env=env_variables, - ) + proc = subprocess.run([python, f.name], capture_output=True) lines = proc.stdout.split(b"\n") lines.extend(proc.stderr.split(b"\n")) @@ -155,6 +156,153 @@ def test_debug_logging(self): jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_stderr_info_logging(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + # test INFO + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + info_lines = log_output.split("\n") + self.assertGreater(len(info_lines), 0) + self.assertIn("INFO", log_output) + self.assertNotIn("DEBUG", log_output) + + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_stderr_debug_logging(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + # test DEBUG + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + self.assertIn("INFO", log_output) + self.assertIn("DEBUG", log_output) + + # test JAX_DEBUG_MODULES + cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + self.assertIn("DEBUG", log_output) + + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_toggling_logging_level(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + _separator = "---------------------------" + program = f""" + import sys + import jax # this prints INFO logging from backend imports + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + jax.config.update("jax_logging_level", None) + sys.stderr.write("{_separator}") + jax.jit(lambda x: x)(1) # should not log anything now + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + m = re.search(_separator, log_output) + self.assertTrue(m is not None) + log_output_verbose = log_output[:m.start()] + log_output_silent = log_output[m.end():] + + self.assertIn("Finished tracing + transforming for pjit", + log_output_verbose) + self.assertEqual(log_output_silent, "") + + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_double_logging_absent(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import jax # this prints INFO logging from backend imports + jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch") + jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + log_output = p.stderr + self.assertNotEmpty(log_output) + log_lines = log_output.strip().split("\n") + # only one tracing line should be printed, if there's more than one + # then logs are printing duplicated + self.assertLen([line for line in log_lines + if "Finished tracing + transforming" in line], 1) + + @unittest.skipIf(platform.system() == "Windows", + "Subprocess test doesn't work on Windows") + def test_subprocess_cpp_logging_level(self): + if sys.executable is None: + raise self.skipTest("test requires access to python binary") + + program = """ + import sys + import jax # this prints INFO logging from backend imports + jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0) + """ + + # strip the leading whitespace from the program script + program = re.sub(r"^\s+", "", program, flags=re.MULTILINE) + + # verbose logging: DEBUG, VERBOSE + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + self.assertIn("Initializing CoordinationService", p.stderr) + + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + self.assertIn("Initializing CoordinationService", p.stderr) + + # verbose logging: WARNING, None + cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + self.assertNotIn("Initializing CoordinationService", p.stderr) + + cmd = shlex.split(f"{sys.executable} -c" + f" '{program}'") + p = subprocess.run(cmd, capture_output=True, text=True) + self.assertNotIn("Initializing CoordinationService", p.stderr) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 21f4b0854a492e7f806341980646ea3015c18e41 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 5 Nov 2024 12:27:41 -0800 Subject: [PATCH 208/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/5a9f79f295ba8d16afce24ea8724da525b8eb87d. PiperOrigin-RevId: 693439980 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 93ac0a9eae0c..3dc24da2559b 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "c66d74e5b3ef0d64c43cdd99c8e6aac8512adb6a" -XLA_SHA256 = "1b036e7adc0d408b76ab4f67705704ad7d95e4070c8e8e905315f678d3f7f1df" +XLA_COMMIT = "5a9f79f295ba8d16afce24ea8724da525b8eb87d" +XLA_SHA256 = "83e516dd8f7c61541aa9e2cba7fe480166ea23f28a41fed445fef4c5b6d45519" def repo(): tf_http_archive( From b60d0ab8ac16fe8d8791bbc653cf997463f65104 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 5 Nov 2024 13:28:17 -0800 Subject: [PATCH 209/698] logging fixes --- tests/logging_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/logging_test.py b/tests/logging_test.py index 05d874b73ac3..a83058095ce6 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -27,6 +27,7 @@ import jax import jax._src.test_util as jtu from jax._src import xla_bridge +from jax._src.logging_config import _default_TF_CPP_MIN_LOG_LEVEL # Note: importing absltest causes an extra absl root log handler to be # registered, which causes extra debug log messages. We don't expect users to @@ -156,6 +157,7 @@ def test_debug_logging(self): jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) + @jtu.skip_on_devices("tpu") @unittest.skipIf(platform.system() == "Windows", "Subprocess test doesn't work on Windows") def test_subprocess_stderr_info_logging(self): @@ -180,6 +182,7 @@ def test_subprocess_stderr_info_logging(self): self.assertIn("INFO", log_output) self.assertNotIn("DEBUG", log_output) + @jtu.skip_on_devices("tpu") @unittest.skipIf(platform.system() == "Windows", "Subprocess test doesn't work on Windows") def test_subprocess_stderr_debug_logging(self): @@ -209,6 +212,7 @@ def test_subprocess_stderr_debug_logging(self): log_output = p.stderr self.assertIn("DEBUG", log_output) + @jtu.skip_on_devices("tpu") @unittest.skipIf(platform.system() == "Windows", "Subprocess test doesn't work on Windows") def test_subprocess_toggling_logging_level(self): @@ -241,6 +245,7 @@ def test_subprocess_toggling_logging_level(self): log_output_verbose) self.assertEqual(log_output_silent, "") + @jtu.skip_on_devices("tpu") @unittest.skipIf(platform.system() == "Windows", "Subprocess test doesn't work on Windows") def test_subprocess_double_logging_absent(self): @@ -267,6 +272,7 @@ def test_subprocess_double_logging_absent(self): self.assertLen([line for line in log_lines if "Finished tracing + transforming" in line], 1) + @jtu.skip_on_devices("tpu") @unittest.skipIf(platform.system() == "Windows", "Subprocess test doesn't work on Windows") def test_subprocess_cpp_logging_level(self): @@ -302,7 +308,8 @@ def test_subprocess_cpp_logging_level(self): cmd = shlex.split(f"{sys.executable} -c" f" '{program}'") p = subprocess.run(cmd, capture_output=True, text=True) - self.assertNotIn("Initializing CoordinationService", p.stderr) + if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1: + self.assertNotIn("Initializing CoordinationService", p.stderr) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From ea1e879577c71aa34fb15eba817339efbfaa272d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 5 Nov 2024 13:42:22 -0800 Subject: [PATCH 210/698] Include mpmath as a bazel dependency of lax_test. This test has additional test cases that require mpmath. PiperOrigin-RevId: 693464078 --- jaxlib/jax.bzl | 1 + tests/BUILD | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 40ec4ca7fe55..b5bfe733b992 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -67,6 +67,7 @@ _py_deps = { "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], "matplotlib": ["@pypi_matplotlib//:pkg"], + "mpmath": [], "opt_einsum": ["@pypi_opt_einsum//:pkg"], "pil": ["@pypi_pillow//:pkg"], "portpicker": ["@pypi_portpicker//:pkg"], diff --git a/tests/BUILD b/tests/BUILD index 58de3404979d..90a8913a1825 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -565,7 +565,7 @@ jax_multiplatform_test( deps = [ "//jax:internal_test_util", "//jax:lax_reference", - ] + py_deps("numpy"), + ] + py_deps("numpy") + py_deps("mpmath"), ) jax_multiplatform_test( From 44c6883ceeb27566e57229afe47fad5ae32a6bf9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 5 Nov 2024 15:36:14 -0800 Subject: [PATCH 211/698] Fix debug_nans false positive in jnp.quantile --- jax/_src/numpy/reductions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index fa8d73361e2b..be1e55675079 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2360,7 +2360,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) + with jax.debug_nans(False): + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) q = lax.mul(q, n - 1) From d62510bfae2c18423575333f211ad34abe0d2785 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 24 Oct 2024 17:57:06 -0700 Subject: [PATCH 212/698] Adding start index and kv_seq_len to decode kernel --- .../pallas/ops/gpu/decode_attention.py | 347 +++++++++++------- tests/pallas/gpu_attention_test.py | 31 +- 2 files changed, 243 insertions(+), 135 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index a7e1b33e1f35..d09f1fbac113 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -14,6 +14,7 @@ """Module containing decode attention.""" from __future__ import annotations +import math import functools from typing import Any @@ -24,82 +25,115 @@ from jax.experimental.pallas import triton as plgpu import jax.numpy as jnp - def attn_forward_kernel( - q_ref, # [num_heads, head_dim] - k_ref, # [k_seq_len, head_dim] - v_ref, # [k_seq_len, head_dim] - o_ref: Any, # [num_heads, head_dim] + # inputs + q_ref, # [num_heads, head_dim] + k_ref, # [k_seq_len, head_dim] + v_ref, # [k_seq_len, head_dim] + start_idx_ref, # [] (i.e., scalar) + kv_seq_len_ref, # [] (i.e., scalar) + # outputs + o_ref: Any, # [num_heads, head_dim] *residual_refs: Any, # Residual outputs: [num_heads,], [num_heads,] sm_scale: float, block_k: int, + block_h: int, + num_heads: int, ): - block_h, head_dim = q_ref.shape - k_seq_len, _ = k_ref.shape - start_q = pl.program_id(0) + _, head_dim = q_ref.shape + split_k_seq_len, _ = k_ref.shape + prog_i, prog_j = pl.program_id(0), pl.program_id(1) + q_slice = pl.ds(0, block_h) + q_mask = (jnp.arange(block_h) < num_heads - block_h * prog_i)[:, None] + + def _compute(start_idx, kv_seq_len, o, m_i, l_i): + # Load q: it will stay in L1 throughout. Indices form a matrix because we + # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. + # q tile has shape [block_h, head_dim]. + q = pl.load(q_ref, (q_slice, pl.ds(None)), mask=q_mask) + + def _dot(a, b): + # if a.shape[0] == 1: + # # Use matrix vector product + # return (a.T * b).sum(axis=0, keepdims=True) + return pl.dot(a, b) + + mask_indices = jnp.arange(block_k) + + # Loop over blocks of kv to process entire kv seq_len. + # Grid loops over q blocks over num_heads. + def body(start_k, carry): + o_prev, m_prev, l_prev = carry + curr_k_slice = pl.ds(start_k * block_k, block_k) + + k = pl.load(k_ref, (curr_k_slice, slice(None))) + qk = _dot(q, k.T) # [block_h, block_k] + if sm_scale != 1.0: + qk *= sm_scale # [block_h, block_k] + + # apply mask if start or sequence length is specified + if start_idx_ref is not None or kv_seq_len_ref is not None: + indices = (prog_j * split_k_seq_len + start_k * block_k + mask_indices) + mask = ((indices >= start_idx) & (indices < kv_seq_len))[None, :] + qk += (~mask) * (0.7 * jnp.finfo(qk.dtype).min) + + m_curr = qk.max(axis=-1) + m_next = jnp.maximum(m_prev, m_curr) + correction = jnp.exp(m_prev - m_next) + l_prev_corr = correction * l_prev + s_curr = jnp.exp( + qk - m_next[:, None] + ) # Use m_next instead of m_curr to avoid a correction on l_curr + l_curr = s_curr.sum(axis=-1) + l_next = l_prev_corr + l_curr + v = pl.load(v_ref, (curr_k_slice, slice(None))) + o_curr = _dot(s_curr.astype(v.dtype), v) + + # flash2 unscaled_o + o_next = correction[:, None] * o_prev + o_curr + return o_next, m_next, l_next + + max_it = jnp.minimum(pl.cdiv((kv_seq_len - prog_j * split_k_seq_len), + block_k), split_k_seq_len // block_k) + (o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i)) + return o, m_i, l_i # o is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention2 paper) are updated during the k,v loop. - m_i = jnp.zeros(block_h, dtype=jnp.float32) - float("inf") + m_i = jnp.zeros(block_h, dtype=jnp.float32) + jnp.finfo(jnp.float32).min l_i = jnp.zeros(block_h, dtype=jnp.float32) o = jnp.zeros((block_h, head_dim), dtype=jnp.float32) - # Load q: it will stay in L1 throughout. Indices form a matrix because we - # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. - # q tile has shape [block_h, head_dim]. - curr_q_slice = pl.dslice(start_q * block_h, block_h) - q = pl.load(q_ref, (curr_q_slice, pl.dslice(None))) - - def _dot(a, b): - # if a.shape[0] == 1: - # # Use matrix vector product - # return (a.T * b).sum(axis=0, keepdims=True) - return pl.dot(a, b) - - # Loop over blocks of kv to process entire kv seq_len. - # Grid loops over q blocks over num_heads. - def body(start_k, carry): - o_prev, m_prev, l_prev = carry - curr_k_slice = pl.dslice(start_k * block_k, block_k) - - k = pl.load(k_ref, (curr_k_slice, slice(None))) - qk = _dot(q, k.T) # [block_h, block_k] - if sm_scale != 1.0: - qk *= sm_scale # [block_h, block_k] - - m_curr = qk.max(axis=-1) - m_next = jnp.maximum(m_prev, m_curr) - correction = jnp.exp(m_prev - m_next) - l_prev_corr = correction * l_prev - s_curr = jnp.exp( - qk - m_next[:, None] - ) # Use m_next instead of m_curr to avoid a correction on l_curr - l_curr = s_curr.sum(axis=-1) - l_next = l_prev_corr + l_curr - v = pl.load(v_ref, (curr_k_slice, slice(None))) - o_curr = _dot(s_curr.astype(v.dtype), v) - - # flash2 unscaled_o - o_next = correction[:, None] * o_prev + o_curr - return o_next, m_next, l_next - - upper_bound = pl.cdiv(k_seq_len, block_k) - # o is left unscaled; it will be scaled in the final reduction step - o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) + start_idx = split_k_seq_len * prog_j + if start_idx_ref is not None: + start_idx = jnp.maximum(start_idx, pl.load(start_idx_ref, ())) + kv_seq_len = (prog_j + 1) * split_k_seq_len # lower bound on actual k_seq_len + if kv_seq_len_ref is not None: + kv_seq_len = jnp.minimum(kv_seq_len, pl.load(kv_seq_len_ref, ())) + if start_idx_ref is None and kv_seq_len is None: + o, m_i, l_i = _compute(start_idx, kv_seq_len, o, m_i, l_i) + else: + o, m_i, l_i = jax.lax.cond( + start_idx >= kv_seq_len, lambda: (o, m_i, l_i), + lambda: _compute(start_idx, kv_seq_len, o, m_i, l_i)) + + # Write output to dram. if residual_refs: l_ref, m_ref = residual_refs - pl.store(l_ref, (curr_q_slice,), l_i) - pl.store(m_ref, (curr_q_slice,), m_i) - # Write output to dram. + vec_q_mask = q_mask.reshape(-1) if q_mask is not None else None + pl.store(l_ref, q_slice, l_i, mask=vec_q_mask) + pl.store(m_ref, q_slice, m_i, mask=vec_q_mask) o = o.astype(o_ref.dtype) - pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o) + pl.store(o_ref, (q_slice, pl.ds(None)), o, mask=q_mask) -def attn_unbatched( - q, # [num_heads, head_dim] - k, # [k_seq_len, head_dim] - v, # [k_seq_len, head_dim] +def decode_attn_unbatched( + q, # [num_heads, head_dim] + k, # [k_seq_len, head_dim] + v, # [k_seq_len, head_dim] + start_idx, # [] + kv_seq_len, # [] sm_scale: float, block_h: int, block_k: int, @@ -113,12 +147,6 @@ def attn_unbatched( num_heads, head_dim = q.shape k_seq_len, _ = k.shape # Pad num query heads to 16 if needed, and slice output at the end. - original_num_heads = None - if num_heads < 16: - q = jnp.pad(q, ((0, 16 - num_heads), (0, 0))) - original_num_heads = num_heads - num_heads = q.shape[0] - block_h = min(block_h, num_heads) head_splits = pl.cdiv(num_heads, block_h) grid_ = grid if grid_ is None: @@ -127,11 +155,16 @@ def attn_unbatched( assert ( k_seq_len % k_splits == 0 ), f"{k_seq_len=} must be divisible by {k_splits=}" + assert k_seq_len // k_splits >= 16, ( + f"{k_seq_len=} divided by {k_splits=} must be >= 16.") + assert block_k >= 16, "block_k must be >= 16" k = k.reshape(k_splits, k_seq_len // k_splits, head_dim) v = v.reshape(k_splits, k_seq_len // k_splits, head_dim) - k_seq_len = k_seq_len // k_splits - assert min(num_heads, head_dim, k_seq_len) >= 16, "Minimum pl.dot size is 16" - block_k = min(block_k, k_seq_len) + split_k_seq_len = k_seq_len // k_splits + block_k = min(block_k, split_k_seq_len) + assert split_k_seq_len % block_k == 0, ( + f"Sequence length ({k_seq_len=}) split by {k_splits=} must by divisible by" + f" {block_k=}") num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 @@ -139,47 +172,49 @@ def attn_unbatched( attn_forward_kernel, sm_scale=sm_scale, block_k=block_k, + block_h=block_h, + num_heads=num_heads, ) o, l, m = pl.pallas_call( - kernel, - grid=grid_, - in_specs=[ - pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)), - pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), - pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), - ], - out_specs=[ - pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o - pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l - pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m - ], - compiler_params=plgpu.TritonCompilerParams( - num_warps=num_warps_, num_stages=num_stages - ), - out_shape=[ - jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o - jax.ShapeDtypeStruct( - shape=(k_splits, num_heads), dtype=jnp.float32 - ), # l - jax.ShapeDtypeStruct( - shape=(k_splits, num_heads), dtype=jnp.float32 - ), # m - ], - debug=debug, - interpret=interpret, - name="mha_forward", - )(q, k, v) + kernel, + grid=grid_, + in_specs=[ + pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)), + pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)), + pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)), + ] + + [None if start_idx is None else pl.BlockSpec((), lambda i, j: ())] + + [None if kv_seq_len is None else pl.BlockSpec((), lambda i, j: ())], + out_specs=[ + pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m + ], + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps_, num_stages=num_stages + ), + out_shape=[ + jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o + jax.ShapeDtypeStruct( + shape=(k_splits, num_heads), dtype=jnp.float32 + ), # l + jax.ShapeDtypeStruct( + shape=(k_splits, num_heads), dtype=jnp.float32 + ), # m + ], + debug=debug, + interpret=interpret, + name="mha_forward", + )(q, k, v, start_idx, kv_seq_len) # final round of flash m_next = m.max(axis=0) correction = jnp.exp(m - m_next[None]) - o = o * correction[:, :, None] + o = o * correction[:, :, None].astype(o.dtype) l_next = (l * correction).sum(axis=0) - o = o.sum(axis=0) / l_next[:, None] - - if original_num_heads is not None: - o = o[:original_num_heads, :] + eps = jnp.finfo(l_next.dtype).eps + o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps) return o @@ -198,10 +233,12 @@ def attn_unbatched( ], ) def mqa( - q, # [batch_size, num_heads, head_dim] - k, # [batch_size, k_seq_len, head_dim] - v, # [batch_size, k_seq_len, head_dim] - sm_scale: float = 1.0, + q, # [batch_size, num_heads, head_dim] + k, # [batch_size, k_seq_len, head_dim] + v, # [batch_size, k_seq_len, head_dim] + start_idx=None, # [batch_size] + kv_seq_len=None, # [batch_size] + sm_scale: float | None = None, block_h: int = 16, block_k: int = 256, k_splits: int = 16, @@ -211,8 +248,14 @@ def mqa( interpret: bool = False, debug: bool = False, ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) + bs = q.shape[0] + if start_idx is not None: + start_idx = jnp.broadcast_to(start_idx, (bs,)) + if kv_seq_len is not None: + kv_seq_len = jnp.broadcast_to(kv_seq_len, (bs,)) inner = functools.partial( - attn_unbatched, + decode_attn_unbatched, sm_scale=sm_scale, block_h=block_h, block_k=block_k, @@ -223,7 +266,7 @@ def mqa( interpret=interpret, debug=debug, ) - return jax.vmap(inner)(q, k, v) + return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len) @functools.partial( @@ -241,12 +284,14 @@ def mqa( ], ) def gqa( - q, # [batch_size, num_q_heads, head_dim] - k, # [batch_size, k_seq_len, num_kv_heads, head_dim] - v, # [batch_size, k_seq_len, num_kv_heads, head_dim] - sm_scale: float = 1.0, + q, # [batch_size, num_q_heads, head_dim] + k, # [batch_size, k_seq_len, num_kv_heads, head_dim] + v, # [batch_size, k_seq_len, num_kv_heads, head_dim] + start_idx=None, # [batch_size] + kv_seq_len=None, # [batch_size] + sm_scale: float | None = None, block_h: int = 16, - block_k: int = 256, + block_k: int = 128, k_splits: int = 16, num_warps: int | None = None, num_stages: int = 2, @@ -254,10 +299,19 @@ def gqa( interpret: bool = False, debug: bool = False, ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) batch_size, q_heads, head_dim = q.shape - kv_heads = k.shape[2] + k_seq_len, kv_heads = k.shape[1], k.shape[2] assert kv_heads == v.shape[2] assert q_heads % kv_heads == 0 + if start_idx is not None: + assert start_idx.ndim in (0, 1) + start_idx = jnp.broadcast_to(jnp.asarray(start_idx)[..., None], + (batch_size, kv_heads)) + if kv_seq_len is not None: + assert kv_seq_len.ndim in (0, 1) + kv_seq_len = jnp.broadcast_to(jnp.asarray(kv_seq_len)[..., None], + (batch_size, kv_heads)) q_heads_per_kv_head = q_heads // kv_heads q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim) k_transposed = jnp.swapaxes( @@ -267,7 +321,7 @@ def gqa( v, 1, 2 ) # [batch_size, num_kv_heads, k_seq_len, head_dim] inner = functools.partial( - attn_unbatched, + decode_attn_unbatched, sm_scale=sm_scale, block_h=block_h, block_k=block_k, @@ -279,42 +333,70 @@ def gqa( debug=debug, ) with_kv_heads = jax.vmap(inner) - o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed) + o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed, + start_idx, kv_seq_len) return o.reshape(batch_size, q_heads, head_dim) @functools.partial(jax.jit, static_argnames=["sm_scale"]) def mqa_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, head_dim] - v, # [bs, k_seq_len, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, head_dim] + v, # [bs, k_seq_len, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, ): + bs = q.shape[0] + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32) + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) return jnp.einsum("bns,bsd->bnd", weights, v) @functools.partial(jax.jit, static_argnames=["sm_scale"]) def mha_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, num_k_heads, head_dim] - v, # [bs, k_seq_len, num_v_heads, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, num_k_heads, head_dim] + v, # [bs, k_seq_len, num_v_heads, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, ): + bs = q.shape[0] + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) assert q.shape[1] == k.shape[2] logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32) + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) return jnp.einsum("bns,bsnd->bnd", weights, v) @functools.partial(jax.jit, static_argnames=["sm_scale"]) def gqa_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, num_k_heads, head_dim] - v, # [bs, k_seq_len, num_v_heads, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, num_k_heads, head_dim] + v, # [bs, k_seq_len, num_v_heads, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) bs, num_q_heads, head_dim = q.shape num_kv_heads = k.shape[2] assert num_q_heads % num_kv_heads == 0 @@ -330,6 +412,15 @@ def gqa_reference( logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype( jnp.float32 ) + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed) - return o.reshape(bs, num_q_heads, head_dim) + o = o.reshape(bs, num_q_heads, head_dim) + return o diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index ed059c235329..afd2f6ae3fcf 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -62,12 +62,15 @@ class DecodeAttentionTest(PallasBaseTest): @parameterized.named_parameters(*[ ( - f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}", + (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_" + f"{start_idx=}_{kv_seq_len=}"), batch_size, seq_len, num_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ) for ( batch_size, @@ -80,6 +83,8 @@ class DecodeAttentionTest(PallasBaseTest): (2, 1024, 2, 64, {}), (1, 1024, 8, 64, {}), ] + for start_idx in [None, 123] + for kv_seq_len in [None, 250] ]) @jax.numpy_dtype_promotion("standard") def test_mqa( @@ -89,6 +94,8 @@ def test_mqa( num_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ): del kwargs @@ -97,19 +104,24 @@ def test_mqa( k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16) v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16) - o = decode_attention.mqa(q, k, v, interpret=self.INTERPRET) - o_ref = decode_attention.mqa_reference(q, k, v) + o = decode_attention.mqa(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len, interpret=self.INTERPRET) + o_ref = decode_attention.mqa_reference(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len) np.testing.assert_allclose(o, o_ref, atol=0.05) @parameterized.named_parameters(*[ ( - f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}_{kwargs=}", + (f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}" + f"_{kwargs=}_{start_idx=}_{kv_seq_len=}"), batch_size, seq_len, num_q_heads, num_kv_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ) for ( batch_size, @@ -123,6 +135,8 @@ def test_mqa( (1, 1024, 16, 16, 64, {}), (1, 1024, 32, 32, 64, {}), ] + for start_idx in [None, 123] + for kv_seq_len in [None, 250] ]) @jax.numpy_dtype_promotion("standard") def test_gqa( @@ -133,6 +147,8 @@ def test_gqa( num_kv_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ): del kwargs @@ -146,9 +162,10 @@ def test_gqa( v = random.normal( k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16 ) - - o = decode_attention.gqa(q, k, v, interpret=self.INTERPRET) - o_ref = decode_attention.gqa_reference(q, k, v) + o = decode_attention.gqa(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len, interpret=self.INTERPRET) + o_ref = decode_attention.gqa_reference(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len) np.testing.assert_allclose(o, o_ref, atol=0.05) class DecodeAttentionInterpretTest(DecodeAttentionTest): From b8c263a56cde8991bc13dffe7a57183c5662e88c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 6 Nov 2024 08:13:01 -0800 Subject: [PATCH 213/698] Add support for tpu v5e to `jax.make_mesh` PiperOrigin-RevId: 693732928 --- jax/_src/mesh_utils.py | 1 + jax/_src/sharding_impls.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index c37bbba4d836..996a6811a20d 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -32,6 +32,7 @@ _TPU_V3 = 'TPU v3' _TPU_V4 = 'TPU v4' _TPU_V5_LITE = "TPU v5 lite" +_TPU_V5E = "TPU v5e" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 9cb1e49299ea..fa65bbe9328d 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1728,7 +1728,7 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], f'of mesh_shape {axis_shapes}') elif axis_size < len(devices): devices = devices[:axis_size] - if devices[0].device_kind == mesh_utils._TPU_V5_LITE: + if devices[0].device_kind in (mesh_utils._TPU_V5_LITE, mesh_utils._TPU_V5E): allow_split_physical_axes = True else: allow_split_physical_axes = False From d698da610a345a5cf4d317ac9e2195c241863e89 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 5 Nov 2024 09:48:24 -0800 Subject: [PATCH 214/698] scipy.special.beta: remove deprecated x and y parameters --- jax/_src/scipy/special.py | 31 ++--------------------- tests/lax_scipy_special_functions_test.py | 19 ++------------ 2 files changed, 4 insertions(+), 46 deletions(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 837aa011f165..605cde19b1e7 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -16,7 +16,7 @@ from functools import partial import operator -from typing import cast, overload, Any +from typing import cast, Any import numpy as np @@ -28,7 +28,6 @@ from jax._src import core from jax._src import custom_derivatives -from jax._src import deprecations from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact @@ -189,16 +188,8 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array: n, = promote_args_inexact("factorial", n) return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) -@overload -def beta(a: ArrayLike, b: ArrayLike) -> Array: ... -@overload -def beta(a: ArrayLike, *, y: ArrayLike) -> Array: ... - -@overload -def beta(*, x: ArrayLike, y: ArrayLike) -> Array: ... - -def beta(*args, **kwds): +def beta(a: ArrayLike, b: ArrayLike) -> Array: r"""The beta function JAX implementation of :obj:`scipy.special.beta`. @@ -220,24 +211,6 @@ def beta(*args, **kwds): - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.betaln` """ - # TODO(jakevdp): deprecation warning added 2024-06-10; finalize after 2024-09-10 - if 'x' in kwds: - msg = "The `x` parameter of jax.scipy.special.beta is deprecated, use `a` instead." - deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) - if 'a' in kwds: - raise TypeError("beta() got both parameter 'a' and parameter 'x'.") - kwds['a'] = kwds.pop('x') - if 'y' in kwds: - msg = "The `y` parameter of jax.scipy.special.beta is deprecated, use `b` instead." - deprecations.warn('jax-scipy-beta-args', msg, stacklevel=2) - if 'b' in kwds: - raise TypeError("beta() got both parameter 'b' and parameter 'y'.") - kwds['b'] = kwds.pop('y') - if extra := kwds.keys() - {'a', 'b'}: - raise TypeError(f"beta() got unexpected keyword arguments {list(extra)}") - return _beta(*args, **kwds) - -def _beta(a, b): a, b = promote_args_inexact("beta", a, b) sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b) return sign * lax.exp(betaln(a, b)) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index bd3bca5385b7..cb40ae291e76 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -23,7 +23,6 @@ import scipy.special as osp_special import jax -from jax._src import deprecations from jax._src import test_util as jtu from jax.scipy import special as lsp_special @@ -252,22 +251,8 @@ def testBetaParameterDeprecation(self): lsp_special.beta(1, 1) lsp_special.beta(1, b=1) lsp_special.beta(a=1, b=1) - if deprecations.is_accelerated('jax-scipy-beta-args'): - with self.assertRaises(ValueError): - lsp_special.beta(x=1, y=1) - else: - with self.assertWarns(DeprecationWarning): - lsp_special.beta(1, y=1) - with self.assertWarns(DeprecationWarning): - lsp_special.beta(a=1, y=1) - with self.assertWarns(DeprecationWarning): - lsp_special.beta(x=1, b=1) - with self.assertWarns(DeprecationWarning): - lsp_special.beta(x=1, y=1) - with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): - lsp_special.beta(1, x=1) - with self.assertRaises(TypeError), self.assertWarns(DeprecationWarning): - lsp_special.beta(b=1, y=1) + with self.assertRaises(TypeError): + lsp_special.beta(x=1, y=1) if __name__ == "__main__": From 542cb2e57e1bc928c19d84556a25ca44d8cc266a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 6 Nov 2024 09:29:28 -0800 Subject: [PATCH 215/698] Fix a bug in jax.scipy.stats.rankdata leading to breakage with shape polymorphism. PiperOrigin-RevId: 693755546 --- jax/_src/scipy/stats/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index f7b28d3ac301..65c457f79cc8 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -198,7 +198,7 @@ def rankdata( return jnp.apply_along_axis(rankdata, axis, a, method) arr = jnp.ravel(a) - arr, sorter = jax.lax.sort_key_val(arr, jnp.arange(len(arr))) + arr, sorter = jax.lax.sort_key_val(arr, jnp.arange(arr.size)) inv = invert_permutation(sorter) if method == "ordinal": @@ -207,7 +207,7 @@ def rankdata( dense = obs.cumsum()[inv] if method == "dense": return dense - count = jnp.nonzero(obs, size=arr.size + 1, fill_value=len(obs))[0] + count = jnp.nonzero(obs, size=arr.size + 1, fill_value=obs.size)[0] if method == "max": return count[dense] if method == "min": From 3df204a457d240650511d5f5baa7f4dd0e8aa7c9 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Wed, 6 Nov 2024 10:09:35 -0800 Subject: [PATCH 216/698] [Mosaic] Verify that tpu.sem_wait semaphore rank is zero Since we only wait on one semaphore, we should enforce this in the verifier. PiperOrigin-RevId: 693770055 --- jaxlib/mosaic/dialect/tpu/tpu.td | 1 + jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 3bd4f651c0fc..3affd31e51d6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -616,6 +616,7 @@ def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> { ); let results = (outs); let assemblyFormat = [{ $semaphore `,` $amount attr-dict `:` type($semaphore)}]; + let hasVerifier = 1; } def TPU_AllocaSemaphoreOp : TPU_Op<"sem_alloc"> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 9a7f4f8a53e1..6f690f6a0fcb 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -820,6 +820,14 @@ LogicalResult SemaphoreSignalOp::verify() { return success(); } +LogicalResult SemaphoreWaitOp::verify() { + auto sem_type = getMemRefType(getSemaphore()); + if (sem_type.getRank() != 0) { + return emitOpError("Semaphore reference must be rank 0"); + } + return success(); +} + LogicalResult EnqueueDMAOp::verify() { auto source_sem = getSourceSemaphore(); if (source_sem) { From b6f5c95a5ae7f24f75bd70a2f3fe2b6c931be691 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 6 Nov 2024 10:28:59 -0800 Subject: [PATCH 217/698] [Pallas:TPU] Fix some stale/wrong skip conditions. Surprised that we didn't test f32 dot_general on TPU (?) Even tpu_ops_test doesn't exercise it. PiperOrigin-RevId: 693777426 --- tests/pallas/ops_test.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 70ced6eb2be6..318df0b0bfcf 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -525,9 +525,6 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): tol = 1e-6 elif name == "exp2": tol = 1e-6 - elif jtu.test_device_matches(["tpu"]): - if not jtu.is_device_tpu_at_least(version=5) and False: - self.skipTest("TODO: not implemented on TPU v{3,4}") def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...]) @@ -1413,7 +1410,7 @@ def test_dot(self, size, dtype, trans_x, trans_y): if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") - if jtu.test_device_matches(["tpu"]): + if jtu.test_device_matches(["tpu"]) and trans_x: self.skipTest("Not implemented: Transposed LHS") @functools.partial( From 4e45a6d94d73ef8349e45e27fea0cf07f08060d6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Nov 2024 11:08:34 -0800 Subject: [PATCH 218/698] Remove some obsolete deprecation registrations PiperOrigin-RevId: 693793727 --- jax/_src/deprecations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 962244a321a9..c7a956068981 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -125,8 +125,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register("jax-numpy-astype-complex-to-real") register("jax-numpy-array-none") -register('jax-scipy-beta-args') -register('tracer-hash') register('jax-numpy-reshape-newshape') register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') From 1af3b01c1c0425db1b6529b17316a0f3e439120d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Nov 2024 11:18:11 -0800 Subject: [PATCH 219/698] register_dataclass: allow marking static fields via field(static=True) --- CHANGELOG.md | 3 ++ jax/_src/tree_util.py | 74 ++++++++++++++++++++++++++++++++--------- tests/tree_util_test.py | 60 +++++++++++++++++++-------------- 3 files changed, 97 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bafe551100d5..0af10bd4cb2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for passing compilation options to XLA. For the moment it's undocumented and may be in flux. + * {func}`jax.tree_util.register_dataclass` now allows metadata fields to be + declared inline via {func}`dataclasses.field`. See the function documentation + for examples. ## jax 0.4.35 (Oct 22, 2024) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 474bdfe4ec04..58f9d7862dcd 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -927,8 +927,8 @@ class that defines how it could be flattened with keys. @export def register_dataclass( nodetype: Typ, - data_fields: Sequence[str], - meta_fields: Sequence[str], + data_fields: Sequence[str] | None = None, + meta_fields: Sequence[str] | None = None, drop_fields: Sequence[str] = (), ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. @@ -945,24 +945,33 @@ def register_dataclass( attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among ``meta_fields`` or ``data_fields``. - meta_fields: auxiliary data field names. These fields *must* contain static, - hashable, immutable objects, as these objects are used to generate JIT cache - keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or - :class:`numpy.ndarray` objects. - data_fields: data field names. These fields *must* be JAX-compatible objects - such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or - pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be - ``None``, as this is recognized by JAX as an empty pytree. + meta_fields: metadata field names: these are attributes which will be treated as + {term}`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is + optional only if ``nodetype`` is a dataclass, in which case individual fields can + be marked static via :func:`dataclasses.field` (see examples below). + Metadata fields *must* be static, hashable, immutable objects, as these objects + are used to generate JIT cache keys. In particular, metadata fields cannot contain + :class:`jax.Array` or :class:`numpy.ndarray` objects. + data_fields: data field names: these are attributes which will be treated as non-static + when this pytree is passed to :func:`jax.jit`. ``data_fields`` is optional only if + ``nodetype`` is a dataclass, in which case fields are assumed data fields unless + marked via :func:`dataclasses.field` (see examples below). + Data fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array` + or :class:`numpy.ndarray`), scalars, or pytrees whose leaves are arrays or scalars. + Note that ``None`` is a valid data field, as JAX recognizes this as an empty pytree. Returns: The input class ``nodetype`` is returned unchanged after being added to JAX's - pytree registry. This return value allows ``register_dataclass`` to be partially - evaluated and used as a decorator as in the example below. + pytree registry, so that :func:`register_dataclass` can be used as a decorator. Examples: + In JAX v0.4.35 or older, you must specify ``data_fields`` and ``meta_fields`` + in order to use this decorator: + + >>> import jax >>> from dataclasses import dataclass >>> from functools import partial - >>> + ... >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) @@ -976,7 +985,26 @@ def register_dataclass( >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') - Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`: + Starting in JAX v0.4.36, the ``data_fields`` and ``meta_fields`` arguments are optional + for :func:`~dataclasses.dataclass` inputs, with fields defaulting to ``data_fields`` + unless marked as static using `static` metadata in :func:`dataclasses.field`. + + >>> import jax + >>> from dataclasses import dataclass, field + ... + >>> @jax.tree_util.register_dataclass + ... @dataclass + ... class MyStruct: + ... x: jax.Array # defaults to non-static data field + ... y: jax.Array # defaults to non-static data field + ... op: str = field(metadata=dict(static=True)) # marked as static meta field. + ... + >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') + >>> m + MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') + + Once this class is registered, it can be used with functions in :mod:`jax.tree` and + :mod:`jax.tree_util`: >>> leaves, treedef = jax.tree.flatten(m) >>> leaves @@ -987,7 +1015,8 @@ def register_dataclass( MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add') In particular, this registration allows ``m`` to be passed seamlessly through code - wrapped in :func:`jax.jit` and other JAX transformations: + wrapped in :func:`jax.jit` and other JAX transformations, with ``data_fields`` being + treated as dynamic arguments, and ``meta_fields`` being treated as static arguments: >>> @jax.jit ... def compiled_func(m): @@ -999,6 +1028,21 @@ def register_dataclass( >>> compiled_func(m) Array([1., 2., 3.], dtype=float32) """ + if data_fields is None or meta_fields is None: + if data_fields is not None or meta_fields is not None: + raise TypeError("register_dataclass: data_fields and meta_fields must both be specified" + f" when either is specified. Got {data_fields=} {meta_fields=}.") + if not dataclasses.is_dataclass(nodetype): + raise TypeError("register_dataclass: data_fields and meta_fields are required when" + f" nodetype is not a dataclass. Got {nodetype=}.") + data_fields = [f.name for f in dataclasses.fields(nodetype) + if not f.metadata.get('static', False)] + meta_fields = [f.name for f in dataclasses.fields(nodetype) + if f.metadata.get('static', False)] + + assert meta_fields is not None + assert data_fields is not None + # Store inputs as immutable tuples in this scope, because we close over them # for later evaluation. This prevents potentially confusing behavior if the # caller were to pass in lists that are later mutated. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 378e3803bba2..a3a8bc96eae0 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -17,7 +17,6 @@ import functools import pickle import re -from typing import TypeVar from absl.testing import absltest from absl.testing import parameterized @@ -142,27 +141,6 @@ def tree_unflatten(cls, meta, data): data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data)) return FlatCache(None, leaves=data, treedef=meta) -_T = TypeVar("_T") - - -# Inspired by Flax. -def pytree_node_dataclass(clz: _T, **kwargs) -> _T: - data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore - meta_fields = [] - data_fields = [] - for field_info in dataclasses.fields(data_clz): - is_pytree_node = field_info.metadata.get("pytree_node", True) - if is_pytree_node: - data_fields.append(field_info.name) - else: - meta_fields.append(field_info.name) - - jax.tree_util.register_dataclass( - data_clz, data_fields, meta_fields - ) - - return data_clz - @tree_util.register_static class StaticInt(int): @@ -231,16 +209,18 @@ def __eq__(self, other): "PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))", ) -@pytree_node_dataclass +@jax.tree_util.register_dataclass +@dataclasses.dataclass class ADataclass: x: tuple[int, int] y: int -@pytree_node_dataclass +@jax.tree_util.register_dataclass +@dataclasses.dataclass class ADataclassWithMeta: x: tuple[int, int] y: int - z: int = dataclasses.field(metadata={"pytree_node": False}) + z: int = dataclasses.field(metadata={"static": True}) TREES += ( (ADataclass(x=(1, 2), y=3),), @@ -1294,6 +1274,36 @@ def test_tree_unflatten(self): class RegistrationTest(jtu.JaxTestCase): + def test_register_dataclass_with_field_specifier(self): + @tree_util.register_dataclass + @dataclasses.dataclass + class Foo: + x: int + y: int = dataclasses.field(metadata=dict(static=True)) + + f = Foo(2, 3) + self.assertLen(jax.tree.leaves(f), 1) + + def test_register_dataclass_field_errors(self): + class Foo: # not a dataclass + x: int + y: int + + msg = ("register_dataclass: data_fields and meta_fields are required" + " when nodetype is not a dataclass. Got nodetype=") + with self.assertRaisesRegex(TypeError, msg): + tree_util.register_dataclass(Foo) + + msg = ("register_dataclass: data_fields and meta_fields must both be specified"\ + r" when either is specified. Got data_fields=\['x'\] meta_fields=None.") + with self.assertRaisesRegex(TypeError, msg): + tree_util.register_dataclass(Foo, data_fields=['x']) + + msg = ("register_dataclass: data_fields and meta_fields must both be specified"\ + r" when either is specified. Got data_fields=None meta_fields=\['y'\].") + with self.assertRaisesRegex(TypeError, msg): + tree_util.register_dataclass(Foo, meta_fields=['y']) + def test_register_dataclass_missing_fields(self): @dataclasses.dataclass class Foo: From ea7683f05829dc9f2a5366ed029fe9754af709e6 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 13:56:16 -0600 Subject: [PATCH 220/698] Add GHA workflow for opening PRs upstream (#116) * Add file for opening PRs upstream * Add HEAD_REF as environment variable * Fill out code for making a new branch and opening a PR to upstream * Add names for steps * Fix yaml * Fix yaml again * Leave a comment on the old PR linking to the new one * Add proper permissions for creating banches and opening PRs * Fix YAML --- .github/workflows/rocm-open-upstream-pr.yml | 39 +++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/rocm-open-upstream-pr.yml diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml new file mode 100644 index 000000000000..09dfd06e907e --- /dev/null +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -0,0 +1,39 @@ +name: ROCm Open Upstream PR +on: + pull_request: + types: [ labeled ] + branches: [ rocm-main ] +jobs: + open-upstream: + if: ${{ github.event.label.name == 'open-upstream' }} + permissions: + contents: write + pull-requests: write + runs-on: ubuntu-latest + outputs: + new-pr-link: ${{ steps.create-pr.outputs.link }} + env: + NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream" + NEW_PR_TITLE: "[ROCM] ${{ github.event.pull_request.title }}" + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Rebase code to main + run: | + git checkout -b $NEW_BRANCH_NAME ${{ github.head_ref }} + git rebase --onto main + git push origin HEAD + # TODO: Change the base of the PR to upstream main + - name: Create a PR to upstream + id: create-pr + run: | + echo link=$(gh pr create --repo rocm/jax --base main --head $NEW_BRANCH_NAME --title "$NEW_PR_TITLE" --body "${{ github.event.pull_request.body }}") >> "$GITHUB_OUTPUT" + comment-link: + needs: open-upstream + permissions: + pull-requests: write + runs-on: ubuntu-latest + steps: + - name: Leave comment on old PR + run: gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body ${{ needs.open-upstream.outputs.new-pr-link.link }} + From 69e93e5a81e379b9291770c999ff81c86b955673 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 14:03:26 -0600 Subject: [PATCH 221/698] Create a new branch when merging upstream main to rocm-main (#128) --- .../workflows/rocm-nightly-upstream-sync.yml | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index a15e49c2e87b..98f3d2cfa39c 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -6,6 +6,8 @@ on: workflow_dispatch: schedule: - cron: '0 6 * * 1-5' +env: + SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }} jobs: sync-main: permissions: @@ -15,12 +17,28 @@ jobs: - run: gh repo sync rocm/jax -b main env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + create-sync-branch: + needs: sync-main + permissions: + contents: write + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Create branch + run: | + git checkout -b $SYNC_BRANCH_NAME main + git push origin HEAD open-sync-pr: + needs: create-sync-branch permissions: pull-requests: write runs-on: ubuntu-latest steps: - run: | - gh pr create --repo $GITHUB_REPOSITORY --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" + gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + From dc33a28028102bcdbf8bed905c44a1b314180aaf Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 6 Nov 2024 12:18:34 -0800 Subject: [PATCH 222/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0f6331b1881ae34c8b1cd59580900d556bc8305c. PiperOrigin-RevId: 693819727 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3dc24da2559b..9190c136f6e8 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "5a9f79f295ba8d16afce24ea8724da525b8eb87d" -XLA_SHA256 = "83e516dd8f7c61541aa9e2cba7fe480166ea23f28a41fed445fef4c5b6d45519" +XLA_COMMIT = "0f6331b1881ae34c8b1cd59580900d556bc8305c" +XLA_SHA256 = "1e4e4317750b2bb2845c6138aaa96b0d94249484d23e9c799d2dd6ecd4b8dd3c" def repo(): tf_http_archive( From fbd409db5eb30ade445df15ff629f57274dbc8f1 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 15:31:57 -0600 Subject: [PATCH 223/698] Fix upstream sync checkout (#130) * Checkout main before trying to switch to it * Fix the checkout command --- .github/workflows/rocm-nightly-upstream-sync.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index 98f3d2cfa39c..f29bef3bc46c 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -29,7 +29,8 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Create branch run: | - git checkout -b $SYNC_BRANCH_NAME main + git checkout origin/main + git checkout -b $SYNC_BRANCH_NAME git push origin HEAD open-sync-pr: needs: create-sync-branch From 350e04d89a71dd8a8fbbcfb66b7f1b8ce795f121 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 16:38:11 -0600 Subject: [PATCH 224/698] Add git fetch (#132) --- .github/workflows/rocm-nightly-upstream-sync.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index f29bef3bc46c..e915ccba390d 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -29,6 +29,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Create branch run: | + git fetch git checkout origin/main git checkout -b $SYNC_BRANCH_NAME git push origin HEAD From 506671291a25068f6a2c2bd45494038726cd3586 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 7 Nov 2024 06:41:04 -0800 Subject: [PATCH 225/698] [Mosaic GPU] Fix the ordering of transforms in async_copy Previously we didn't really fully discharge squeezing the indexed dims before applying other GMEM transforms, leading to potential failures because they were not anticipating the increased rank. PiperOrigin-RevId: 694098739 --- jax/experimental/mosaic/gpu/core.py | 52 +++++++++++++++---- .../mosaic/gpu/examples/flash_attention.py | 9 +--- tests/mosaic/gpu_test.py | 24 +++++++++ 3 files changed, 68 insertions(+), 17 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index b08b40aa1860..409a87eb9af7 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -133,6 +133,14 @@ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: raise NotImplementedError("Subclasses should override this method") + def batch(self, leading_rank: int) -> 'MemRefTransform': + """Returns a transform that accepts a ref with the extra `leading_rank` dims. + + The returned transform should leave the leading dimensions unchanged and + only apply to the suffix of the shape. + """ + raise NotImplementedError("Subclasses should override this method") + @dataclasses.dataclass(frozen=True) class TileTransform(MemRefTransform): @@ -198,6 +206,9 @@ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: *self.tiling, ) + def batch(self, leading_rank: int) -> MemRefTransform: + return self + @dataclasses.dataclass(frozen=True) class TransposeTransform(MemRefTransform): @@ -217,6 +228,11 @@ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: return tuple(shape[p] for p in self.permutation) + def batch(self, leading_rank: int) -> MemRefTransform: + return TransposeTransform( + (*range(leading_rank), *(d + leading_rank for d in self.permutation)) + ) + OnDeviceProfiler = profiler.OnDeviceProfiler @@ -388,16 +404,26 @@ def async_copy( dyn_base_indices = tuple( c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices ) + squeezed_dims = [i for i, squeezed in enumerate(is_squeezed) if squeezed] + sliced_dims = [i for i, squeezed in enumerate(is_squeezed) if not squeezed] + # Indexing is really slicing + squeezing, and user transforms are meant to + # apply after that. However, we actually have to apply the indexing last + # (it's fused into the TMA) and so we need to commute it with all the user + # transforms. For slicing this is done using transform_index and + # transform_shape. For squeezing we actually move all the squeezed dims to + # the front, and then batch each transform, making it ignore the extra dims. + if squeezed_dims: + gmem_transform = (TransposeTransform((*squeezed_dims, *sliced_dims)), + *(t.batch(len(squeezed_dims)) for t in gmem_transform)) + slice_shape = tuple(slice_shape) for t in gmem_transform: dyn_base_indices = t.transform_index(dyn_base_indices) slice_shape = t.transform_shape(slice_shape) - for dim, squeezed in enumerate(is_squeezed): - if squeezed: - smem_ref = utils.memref_unsqueeze(smem_ref, dim) - smem_ref_ty = ir.MemRefType(smem_ref.type) - if slice_shape != tuple(smem_ref_ty.shape): + smem_ref_ty = ir.MemRefType(smem_ref.type) + # We moved all squeezed dims to the front. + if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape): raise ValueError( "Expected the SMEM reference to have the same shape as the" f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" @@ -411,6 +437,7 @@ def async_copy( dyn_base_indices = list(dyn_base_indices) slice_shape = list(slice_shape) + assert all(d == 1 for d in slice_shape[:len(squeezed_dims)]) collective_size = 1 if collective is not None: if isinstance(collective, gpu.Dimension): @@ -418,13 +445,16 @@ def async_copy( collective_size = math.prod(self.cluster_size[d] for d in collective) if collective_size > 1: def partition_dim(dim: int, idx: ir.Value, num_chunks: int): + # No need to partition squeezed dims. They don't even exist in smem_ref. + assert dim >= len(squeezed_dims) nonlocal smem_ref slice_shape[dim] //= num_chunks block_offset = arith.muli(idx, c(slice_shape[dim], index)) dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) smem_ref = utils.memref_slice( smem_ref, - (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) + (slice(None),) * (dim - len(squeezed_dims)) + + (utils.ds(block_offset, slice_shape[dim]),), ) stride = 1 idx = c(0, index) @@ -440,10 +470,12 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): rem_collective_size = 1 break elif rem_collective_size % slice_size == 0: - dim_idx = arith.remui(idx, c(slice_size, index)) - partition_dim(dim, dim_idx, slice_size) - idx = arith.divui(idx, c(slice_size, index)) - rem_collective_size //= slice_size + # This is an optimization and it lets us skip squeezed dims. + if slice_size > 1: + dim_idx = arith.remui(idx, c(slice_size, index)) + partition_dim(dim, dim_idx, slice_size) + idx = arith.divui(idx, c(slice_size, index)) + rem_collective_size //= slice_size else: break # We failed to partition the leading dimensions. del idx # We overwrote the block index in the loop. diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 04a64098ff17..808afae8fc05 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -300,9 +300,7 @@ def kv_loop(kv_step, carry): with ir.InsertionPoint(if_compute.else_block): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) with single_thread(per_block=False): - k_tr = ( - TileTransform(tiling), TransposeTransform((0, 2, 1, 3, 4)), - ) + k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): @@ -396,10 +394,7 @@ def kv_copy_init(slot, kv_seq_base): with single_thread(per_block=False): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) - k_tr = ( - TileTransform(tiling), - TransposeTransform((0, 2, 1, 3, 4)), - ) + k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)): ctx.async_copy( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 062d2de02bac..03e3257b45bd 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1060,6 +1060,30 @@ def kernel(ctx, src, dst, scratch): y = f(x) np.testing.assert_array_equal(y, x) + def test_tma_load_indexed_tiled(self): + shape = (128, 2, 128) + tiling = mgpu.TileTransform((32, 32)) + def kernel(ctx, src, dst, scratch): + tmp, barrier = scratch + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + barrier=barrier, + gmem_transform=tiling, + gmem_slice=(slice(None), 1, slice(None)), + ) + barrier.wait() + ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_transform=tiling) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=jnp.float32).reshape(shape) + smem = ( + jax.ShapeDtypeStruct((4, 4, 32, 32), jnp.float32), + mgpu.TMABarrier(), + ) + out_shape = jax.ShapeDtypeStruct((128, 128), jnp.float32) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, out_shape, smem) + np.testing.assert_array_equal(f(x), x[:, 1, :]) + @parameterized.product( swizzle=(None, 128), dtype=(jnp.float16, jnp.float32), From f8dba3c8a4ba5185e267761becac7ce9a11360a3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 7 Nov 2024 07:02:42 -0800 Subject: [PATCH 226/698] [Pallas:MGPU] Add support for multiple heads in attention PiperOrigin-RevId: 694104006 --- .../pallas/ops/gpu/attention_mgpu.py | 39 +++++++++++++------ tests/pallas/mgpu_attention_test.py | 8 ++-- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 1b240305aeff..56db5379d5e2 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -32,6 +32,14 @@ class TuningConfig: block_kv: int max_concurrent_steps: int + def __post_init__(self): + if self.block_q % 64: + raise ValueError(f"{self.block_q=} must be a multiple of 64") + if self.block_kv % 64: + raise ValueError(f"{self.block_kv=} must be a multiple of 64") + if self.max_concurrent_steps < 2: + raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") + @functools.partial(jax.jit, static_argnames=["config"]) def attention(q, k, v, config: TuningConfig): @@ -46,14 +54,16 @@ def attention(q, k, v, config: TuningConfig): raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)") if (dtype := q.dtype) != k.dtype or dtype != v.dtype: raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}") - if batch_size != 1 or num_q_heads != 1 or num_kv_heads != 1: - raise NotImplementedError( - "Only batch_size=1, num_q_heads=1, and num_kv_heads=1 are supported," - f" got: {batch_size=}, {num_q_heads=}, {num_kv_heads=}" - ) + if num_q_heads % num_kv_heads: + raise ValueError(f"{num_q_heads=} must be divisible by and {num_kv_heads=}") + q_heads_per_kv_head = num_q_heads // num_kv_heads + if head_dim % 64: + raise ValueError(f"{head_dim=} must be divisible by 64") + if batch_size != 1: + raise NotImplementedError(f"Only batch_size=1 is supported, got: {batch_size=}") if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") - q, k, v = map(lambda x: x[0, :, 0, :], (q, k, v)) + q, k, v = map(lambda x: x[0], (q, k, v)) max_concurrent_steps = min( config.max_concurrent_steps, kv_seq_len // config.block_kv ) @@ -74,9 +84,10 @@ def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] q_seq_base = lax.axis_index("q") * (2 * block_q) + wg_idx * block_q + q_head = lax.axis_index("heads") plgpu.copy_gmem_to_smem( - q_ref.at[pl.ds(q_seq_base, block_q)], + q_ref.at[pl.ds(q_seq_base, block_q), q_head], qo_smem, barrier=q_barriers.at[wg_idx], ) @@ -146,21 +157,22 @@ def _wait(): qo_smem[...] = acc.astype(dtype) plgpu.commit_smem() plgpu.copy_smem_to_gmem( - qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)], + qo_smem, out_ref.at[pl.ds(q_seq_base, block_q), q_head], ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) def _memory_wg(): plgpu.set_max_registers(40, action="decrease") + kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) for i in range(max_concurrent_steps): - s = pl.ds(i * block_kv, block_kv) + s = (pl.ds(i * block_kv, block_kv), kv_head) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], barrier=k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], barrier=v_barriers.at[i]) def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) - s = pl.ds(tma_step * block_kv, block_kv) + s = (pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barrier) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], barrier=k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barrier) @@ -179,7 +191,10 @@ def run(refs): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") mesh = plgpu.GPUMesh( - grid=(num_q_tiles,), num_threads=3, axis_names=("q", "wg"), approx_math=True, + grid=(num_q_tiles, num_q_heads), + num_threads=3, + axis_names=("q", "heads", "wg"), + approx_math=True, ) @pl.core_map(mesh) def _kernel_entry(): @@ -212,7 +227,7 @@ def _kernel_entry(): ) _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) - return out[None, :, None, :] + return out[None] @jax.jit diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index 87d58e96ad40..32319e45e2dc 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -52,11 +52,9 @@ def setUp(self): batch_size=(1,), q_seq_len=(4096,), kv_seq_len=(4096,), - num_q_and_kv_heads=((1, 1),), - # TODO(apaszke): Enable once we support many heads. - # num_q_and_kv_heads=((4, 1), # MQA - # (6, 3), # GQA - # (4, 4),), # MHA + num_q_and_kv_heads=((4, 1), # MQA + (6, 3), # GQA + (4, 4),), # MHA head_dim=(64, 128, 256), ) def test_flash_attention( From de06584d9895f94af831448c03b8a0d8cbebc552 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 7 Nov 2024 07:08:07 -0800 Subject: [PATCH 227/698] [Mosaic GPU] Introduce a more flexible layout system So far all of our layouts have been tailored to a limited set of use cases we've tried so far, but they're still not general enough to handle all of the register layouts needed for WGMMA or mixed precision matmuls (incl. intermediate steps during conversions). Instead of adding more special cases, I decided to adopt XLA tiled layouts and they do seem to work quite well! This change only lays the groundwork for the new layout system. Future changes will build upon them to add new features and eventually replace `WGMMA_LAYOUT` altogether. PiperOrigin-RevId: 694105514 --- .../mosaic/gpu/fragmented_array.py | 348 ++++++++++++++++-- jax/experimental/mosaic/gpu/utils.py | 9 + tests/mosaic/gpu_test.py | 25 ++ 3 files changed, 353 insertions(+), 29 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 5639f7356ea9..a9d12706ff47 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -14,10 +14,12 @@ # ============================================================================== """Utilities for code generator.""" +from __future__ import annotations + import dataclasses import functools import math -from typing import Callable +from typing import Sequence, TypeVar, Iterable import jax from jaxlib.mlir import ir @@ -35,44 +37,239 @@ # mypy: ignore-errors +T = TypeVar("T") WARPGROUP_SIZE = utils.WARPGROUP_SIZE +WARP_SIZE = 32 +WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE c = utils.c @dataclasses.dataclass(frozen=True) -class WGSplatFragLayout: - """A fragmented array where all the values are equal represented as a register per thread. +class Tiling: + """A tiling expression describing a permutation of elements of an nd-array. - FragmentedArrays in this layout can be are always the result of a - splat, each thread in the warpgroup has a single copy of the value, - while the FragmentedArray pretends it has whatever shape the user - wants. This means we can trivially broadcast, reshape and do - elementwise operations with all other layouts. + To apply one level of tiling to an array, each of the trailing dimensions (up + to the rank of the tile) is unfolded into two dimensions: first equal to the + ratio of the dimension size and the tile size, and second equal to the tile + size. Then, all newly unfolded minor dimensions are transposed to appear at + the end. - Examples: + This expression describes multi-level tiling, by applying each element of + `tiles` in sequence to the array. - To load a value in - ``` - FragmentedArray.splat(memref.load(ref_1d, [1]), (10,20,2)) - ``` + See https://openxla.org/xla/tiled_layout for a more detailed explanation. + """ + tiles: tuple[tuple[int, ...], ...] + + def __post_init__(self): + max_rank = math.inf + for tile in self.tiles: + if not tile: + raise ValueError("Tiles must not be empty") + if len(tile) > max_rank: + raise ValueError("Tile ranks must be non-increasing") + max_rank = len(tile) + if any(d <= 0 for d in tile): + raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") + + def __str__(self): + return f"Tiling({''.join(map(str, self.tiles))})" + + def tile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Computes the shape of an array after tiling.""" + def fail(): + raise ValueError(f"Tiling {self.tiles} does not apply to shape {shape}") + for tile in self.tiles: + if len(tile) > len(shape): + fail() + untiled_dims, tiled_dims = shape[:-len(tile)], shape[-len(tile):] + if any(s % t != 0 for s, t in zip(tiled_dims, tile)): + fail() + shape = (*untiled_dims, *(d // t for d, t in zip(tiled_dims, tile)), *tile) + return shape + + def untile_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Computes the shape of an array before tiling from its tiled shape.""" + def fail(): + raise ValueError("Shape does not look like it's been tiled?") + for tile in reversed(self.tiles): + if len(tile) > len(shape): + fail() + untiled_dims = shape[:-2 * len(tile)] + tiled_dims = shape[-2 * len(tile):-len(tile)] + tiling_dims = shape[-len(tile):] + if tiling_dims != tile: + fail() + shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) + return shape + + def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: + """Computes the strides of an array after tiling.""" + for tile in self.tiles: + untiled, tiled = strides[:-len(tile)], strides[-len(tile):] + strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) + return strides + + +def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: + """Like built-in enumerate, but returns negative indices into the sequence.""" + offset = len(elems) + for i, e in enumerate(elems): + yield i - offset, e - A shape is always provided for sanity check reasons. +@dataclasses.dataclass(frozen=True) +class TiledLayout: + """A FragmentedArray layout derived from a tiling expression. + + A logical array is transformed according to the tiling expression, and then + split across warps (within a warpgroup), lanes, and vectorized according to + the dimension indices. All dimension indices must be negative and should refer + to the dimensions after tiling is applied. + + Note that warp_dim and vector_dim could be sets as well, but we don't have a + usecase for that yet. + + To better understand this layout, consider the example of WGMMA-related tiling + from https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d as + applied to a 128x128 array. The corresponding TiledLayout has a tiling of: + + (64, 8)(16, 8)(8, 8)(1, 2) + + and warp_dim=-8, lane_dims={-4, -3}, vector_dim=-1. + + We begin by applying the tiling (note that it always applies to a suffix): + + Tiled shape Remaining tiling actions + =========================================================================== + 128 128 (64, 8)(16, 8)(8, 8)(1, 2) + 2 16 64 8 (16, 8)(8, 8)(1, 2) + 2 16 4 1 16 8 (8, 8)(1, 2) + 2 16 4 1 2 1 8 8 (1, 2) + 2 16 4 1 2 1 8 4 1 2 + + The last expression is our final shape. At this stage, we're ready to + interpret the dimensions: warp_dim=-8 means that the 8-th dimension from the + end is partitioned over 4 warps in a warpgroup (and so it must be of size 4). + lane_dims={-4, -3} indicate that those two dimensions are partitioned over + the lanes within a warp (their product must be equal to 32, i.e. warp size). + Finally, vector_dim=-1 indicates that each (logical) register is a vector + containing 2 elements (there are no shape restrictions here). + + Given the above, the shape of the (logical) register array used to represent + the array in each thread is: (2, 16, 1, 1, 2, 1, 1, 1, 1, 1). We have set all + the dimensions above to 1, since each thread is a member of a single warp, + a single lane, and the elements along the vectorized dimension are represented + by a single (logical) register. """ + tiling: Tiling + warp_dim: int + lane_dims: frozenset[int] + vector_dim: int - shape: tuple[int, ...] = () + def __post_init__(self): + if not self.tiling.tiles: + raise ValueError("Tiling must have at least one tile") + min_shape = self.tiling.tiles[0] + min_tiled_shape = self.tiling.tile_shape(min_shape) + dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} + if len(dims_set) != len(self.lane_dims) + 2: + raise ValueError + for d in dims_set: + if d >= 0: + raise ValueError("All dimensions must be negative") + if d < -(len(min_tiled_shape) - len(min_shape)): + raise ValueError("Dimension out of range") + if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: + raise ValueError + if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: + raise ValueError - def can_broadcast_to(self, shape) -> bool: - """Check that the shape can be broadcast. + @functools.cached_property + def tiled_tiling_shape(self) -> tuple[int, ...]: + """The shape of the suffix of the array after tiling. - Only dimensions of size 1 can be broadcast. All other dimensions - must be the same as the argument shape. + We only allow our repeated tiling actions to further subdivide the + dimensions created by previous tiling actions (except for the first one), + so the tiled shape always ends with this suffix, no matter what array shape + it's applied to. """ - return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + return self.tiling.tile_shape(self.tiling.tiles[0]) - def thread_idxs(self, shape): - assert shape == self.shape - raise NotImplementedError + @property + def vector_length(self) -> int: + return self.tiled_tiling_shape[self.vector_dim] + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + tiled_shape = list(self.tiling.tile_shape(shape)) + tiled_shape[self.warp_dim] = 1 + for d in self.lane_dims: + tiled_shape[d] = 1 + tiled_shape[self.vector_dim] = 1 + return tuple(tiled_shape) + + def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the logical shape of an array given its register array shape. + + Inverse to `registers_shape`. + """ + tiled_tiling = self.tiled_tiling_shape + shape = list(shape) + shape[self.warp_dim] = WARPS_IN_WARPGROUP + for d in self.lane_dims: + shape[d] = tiled_tiling[d] + shape[self.vector_dim] = tiled_tiling[self.vector_dim] + return self.tiling.untile_shape(tuple(shape)) + + def lane_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = tuple( + d if i in self.lane_dims else 1 + for i, d in enumerate_negative(self.tiled_tiling_shape) + ) + assert math.prod(tiled_shape) == WARP_SIZE + lane_strides = utils.get_contiguous_strides(tiled_shape) + lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) + return tuple( + arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) + for stride, size in zip(lane_strides, tiled_shape) + ) + + def warp_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = tuple( + d if i == self.warp_dim else 1 + for i, d in enumerate_negative(self.tiled_tiling_shape) + ) + assert math.prod(tiled_shape) == WARPS_IN_WARPGROUP + warp_idx = arith.remui( + arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), + c(WARPS_IN_WARPGROUP, i32), + ) + indices = [arith.constant(i32, 0)] * len(tiled_shape) + indices[self.warp_dim] = warp_idx + return tuple(indices) + + +def _tiled_wgmma_layout(shape: tuple[int, ...]): + """Returns the tiled layout relevant for WGMMA operations. + + The tiled layout is equivalent to one described here in PTX documentation: + https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d + + This tiled layout is equivalent to WGMMAFragLayout and will subsume it. + """ + if len(shape) != 2: + raise ValueError(f"Shape {shape} is not 2D") + if shape[0] % 64 != 0 or shape[1] % 8 != 0: + raise ValueError(f"Shape {shape} is not a multiple of 64x8") + return TiledLayout( + Tiling(((64, 8), (16, 8), (8, 8), (1, 2))), + warp_dim=-8, + lane_dims=frozenset((-4, -3)), + vector_dim=-1, + ) @dataclasses.dataclass(frozen=True) @@ -96,6 +293,11 @@ def thread_idxs(self, shape): row = arith.addi(row_base, c(row_group + row_subgroup, index)) yield row, arith.addi(col_base, c(col_group, index)) + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + assert len(shape) == 2 + assert shape[0] % 64 == 0 and shape[1] % 8 == 0 + return (shape[0] // 64, shape[1] // 8, 2, 1) + @dataclasses.dataclass(frozen=True) class WGMMARowFragLayout: @@ -105,6 +307,42 @@ def thread_idxs(self, shape): raise NotImplementedError +@dataclasses.dataclass(frozen=True) +class WGSplatFragLayout: + """A fragmented array where all the values are equal represented as a register per thread. + + FragmentedArrays in this layout can be are always the result of a + splat, each thread in the warpgroup has a single copy of the value, + while the FragmentedArray pretends it has whatever shape the user + wants. This means we can trivially broadcast, reshape and do + elementwise operations with all other layouts. + + Examples: + + To load a value in + ``` + FragmentedArray.splat(memref.load(ref_1d, [1]), (10,20,2)) + ``` + + A shape is always provided for sanity check reasons. + + """ + + shape: tuple[int, ...] = () + + def can_broadcast_to(self, shape) -> bool: + """Check that the shape can be broadcast. + + Only dimensions of size 1 can be broadcast. All other dimensions + must be the same as the argument shape. + """ + return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + + def thread_idxs(self, shape): + assert shape == self.shape + raise NotImplementedError + + @dataclasses.dataclass(frozen=True) class WGStridedFragLayout: """Convert the array to 1D and then shard across threads.""" @@ -162,7 +400,7 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout | TiledLayout WGMMA_LAYOUT = WGMMAFragLayout() @@ -230,6 +468,14 @@ def __init__( if _registers.size != 1: raise ValueError(f"Invalid register array shape: {_registers.shape}") + case TiledLayout(): + try: + self.layout.shape_from_registers_shape(_registers.shape) + except ValueError: + raise ValueError( + "Register array shape does not match the tiled layout" + ) from None + case _: raise NotImplementedError @@ -304,15 +550,21 @@ def shape(self): return shape case WGSplatFragLayout(shape=shape): return shape + case TiledLayout(): + return self.layout.shape_from_registers_shape(self.registers.shape) + case _: + raise NotImplementedError @property def mlir_dtype(self): reg_ty = self.registers.flat[0].type match self.layout: - case WGMMAFragLayout() | WGStridedFragLayout(): + case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type case WGMMARowFragLayout() | WGSplatFragLayout(): return reg_ty + case _: + raise NotImplementedError def to_layout(self, new_layout: FragmentedLayout): """Converts the fragmented array to the given layout. @@ -321,6 +573,17 @@ def to_layout(self, new_layout: FragmentedLayout): """ if self.layout == new_layout: return self + shape = self.shape + if len(shape) == 2 and shape[0] % 64 == 0 and shape[1] % 8 == 0: + tiled_layout = _tiled_wgmma_layout(shape) + if (self.layout == WGMMA_LAYOUT and new_layout == tiled_layout) or ( + self.layout == tiled_layout and new_layout == WGMMA_LAYOUT + ): + return FragmentedArray( + _registers=self.registers.reshape(new_layout.registers_shape(shape)), + _layout=new_layout, + _is_signed=self.is_signed, + ) if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" @@ -745,10 +1008,9 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") new_registers = np.empty_like(self.registers) match self.layout: - case WGMMAFragLayout(): - new_reg_ty = ir.VectorType.get((2,), new_dtype) - case WGStridedFragLayout(vec_size=vec_size): - new_reg_ty = ir.VectorType.get((vec_size,), new_dtype) + case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout(): + shape = ir.VectorType(self.registers.flat[0].type).shape + new_reg_ty = ir.VectorType.get(shape, new_dtype) case WGMMARowFragLayout() | WGSplatFragLayout(): new_reg_ty = new_dtype case _: @@ -916,6 +1178,8 @@ def store_untiled(self, ref: ir.Value): self._store_untiled_splat(ref) case WGStridedFragLayout(): self._store_untiled_wg_strided(ref) + case TiledLayout(): + self._store_untiled_tiled(ref) case _: raise NotImplementedError(self.layout) @@ -982,6 +1246,32 @@ def c(x): col = arith.addi(col_base, c(col_tile * 8 + col_idx)) memref.store(value, ref, [row, col]) + def _store_untiled_tiled(self, ref: ir.Value): + """Stores an array with a tiled layout. Not optimized at the moment.""" + i32 = ir.IntegerType.get_signless(32) + layout = self.layout + assert isinstance(layout, TiledLayout) + ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() + if ref_strides[layout.vector_dim] != 1: + raise NotImplementedError( + "Can't use vector stores with non-unit minormost stride" + ) + strides = layout.tiling.tile_strides(ref_strides) + ptr = utils.memref_ptr(ref) + # Fold warp and lane offsets into the pointer once, since they are dynamic. + dyn_strides = [arith.constant(i32, s) for s in strides] + def dyn_dot(x, y): + return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y))) + warp_offset = dyn_dot(layout.warp_indices(), dyn_strides) + lane_offset = dyn_dot(layout.lane_indices(), dyn_strides) + dyn_offset = arith.addi(warp_offset, lane_offset) + ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) + # All warp tile offsets are static and can be fused into the store. + for tile_idx, reg in np.ndenumerate(self.registers): + lin_idx = sum(i * s for i, s in zip(tile_idx, strides, strict=True)) + reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) + llvm.store(reg, reg_ptr) + def store_tiled(self, ref, swizzle: int | None): if self.layout != WGMMA_LAYOUT: raise NotImplementedError diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 87ffe09291fc..6d3fd54cf2f7 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -40,6 +40,7 @@ WARPGROUP_SIZE: int = 128 DYNAMIC = -9223372036854775808 +DYNAMIC32 = -2147483648 # pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes @@ -1036,3 +1037,11 @@ def is_signed(dtype: jax.typing.DTypeLike) -> bool | None: elif jnp.issubdtype(dtype, jnp.integer): return jnp.issubdtype(dtype, jnp.signedinteger) return None + + +def getelementptr( + ptr: ir.Value, indices: Sequence[ir.Value | int], dtype: ir.Type +) -> ir.Value: + static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices] + dyn_indices = [i for i in indices if not isinstance(i, int)] + return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 03e3257b45bd..9a3f8ccfdadd 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -29,6 +29,7 @@ from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +from jax.experimental.mosaic.gpu import fragmented_array as fa import jax.numpy as jnp import numpy as np try: @@ -1583,5 +1584,29 @@ def kernel(ctx, i_gmem, o_gmem, _): del y # Make sure the destructor runs successfully. +class LayoutTest(TestCase): + + @parameterized.product( + shape=((128, 128), (64, 8), (64, 256)), + dtype=(jnp.int32, jnp.int16, jnp.int8), + ) + def test_wgmma_tiled_layout(self, shape, dtype): + def kernel(ctx, dst, _): + iota = iota_tensor(*shape, dtype) + tiled = iota.to_layout(fa._tiled_wgmma_layout(shape)) + # Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1) + self.assertEqual( + tiled.registers.shape, + (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1), + ) + self.assertEqual(tiled.shape, shape) + self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) + tiled.store_untiled(dst) + ty = jax.ShapeDtypeStruct(shape, dtype) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ()) + expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) + np.testing.assert_array_equal(f(), expected) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 9e7113e6ed2c5977b3df9cfe7e9a43757dd27f2d Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 7 Nov 2024 07:28:20 -0800 Subject: [PATCH 228/698] add extensibility pointers to `jax.extend` docstring --- jax/extend/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index e8ef32935cbf..bbb5925ab41a 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -14,16 +14,24 @@ """Modules for JAX extensions. -The :mod:`jax.extend` package provides modules for access to JAX +The :mod:`jax.extend` module provides modules for access to JAX internal machinery. See `JEP #15856 `_. +This module is not the only means by which JAX aims to be +extensible. For example, the main JAX API offers mechanisms for +`customizing derivatives +`_, +`registering custom pytree definitions +`_, +and more. + API policy ---------- Unlike the `public API `_, -this package offers **no compatibility guarantee** across releases. +this module offers **no compatibility guarantee** across releases. Breaking changes will be announced via the `JAX project changelog `_. """ From 1a544b6f363fbb03edc40e03d759cd42a6b64733 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Thu, 7 Nov 2024 08:36:44 -0800 Subject: [PATCH 229/698] [Pallas] Fix lowering tests for reduction ops Remove unnecessary skip statements. Also added tests for bf16 types. PiperOrigin-RevId: 694130207 --- tests/pallas/ops_test.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 318df0b0bfcf..58d353677535 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1793,6 +1793,7 @@ def reduce(x_ref, y_ref): for axis in [0, 1, (1,), (0, 1)] for dtype in [ "float16", + "bfloat16", "float32", "float64", "int32", @@ -1800,28 +1801,29 @@ def reduce(x_ref, y_ref): "uint32", "uint64", ] - if isinstance(axis, int) or "arg" not in op_name ]) def test_array_reduce(self, op, dtype, axis): - if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: - self.skipTest("16-bit types are not supported on TPU") + if not isinstance(axis, int): + self.skipTest("TODO: tuple axes are not yet supported") if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") + if jtu.test_device_matches(["tpu"]): + self.skipTest("Unimplemented primitive: broadcast_to") + + if jtu.test_device_matches(["tpu"]) and dtype == "float16": + self.skipTest("float16 is not supported on TPU") + # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects # `index_type` to be i32 if ( jax.config.x64_enabled and jtu.test_device_matches(["gpu"]) - and op in {jnp.argmin, jnp.argmax} + and op in (jnp.argmin, jnp.argmax) ): self.skipTest("Not supported on GPU in 64-bit mode") - # The Pallas TPU lowering currently supports only blocks of rank >= 1 - if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") - m, n = 32, 8 def make_x(key): From 83383fc7171e62bdbc35cadc7659d542a33d9048 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 5 Nov 2024 09:26:42 -0800 Subject: [PATCH 230/698] Error on numpy array conversion of PRNG key array --- CHANGELOG.md | 2 ++ jax/_src/prng.py | 5 +++++ tests/random_test.py | 6 ++++++ 3 files changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bafe551100d5..10e3bca6e150 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. + * Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`) + now raises an error. Previously, this returned a scalar object array. * The following deprecated methods and functions in {mod}`jax.export` have been removed: * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 8925a4342b29..8d1af46becfd 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -279,6 +279,11 @@ def copy(self): __hash__ = None # type: ignore[assignment] __array_priority__ = 100 + def __array__(self, dtype: np.dtype | None = None, copy: bool | None = None) -> np.ndarray: + raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array." + " Use jax.random.key_data(arr) if you wish to extract the underlying" + " integer array.") + # Overwritten immediately below @property def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override] diff --git a/tests/random_test.py b/tests/random_test.py index fed12792d5c6..e18100a63664 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1160,6 +1160,12 @@ def _f_bwd(_, state_bar): result = jax.grad(lambda theta: f(theta, state)[0])(3.0) self.assertEqual(result, 1.0) + def test_keyarray_array_conversion_fails(self): + key = jax.random.key(0) + msg = "JAX array with PRNGKey dtype cannot be converted to a NumPy array." + with self.assertRaisesRegex(TypeError, msg): + np.asarray(key) + # TODO(frostig,mattjj): more polymorphic primitives tests From 37fc834975375d63bd7843a53f0a100e5049689f Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 7 Nov 2024 12:55:07 -0800 Subject: [PATCH 231/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b0aae988216d48e2379c8de1c7c4aedeb98d8985. PiperOrigin-RevId: 694218761 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 9190c136f6e8..0aa248d61faa 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0f6331b1881ae34c8b1cd59580900d556bc8305c" -XLA_SHA256 = "1e4e4317750b2bb2845c6138aaa96b0d94249484d23e9c799d2dd6ecd4b8dd3c" +XLA_COMMIT = "b0aae988216d48e2379c8de1c7c4aedeb98d8985" +XLA_SHA256 = "42ed25652bf91b98c31a7d456d12ea4ca78d3b4083514047f650851383f2cb9d" def repo(): tf_http_archive( From 88a62a45d3e8431e6e9f65d88aef0749db20bc52 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 7 Nov 2024 13:09:18 -0800 Subject: [PATCH 232/698] Reverts 1a544b6f363fbb03edc40e03d759cd42a6b64733 PiperOrigin-RevId: 694223298 --- tests/pallas/ops_test.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 58d353677535..318df0b0bfcf 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1793,7 +1793,6 @@ def reduce(x_ref, y_ref): for axis in [0, 1, (1,), (0, 1)] for dtype in [ "float16", - "bfloat16", "float32", "float64", "int32", @@ -1801,29 +1800,28 @@ def reduce(x_ref, y_ref): "uint32", "uint64", ] + if isinstance(axis, int) or "arg" not in op_name ]) def test_array_reduce(self, op, dtype, axis): - if not isinstance(axis, int): - self.skipTest("TODO: tuple axes are not yet supported") + if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: + self.skipTest("16-bit types are not supported on TPU") if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") - if jtu.test_device_matches(["tpu"]): - self.skipTest("Unimplemented primitive: broadcast_to") - - if jtu.test_device_matches(["tpu"]) and dtype == "float16": - self.skipTest("float16 is not supported on TPU") - # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects # `index_type` to be i32 if ( jax.config.x64_enabled and jtu.test_device_matches(["gpu"]) - and op in (jnp.argmin, jnp.argmax) + and op in {jnp.argmin, jnp.argmax} ): self.skipTest("Not supported on GPU in 64-bit mode") + # The Pallas TPU lowering currently supports only blocks of rank >= 1 + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU") + m, n = 32, 8 def make_x(key): From 8b7bcadebe57040d80ff09024d2466b479993d5a Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Thu, 7 Nov 2024 13:51:04 -0800 Subject: [PATCH 233/698] [Mosaic] Fix canonicalize_extract op name. PiperOrigin-RevId: 694236671 --- jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 232121cc834c..b471f92609c3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -528,7 +528,7 @@ const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ {tpu::MatmulOp::getOperationName(), canonicalize_matmul}, {vector::ContractionOp::getOperationName(), canonicalize_contraction}, - {vector::ContractionOp::getOperationName(), canonicalize_extract}, + {vector::ExtractOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}, {arith::SelectOp::getOperationName(), canonicalize_select}, From 3b2e4a16007e4554fe7e1aece87b0f0968d77676 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 7 Nov 2024 14:11:45 -0800 Subject: [PATCH 234/698] Remove sharding from custom_root_test. This test only takes around 30s on most hardware platforms, it does not need 10 shards. PiperOrigin-RevId: 694243316 --- tests/BUILD | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 90a8913a1825..ba725842b36c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -416,11 +416,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, ) jax_multiplatform_test( From 04a66522431f6932ce5f4ba0fe345b427eb86ad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 7 Nov 2024 14:28:06 -0800 Subject: [PATCH 235/698] [Mosaic] Fix handling of i1 splat constants PiperOrigin-RevId: 694248723 --- jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc | 4 ++-- jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index b4b2280ceea8..80d0e69e128c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3190,8 +3190,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, } const VectorLayout &layout_out = *layouts_out.front(); DenseElementsAttr value = cast(constant_op.getValue()); - const VectorType target_vty = - getNativeVregType(vty.getElementType(), ctx.target_shape); + const VectorType target_vty = getNativeVregOrVmaskType( + vty.getElementType(), layout_out.bitwidth(), ctx.target_shape); if (value.isSplat()) { if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) { return op.emitOpError( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 3f5d6262d13c..a079815fa165 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -365,6 +365,11 @@ class VectorLayoutInferer { TPU_CHECK_OP(ty.getRank() > 0, "rank 0 vectors unsupported"); TPU_CHECK_OP(elems, "expected vector constants to use DenseElementsAttr"); auto bitwidth = ty.getElementTypeBitWidth(); + if (bitwidth == 1) { + // i1 is a special case where the layout bitwidth can be different from + // the element bitwidth, see comment in VectorLayout class + bitwidth = kNativeBitwidth; + } if (elems.isSplat()) { if (ty.getRank() == 1) { // Here, we choose to lay out along lanes arbitrarily. It would be From eb2dd2ab0f2eaea2f8fa9401eded6ee876e4c26a Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 7 Nov 2024 15:01:12 -0800 Subject: [PATCH 236/698] Fix copy-paste typo PiperOrigin-RevId: 694259486 --- jax/_src/pallas/mosaic/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f9014da221eb..677996fac9d3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1914,7 +1914,7 @@ def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): lowering_rules[lax.sub_p] = _sub_lowering_rule -skip_mlir_conversions.add(lax.max_p) +skip_mlir_conversions.add(lax.sub_p) def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): From 8b89adc665464a6217be5e299a99be45a830a38c Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 7 Nov 2024 15:30:36 -0800 Subject: [PATCH 237/698] Plumb dot dimension numbers into TPU matmul op. PiperOrigin-RevId: 694268559 --- jax/_src/pallas/mosaic/lowering.py | 93 ++++++++++++++++++++++++------ 1 file changed, 76 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 677996fac9d3..489aae59dcd2 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1564,6 +1564,71 @@ def _proxy_fun(val, *, shape, broadcast_dimensions): lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule +def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape): + """Converts a jax dot dimension numbers to a tpu dot dimension numbers. + + Jax dot dimension numbers are given as a tuple of tuples of sequences of ints + of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, + rhs_batch_dims)). + + TPU dot dimension numbers are given as an MLIR definition of the form + #tpu.dot_dimension_numbers - which can be found in the tpu dilect definition + # file, tpu.td . + """ + (contracting_dims, batch_dims) = dimension_numbers + lhs_contracting_dims, rhs_contracting_dims = contracting_dims + lhs_batch_dims, rhs_batch_dims = batch_dims + + lhs_total_dims = set(range(len(lhs_shape))) + rhs_total_dims = set(range(len(rhs_shape))) + + lhs_non_contracting_dims = sorted( + lhs_total_dims - set(lhs_contracting_dims) - set(lhs_batch_dims) + ) + rhs_non_contracting_dims = sorted( + rhs_total_dims - set(rhs_contracting_dims) - set(rhs_batch_dims) + ) + + # Create output_dim_order + # Note: we assume that the output dimensions are ordered as batch dims, lhs_non_contracting_dims, + # rhs_non_contracting_dims - this assumption is safe to make, as it is + # the same one made in jax's dot_general. + output_dim_order = [] + + lhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(lhs_shape)))} + rhs_dim_map = {dim: idx for idx, dim in enumerate(range(len(rhs_shape)))} + + for dim in lhs_batch_dims: + output_dim_order.append(0) + output_dim_order.append(lhs_dim_map[dim]) + + for dim in lhs_non_contracting_dims: + output_dim_order.append(0) + output_dim_order.append(lhs_dim_map[dim]) + + for dim in rhs_non_contracting_dims: + output_dim_order.append(1) + output_dim_order.append(rhs_dim_map[dim]) + + def format_dims(dims): + return "[" + ", ".join(str(d) for d in dims) + "]" + + all_dims = ( + lhs_contracting_dims, + rhs_contracting_dims, + lhs_non_contracting_dims, + rhs_non_contracting_dims, + output_dim_order, + lhs_batch_dims, + rhs_batch_dims, + ) + tpu_dim_numbers_str = ( + f"#tpu.dot_dimension_numbers<{','.join(map(format_dims, all_dims))}>" + ) + + return ir.Attribute.parse(tpu_dim_numbers_str) + + def _dot_general_lowering_rule( ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_ ): @@ -1589,7 +1654,7 @@ def _dot_general_lowering_rule( raise NotImplementedError( f"Only 2D tensors supported in dot; received: {ctx.avals_in}" ) - lhs_aval, _ = ctx.avals_in + lhs_aval, rhs_aval = ctx.avals_in # This is really a matrix-vector product. It only looks like matrix-matrix. if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1: if ctx.avals_in[0].shape != ctx.avals_in[1].shape: @@ -1615,19 +1680,10 @@ def _dot_general_lowering_rule( ) return vector.shape_cast(out_type, red) - # TODO(mvoz): Plumb these into dot dimension numbers on the matmul op! - if lhs_dims == (1,): - transpose_lhs = False - elif lhs_dims == (0,): - transpose_lhs = True - else: - raise NotImplementedError - if rhs_dims == (0,): - transpose_rhs = False - elif rhs_dims == (1,): - transpose_rhs = True - else: - raise NotImplementedError + tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims( + dimension_numbers, lhs_aval.shape, rhs_aval.shape + ) + if precision is not None: if precision[0] != precision[1]: raise NotImplementedError("Per-operand dot precision unsupported") @@ -1644,9 +1700,12 @@ def _dot_general_lowering_rule( out_type, ir.DenseElementsAttr.get_splat(out_type, val) ) return tpu.matmul( - out_type, x, y, out_tile, - transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs, - precision=precision_attr + out_type, + x, + y, + out_tile, + dimension_numbers=tpu_dot_dims, + precision=precision_attr, ) From 0bb30f07778923420aaaa470b3094e728adcd693 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 7 Nov 2024 15:50:32 -0800 Subject: [PATCH 238/698] Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified. This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT. This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily. Fixes https://github.com/jax-ml/jax/issues/24521 PiperOrigin-RevId: 694274616 --- jax/_src/api.py | 2 +- jax/_src/array.py | 23 +++++++---- jax/_src/dispatch.py | 19 ++++++--- jax/_src/earray.py | 4 +- jax/_src/interpreters/pxla.py | 63 +++++++++++++++++++++--------- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/prng.py | 5 ++- tests/lax_test.py | 2 +- tests/memories_test.py | 29 ++++++++++++++ tests/pmap_test.py | 2 +- 10 files changed, 112 insertions(+), 39 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index a902c7de4c3e..2ae64b39b44a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1599,7 +1599,7 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - lambda x, s: pxla.shard_args([s], [None], [x])[0], + lambda x, s: pxla.shard_args([s], [None], [None], [x])[0], pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) diff --git a/jax/_src/array.py b/jax/_src/array.py index 30fedf4cff50..515fc2c7c7e6 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -40,6 +40,7 @@ from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe +from jax._src.lib import xla_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, NamedSharding, @@ -1110,7 +1111,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): # Look up all buffers that contain the correct slice of the logical array. candidates_list = candidates[hashed_index(idx)] if not candidates_list: - return pxla.shard_args([sharding], [None], [x._value], + return pxla.shard_args([sharding], [None], [None], [x._value], canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. @@ -1130,11 +1131,13 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): return dst_indices, tuple(src_indices) == tuple(dst_indices) -def _array_shard_arg(xs, shardings, layouts): +def _array_shard_arg(xs, shardings, layouts, copy_semantics): results = [] batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] + batch_cs = [] - for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)): + for i, (x, sharding, layout, cs) in enumerate( + safe_zip(xs, shardings, layouts, copy_semantics)): x._check_if_deleted() indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) same_layout = (True if layout is None else @@ -1156,6 +1159,7 @@ def _array_shard_arg(xs, shardings, layouts): batch_devs.append(list(devices)) batch_shardings.append(sharding) batch_indices.append(i) + batch_cs.append(cs) # Resharding starts here: elif not same_layout: results.append(api.device_put(x, Layout(layout, sharding))) @@ -1165,8 +1169,12 @@ def _array_shard_arg(xs, shardings, layouts): results.append( shard_sharded_device_array_slow_path(x, devices, indices, sharding)) - copy_outs = xc.batched_copy_array_to_devices_with_sharding( - batch_xs, batch_devs, batch_shardings) + if xla_extension_version >= 296: + copy_outs = xc.batched_copy_array_to_devices_with_sharding( + batch_xs, batch_devs, batch_shardings, batch_cs) + else: + copy_outs = xc.batched_copy_array_to_devices_with_sharding( # type: ignore + batch_xs, batch_devs, batch_shardings) for i, copy_out in safe_zip(batch_indices, copy_outs): assert results[i] is None results[i] = copy_out @@ -1200,8 +1208,9 @@ def _array_local_result_handler(aval, sharding, indices): # Token handlers -def _token_shard_arg(xs, shardings, layouts): - return _array_shard_arg([x._buf for x in xs], shardings, layouts) +def _token_shard_arg(xs, shardings, layouts, copy_semantics): + return _array_shard_arg([x._buf for x in xs], shardings, layouts, + copy_semantics) pxla.shard_arg_handlers[core.Token] = _token_shard_arg diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b0b390773512..8d53742bc7cf 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -137,7 +137,7 @@ def get_token_input( # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. s = jax.sharding.GSPMDSharding.get_replicated(devices) - sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0]) + sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -391,6 +391,7 @@ class _DeferredShardArg: s: Sharding aval: core.AbstractValue committed: bool + copy_semantics: CopySemantics @property def result_handler(self): @@ -435,7 +436,7 @@ def _device_put_sharding_impl(x, aval, device, copy): "device_put's second argument must be a Device or a Sharding which" f" represents addressable devices, but got {s}. Please pass device or" " Sharding which represents addressable devices.") - return _DeferredShardArg(x, s, aval, True) + return _DeferredShardArg(x, s, aval, True, copy) # Only `Device` exists below. `Sharding` instance is handled above. if isinstance(x, array.ArrayImpl): @@ -443,8 +444,11 @@ def _device_put_sharding_impl(x, aval, device, copy): raise ValueError( "device_put's first argument must be a fully addressable array, but " f"got value with devices {x.devices()}") - if device is None and copy == CopySemantics.ALIAS: - return x + if device is None: + if copy == CopySemantics.ALIAS: + return x + else: + return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) elif is_single_device_sharding(x.sharding): device = x.sharding._device_assignment[0] if device is None else device return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], @@ -452,7 +456,7 @@ def _device_put_sharding_impl(x, aval, device, copy): sh = SingleDeviceSharding(pxla._get_default_device() if device is None else device) - return _DeferredShardArg(x, sh, aval, device is not None) + return _DeferredShardArg(x, sh, aval, device is not None, copy) def _device_put_impl( @@ -501,12 +505,14 @@ def _batched_device_put_impl( copy_semantics: Sequence[CopySemantics]): ys = [] shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], [] + shard_arg_copy_semantics = [] for i, (x, device, src, cp) in enumerate(zip(xs, devices, srcs, copy_semantics)): y = _device_put_impl(x, device=device, src=src, copy=cp) if isinstance(y, _DeferredShardArg): shard_arg_indices.append(i) shard_arg_xs.append(y.x) shard_arg_shardings.append(y.s) + shard_arg_copy_semantics.append(y.copy_semantics) ys.append(y) if shard_arg_xs: @@ -515,7 +521,8 @@ def _batched_device_put_impl( # device_put handles `Layout` via a different path, so just pass `None` as # the layout here. shard_arg_results = pxla.shard_args( - shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs) + shard_arg_shardings, [None] * len(shard_arg_xs), + shard_arg_copy_semantics, shard_arg_xs) for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results): assert isinstance(ys[i], _DeferredShardArg) ys[i] = ys[i].result_handler(shard_arg_result) diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 4be10e733c0d..7bade8171078 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -108,12 +108,12 @@ def global_shards(self): # TODO(mattjj): _set_array_base_attributes -def _earray_shard_arg_handler(xs, shardings, layouts): +def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics): arrs = [x._data for x in xs] phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] # TODO(yashkatariya): `layouts` should be converted to physical layouts. - return pxla.shard_args(phys_shardings, layouts, arrs) + return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 04d479fb757c..c83d3e3a4804 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -61,6 +61,7 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -105,21 +106,44 @@ class WeakRefList(list): ### util + +def to_xc_copy_semantics(copy_semantics): + if xla_extension_version < 296: + return [None] * len(copy_semantics) + out = [] + for cs in copy_semantics: + if cs is None or cs == dispatch.CopySemantics.ALIAS: + out.append(xc.ArrayCopySemantics.REUSE_INPUT) + elif cs == dispatch.CopySemantics.COPY: + out.append(xc.ArrayCopySemantics.ALWAYS_COPY) + elif cs == dispatch.CopySemantics.DONATE: + out.append(xc.ArrayCopySemantics.DONATE_INPUT) + else: + assert isinstance(cs, xc.ArrayCopySemantics) + out.append(cs) + return out + + def identity(x): return x @profiler.annotate_function -def shard_args(shardings: Sequence[JSharding], layouts, args, - canonicalize=True) -> Sequence[xc.ArrayImpl]: +def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics, + args, canonicalize=True) -> Sequence[xc.ArrayImpl]: + xc_copy_semantics = to_xc_copy_semantics(copy_semantics) + del copy_semantics # Fast path for one argument. if len(args) == 1: arg = args[0] if canonicalize: arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)]([arg], shardings, layouts) - - # type(arg) -> (list[indices], list[args], list[shardings]) - batches = collections.defaultdict(lambda: ([], [], [], [])) # type: ignore - for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)): + return shard_arg_handlers[type(arg)]([arg], shardings, layouts, + xc_copy_semantics) + + # type(arg) -> (list[indices], list[args], list[shardings], list[layouts], + # list[copy_semantics]) + batches = collections.defaultdict(lambda: ([], [], [], [], [])) # type: ignore + for i, (arg, sharding, layout, cs) in enumerate( + safe_zip(args, shardings, layouts, xc_copy_semantics)): if canonicalize: arg = xla.canonicalize_dtype(arg) batch = batches[type(arg)] @@ -127,14 +151,15 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, batch[1].append(arg) batch[2].append(sharding) batch[3].append(layout) + batch[4].append(cs) # Call `shard_arg_handlers` per batch and build a flat list of arrays returned # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. results: list[jax.Array | None] = [None] * len(args) - for t, (indices, a, s, l) in batches.items(): - outs = shard_arg_handlers[t](a, s, l) + for t, (indices, a, s, l, cs) in batches.items(): + outs = shard_arg_handlers[t](a, s, l, cs) for i, out in safe_zip(indices, outs): results[i] = out assert all(result is not None for result in results) @@ -142,7 +167,8 @@ def shard_args(shardings: Sequence[JSharding], layouts, args, shard_arg_handlers: dict[ - Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]] + Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any], Sequence[Any]], + Sequence[Any]] ] = {} @@ -172,12 +198,12 @@ def is_default_layout(curr_layout, sharding, aval): raise -def _masked_array_error(xs, shardings, layouts): +def _masked_array_error(xs, shardings, layouts, copy_semantics): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_np_array(xs, shardings, layouts): +def _shard_np_array(xs, shardings, layouts, copy_semantics): results = [] for x, sharding, layout in safe_zip(xs, shardings, layouts): devices = sharding._addressable_device_assignment @@ -197,12 +223,12 @@ def _shard_np_array(xs, shardings, layouts): for _t in array_types: shard_arg_handlers[_t] = _shard_np_array -def _shard_darray(xs, shardings, layouts): - return shard_args(shardings, layouts, [x._data for x in xs]) +def _shard_darray(xs, shardings, layouts, copy_semantics): + return shard_args(shardings, layouts, copy_semantics, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(xs, shardings, layouts): - return shard_args(shardings, layouts, [x._buf for x in xs]) +def _shard_mutable_array(xs, shardings, layouts, copy_semantics): + return shard_args(shardings, layouts, copy_semantics, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, @@ -1135,7 +1161,8 @@ class InputsHandler: def __init__(self, in_shardings, in_layouts, local_devices=None, input_indices=None): - self.handler = partial(shard_args, in_shardings, in_layouts) + self.handler = partial(shard_args, in_shardings, in_layouts, + [None] * len(in_shardings)) self.in_shardings = in_shardings self.in_layouts = in_layouts self.local_devices = local_devices @@ -3047,7 +3074,7 @@ def aot_cache_miss(*args, **kwargs): JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): - return shard_args([sharding], [layout], [x])[0] + return shard_args([sharding], [layout], [None], [x])[0] def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 19d3429d2675..b5bb8658e675 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -704,7 +704,7 @@ def _maybe_put(x): aval = shaped_abstractify(x) s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) - return result_handler(pxla.shard_args([s], [None], [x])) + return result_handler(pxla.shard_args([s], [None], [None], [x])) else: return x diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 8925a4342b29..1a5ceab5d984 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -468,12 +468,13 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts): +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts, + copy_semantics): arrs = [x._base_array for x in xs] phys_shardings = [physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] # TODO(yashkatariya): `layouts` should be converted to physical layouts. - return pxla.shard_args(phys_shardings, layouts, arrs) + return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler diff --git a/tests/lax_test.py b/tests/lax_test.py index 12149700cb30..ab1557450864 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3817,7 +3817,7 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(xs, shardings, layouts): +def shard_foo_array_handler(xs, shardings, layouts, copy_semantics): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment diff --git a/tests/memories_test.py b/tests/memories_test.py index 5f0ab04612e2..337b1c24d835 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -25,6 +25,7 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb +from jax._src.lib import xla_extension_version from jax._src.layout import DeviceLocalLayout as DLL, Layout from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint @@ -695,6 +696,34 @@ def foo(x): if compiled_text is not None: self.assertIn('custom_call_target="AllocateBuffer"', compiled_text) + def test_disallow_alias_copies_arrays(self): + if xla_extension_version < 296: + self.skipTest("Requires xla_extension_version >= 296") + _, _, _, inp_host = _create_inputs( + (8, 2), P("x", "y"), mem_kind="pinned_host") + + inp_host_copy = jax.device_put(inp_host, may_alias=False) + + for a in jax.tree.leaves(inp_host): + a.delete() + + jax.block_until_ready(inp_host_copy) + + def test_disallow_alias_copies_arrays_with_donated_input(self): + if xla_extension_version < 296: + self.skipTest("Requires xla_extension_version >= 296") + _, _, _, inp_host = _create_inputs( + (8, 2), P("x", "y"), mem_kind="pinned_host") + + inp_host_donate = jax.jit(lambda x: x, donate_argnums=0)(inp_host) + + inp_host_donate_copy = jax.device_put(inp_host_donate, may_alias=False) + + for a in jax.tree.leaves(inp_host_donate): + a.delete() + + jax.block_until_ready(inp_host_donate_copy) + class ComputeOffload(jtu.BufferDonationTestCase): diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 0df5d99715e6..f611ee981335 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3013,7 +3013,7 @@ def testShardArgs(self, shape, spec, make_arg): x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) - results = pxla.shard_args([sharding], [None], [arg]) + results = pxla.shard_args([sharding], [None], [None], [arg]) self.assertEqual(len(results), 1) if isinstance(results[0], array.ArrayImpl): bufs = results[0]._arrays From 1f1d27de2f6c6605f0dcd55677d5b3574dec9741 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 7 Nov 2024 15:57:07 -0800 Subject: [PATCH 239/698] [Mosaic GPU] Implement the skeleton of a lowering pass for the Mosaic GPU dialect. Also add a lowering rule for `mosaic_gpu.initialize_barrier`. PiperOrigin-RevId: 694276698 --- jax/experimental/mosaic/gpu/__init__.py | 6 + .../mosaic/gpu/dialect_lowering.py | 125 ++++++++++++++++++ tests/mosaic/gpu_dialect_test.py | 86 +++++++++++- 3 files changed, 216 insertions(+), 1 deletion(-) create mode 100644 jax/experimental/mosaic/gpu/dialect_lowering.py diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 4feb12704f98..cf8d2c84c246 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -27,6 +27,12 @@ Union as Union, as_gpu_kernel as as_gpu_kernel, ) + +if dialect is not None: + from .dialect_lowering import lower_mgpu_dialect +else: + lower_mgpu_dialect = None + from .fragmented_array import ( FragmentedArray as FragmentedArray, FragmentedLayout as FragmentedLayout, diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py new file mode 100644 index 000000000000..7d36272dc111 --- /dev/null +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -0,0 +1,125 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lowering rules and pass for the MLIR Mosaic GPU dialect.""" + +from collections.abc import Callable +import functools +import operator +from typing import Sequence, Type + +from jax._src.interpreters import mlir as mlir_interpreter +from jax._src.lib import mosaic_gpu_dialect as mgpu + +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvvm +from .utils import c, memref_ptr, single_thread_predicate + + +MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]] + + +_lowerings: dict[str, MlirLoweringRule] = {} + + +# TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36. +# Jaxlib doesn't contain Mosaic GPU dialect bindings. +InitializeBarrierOp = mgpu.InitializeBarrierOp if mgpu is not None else None + +def _register_lowering( + op: str | Type[ir.OpView] +) -> Callable[[MlirLoweringRule], MlirLoweringRule]: + def wrapper(f): + op_name = op if isinstance(op, str) else op.OPERATION_NAME # pytype: disable=attribute-error + _lowerings[op_name] = f + return f + + return wrapper + + +def _lowered_barrier_type() -> ir.Type: + return ir.IntegerType.get_signless(64) + + +def _gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int: + match address_space: + case gpu.AddressSpace.Global: + return 1 + case gpu.AddressSpace.Workgroup: + return 3 + case _: + raise NotImplementedError(f"address_space not supported: {address_space}") + + +@_register_lowering(InitializeBarrierOp) +def _initialize_barrier_op_lowering_rule( + initialize_barrier_op: InitializeBarrierOp) -> Sequence[ir.Value]: + + shape = initialize_barrier_op.barriers_ref.type.shape + num_barriers = functools.reduce(operator.mul, shape, 1) + + i32 = ir.IntegerType.get_signless(32) + workgroup_nvptx_address_space = _gpu_address_space_to_nvptx( + gpu.AddressSpace.Workgroup) + ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") + + lowered_barrier_type = _lowered_barrier_type() + lowered_barrier_ref = memref.alloca( + ir.MemRefType.get(shape, lowered_barrier_type), [], []) + barrier_ref_address = memref_ptr( + lowered_barrier_ref, memory_space=workgroup_nvptx_address_space) + + predicate = single_thread_predicate(per_block=True) + for i in range(num_barriers): + nvvm.mbarrier_init_shared( + llvm.getelementptr(ptr_ty, barrier_ref_address, [], [i], + lowered_barrier_type), + c(initialize_barrier_op.arrival_count.value, i32), + predicate=predicate + ) + return barrier_ref_address, + + +def lower_mgpu_dialect(module: ir.Module): + module.context.append_dialect_registry(mlir_interpreter.upstream_dialects) + module.context.load_all_available_dialects() + + lowered_operations: set[ir.Operation | ir.OpView] = set() + + def _lower_op(op: ir.OpView): + if op.name not in _lowerings: + return + lowering_rule = _lowerings[op.name] + new_results = lowering_rule(op) + for old, new in zip(op.results, new_results): + old.replace_all_uses_with(new) + lowered_operations.add(op) + + def _traverse_and_lower_op(op: ir.OpView): + for region in op.operation.regions: + for block in region: + for block_op in list(block): + with ir.InsertionPoint(block_op): + _traverse_and_lower_op(block_op) + _lower_op(op) + + with ir.InsertionPoint(module.body): + for op in module.body: + _traverse_and_lower_op(op) + + for lowered_op in lowered_operations: + lowered_op.erase() diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index d6428a98c96a..c5a3e9d6cc57 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -14,12 +14,19 @@ # ============================================================================== """(Deviceless) tests for the Mosaic GPU MLIR dialect.""" +from typing import Callable + from absl.testing import parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir -from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import nvvm +from jax._src.lib.mlir.dialects import scf +from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member +from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import _cext = mgpu._cext if mgpu is not None else None @@ -29,10 +36,35 @@ def _make_ir_context(): context = ir.Context() + context.append_dialect_registry(mlir_interpreter.upstream_dialects) + context.load_all_available_dialects() mgpu.register_dialect(context) return context +def walk_operations(op: ir.OpView, callback): + for region in op.operation.regions: + for block in region: + for block_op in block: + walk_operations(block_op, callback) + callback(op) + + +def find_if(module: ir.Module, + predicate: Callable[[ir.OpView], bool]) -> list[ir.OpView]: + result = [] + def callback(op: ir.OpView): + if predicate(op): + result.append(op) + for op in module.body.operations: + walk_operations(op, callback) + return result + + +def is_mosaic_gpu_op(op: ir.OpView) -> bool: + return op.name.startswith("mosaic_gpu.") + + class DialectTest(parameterized.TestCase): def setUp(self): @@ -72,5 +104,57 @@ def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self): mgpu.InitializeBarrierOp) +class DialectLoweringTest(DialectTest): + + def test_lowering_removes_mosaic_gpu_ops(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + arrival_count=1) + lower_mgpu_dialect(self.module) + + self.assertEmpty( + list(filter(is_mosaic_gpu_op, self.module.body.operations))) + + def test_lowering_traverses_regions_correctly(self): + with ir.InsertionPoint(self.module.body): + bool_type = ir.IntegerType.get_signless(1) + cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1)) + if_op = scf.IfOp(cst_true) + with ir.InsertionPoint(if_op.then_block): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + arrival_count=1) + scf.yield_([]) + lower_mgpu_dialect(self.module) + + self.assertEmpty( + list(filter(is_mosaic_gpu_op, if_op.then_block.operations))) + + def test_initialize_barrier_op_lowering_rule(self): + shape = (3, 4) + num_shape_elements = shape[0] * shape[1] + arrival_count = 1337 + + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), + arrival_count=arrival_count) + lower_mgpu_dialect(self.module) + + all_mbarrier_init_shared_ops = find_if( + self.module, + lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME) + + # One nvvm.mbarrier_init_shared is issued per barrier. + self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements) + + # Each barrier has its count equal to the arrival count. + for op in all_mbarrier_init_shared_ops: + count = op.count.owner.opview + self.assertIsInstance(count, arith.ConstantOp) + self.assertEqual(count.literal_value, arrival_count) + + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 58dee3ea33e1143aee601ddd3735f489fb9e162e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 7 Nov 2024 16:01:22 -0800 Subject: [PATCH 240/698] jax.device_get: handle generic extended dtypes --- jax/_src/api.py | 8 ++------ jax/_src/prng.py | 5 ----- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index a902c7de4c3e..e2dc9c9a1f57 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2445,12 +2445,8 @@ def _device_get(x): # Extended dtypes dispatch via their device_get rule. if isinstance(x, basearray.Array) and dtypes.issubdtype(x.dtype, dtypes.extended): - try: - to_device = x.dtype._rules.device_get - except AttributeError: - pass - else: - return to_device(x) + bufs, tree = tree_util.dispatch_registry.flatten(x) + return tree.unflatten(device_get(bufs)) # Other types dispatch via their __array__ method. try: diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 8925a4342b29..b9054431e020 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -400,11 +400,6 @@ def make_sharded_array(aval, sharding, arrays, committed): phys_result = phys_handler(phys_arrays) return PRNGKeyArray(aval.dtype._impl, phys_result) - @staticmethod - def device_get(val): - buffer = api.device_get(random_unwrap(val)) - return random_wrap(buffer, impl=val.dtype._impl) - @staticmethod def device_put_sharded(vals, aval, sharding, devices): physical_aval = core.physical_aval(aval) From 6e8a35f08ca9f76d19534aaf9eb6550f143f0e63 Mon Sep 17 00:00:00 2001 From: James Martens Date: Thu, 7 Nov 2024 17:05:13 -0800 Subject: [PATCH 241/698] Adding support for copy_p primitive to jet. PiperOrigin-RevId: 694296952 --- jax/experimental/jet.py | 1 + tests/jet_test.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 8dd2a319a1cb..827e4d01b390 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -329,6 +329,7 @@ def linear_prop(prim, primals_in, series_in, **params): deflinear(lax.reduce_sum_p) deflinear(lax.reduce_window_sum_p) deflinear(lax.fft_p) +deflinear(lax.copy_p) deflinear(dispatch.device_put_p) def _dynamic_slice_jet_rule(primals_in, series_in, **params): diff --git a/tests/jet_test.py b/tests/jet_test.py index 4e437c044426..7c2c71e9bbfa 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -319,6 +319,8 @@ def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0)) def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1))) @jtu.skip_on_devices("tpu") def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3))) + @jtu.skip_on_devices("tpu") + def test_copy(self): self.unary_check(jnp.array) @jtu.skip_on_devices("tpu") From 61150607e5cac7323732638fc1a2da3f1960213e Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 8 Nov 2024 01:21:11 +0000 Subject: [PATCH 242/698] don't warn on unused `type: ignore` --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9f5f06e7a1b0..6e625e708d7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ show_error_codes = true disable_error_code = "attr-defined, name-defined, annotation-unchecked" no_implicit_optional = true warn_redundant_casts = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = [ From 60a6cd475b4155892a90da60bb2af086be5fc9b3 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 7 Nov 2024 18:08:17 -0800 Subject: [PATCH 243/698] Add note on etils requirement for the Jax compilation cache. The compilation cache has a dependency on etils.epath if the cache is not on a local filesystem. PiperOrigin-RevId: 694311585 --- docs/persistent_compilation_cache.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index c49e18394e9a..246d3a6cb084 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -1,11 +1,18 @@ # Persistent compilation cache - + JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly. +Note: if the compilation cache is not on a local filesystem, +[etils](https://pypi.org/project/etils/) needs to be installed. + +```python +pip install etils +``` + ## Usage ### Quick start From 0a42bf12c3fcbd9969caa7d514cbe9d627f93ccb Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 7 Nov 2024 07:28:20 -0800 Subject: [PATCH 244/698] add about page This is an initial draft. There is more to come back and add/improve. --- docs/about.md | 123 +++++++++++++++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + 2 files changed, 124 insertions(+) create mode 100644 docs/about.md diff --git a/docs/about.md b/docs/about.md new file mode 100644 index 000000000000..c4bc93140fbc --- /dev/null +++ b/docs/about.md @@ -0,0 +1,123 @@ +(about-the-project)= + +# About the project + +The JAX project is led by the JAX core team. We develop in the open, +and welcome open-source contributions from across the community. We +frequently see contributions from [Google +DeepMind](https://deepmind.google/), Alphabet more broadly, +[NVIDIA](https://docs.nvidia.com/deeplearning/frameworks/jax-release-notes/overview.html), +and elsewhere. + +At the heart of the project is the [JAX +core](http://github.com/google/jax) library, which focuses on the +fundamentals of machine learning and numerical computing, at scale. + +When [developing](#development) the core, we want to maintain agility +and a focused scope, so we lean heavily on a surrounding [modular +technology stack](#components). First, we design the `jax` module +to be +[composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) +and +[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +that a wide variety of domain-specific libraries can thrive outside of +it in a decentralized manner. Second, we lean heavily on a modular +backend stack (compiler and runtime) to target different +accelerators. Whether you are [writing a new domain-specific library +built with JAX](#upstack), or looking to [support +new hardware](#downstack), you can often +contribute these with *minimal to no modifications* to the JAX core +codebase. + +Many of JAX's core contributors have roots in open-source software and +in research, in fields spanning computer science and the natural +sciences. We strive to continuously enable the cutting edge of machine +learning and numerical computing---across all compute platforms and +accelerators---and to discover the truths of array programming at +scale. + +(development)= +## Open development + +JAX's day-to-day development takes place in the open on GitHub, using +pull requests, the issue tracker, discussions, and [JAX Enhancement +Proposals +(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +and participating in these is a good way to get involved. We also +maintain [developer +notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +that cover JAX's internal design. + +The JAX core team determines whether to accept changes and +enhancements. Maintaining a simple decision-making structure currently +helps us develop at the speed of the research frontier. Open +development is a core value of ours, and we may adapt to a more +intricate decision structure over time (e.g. with designated area +owners) if/when it becomes useful to do so. + +For more see [contributing to +JAX](https://jax.readthedocs.io/en/latest/contributing.html). + +(components)= +## A modular stack + +To enable (a) a growing community of users across numerical domains, +and (b) an advancing hardware landscape, we lean heavily on +**modularity**. + +(upstack)= +### Libraries built on JAX + +While the JAX core library focuses on the fundamentals, we want to +encourage domain-specific libraries and tools to be built on top of +JAX. Indeed, [many +libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +emerged around JAX to offer higher-level features and extensions. + +How do we encourage such decentralized development? We guide it with +several technical choices. First, JAX's main API focuses on basic +building blocks (e.g. numerical primitives, NumPy operations, arrays, +and transformations), encouraging auxiliary libraries to develop +utilities as needed for their domain. In addition, JAX exposes a +handful of more advanced APIs for +[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +and +[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +can [lean on these +APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +order to use JAX as an internal means of implementation, to integrate +more with its transformations like autodiff, and more. + +Projects across the JAX ecosystem are developed in a distributed and +often open fashion. They are not governed by the JAX core team, even +though sometimes team members contribute to them or maintain contact +with their developers. + +(downstack)= +### A pluggable backend + +We want JAX to run on CPUs, GPUs, TPUs, and other hardware platforms +as they emerge. To encourage unhindered support of JAX on new +platforms, the JAX core emphasizes modularity in its backend too. + +To manage hardware devices and memory, and for compilation to such +devices, JAX calls out to the open [XLA +compiler](https://openxla.org/) and the [PJRT +runtime](https://github.com/openxla/xla/tree/main/xla/pjrt/c#pjrt---uniform-device-api). Both +of these are projects external to the JAX core, governed and +maintained by OpenXLA (again, with frequent contributions from and +discussion with the JAX core developers). + +XLA aims for interoperability across accelerators (e.g. by ingesting +[StableHLO](https://openxla.org/stablehlo) as input) and PJRT offers +extensibility through a plug-in device API. Adding support for new +devices is done by implementing a backend lowering for XLA, and +implementing a plug-in device API defined by PJRT. If you're looking +to contribute to compilation, or to supporting new hardware, we +encourage you to contribute at the XLA and PJRT layers. + +These open system components allow third parties to support JAX on new +accelerator platforms, *without requiring changes in the JAX +core*. There are several plug-ins in development today. For example, a +team at Apple is working on a PJRT plug-in to get [JAX running on +Apple Metal](https://developer.apple.com/metal/jax/). diff --git a/docs/index.rst b/docs/index.rst index 7d555fe85eb4..ba724f8e77ab 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -142,6 +142,7 @@ maintains an up-to-date list. extensions notes jax + about .. toctree:: From 218f7632556f9af187d273fd2a33628f4b9411f7 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Thu, 7 Nov 2024 23:55:38 -0600 Subject: [PATCH 245/698] (follow-up of PR #23852) add missing `typename` keyword to work with `gcc` This update is a follow-up of PR #23852. In the previous PR, there was one missing place where the `typename` was not added. --- jaxlib/gpu/solver_kernels_ffi.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 7852da4bc04f..b5742b5a7972 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -915,7 +915,7 @@ ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto s_data = static_cast::value*>(s->untyped_data()); + auto s_data = static_cast::value*>(s->untyped_data()); auto u_data = static_cast(u->untyped_data()); auto v_data = static_cast(v->untyped_data()); auto info_data = info->typed_data(); From 927d7fc20519788a5c343377fb71472162a5931b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 7 Nov 2024 22:59:58 -0800 Subject: [PATCH 246/698] Skip flaky test on tpuv4 PiperOrigin-RevId: 694372268 --- tests/pallas/pallas_jumble_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py index f26352da0f38..8452d1ee7264 100644 --- a/tests/pallas/pallas_jumble_test.py +++ b/tests/pallas/pallas_jumble_test.py @@ -227,6 +227,9 @@ def test_vmap_jumble_over_matmul_kernel(self): if not jtu.test_device_matches(["tpu"]): self.skipTest("Only tested on TPU") + if jtu.is_device_tpu(version=4): + self.skipTest("Flaky 15% of the time on tpuv4?") + m = 128 k = 640 n = 640 From 5e43220e974fe05425680023bf9bbe0e9701bf1f Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 8 Nov 2024 02:58:38 -0800 Subject: [PATCH 247/698] [mosaic_gpu] Scalar arguments to kernels. PiperOrigin-RevId: 694426328 --- jax/experimental/mosaic/gpu/utils.py | 28 +++++++++++++++------------- tests/mosaic/gpu_test.py | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 6d3fd54cf2f7..f8918488563e 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -46,27 +46,29 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType): - if len(memref_ty.shape) == 0: - raise NotImplementedError i64 = ir.IntegerType.get_signless(64) rank = len(memref_ty.shape) - desc_ty = ir.Type.parse( - f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>" - ) + if rank > 0: + desc_ty = ir.Type.parse( + f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>" + ) + else: + desc_ty = ir.Type.parse("!llvm.struct<(ptr, ptr, i64)>") desc = llvm.UndefOp(desc_ty) desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base desc = llvm.InsertValueOp( desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, 0)), [2] ) - for i, s in enumerate(memref_ty.shape): - desc = llvm.InsertValueOp( - desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] - ) - for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): - desc = llvm.InsertValueOp( - desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] - ) + if rank > 0: + for i, s in enumerate(memref_ty.shape): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] + ) + for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): + desc = llvm.InsertValueOp( + desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] + ) return builtin.unrealized_conversion_cast([memref_ty], [desc]) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 9a3f8ccfdadd..382cad79fb2b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -363,6 +363,26 @@ def kernel(ctx, inp, out, _): else: do_test() + @parameterized.parameters(jnp.uint64, jnp.uint32, jnp.uint16, jnp.uint8) + def test_scalar_argument(self, dtype): + scalar = 42 + expected = np.full((128, 128), scalar, dtype=dtype) + + def kernel(ctx, inp, out, _): + del ctx + inp = memref.load(inp, []) + mgpu.FragmentedArray.splat(inp, expected.shape, is_signed=True).store_untiled(out) + + res = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + jax.ShapeDtypeStruct(shape=(), dtype=expected.dtype), + expected, + (), + )(scalar) + np.testing.assert_array_equal(res, expected) + def get_packed_shape(strides, shape): perm = sorted(range(len(strides)), key=lambda i: strides[i], reverse=True) From 6a124ac554f24e1dd5f494b3e3c1c97bc8057cba Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 8 Nov 2024 04:37:32 -0800 Subject: [PATCH 248/698] [Mosaic GPU] Implement tiled and swizzled transfers for tiled layouts PiperOrigin-RevId: 694449664 --- .../mosaic/gpu/fragmented_array.py | 216 +++++++++++++++--- jax/experimental/mosaic/gpu/utils.py | 4 + tests/mosaic/BUILD | 2 +- tests/mosaic/gpu_test.py | 59 +++++ 4 files changed, 245 insertions(+), 36 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index a9d12706ff47..040174b900c9 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -19,7 +19,7 @@ import dataclasses import functools import math -from typing import Sequence, TypeVar, Iterable +from typing import Iterable, Sequence, TypeVar import jax from jaxlib.mlir import ir @@ -110,6 +110,23 @@ def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) return strides + def tile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: + for tile in self.tiles: + untiled, tiled = indices[:-len(tile)], indices[-len(tile):] + indices = ( + *untiled, + *(i // t for i, t in zip(tiled, tile)), + *(i % t for i, t in zip(tiled, tile)), + ) + return indices + + def untile_indices(self, indices: tuple[int, ...]) -> tuple[int, ...]: + for tile in reversed(self.tiles): + untiled = indices[:-2 * len(tile)] + outer = indices[-2 * len(tile):-len(tile)] + inner = indices[-len(tile):] + indices = (*untiled, *(o * t + i for o, i, t in zip(outer, inner, tile))) + return indices def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: """Like built-in enumerate, but returns negative indices into the sequence.""" @@ -185,6 +202,15 @@ def __post_init__(self): if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: raise ValueError + @property + def base_tile_shape(self) -> int: + """The shape of the first tile in the tiling expression. + + This tile acts as the divisibility constraint for a suffix of arrays to + which this layout applies. + """ + return self.tiling.tiles[0] + @functools.cached_property def tiled_tiling_shape(self) -> tuple[int, ...]: """The shape of the suffix of the array after tiling. @@ -194,7 +220,7 @@ def tiled_tiling_shape(self) -> tuple[int, ...]: so the tiled shape always ends with this suffix, no matter what array shape it's applied to. """ - return self.tiling.tile_shape(self.tiling.tiles[0]) + return self.tiling.tile_shape(self.base_tile_shape) @property def vector_length(self) -> int: @@ -231,6 +257,8 @@ def lane_indices(self) -> tuple[ir.Value, ...]: assert math.prod(tiled_shape) == WARP_SIZE lane_strides = utils.get_contiguous_strides(tiled_shape) lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) + # TODO(apaszke): Rewrite so that we can be sure that this never actually + # does arithmetic for any dimensions that are not in lane_dims. return tuple( arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) for stride, size in zip(lane_strides, tiled_shape) @@ -1260,10 +1288,8 @@ def _store_untiled_tiled(self, ref: ir.Value): ptr = utils.memref_ptr(ref) # Fold warp and lane offsets into the pointer once, since they are dynamic. dyn_strides = [arith.constant(i32, s) for s in strides] - def dyn_dot(x, y): - return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y))) - warp_offset = dyn_dot(layout.warp_indices(), dyn_strides) - lane_offset = dyn_dot(layout.lane_indices(), dyn_strides) + warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) + lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) dyn_offset = arith.addi(warp_offset, lane_offset) ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) # All warp tile offsets are static and can be fused into the store. @@ -1273,41 +1299,68 @@ def dyn_dot(x, y): llvm.store(reg, reg_ptr) def store_tiled(self, ref, swizzle: int | None): - if self.layout != WGMMA_LAYOUT: - raise NotImplementedError - dtype = self.mlir_dtype - bw = mgpu.bytewidth(dtype) - m, n = self.shape - assert m % 64 == 0 # This is implied by the layout. - cols_per_tile = swizzle // bw - expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] - if n < cols_per_tile: # We allow singular tiles shorter than swizzle. - expected_shape = [m // 64, 1, 64, cols_per_tile] - if ir.MemRefType(ref.type).shape != expected_shape: - raise ValueError(ref.type, (m, n)) - for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): - vector.store(get(self.registers), ref, idxs) + match self.layout: + case WGMMAFragLayout(): + dtype = self.mlir_dtype + bw = mgpu.bytewidth(dtype) + m, n = self.shape + assert m % 64 == 0 # This is implied by the layout. + cols_per_tile = swizzle // bw + expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] + if n < cols_per_tile: # We allow singular tiles shorter than swizzle. + expected_shape = [m // 64, 1, 64, cols_per_tile] + if ir.MemRefType(ref.type).shape != expected_shape: + raise ValueError(ref.type, (m, n)) + for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): + vector.store(get(self.registers), ref, idxs) + case TiledLayout(): + layout, shape = self.layout, self.shape + for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): + llvm.store(get(self.registers), ptr) + case _: + raise NotImplementedError(self.layout) @classmethod def load_tiled( - cls, ref, swizzle: int | None, *, is_signed: bool | None = None + cls, + ref, + swizzle: int | None, + *, + is_signed: bool | None = None, + layout: FragmentedLayout = WGMMA_LAYOUT, ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type - bw = mgpu.bytewidth(dtype) - m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape - if m_tile_size != 64 or n_tile_size != (swizzle // bw): - raise ValueError - m, n = m_tiles * m_tile_size, n_tiles * n_tile_size - assert m % 64 == 0 # This is implied by the layout. - registers = np.full( - (m_tiles, n // 8, 2, 1), - vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)), - dtype=object, - ) - for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle): - update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs)) - return cls(_registers=registers, _layout=WGMMA_LAYOUT, _is_signed=is_signed) + match layout: + case TiledLayout(): + ref_ty = ir.MemRefType(ref.type) + tiled_shape = ref_ty.shape + if len(tiled_shape) % 2: + raise ValueError("Tiled reference must have even rank") + tiling = Tiling((tiled_shape[len(tiled_shape) // 2:],)) + shape = tiling.untile_shape(tiled_shape) + registers = np.full(layout.registers_shape(shape), None, dtype=object) + reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) + for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): + update(registers, llvm.load(reg_ty, ptr)) + assert all(r is not None for r in registers.flat) + case WGMMAFragLayout(): + bw = mgpu.bytewidth(dtype) + m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape + if m_tile_size != 64 or n_tile_size != (swizzle // bw): + raise ValueError + m, n = m_tiles * m_tile_size, n_tiles * n_tile_size + assert m % 64 == 0 # This is implied by the layout. + registers = np.full( + (m_tiles, n // 8, 2, 1), + vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)), + dtype=object, + ) + for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle): + update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs)) + case _: + raise NotImplementedError(layout) + return cls(_registers=registers, _layout=layout, _is_signed=is_signed) @staticmethod def transfer_tiled(shape, dtype, swizzle: int | None): @@ -1393,6 +1446,99 @@ def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx): regs[right_idx] = arith.select(is_stagger_left, regs[right_idx], new) yield get_register, update_registers, idx + @staticmethod + def transfer_tiled2( + ref: ir.Value, + swizzle: int | None, + layout: TiledLayout, + shape: tuple[int, ...], + ): + """Generate a transfer schedule for a tiled layout. + + Given a ref with one level tiling applied to it (we assume all dimensions + have been tiled), this function generates an iterable describing a good + schedule for swizzled SMEM loads/stores. + + At each step, the iterable yields a tuple of three values: + * a function that takes a register array and returns the register to be + stored at the current address + * a function that takes a register array and a register loaded from the + current address, and updates the register array with that register + * the current address for load/store instructions + """ + # TODO(apaszke): Use ldmatrix/stmatrix when possible. + c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + tiling = layout.tiling + + ref_ty = ir.MemRefType(ref.type) + dtype = ref_ty.element_type + if ref_ty.rank % 2: + raise ValueError("Tiled refence must have even rank") + ref_tiling_shape = tuple(ref_ty.shape[ref_ty.rank // 2:]) + ref_tiling = Tiling((ref_tiling_shape,)) + ref_strides, _ = ref_ty.get_strides_and_offset() + if ref_tiling.untile_shape(tuple(ref_ty.shape)) != shape: + raise ValueError() + if len(layout.base_tile_shape) > len(ref_tiling_shape): + raise ValueError("Memory tiling must be a multiple of the register tiling") + ref_tiling_suffix = ref_tiling_shape[-len(layout.base_tile_shape):] + if any(t % wt for t, wt in zip(ref_tiling_suffix, layout.base_tile_shape)): + raise ValueError("Memory tiling must be a multiple of the register tiling") + + if swizzle not in {32, 64, 128}: + raise ValueError("Only swizzled transfers supported") + bw = mgpu.bytewidth(dtype) + swizzle_tile_elems = 16 // bw + swizzle_group_elems = 128 // bw + swizzle_groups_per_block = swizzle // 16 + swizzle_block_elems = swizzle_groups_per_block * swizzle_group_elems + + tiled_strides = list(tiling.tile_strides(tuple(ref_strides))) + tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape))) + if tiled_strides[layout.vector_dim] != 1: + raise ValueError("Stride of the vectorized dimension should be 1") + for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): + tiled_shape[d] = 1 + full_tiling = Tiling((ref_tiling_shape, *tiling.tiles)) + full_layout = dataclasses.replace(layout, tiling=full_tiling) + + # XXX: This method is still slightly incompete. For example, it does not + # verify that the vector transfers don't cross swizzle tile boundaries. It + # also does not guarantee that the transfer pattern does not cause bank + # conflicts. For that reason, we only allow a select subset of layouts. + if layout != _tiled_wgmma_layout(shape) or bw > 2: + raise NotImplementedError("transfer_tiled2 not general enough yet") + + dyn_tiled_strides = [c(s) for s in tiled_strides] + lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides) + warp_offset = utils.dyn_dot(full_layout.warp_indices(), dyn_tiled_strides) + dyn_offset = arith.addi(lane_offset, warp_offset) + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError("Tiled stores can be performed into SMEM") + ptr = utils.memref_ptr(ref, memory_space=3) + for tile_idx in np.ndindex(*tiled_shape): + const_offset = sum(i * s for i, s in zip(tile_idx, tiled_strides)) + # We split the offset into a part that interacts with swizzling and a + # part that doesn't. This lets us generate better code because constant + # offsets can be fused into load and store instructions. + const_offset_swizzle = const_offset % swizzle_block_elems + const_offset_no_swizzle = const_offset - const_offset_swizzle + offset_pre_swizzle = arith.addi(dyn_offset, c(const_offset_swizzle)) + swizzle_group = arith.remui( + arith.divui(offset_pre_swizzle, c(swizzle_group_elems)), + c(swizzle_groups_per_block), + ) + swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems)) + offset = arith.xori(offset_pre_swizzle, swizzle_bits) + reg_ptr = utils.getelementptr(ptr, [offset], dtype) + reg_ptr = utils.getelementptr(reg_ptr, [const_offset_no_swizzle], dtype) + reg_idx = tiling.tile_indices(full_tiling.untile_indices(tile_idx)) + def get_register(regs, reg_idx=reg_idx): + return regs[reg_idx] + def update_registers(regs, new, reg_idx=reg_idx): + regs[reg_idx] = new + yield get_register, update_registers, reg_ptr + def tree_flatten(self): aux = self.layout, self.registers.shape, self.is_signed return list(self.registers.flat), aux diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index f8918488563e..b716456eceb3 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1047,3 +1047,7 @@ def getelementptr( static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices] dyn_indices = [i for i in indices if not isinstance(i, int)] return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype) + + +def dyn_dot(x, y): + return functools.reduce(arith.addi, (arith.muli(a, b) for a, b in zip(x, y))) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index ca2c9a4bf27d..6ea9c02b9639 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -37,7 +37,7 @@ jax_multiplatform_test( "gpu_h100", "gpu_h100_2gpu", ], - shard_count = 4, + shard_count = 8, tags = ["multiaccelerator"], deps = [ "//jax:mosaic_gpu", diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 382cad79fb2b..157f682f5eef 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -18,6 +18,8 @@ import itertools import math import operator +import os +import re import unittest from absl.testing import absltest, parameterized @@ -1627,6 +1629,63 @@ def kernel(ctx, dst, _): expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) np.testing.assert_array_equal(f(), expected) + @parameterized.product( + load_tiled=[False, True], + store_tiled=[False, True], + dtype=[jnp.int16], + swizzle=[32, 64, 128], + num_col_tiles=[1, 2, 4], + ) + def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles): + mlir_dtype = utils.dtype_to_ir_type(dtype) + col_tiling = swizzle // bytewidth(mlir_dtype) + m, n = 128, col_tiling * num_col_tiles + tiling = (64, col_tiling) + tiled_layout = fa._tiled_wgmma_layout((m, n)) + load_layout = tiled_layout if load_tiled else mgpu.WGMMA_LAYOUT + store_layout = tiled_layout if store_tiled else mgpu.WGMMA_LAYOUT + def kernel(ctx, in_, out, smems): + smem_in, smem_out, barrier = smems + ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) + barrier.wait() + t = mgpu.FragmentedArray.load_tiled( + smem_in, swizzle=swizzle, is_signed=True, layout=load_layout + ) + t.to_layout(store_layout).store_tiled(smem_out, swizzle=swizzle) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem_out, dst_ref=out, swizzle=swizzle) + ctx.await_async_copy(0) + expected = ( + np.arange(m * n, dtype=dtype) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .transpose(0, 2, 1, 3) + ) + + prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) + os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" + try: + with jtu.capture_stdout() as get_sass: + iota = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, + [expected, expected, mgpu.TMABarrier()], + )(expected) + finally: + if prev_dump is not None: + os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump + np.testing.assert_array_equal(iota, expected) + + # Verify that we don't use too many registers for the transfers. + # We verify LDS and STS separately, because they might use two different + # methods of computing offsets and we don't rely on CSE between them. + register_pattern = re.compile(r"(R[0-9]+)") + expected_regs = swizzle // bytewidth(mlir_dtype) // 8 + for instr in ("STS", "LDS"): + with self.subTest(instr + " count"): + addrs = re.findall(instr + r".* \[(.*)\]", get_sass()) + chain = itertools.chain.from_iterable + used_regs = set(chain(register_pattern.findall(addr) for addr in addrs)) + self.assertLen(used_regs, expected_regs) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 834e71bbe1e4e23de8f658e7380c385be7c5099a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 8 Nov 2024 09:56:12 -0500 Subject: [PATCH 249/698] Don't perform size 0 slices into scipy rotations. This is disallowed by scipy after https://github.com/scipy/scipy/pull/21776. --- tests/scipy_spatial_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index 540136b33870..fe2232d7ffe6 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -164,7 +164,7 @@ def testRotationConcatenate(self, shape, other_shape, dtype): @jtu.sample_product( dtype=float_dtypes, shape=[(10, 4)], - indexer=[slice(1, 5), slice(0), slice(-5, -3)], + indexer=[slice(1, 5), slice(0, 1), slice(-5, -3)], ) def testRotationGetItem(self, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) From 9763044d27c7584d77e946cc8f6cc36c34355599 Mon Sep 17 00:00:00 2001 From: dymil <30931139+dymil@users.noreply.github.com> Date: Fri, 1 Nov 2024 02:00:05 -0400 Subject: [PATCH 250/698] Fix argmin docstring to not say "maximum" --- jax/_src/numpy/lax_numpy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4200a9fdae72..d2e89833915d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10185,18 +10185,18 @@ def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the minimum value of an array. - JAX implementation of :func:`numpy.argmax`. + JAX implementation of :func:`numpy.argmin`. Args: a: input array - axis: optional integer specifying the axis along which to find the maximum + axis: optional integer specifying the axis along which to find the minimum value. If ``axis`` is not specified, ``a`` will be flattened. out: unused by JAX keepdims: if True, then return an array with the same number of dimensions as ``a``. Returns: - an array containing the index of the maximum value along the specified axis. + an array containing the index of the minimum value along the specified axis. See also: - :func:`jax.numpy.argmax`: return the index of the maximum value. From ce3826d09897eee82e99315ff4f6c9b41c78ce1a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 8 Nov 2024 08:34:35 -0800 Subject: [PATCH 251/698] [Mosaic GPU] Make sure to free the cloned MLIR module when debugging We only recently started using this in tests and it has caused ASAN to report a bunch of leaks. PiperOrigin-RevId: 694510867 --- jaxlib/mosaic/gpu/BUILD | 1 + jaxlib/mosaic/gpu/custom_call.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 875dc1d151ba..2fb8f0103e65 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -120,6 +120,7 @@ cc_library( ":target", "//jaxlib/cuda:cuda_vendor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index e3bbcf0cd0e3..2d479f712408 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -32,6 +32,7 @@ limitations under the License. #include #include "absl/base/optimization.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/status/status.h" @@ -290,6 +291,7 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, } module = module.clone(); // Prevent accidental modification. + absl::Cleanup module_destroyer = [module] { module->erase(); }; auto passes = GetPassPipeline( module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa); if (mlir::failed(passes) || From 4a365670f7a5d05f275b71c8f675422773d53aa5 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 7 Nov 2024 20:06:22 -0500 Subject: [PATCH 252/698] Fix pre-commit to run on all files in CI. --- .github/workflows/ci-build.yaml | 2 +- jax/_src/array.py | 2 +- jax/_src/config.py | 4 ++-- jax/experimental/mosaic/gpu/dialect_lowering.py | 2 ++ jax/experimental/mosaic/gpu/fragmented_array.py | 1 + 5 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 5c786272ee3d..db1477ac38b1 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -39,7 +39,7 @@ jobs: with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} - - run: pre-commit run --show-diff-on-failure --color=always + - run: pre-commit run --show-diff-on-failure --color=always --all-files build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" diff --git a/jax/_src/array.py b/jax/_src/array.py index 515fc2c7c7e6..cf346067ea31 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1173,7 +1173,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): copy_outs = xc.batched_copy_array_to_devices_with_sharding( batch_xs, batch_devs, batch_shardings, batch_cs) else: - copy_outs = xc.batched_copy_array_to_devices_with_sharding( # type: ignore + copy_outs = xc.batched_copy_array_to_devices_with_sharding( # pytype: disable=missing-parameter batch_xs, batch_devs, batch_shardings) for i, copy_out in safe_zip(batch_indices, copy_outs): assert results[i] is None diff --git a/jax/_src/config.py b/jax/_src/config.py index 215ef443c799..f3edde69981f 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -22,7 +22,7 @@ import os import sys import threading -from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING +from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast from jax._src import lib from jax._src.lib import guard_lib @@ -371,7 +371,7 @@ class _Unset: pass _thread_local_state = threading.local() - class State(Generic[_T]): + class State(Generic[_T]): # type: ignore[no-redef] __slots__ = ( '_name', '_value', '_update_thread_local_hook', '_update_global_hook', diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 7d36272dc111..8eee9aef3f2e 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -29,6 +29,8 @@ from jaxlib.mlir.dialects import nvvm from .utils import c, memref_ptr, single_thread_predicate +# mypy: ignore-errors + MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]] diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 040174b900c9..0c5dd0ef793e 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -19,6 +19,7 @@ import dataclasses import functools import math +from collections.abc import Callable from typing import Iterable, Sequence, TypeVar import jax From 78da9fa4322bb62e4e1cc55977cb79fc20cb0ccb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Oct 2024 15:33:24 -0700 Subject: [PATCH 253/698] Add float8_e4m3 and float8_e3m4 types support --- jax/_src/dtypes.py | 19 +++++++++++++++++++ jax/_src/export/serialization.fbs | 2 ++ jax/_src/export/serialization.py | 4 ++++ jax/_src/export/serialization_generated.py | 2 ++ jax/_src/interpreters/mlir.py | 12 ++++++------ jax/_src/lax/lax.py | 12 ++++++++++-- jax/_src/numpy/lax_numpy.py | 4 ++++ jax/_src/public_test_util.py | 14 ++++++++++++++ jax/_src/test_util.py | 17 +++++++++++++---- jax/numpy/__init__.py | 9 +++++++++ tests/dtypes_test.py | 4 ++++ 11 files changed, 87 insertions(+), 12 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index ac0418932b83..c9710c287879 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,12 +90,17 @@ def type(self) -> type: ... # fp8 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float8_e3m4: type[np.generic] | None = None +float8_e4m3: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz +_float8_e3m4_dtype: np.dtype | None = None +_float8_e4m3_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] +# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 +if hasattr(ml_dtypes, "float8_e4m3"): + float8_e4m3 = ml_dtypes.float8_e4m3 + _float8_e4m3_dtype = np.dtype(float8_e4m3) + _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e4m3_dtype) + _float8_dtypes.insert(0, _float8_e4m3_dtype) +if hasattr(ml_dtypes, "float8_e3m4"): + float8_e3m4 = ml_dtypes.float8_e3m4 + _float8_e3m4_dtype = np.dtype(float8_e3m4) + _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e3m4_dtype) + _float8_dtypes.insert(0, _float8_e3m4_dtype) + # 2-bit integer support int2: type[np.generic] | None = None uint2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 3198f83aa120..b71b377d8999 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -67,6 +67,8 @@ enum DType: byte { i4 = 15, ui4 = 16, + f8_e3m4 = 24, + f8_e4m3 = 23, f8_e4m3b11fnuz = 17, f8_e4m3fn = 18, f8_e4m3fnuz = 19, diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index e392289da64d..0d9ce961b556 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -359,6 +359,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, } +if dtypes._float8_e3m4_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 +if dtypes._float8_e4m3_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index 18dd2c3cbab1..70d298020961 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -53,6 +53,8 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 + f8_e3m4 = 24 + f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2c0e26019e4d..54a85f92c873 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -184,13 +184,13 @@ def _is_ir_values(x: IrValues) -> bool: if dtypes.int2 is not None: assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( - ir.IntegerType.get_signless, 2 - ) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( - ir.IntegerType.get_unsigned, 2 - ) + _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) + _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) +if dtypes.float8_e3m4 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get +if dtypes.float8_e4m3 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8b6a517a54b3..7fa2dd4acbfa 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -937,11 +937,15 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): - fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), + fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)) + np.dtype(dtypes.float8_e5m2fnuz)] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3625,6 +3629,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + if dtypes.float8_e3m4 is not None: + fp8_dtypes += (dtypes.float8_e3m4,) + if dtypes.float8_e4m3 is not None: + fp8_dtypes += (dtypes.float8_e4m3,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d2e89833915d..c419c083f837 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -217,6 +217,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) +if dtypes.float8_e3m4 is not None: + float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +if dtypes.float8_e4m3 is not None: + float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 9859eb64cda2..6bbcdd08471f 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -90,6 +90,14 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } +# TODO: make this unconditional when ml_dtypes>=0.5.0 is required +if _dtypes.float8_e3m4 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 +if _dtypes.float8_e4m3 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.float8_e5m2fnuz, _dtypes.bfloat16, ] + + if _dtypes.float8_e4m3 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e4m3) + if _dtypes.float8_e3m4 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e3m4) + def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index bb81c979bc48..e7707f58fc4a 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1431,10 +1431,19 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ - _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, - _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + float_dtypes = [ + _dtypes.bfloat16, + _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e4m3fn, + _dtypes.float8_e4m3fnuz, + _dtypes.float8_e5m2, + _dtypes.float8_e5m2fnuz, + ] + if _dtypes.float8_e3m4 is not None: + float_dtypes += [_dtypes.float8_e3m4] + if _dtypes.float8_e4m3 is not None: + float_dtypes += [_dtypes.float8_e4m3] + return [np.dtype(t) for t in float_dtypes] @_cached_property def floating(self): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9be73e96adcf..9a643bf49bf0 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -273,6 +273,15 @@ except ImportError: pass +# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 +try: + from jax._src.numpy.lax_numpy import ( + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + ) +except ImportError: + pass + from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 89d70871a8f9..6c7e9e3ab712 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -64,6 +64,10 @@ fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz)] +if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] +if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes From afd8239ea435e329f1778d22c77b0269788160fd Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Fri, 8 Nov 2024 11:02:09 -0800 Subject: [PATCH 254/698] [SDY] add JAX lowering to Shardy `ShardingGroupOp` for shard_alike. PiperOrigin-RevId: 694567084 --- jax/_src/shard_alike.py | 8 +++++++- tests/BUILD | 1 + tests/shard_alike_test.py | 8 +++----- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/jax/_src/shard_alike.py b/jax/_src/shard_alike.py index 574d725c4999..e2ddec15e8d4 100644 --- a/jax/_src/shard_alike.py +++ b/jax/_src/shard_alike.py @@ -15,6 +15,7 @@ from functools import partial import itertools +from jax._src import config from jax._src import core from jax._src.interpreters import ad from jax._src.interpreters import mlir @@ -24,7 +25,7 @@ from jax._src.util import safe_zip from jax._src.lib import xla_client as xc from jax._src.api_util import shaped_abstractify -from jax._src.lib.mlir import ir +from jax._src.lib.mlir import dialects, ir _next_shard_group_id = itertools.count() @@ -91,6 +92,11 @@ def _group_shard( ) -> tuple[ir.Value, ir.Value]: shard_group_id = next(_next_shard_group_id) + if config.use_shardy_partitioner.value: + dialects.sdy.ShardingGroupOp(x, shard_group_id) + dialects.sdy.ShardingGroupOp(y, shard_group_id) + return x, y + unknown_op_sharding = xc.OpSharding() unknown_op_sharding.type = xc.OpSharding.Type.UNKNOWN unknown_op_sharding.is_shard_group = True diff --git a/tests/BUILD b/tests/BUILD index ba725842b36c..bd248fad2902 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -283,6 +283,7 @@ jax_multiplatform_test( "tpu_v3_2x2", "tpu_v5e_4x2", "tpu_v4_2x2", + "tpu_v3_2x2_shardy", ], deps = [ "//jax:experimental", diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 10267ff5eb98..557be3839baf 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import numpy as np from absl.testing import absltest +from jax._src import config from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec as P from jax.experimental.shard_alike import shard_alike @@ -221,18 +222,16 @@ def test_shard_alike_inputs(self): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) - rep_s = NamedSharding(mesh, P()) arr = jax.device_put(np_inp, s) - arr2 = jax.device_put(np_inp, rep_s) def f(x, y): return shard_alike(x, y) - eager_out1, eager_out2 = f(arr, arr2) + eager_out1, eager_out2 = f(arr, np_inp) self.assertEqual(eager_out1.sharding, s) self.assertEqual(eager_out2.sharding, s) - out1, out2 = jax.jit(f)(arr, arr2) + out1, out2 = jax.jit(f)(arr, np_inp) self.assertEqual(out1.sharding, s) self.assertEqual(out2.sharding, s) @@ -282,6 +281,5 @@ def test_sharding_preserverd_single_device(self): _, y = shard_alike(x, jnp.arange(8)) self.assertEqual(y.sharding, s) - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 8f169e7fb549cb5d6711f60d9433def5d3343c3f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 8 Nov 2024 11:19:46 -0800 Subject: [PATCH 255/698] Disable the paged_attention test on TPU v5p. This test is failing in CI. PiperOrigin-RevId: 694574616 --- tests/pallas/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 3166526dffcb..f95ea53b4929 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -390,6 +390,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], + disable_configs = [ + "tpu_v5p_1x1", + ], enable_backends = ["tpu"], shard_count = 5, tags = [ From 7285f10e84dc95491792234c59e9638aa6f8d35f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 8 Nov 2024 11:32:40 -0800 Subject: [PATCH 256/698] Disable lax_test on ARM in Google's internal CI. There are numerical errors from the complex plane function tests. PiperOrigin-RevId: 694579368 --- tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/BUILD b/tests/BUILD index bd248fad2902..3d79298b8a57 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -551,6 +551,7 @@ jax_multiplatform_test( name = "lax_test", srcs = ["lax_test.py"], backend_tags = { + "cpu": ["not_run:arm"], # Numerical issues, including https://github.com/jax-ml/jax/issues/24787 "tpu": ["noasan"], # Times out. }, shard_count = { From 0cc1747873fc869fd04ca0444e1b969df51f8169 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 8 Nov 2024 13:35:04 -0800 Subject: [PATCH 257/698] Add tests for jnp.einsum in Pallas PiperOrigin-RevId: 694622626 --- tests/pallas/ops_test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 318df0b0bfcf..e4017c21f017 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1033,6 +1033,19 @@ def isnan(x_ref, o_ref): x = x.at[3].set(jnp.nan) np.testing.assert_allclose(isnan(x), jnp.isnan(x)) + def test_jnp_einsum_grad_y_pallas(self): + x = jnp.arange(128 * 256, dtype=jnp.float32).reshape((128, 256)) + y = jnp.arange(256 * 128, dtype=jnp.float32).reshape((128, 256)) + + def kernel(x_ref, y_ref, out_ref): + # grad_y side of grouped matmul + out_ref[...] = jnp.einsum('mk,mn->kn', x_ref[...], y_ref[...]) + + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((256, 256), jnp.float32) + )(x, y) + np.testing.assert_array_equal(out, jnp.einsum('mk,mn->kn', x, y)) + @parameterized.parameters( ("int32", "float32"), ("float32", "float32"), From aa6adfb9c51a06812cfcc3e97d112e18f6d17454 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Nov 2024 13:36:28 -0800 Subject: [PATCH 258/698] Remove unused import --- tests/shard_alike_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 557be3839baf..c5f80a6d97f3 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -18,7 +18,6 @@ import jax.numpy as jnp import numpy as np from absl.testing import absltest -from jax._src import config from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec as P from jax.experimental.shard_alike import shard_alike From c1360f5463b49da81189fb9b9bff0a0b219924d4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 8 Nov 2024 13:40:27 -0800 Subject: [PATCH 259/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/6d95565e652fdd021dcb6d306a54e786572e7a34. PiperOrigin-RevId: 694624452 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 0aa248d61faa..7226f7900fce 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "b0aae988216d48e2379c8de1c7c4aedeb98d8985" -XLA_SHA256 = "42ed25652bf91b98c31a7d456d12ea4ca78d3b4083514047f650851383f2cb9d" +XLA_COMMIT = "6d95565e652fdd021dcb6d306a54e786572e7a34" +XLA_SHA256 = "1fc547c054905d2724ecc9f4698d40c3f887e0193aed55aaf6ac36774b800c66" def repo(): tf_http_archive( From 7404e0d29d1f3fe1c275fb497d5983304c9419be Mon Sep 17 00:00:00 2001 From: Ke Wu Date: Fri, 8 Nov 2024 14:20:57 -0800 Subject: [PATCH 260/698] Add typing overloads for jax.extend.ffi.ffi_call() to aid type checkers PiperOrigin-RevId: 694639758 --- jax/_src/extend/ffi.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 3012b74cf941..60db341254c6 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -18,7 +18,7 @@ import ctypes import functools import os -from typing import Any +from typing import Any, overload import numpy as np @@ -240,6 +240,43 @@ def _convert_layouts_for_ffi_call( for aval, layout in zip(avals, layouts)) +# ffi_call() returns as many results as result_shape_dtypes. +@overload +def ffi_call( + target_name: str, + result_shape_dtypes: ResultMetadata, + *deprecated_args: ArrayLike, + has_side_effect: bool = ..., + vmap_method: str | None = ..., + input_layouts: Sequence[FfiLayoutOptions] | None = ..., + output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ..., + input_output_aliases: dict[int, int] | None = ..., + custom_call_api_version: int = ..., + legacy_backend_config: str | None = ..., + vectorized: bool | DeprecatedArg = ..., + **deprecated_kwargs: Any, +) -> Callable[..., Array] | Array: + ... + + +@overload +def ffi_call( + target_name: str, + result_shape_dtypes: Sequence[ResultMetadata], + *deprecated_args: ArrayLike, + has_side_effect: bool = ..., + vmap_method: str | None = ..., + input_layouts: Sequence[FfiLayoutOptions] | None = ..., + output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ..., + input_output_aliases: dict[int, int] | None = ..., + custom_call_api_version: int = ..., + legacy_backend_config: str | None = ..., + vectorized: bool | DeprecatedArg = ..., + **deprecated_kwargs: Any, +) -> Callable[..., Sequence[Array]] | Sequence[Array]: + ... + + def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], From d833066a1fddb0363ffa15fa2eb3fff29cf446fe Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Fri, 8 Nov 2024 14:32:59 -0800 Subject: [PATCH 261/698] [MOSAIC:GPU] Add `async_load`, `async_store`, and supporting attributes to the MLIR Mosaic GPU Dialect. PiperOrigin-RevId: 694643777 --- jaxlib/mosaic/dialect/gpu/BUILD | 31 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc | 91 +++++- jaxlib/mosaic/dialect/gpu/mosaic_gpu.h | 11 +- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 185 +++++++++++ jaxlib/mosaic/python/BUILD | 3 +- jaxlib/mosaic/python/mosaic_gpu.py | 3 +- tests/mosaic/gpu_dialect_test.py | 390 +++++++++++++++++++++++- 7 files changed, 686 insertions(+), 28 deletions(-) diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 4207e769e6a1..7c7a3589d9e6 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -29,7 +29,9 @@ td_library( srcs = ["mosaic_gpu.td"], includes = ["."], deps = [ + "@llvm-project//mlir:BasicPtxBuilderIntTdFiles", "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", ], ) @@ -109,6 +111,7 @@ cc_library( hdrs = ["mosaic_gpu.h"], deps = [ ":mosaic_gpu_inc_gen", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -116,9 +119,11 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", "@tsl//tsl/platform:statusor", @@ -152,12 +157,19 @@ cc_test( gentbl_filegroup( name = "mosaic_gpu_python_gen_raw", tbl_outs = [ + ( + [ + "-gen-python-enum-bindings", + "-bind-dialect=mosaic_gpu", + ], + "_mosaic_gpu_gen_enums_raw.py", + ), ( [ "-gen-python-op-bindings", "-bind-dialect=mosaic_gpu", ], - "_mosaic_gpu_gen_raw.py", + "_mosaic_gpu_gen_ops_raw.py", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", @@ -169,10 +181,19 @@ gentbl_filegroup( ) genrule( - name = "mosaic_gpu_python_gen", - srcs = ["_mosaic_gpu_gen_raw.py"], - outs = ["_mosaic_gpu_gen.py"], - cmd = "cat $(location _mosaic_gpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@", + name = "mosaic_gpu_python_gen_enums", + srcs = ["_mosaic_gpu_gen_enums_raw.py"], + outs = ["_mosaic_gpu_gen_enums.py"], + cmd = """ + cat $(location _mosaic_gpu_gen_enums_raw.py) | \ + sed -e 's/^from \\.\\.ir/from jaxlib\\.mlir\\.ir/g; s/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@""", +) + +genrule( + name = "mosaic_gpu_python_gen_ops", + srcs = ["_mosaic_gpu_gen_ops_raw.py"], + outs = ["_mosaic_gpu_gen_ops.py"], + cmd = "cat $(location _mosaic_gpu_gen_ops_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@", ) DIALECT_CAPI_SOURCES = [ diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index 8c5573bf1b80..c86450fbdf0c 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -18,18 +18,17 @@ limitations under the License. #include #include -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -44,6 +43,12 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mlir/include/mlir/IR/Diagnostics.h" #include "tsl/platform/statusor.h" // Generated definitions. @@ -232,11 +237,89 @@ void DeclareRuntimeFunctions(mlir::OpBuilder& builder) { .setVisibility(mlir::func::FuncOp::Visibility::Private); } +bool IsContiguous(mlir::MemRefType type) { + return type.getLayout().isIdentity() || + (type.hasStaticShape() && type.getNumElements() > 0 && + mlir::memref::isStaticShapeAndContiguousRowMajor(type)); +} + +namespace { +llvm::LogicalResult VerifyCommonLoadStoreOp( + mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name, + mlir::MemRefType smem_type, absl::string_view smem_name, + mlir::ArrayRef slice_lengths, int num_indices) { + auto error = [loc](auto... params) { + return emitError(loc, llvm::formatv(params...)); + }; + + if (!IsContiguous(smem_type)) { + return error("The `{0}` memref must be contiguous.", smem_name); + } + if (gmem_type.getElementType() != smem_type.getElementType()) { + return error( + "The `source` and `destination` memrefs must have the same element " + "type."); + } + if (absl::c_any_of(slice_lengths, [](int64_t s) { return s < -1; })) { + return error( + "The `slice_lengths` attribute must not contain values less than -1."); + } + if (gmem_type.getRank() != + smem_type.getRank() + absl::c_count(slice_lengths, -1)) { + return error( + "The rank of the `{0}` must be equal to the rank of the " + "`{1}` plus the number of collapsed dimensions as indicated " + "by -1 values in `slice_lengths`.", + gmem_name, smem_name); + } + if (num_indices != gmem_type.getRank()) { + return error("The size of `indices` must be equal to the rank of `{0}`.", + gmem_name); + } + if (slice_lengths.size() != gmem_type.getRank()) { + return error( + "The size of `slice_lengths` must be equal to the rank of `{0}`.", + gmem_name); + } + return llvm::success(); +} +} // namespace + +llvm::LogicalResult AsyncLoadOp::verify() { + auto r = VerifyCommonLoadStoreOp(getLoc(), getSource().getType(), "source", + getDestination().getType(), "destination", + getSliceLengths(), getIndices().size()); + if (failed(r)) { + return r; + } + + for (int i = 0; i < getCollective().size(); ++i) { + for (int k = i + 1; k < getCollective().size(); ++k) + if (getCollective()[i] == getCollective()[k]) { + return emitError( + "The `collective` attribute must not contain duplicate " + "dimensions."); + } + } + + return llvm::success(); +} + +llvm::LogicalResult AsyncStoreOp::verify() { + return VerifyCommonLoadStoreOp(getLoc(), getDestination().getType(), + "destination", getSource().getType(), "source", + getSliceLengths(), getIndices().size()); +} + void MosaicGPUDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc" + >(); addOperations< #define GET_OP_LIST #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_ops.cc.inc" diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index b46675d1c9a7..14c0d0295a8f 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -19,14 +19,17 @@ limitations under the License. #include #include -#include "absl/status/status.h" -#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep @@ -43,6 +46,10 @@ namespace mosaic_gpu { using Memref = ::mlir::TypedValue<::mlir::MemRefType>; using Pointer = ::mlir::TypedValue<::mlir::LLVM::LLVMPointerType>; +struct GlobalMemory : public mlir::SideEffects::Resource::Base { + llvm::StringRef getName() final { return ""; } +}; + constexpr absl::string_view kRuntimeTmaDescriptorInitializerName = "mosaic_gpu_init_tma_desc"; constexpr absl::string_view kRuntimeMemcpyAsyncH2DName = diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index b05e6ebd71b7..e20dd55043e6 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -16,6 +16,8 @@ limitations under the License. #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ +include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/IR/CommonAttrConstraints.td" @@ -28,6 +30,7 @@ def MosaicGPU_Dialect : Dialect { let name = "mosaic_gpu"; let cppNamespace = "::mosaic_gpu"; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } class MosaicGPU_Type traits = []> @@ -35,6 +38,11 @@ class MosaicGPU_Type traits = []> let mnemonic = mnemonic_; } +class MosaicGPU_Attr + : AttrDef { + let mnemonic = mnemonic_; +} + def MosaicGPU_Barrier : MosaicGPU_Type<"Barrier", "barrier", [MemRefElementTypeInterface]> { let summary = "barrier"; let description = "A barrier to use for synchronizing threads"; @@ -83,4 +91,181 @@ def MosaicGPU_FragmentedLayoutAttr : EnumAttr< let assemblyFormat = "`<` $value `>`"; } +// Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td +// but it was not possible to reuse that definition. Including that file +// pulls in ops definitions that we don't want and they fail to compile. +def MosaicGPU_Dimension : I32EnumAttr<"Dimension", + "a dimension, either 'x', 'y', or 'z'", + [ + I32EnumAttrCase<"x", 0>, + I32EnumAttrCase<"y", 1>, + I32EnumAttrCase<"z", 2> + ]>{ + let cppNamespace = "::mosaic_gpu"; + let genSpecializedAttr = 0; +} + +def MosaicGPU_DimensionAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode", + "What swizzling to use for a memory access.", + [ + I32EnumAttrCase<"kNoSwizzle", 0, "none">, + I32EnumAttrCase<"k32ByteSwizzle", 1, "32">, + I32EnumAttrCase<"k64ByteSwizzle", 2, "64">, + I32EnumAttrCase<"k128ByteSwizzle", 3, "128"> + ]>{ + let cppNamespace = "::mosaic_gpu"; + let genSpecializedAttr = 0; +} + +def MosaicGPU_SwizzlingModeAttr : EnumAttr; + +def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> { + let parameters = (ins Variadic:$tiling); + let summary = "Tiles a suffix of memref dimensions."; + let description = [{ + For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), + the shape of the result will be (5, 2, 4, 64, 32). The shape always ends + with the tile shape, and the size of tiled dimensions is divided by the tile + size. This is especially useful for swizzled WGMMA, which expect tiled + layouts in shared memory. + + Each tiled dimension must have a size that is either smaller than the + corresponding tile size or a multiple of the tile size. + }]; + let assemblyFormat = "`<` $tiling `>`"; +} + +def TransposeTransformAttr : MosaicGPU_Attr<"TransposeTransform", "transpose"> { + let parameters = (ins Variadic:$permutation); + let summary = "Specifies how to transpose a memref."; + let assemblyFormat = "`<` $permutation `>`"; +} + +def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">; + +def MosaicGPU_AsyncLoadOp : Op]>]> { + let summary = "Schedules an async load of a MemRef from GMEM to SMEM"; + let description = [{ + Schedules an async copy of the contents of the `source` MemRef in GMEM to + the `destination` MemRef in SMEM. The `destination` MemRef in SMEM must be + contiguous. + + If `arrive` is true, the `arrive.expect-tx(expect_count)` operation will be + executed on the provided `barrier` before the copy is scheduled. Upon + completion of the copy, the `complete-tx(complete-count)` operation will + always be executed on the provided `barrier`. + + The `indices` and `slice_lengths` inputs define what slice of the GMEM + `source` corresponds to the SMEM `destination`. Both `indices` and + `slice_lengths` must have a length equal to the rank of the `source`. The + values in `indices` are the starting indices of each dimension and the + values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths` + indicates that the slice length is 1 and that the corresponding dimension + should be collapsed and does not appear in the `destination` MemRef. + + Additional `transforms` may be provided to control how the `source` data is + mapped to the `destination`. The transformations will be composed in the + order they are provided. The `swizzle` attribute controls what swizzling + is applied to the data after it is transformed, before it is finally written + to SMEM. The transformed data is written in row-major order to the + contiguous SMEM `destination`. The untransformed `source` data does not need + to be contiguous, except for the last dimension, which needs to be + contiguous and the minor-most dimension. + + The `collective` attribute can be provided to use TMA multicast to more + efficiently load the GMEM data in cases where multiple thread blocks are + grouped together in a cluster and need to load the same data. Each block in + a cluster will first load a slice from GMEM to SMEM and then the slices will + be multicast to all other blocks in the cluster. In this way TMA multicast + guarnatees L2 cache hits. The `collective` attribute is the list of + cluster dimensions along which to partition the input data loads. + + The `predicate` input should be set to `true` by a single thread in the + warpgroup so that it schedules the load operation. All other threads in the + warpgroup should set the `predicate` to `false`. + }]; + + let arguments = (ins + MemRefOf<[AnyType]>:$source, + MemRefOf<[AnyType]>:$destination, + MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier, + Variadic:$indices, + PtxPredicate:$predicate, + + // Attributes + DenseI64ArrayAttr:$slice_lengths, + TypedArrayAttrBase, "transforms">:$transforms, + DefaultValuedAttr:$swizzle, + DefaultValuedAttr:$arrive, + TypedArrayAttrBase:$collective + ); + + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `destination` `(` $destination `:` type($destination) `)` + `barrier` `(` $barrier `:` type($barrier) `)` + `indices` `(` $indices `)` + `predicate` `(` $predicate `)` + attr-dict + }]; + + let hasVerifier = 1; +} + +def MosaicGPU_AsyncStoreOp : Op]>]> { + let summary = "Schedules an async store of a MemRef from SMEM to GMEM"; + let description = [{ + Schedules an async store of the contents of the `source` MemRef in SMEM to + the `destination` MemRef in GMEM. The `source` MemRef in SMEM must be + contiguous. + + The `indices` and `slice_lengths` inputs define what slice of the GMEM + `destination` corresponds to the SMEM `source`. Both `indices` and + `slice_lengths` must have a length equal to the rank of the `destination`. + The values in `indices` are the starting indices of each dimension and the + values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths` + indicates that this dimension is collapsed in the `source` and needs to be + expanded to a slice of size 1 in the `destination`. + + Additional `transforms` may be provided to control how the `destination` + data in GMEM is mapped to the `source` data in SMEM. The transformations + will be composed in the order they are provided. The `swizzle` attribute + is the swizzling mode of the `source` data in SMEM. The `source` SMEM data + is contiguous and the transformed data is written to the `destination` GMEM + which does not need to be contiguous. + + The `predicate` input should be set to `true` by a single thread in the + warpgroup so that it schedules the store operation. All other threads in the + warpgroup should set the `predicate` to `false`. + }]; + + let arguments = (ins + MemRefOf<[AnyType]>:$source, + MemRefOf<[AnyType]>:$destination, + Variadic:$indices, + PtxPredicate:$predicate, + + // Attributes + DenseI64ArrayAttr:$slice_lengths, + TypedArrayAttrBase, "transforms">:$transforms, + DefaultValuedAttr:$swizzle + ); + + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `destination` `(` $destination `:` type($destination) `)` + `indices` `(` $indices `)` + `predicate` `(` $predicate `)` + attr-dict + }]; + + let hasVerifier = 1; +} + #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index ef6230f70321..6e575fb3092a 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -21,7 +21,8 @@ py_library( name = "gpu_dialect", srcs = [ "mosaic_gpu.py", - "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py", + "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_enums.py", + "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_ops.py", ], visibility = ["//visibility:public"], deps = [ diff --git a/jaxlib/mosaic/python/mosaic_gpu.py b/jaxlib/mosaic/python/mosaic_gpu.py index 3157242e48a8..f99f53cfdb69 100644 --- a/jaxlib/mosaic/python/mosaic_gpu.py +++ b/jaxlib/mosaic/python/mosaic_gpu.py @@ -23,7 +23,8 @@ # pylint: disable=g-bad-import-order -from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen import * # pylint: disable=wildcard-import # type: ignore[import-not-found] +from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_ops import * # pylint: disable=wildcard-import # type: ignore[import-not-found] +from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_enums import * # pylint: disable=wildcard-import # type: ignore[import-not-found] from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import * # pylint: disable=wildcard-import # type: ignore[import-not-found] try: diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index c5a3e9d6cc57..ae16efe66e71 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -22,9 +22,9 @@ from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf - from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import @@ -50,12 +50,15 @@ def walk_operations(op: ir.OpView, callback): callback(op) -def find_if(module: ir.Module, - predicate: Callable[[ir.OpView], bool]) -> list[ir.OpView]: +def find_if( + module: ir.Module, predicate: Callable[[ir.OpView], bool] +) -> list[ir.OpView]: result = [] + def callback(op: ir.OpView): if predicate(op): result.append(op) + for op in module.body.operations: walk_operations(op, callback) return result @@ -81,16 +84,19 @@ def test_dialect_module_is_loaded(self): def test_initialize_barrier_op_result_memref_must_wrap_barriers(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( - ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1) + ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1 + ) with self.assertRaisesRegex( - ir.MLIRError, "must be memref of barrier values"): + ir.MLIRError, "must be memref of barrier values" + ): self.module.operation.verify() def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=0) + arrival_count=0, + ) with self.assertRaisesRegex(ir.MLIRError, "value is positive"): self.module.operation.verify() @@ -98,10 +104,358 @@ def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=1) + arrival_count=1, + ) self.assertTrue(self.module.operation.verify()) - self.assertIsInstance(self.module.body.operations[0], - mgpu.InitializeBarrierOp) + self.assertIsInstance( + self.module.body.operations[0], mgpu.InitializeBarrierOp + ) + + def test_async_load_op_dest_must_be_contiguous(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get( + [4, 8], + ir.F32Type.get(), + layout=ir.Attribute.parse("strided<[16, 1]>"), + ), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `destination` memref must be contiguous", + ): + self.module.operation.verify() + + def test_async_load_op_source_and_dest_must_have_same_element_type(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F64Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`source` and `destination` memrefs must have the same element", + ): + self.module.operation.verify() + + def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[-2, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `slice_lengths` attribute must not contain values less than -1", + ): + self.module.operation.verify() + + def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([1, 4, 8], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[-1, 4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`destination` plus the number of collapsed dimensions as indicated", + ): + self.module.operation.verify() + + def test_async_load_op_indices_size_must_match_source_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `indices` must be equal to the rank of `source`", + ): + self.module.operation.verify() + + def test_async_load_op_slice_lengths_size_must_match_source_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `slice_lengths` must be equal to the rank of `source`", + ): + self.module.operation.verify() + + def test_async_load_op_slice_collective_must_be_unique(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")), + ir.IntegerType.get_signless(32), + name="async_load", + )( + lambda source, destination, barrier, *indices: mgpu.async_load( + source, + destination, + barrier, + indices, + slice_lengths=[4], + transforms=ir.ArrayAttr.get([]), + collective=ir.ArrayAttr.get([ + ir.Attribute.parse( + f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>" + ), + ir.Attribute.parse( + f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>" + ), + ]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `collective` attribute must not contain duplicate dimensions", + ): + self.module.operation.verify() + + def test_async_store_op_source_must_be_contiguous(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get( + [4, 8], + ir.F32Type.get(), + layout=ir.Attribute.parse("strided<[16, 1]>"), + ), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `source` memref must be contiguous", + ): + self.module.operation.verify() + + def test_async_store_op_source_and_dest_must_have_same_element_type(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F64Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`source` and `destination` memrefs must have the same element", + ): + self.module.operation.verify() + + def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[-2, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The `slice_lengths` attribute must not contain values less than -1", + ): + self.module.operation.verify() + + def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([1, 4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[-1, 4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "`source` plus the number of collapsed dimensions as indicated", + ): + self.module.operation.verify() + + def test_async_store_op_indices_size_must_match_destination_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.MemRefType.get([4, 8], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `indices` must be equal to the rank of `destination`", + ): + self.module.operation.verify() + + def test_async_store_op_slice_lengths_size_must_match_source_rank(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.MemRefType.get([4], ir.F32Type.get()), + ir.MemRefType.get([4], ir.F32Type.get()), + ir.IntegerType.get_signless(32), + name="async_store", + )( + lambda source, destination, *indices: mgpu.async_store( + source, + destination, + indices, + slice_lengths=[4, 8], + transforms=ir.ArrayAttr.get([]), + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + "The size of `slice_lengths` must be equal to the rank of" + " `destination`", + ): + self.module.operation.verify() class DialectLoweringTest(DialectTest): @@ -110,11 +464,13 @@ def test_lowering_removes_mosaic_gpu_ops(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=1) + arrival_count=1, + ) lower_mgpu_dialect(self.module) self.assertEmpty( - list(filter(is_mosaic_gpu_op, self.module.body.operations))) + list(filter(is_mosaic_gpu_op, self.module.body.operations)) + ) def test_lowering_traverses_regions_correctly(self): with ir.InsertionPoint(self.module.body): @@ -124,12 +480,14 @@ def test_lowering_traverses_regions_correctly(self): with ir.InsertionPoint(if_op.then_block): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=1) + arrival_count=1, + ) scf.yield_([]) lower_mgpu_dialect(self.module) self.assertEmpty( - list(filter(is_mosaic_gpu_op, if_op.then_block.operations))) + list(filter(is_mosaic_gpu_op, if_op.then_block.operations)) + ) def test_initialize_barrier_op_lowering_rule(self): shape = (3, 4) @@ -139,12 +497,14 @@ def test_initialize_barrier_op_lowering_rule(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=arrival_count) + arrival_count=arrival_count, + ) lower_mgpu_dialect(self.module) all_mbarrier_init_shared_ops = find_if( self.module, - lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME) + lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME, + ) # One nvvm.mbarrier_init_shared is issued per barrier. self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements) From d352f4f245246cab78b43d9cdae0b448d111de02 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 8 Nov 2024 18:14:51 -0800 Subject: [PATCH 262/698] Put the set of current spmd axis names in the axis env instead of spelunking through the trace stack to find it. PiperOrigin-RevId: 694710181 --- jax/_src/core.py | 21 ++++++++++++++++++--- jax/_src/interpreters/batching.py | 14 ++++++++------ jax/experimental/shard_map.py | 7 ++----- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 6f96dc760cc0..d9c9306d854c 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -955,6 +955,7 @@ def __eq__(self, other): @dataclass(frozen=True) class AxisEnv: axis_sizes : dict[AxisName, int] + spmd_axis_names : set[AxisName] def axis_size(self, axis_name): if axis_name not in self.axis_sizes: @@ -971,20 +972,24 @@ def axis_names(self): def pop_pure(self, axis_name): new_sizes = self.axis_sizes.copy() new_sizes.pop(axis_name) - return AxisEnv(new_sizes) + return AxisEnv(new_sizes, self.spmd_axis_names) def extend_pure(self, name_size_pairs): new_sizes = self.axis_sizes.copy() new_sizes.update((name, size) for name, size in name_size_pairs if name is not no_axis_name) - return AxisEnv(new_sizes) + return AxisEnv(new_sizes, self.spmd_axis_names) + + def add_spmd_axis_names(self, axis_names): + new_spmd_axis_names = self.spmd_axis_names | set(axis_names) + return AxisEnv(self.axis_sizes, new_spmd_axis_names) def as_hashable_key(self): return tuple((name, size) for (name, size) in self.axis_sizes.items() if name is not no_axis_name) eval_trace = EvalTrace() -top_axis_env = AxisEnv({}) +top_axis_env = AxisEnv({}, set()) class TracingContext(threading.local): trace: Trace | None @@ -1045,6 +1050,16 @@ def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]): finally: trace_ctx.set_axis_env(prev) +@contextmanager +def add_spmd_axis_names(axis_names: AxisName | None): + prev = trace_ctx.axis_env + try: + if axis_names is not None: + trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names)) + yield + finally: + trace_ctx.set_axis_env(prev) + def get_axis_env(): return trace_ctx.axis_env diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 590e60383b90..0adb582a7993 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -596,9 +596,10 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, source_info_util.current())) in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) - with core.set_current_trace(trace): - with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]): - outs = yield in_tracers, {} + with (core.set_current_trace(trace), + core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), + core.add_spmd_axis_names(axis_data.spmd_name)): + outs = yield in_tracers, {} out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), @@ -795,9 +796,10 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): _, in_axes = resolve_ragged_axes(in_vals, in_axes) in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val for val, dim in zip(in_vals, in_axes)] - with core.set_current_trace(trace): - with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]): - outs = yield in_tracers, {} + with (core.set_current_trace(trace), + core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), + core.add_spmd_axis_names(axis_data.spmd_name)): + outs = yield in_tracers, {} out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) new_out_axes = indirectify_ragged_axes_against_inputs_outputs( out_axes, in_vals, out_vals) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c67b4f68cc9b..7ddd3805b5d0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1506,7 +1506,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-rep case. name_set = {n for ns in names.values() for n in ns} - return [n for n in mesh.axis_names if n not in name_set] + return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set] def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, @@ -1652,10 +1652,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]: - trace = core.unsafe_get_current_trace() if trace is None else trace - stack = core.unsafe_get_trace_stack(trace) - batch_traces = [t for t in stack if isinstance(t, batching.BatchTrace)] - spmd_names = {n for trace in batch_traces for n in trace.axis_data.spmd_name } + spmd_names = core.get_axis_env().spmd_axis_names return tuple(name for name in mesh.axis_names if name not in spmd_names) # DCE From 87ce0cbb00c8a31a0266e2b10809a184b989a2cf Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 8 Nov 2024 18:28:52 -0800 Subject: [PATCH 263/698] Make GPU work with copy=True and device_put since same device pinned_host -> pinned_host copy is possible. PiperOrigin-RevId: 694713334 --- tests/memories_test.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 337b1c24d835..da4239338c02 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -699,8 +699,10 @@ def foo(x): def test_disallow_alias_copies_arrays(self): if xla_extension_version < 296: self.skipTest("Requires xla_extension_version >= 296") - _, _, _, inp_host = _create_inputs( - (8, 2), P("x", "y"), mem_kind="pinned_host") + mesh = jtu.create_mesh((2,), ("x",)) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P("x"), memory_kind="pinned_host") + inp_host = jax.device_put(np_inp, s) inp_host_copy = jax.device_put(inp_host, may_alias=False) @@ -712,8 +714,10 @@ def test_disallow_alias_copies_arrays(self): def test_disallow_alias_copies_arrays_with_donated_input(self): if xla_extension_version < 296: self.skipTest("Requires xla_extension_version >= 296") - _, _, _, inp_host = _create_inputs( - (8, 2), P("x", "y"), mem_kind="pinned_host") + mesh = jtu.create_mesh((2,), ("x",)) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P("x"), memory_kind="pinned_host") + inp_host = jax.device_put(np_inp, s) inp_host_donate = jax.jit(lambda x: x, donate_argnums=0)(inp_host) From 45ae4dfb9e612d84c8f9b8ccc4acdf3d1b376169 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sat, 9 Nov 2024 11:10:16 +0200 Subject: [PATCH 264/698] [shape_poly] Remove caching for the symbolic shape evaluator The caching used for the shape_poly.CachingShapeEvaluator leads to leaked tracer errors. This is because the `lru_cache` is attached to the `CachingShapeEvaluator.evaluate` and persists for the duration of the program. It is possible to reimplement the caching, but in this case caching does not help much so we just remove it. --- jax/_src/export/_export.py | 7 ++++--- jax/_src/export/shape_poly.py | 26 +++++++++++++------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index fe4deed57d19..ad2c7fdac2dc 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1297,9 +1297,10 @@ def pp_arg_dim(dim_idx: int | None) -> str: # Must express the exported_dim_vars in terms of the shapes in in_avals. solution, shape_constraints, synth_dim_vars = shape_poly.solve_dim_vars( exported.in_avals, args_kwargs_tree=exported.in_tree) - synthetic_env = {vname: in_avals[arg_idx].shape[dim_idx] - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = shape_poly.CachingShapeEvaluator(**synthetic_env) + synthetic_env: shape_poly.DimVarEnv = { + vname: in_avals[arg_idx].shape[dim_idx] + for (vname, arg_idx, dim_idx) in synth_dim_vars} + synthetic_eval = shape_poly.ShapeEvaluator(synthetic_env) # We discharge all the constraints statically. This results in much simpler # composability (because we do not have to worry about the constraints of the # Exported called recursively; we only need to worry about entry-point diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 77786cbf1a9d..15f99533d59e 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1746,11 +1746,10 @@ def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]: return sorted(dim_vars) -class CachingShapeEvaluator: - def __init__(self, **env): +class ShapeEvaluator: + def __init__(self, env: DimVarEnv): self.env = env - @functools.lru_cache(128) def evaluate(self, e: DimSize): if core.is_constant_dim(e): res = op.index(e) # type: ignore @@ -1769,7 +1768,7 @@ class ShapeConstraint: # is formed by evaluating the DimSize and concatenating the sequence. error_message_pieces: Sequence[str | DimSize] - def check_statically(self, eval: CachingShapeEvaluator) -> None: + def check_statically(self, eval: ShapeEvaluator) -> None: """Evaluates a constraint statically.""" left, right = eval.evaluate(self.left), eval.evaluate(self.right) try: @@ -1785,7 +1784,7 @@ def check_statically(self, eval: CachingShapeEvaluator) -> None: if not ok: raise self.make_error(eval) - def compute(self, eval: CachingShapeEvaluator) -> jax.Array | None: + def compute(self, eval: ShapeEvaluator) -> jax.Array | None: """Computes if the constraint is satisfied. If the constraint can be resolved statically returns None @@ -1820,7 +1819,7 @@ def __str__(self): def error_message_and_inputs( self, - eval: CachingShapeEvaluator) -> tuple[str, Sequence[Any]]: + eval: ShapeEvaluator) -> tuple[str, Sequence[Any]]: """Forms the error_message and error message_inputs. See shape_assertion. """ @@ -1849,7 +1848,7 @@ def error_message_and_inputs( return ("".join(error_message_strings), error_message_inputs) - def make_error(self, eval: CachingShapeEvaluator) -> Exception: + def make_error(self, eval: ShapeEvaluator) -> Exception: error_message, error_message_inputs = self.error_message_and_inputs(eval) return ValueError(error_message.format(*error_message_inputs)) @@ -1865,7 +1864,7 @@ def add_constraint(self, c = ShapeConstraint(comp, left, right, error_message_pieces) self.constraints.append(c) - def check_statically(self, eval: CachingShapeEvaluator) -> None: + def check_statically(self, eval: ShapeEvaluator) -> None: """Evaluates all the constraints statically. If the static checking of any constraint fails, raises ValueError. @@ -1873,7 +1872,7 @@ def check_statically(self, eval: CachingShapeEvaluator) -> None: for constraint in self.constraints: constraint.check_statically(eval) - def shape_assertions(self, eval: CachingShapeEvaluator) -> None: + def shape_assertions(self, eval: ShapeEvaluator) -> None: """Computes the shape assertions for the set of constraints. See jax_export.Exported docstring. @@ -2014,10 +2013,11 @@ def compute_dim_vars_from_arg_shapes( tuple(args_avals), args_kwargs_tree=args_kwargs_tree) # Replace the synthetic vars with the dynamic shape of the actual arg - synthetic_env = {vname: dimension_size_p.bind(actual_args[arg_idx], - dimension=dim_idx) - for (vname, arg_idx, dim_idx) in synth_dim_vars} - synthetic_eval = CachingShapeEvaluator(**synthetic_env) + synthetic_env: DimVarEnv = { + vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx) + for (vname, arg_idx, dim_idx) in synth_dim_vars + } + synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars] return tuple(dim_values) From b51187ca0c76b2e8e77fa9c49affd9198a279962 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 9 Nov 2024 14:38:30 -0800 Subject: [PATCH 265/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/290673692ae80faf3128b6642bd8727f8527cb12. PiperOrigin-RevId: 694921620 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 7226f7900fce..f0433caa7d76 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "6d95565e652fdd021dcb6d306a54e786572e7a34" -XLA_SHA256 = "1fc547c054905d2724ecc9f4698d40c3f887e0193aed55aaf6ac36774b800c66" +XLA_COMMIT = "290673692ae80faf3128b6642bd8727f8527cb12" +XLA_SHA256 = "56ea1710da3730e86fddd2c605fee3d2009c5a8fb87ca1a6b2190b808fe085fb" def repo(): tf_http_archive( From 098d582e70dd1b5fe469fb9d808ae02e8b2ae809 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 10 Nov 2024 14:42:16 -0800 Subject: [PATCH 266/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/cebb50cc77915aaf16e76b10d78e4c3deb955939. PiperOrigin-RevId: 695127776 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f0433caa7d76..ace06be47494 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "290673692ae80faf3128b6642bd8727f8527cb12" -XLA_SHA256 = "56ea1710da3730e86fddd2c605fee3d2009c5a8fb87ca1a6b2190b808fe085fb" +XLA_COMMIT = "cebb50cc77915aaf16e76b10d78e4c3deb955939" +XLA_SHA256 = "e4be11c05a6b59e8a090e6205d34f138f889826b74113633822fd11b65258668" def repo(): tf_http_archive( From a041ea152ee9784db6d390c93397d28ec03718c9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 10 Nov 2024 16:37:29 -0800 Subject: [PATCH 267/698] Skip test_jnp_einsum_grad_y_pallas on gpu due to ooms PiperOrigin-RevId: 695143627 --- tests/pallas/ops_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index e4017c21f017..b8a42ecf1835 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1034,6 +1034,9 @@ def isnan(x_ref, o_ref): np.testing.assert_allclose(isnan(x), jnp.isnan(x)) def test_jnp_einsum_grad_y_pallas(self): + if jtu.test_device_matches(["gpu"]): + self.skipTest("This test ooms on gpu") + x = jnp.arange(128 * 256, dtype=jnp.float32).reshape((128, 256)) y = jnp.arange(256 * 128, dtype=jnp.float32).reshape((128, 256)) From 763952a607038ffc65b7112f87b834341f67279c Mon Sep 17 00:00:00 2001 From: Dougal Date: Sun, 10 Nov 2024 18:07:31 -0500 Subject: [PATCH 268/698] Fix buggy and confusing logic in the C++/pjit caching path. When we have a cache miss in `_cpp_pjit` we want to compile the function and store the executable. Previously we had a roundabout way of getting hold of that executable. We'd trace the function to a jaxpr but we wouldn't lower and compile it ourselves. Instead, we'd call `pjit_p.bind`. The layers of the tracing onion would be peeled off and eventually we'd hit the `pjit_p` impl rule, `_pjit_call_impl`. This rule has its own cache. With luck we'd also miss *that* cache, and then `_pjit_call_impl` would lower and compile the jaxpr and store the executable in `most_recent_pjit_call_executable`. We'd eventually pop the stack back up to the `_cpp_pjit` cache miss and then we'd get hold of the compiled object by looking up `most_recent_pjit_call_executable`. There's room for bugs here if we hit one cache but not the other. For example, if we miss the `_cpp_pjit` cache but we hit the `_pjit_call_impl` cache then we won't compile the executable. Normally that would just mean that the `_cpp_pjit` cache won't be populated. But if we've previously hit a function with the same jaxpr but slightly different compilation parameters (e.g. device IDs) then we'll get a bogus hit in `most_recent_call_exectuable` and we'll add an incorrect cache entry. The divergent cache behavior you need to trigger this started happening with the "stackless" change because the tracing context became a bigger part of the cache key and `_cpp_pjit` and `_pjit_call_impl` will in general have different tracing contexts. With this change, we remove the whole `most_recent_pjit_call_executable` system. Instead `_cpp_pjit` lowers, compiles and runs the jaxpr itself and obtains the executable directly rather than calling into `pjit_p.bind`. We do call into `pjit_p.bind` if we're not in an eval context, but in that case we don't expect to be able to populate the `_cpp_pjit` cache anyway. --- jax/_src/core.py | 14 +++++++----- jax/_src/pjit.py | 53 +++++++++++++++++----------------------------- tests/pjit_test.py | 17 +++++++++++++++ 3 files changed, 45 insertions(+), 39 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index d9c9306d854c..e3aa40f75bb6 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -892,6 +892,11 @@ def unsafe_buffer_pointer(self): aval_property = namedtuple("aval_property", ["fget"]) aval_method = namedtuple("aval_method", ["fun"]) +def check_eval_args(args): + for arg in args: + if isinstance(arg, Tracer): + raise escaped_tracer_error(arg) + class EvalTrace(Trace): def process_primitive(self, primitive, args, params): @@ -902,12 +907,11 @@ def process_primitive(self, primitive, args, params): else: # TODO(dougalm): delete. this shouldn't be necessary args = map(full_lower, args) - for arg in args: - if isinstance(arg, Tracer): - if config.data_dependent_tracing_fallback.value: + if config.data_dependent_tracing_fallback.value: + for arg in args: + if isinstance(arg, Tracer): return primitive.bind_with_trace(arg._trace, args, params) - else: - raise escaped_tracer_error(arg) + check_eval_args(args) return primitive.impl(*args, **params) def process_call(self, primitive, f, tracers, params): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 604acfb39c16..6ab8c90811a6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -23,7 +23,6 @@ import operator as op import weakref from typing import NamedTuple, Any, Union, cast -import threading import warnings import numpy as np @@ -185,7 +184,16 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - out_flat = pjit_p.bind(*args_flat, **p.params) + if (core.trace_state_clean() and + not config.debug_key_reuse.value and + not config.data_dependent_tracing_fallback.value): + args_flat = map(core.full_lower, args_flat) + core.check_eval_args(args_flat) + out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) + else: + out_flat = pjit_p.bind(*args_flat, **p.params) + compiled = None + profiler = None except pxla.DeviceAssignmentMismatchError as e: fails, = e.args api_name = 'jit' if p.params['resource_env'] is None else 'pjit' @@ -215,7 +223,8 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs): _set_states(p.attrs_tracked, final_states) outs = tree_unflatten(p.out_tree, out_flat) - return outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], p.attrs_tracked + return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], + p.attrs_tracked, compiled, profiler) def _set_states(attrs_tracked, vals): @@ -286,21 +295,6 @@ def _get_fastpath_data( return fastpath_data -class _MostRecentPjitCallExecutable(threading.local): - def __init__(self): - self.weak_key_dict = weakref.WeakKeyDictionary() - self.weak_pgle_profiler_dict = weakref.WeakKeyDictionary() - -_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable() - - -def _read_most_recent_pjit_call_executable(jaxpr): - return _most_recent_pjit_call_executable.weak_key_dict.get(jaxpr, None) - - -def _read_pgle_profiler(jaxpr): - return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None) - def _cpp_pjit_evict_fn(self): self._clear_cache() _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error @@ -335,10 +329,9 @@ def cache_miss(*args, **kwargs): if config.no_tracing.value: raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " "`jit`, but 'no_tracing' is set") - outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper( + outs, out_flat, out_tree, args_flat, jaxpr, \ + attrs_tracked, executable, pgle_profiler = _python_pjit_helper( fun, jit_info, *args, **kwargs) - executable = _read_most_recent_pjit_call_executable(jaxpr) - pgle_profiler = _read_pgle_profiler(jaxpr) maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, @@ -1619,17 +1612,11 @@ def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): - global _most_recent_pjit_call_executable - pgle_compile_options, pgle_profiler = {}, None - pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: - if jaxpr not in pgle_profiler_dict: - pgle_profiler_dict[jaxpr] = profiler.PGLEProfiler( - config.pgle_profiling_runs.value, - config.pgle_aggregation_percentile.value) - - pgle_profiler = pgle_profiler_dict[jaxpr] + pgle_profiler = profiler.PGLEProfiler( + config.pgle_profiling_runs.value, + config.pgle_aggregation_percentile.value) # The method below will return FDO profile when module was profiled # config.jax_pgle_profiling_runs amount of times, otherwise the result will # be None. @@ -1652,7 +1639,6 @@ def _pjit_call_impl_python( compiler_options_kvs=compiler_options_kvs, ).compile() - _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. if compiled._auto_spmd_lowering and config.enable_checks.value: pxla.check_array_xla_sharding_layout_match( @@ -1674,7 +1660,7 @@ def _pjit_call_impl_python( ("abstract args", map(xla.abstractify, args)), ("fingerprint", fingerprint)) try: - return compiled.unsafe_call(*args), compiled + return compiled.unsafe_call(*args), compiled, pgle_profiler except FloatingPointError as e: assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case @@ -1720,13 +1706,12 @@ def _pjit_call_impl(*args, jaxpr, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): def call_impl_cache_miss(*args_, **kwargs_): - out_flat, compiled = _pjit_call_impl_python( + out_flat, compiled, pgle_profiler = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) - pgle_profiler = _read_pgle_profiler(jaxpr) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, jaxpr.consts, None, pgle_profiler) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3a0b6cc86114..a9760d02fc0f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1292,6 +1292,23 @@ def f(x): with self.assertRaisesRegex(ValueError, "spmd_axis_name"): jax.vmap(f, spmd_axis_name='x')(xs) + def test_cache_bug(self): + devices = list(jax.devices()) + if len(devices) < 2: + raise unittest.SkipTest("Test requires 2 devices") + + def under_jvp(f): + return jax.jvp(f, (), ()) + + x0 = jnp.zeros(1, device=devices[0]) + x1 = jnp.zeros(1, device=devices[1]) + + # comments describe how caches worked under the old `_most_recent_pjit_call_executable` system + under_jvp(lambda: jnp.sin(x0)) # cpp_pjit miss, pjit_call_impl miss + jnp.sin(x1) # cpp_pjit miss, pjit_call_impl miss + ans1 = jnp.sin(x0) # cpp_pjit miss, pjit_call_impl hit. Bad cpp_pjit entry created + ans2 = jnp.sin(x0) # cpp_pjit hit with bad cache entry + assert(ans1.devices() == ans2.devices()) @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): From 7491fdd94c4aaf0a93339d164de6c3f50d4c977b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 11 Nov 2024 04:08:39 -0800 Subject: [PATCH 269/698] Disable for_loop_test on TPU v5p. This test is failing in CI. PiperOrigin-RevId: 695278007 --- tests/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index 3d79298b8a57..dc81c408c4ce 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1314,6 +1314,9 @@ jax_multiplatform_test( jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], + disable_configs = [ + "tpu_v5p_1x1", # TODO(b/377666550): enable once XLA is fixed. + ], shard_count = { "cpu": 20, "gpu": 10, From da89c9e38c00a3499d8f5ac381fb29de0ea0c597 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 11 Nov 2024 06:17:05 -0800 Subject: [PATCH 270/698] [Mosaic GPU] Add `base_pointer` argument to `InitializeBarrierOp`. This corresponds to what's implemented in `BarrierRef`, and ultimately makes it easier to allocate barriers at a specific address in dynamic shared memory. PiperOrigin-RevId: 695308297 --- jax/experimental/mosaic/gpu/__init__.py | 7 ++- .../mosaic/gpu/dialect_lowering.py | 16 +++---- jaxlib/mosaic/dialect/gpu/BUILD | 1 + jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 12 ++++- tests/mosaic/gpu_dialect_test.py | 47 +++++++++++++------ 5 files changed, 54 insertions(+), 29 deletions(-) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index cf8d2c84c246..7857ffb3c09b 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -29,9 +29,12 @@ ) if dialect is not None: - from .dialect_lowering import lower_mgpu_dialect + from .dialect_lowering import ( + gpu_address_space_to_nvptx as gpu_address_space_to_nvptx, + lower_mgpu_dialect as lower_mgpu_dialect + ) else: - lower_mgpu_dialect = None + gpu_address_space_to_nvptx, lower_mgpu_dialect = None, None from .fragmented_array import ( FragmentedArray as FragmentedArray, diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 8eee9aef3f2e..927da0f30418 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -25,9 +25,8 @@ from jaxlib.mlir import ir from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm -from .utils import c, memref_ptr, single_thread_predicate +from .utils import c, single_thread_predicate # mypy: ignore-errors @@ -57,7 +56,7 @@ def _lowered_barrier_type() -> ir.Type: return ir.IntegerType.get_signless(64) -def _gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int: +def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int: match address_space: case gpu.AddressSpace.Global: return 1 @@ -75,25 +74,22 @@ def _initialize_barrier_op_lowering_rule( num_barriers = functools.reduce(operator.mul, shape, 1) i32 = ir.IntegerType.get_signless(32) - workgroup_nvptx_address_space = _gpu_address_space_to_nvptx( + workgroup_nvptx_address_space = gpu_address_space_to_nvptx( gpu.AddressSpace.Workgroup) ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") lowered_barrier_type = _lowered_barrier_type() - lowered_barrier_ref = memref.alloca( - ir.MemRefType.get(shape, lowered_barrier_type), [], []) - barrier_ref_address = memref_ptr( - lowered_barrier_ref, memory_space=workgroup_nvptx_address_space) predicate = single_thread_predicate(per_block=True) for i in range(num_barriers): nvvm.mbarrier_init_shared( - llvm.getelementptr(ptr_ty, barrier_ref_address, [], [i], + llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i], lowered_barrier_type), c(initialize_barrier_op.arrival_count.value, i32), predicate=predicate ) - return barrier_ref_address, + + return initialize_barrier_op.base_pointer, def lower_mgpu_dialect(module: ir.Module): diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 7c7a3589d9e6..681ee708edd8 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -32,6 +32,7 @@ td_library( "@llvm-project//mlir:BasicPtxBuilderIntTdFiles", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", "@llvm-project//mlir:OpBaseTdFiles", ], ) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index e20dd55043e6..4129dcd1b345 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -16,6 +16,7 @@ limitations under the License. #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/AttrTypeBase.td" @@ -48,19 +49,26 @@ def MosaicGPU_Barrier : MosaicGPU_Type<"Barrier", "barrier", [MemRefElementTypeI let description = "A barrier to use for synchronizing threads"; } +def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; + def MosaicGPU_InitializeBarrierOp : Op { let summary = "Initializes a memref of barriers"; let description = [{ Initializes a memref of barriers each meant to synchronize exactly `arrival_count` threads. + + The base pointer of the result memref corresponds to `base_pointer`, which + must be a pointer to a shared memory location. }]; - let arguments = (ins ConfinedAttr:$arrival_count); + let arguments = (ins + LLVM_PointerShared:$base_pointer, + ConfinedAttr:$arrival_count); let results = (outs MemRefOf<[MosaicGPU_Barrier]>:$barriers_ref); let assemblyFormat = [{ - $arrival_count attr-dict `:` type($barriers_ref) + $base_pointer $arrival_count attr-dict `:` type($barriers_ref) }]; } diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index ae16efe66e71..68d0d3fdd5eb 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -23,9 +23,12 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import gpu +from jax._src.lib.mlir.dialects import llvm from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member +from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx # pylint: disable=g-importing-member,g-multiple-import from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import _cext = mgpu._cext if mgpu is not None else None @@ -68,6 +71,12 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool: return op.name.startswith("mosaic_gpu.") +def workgroup_ptr_ty() -> ir.Type: + workgroup_nvptx_address_space = gpu_address_space_to_nvptx( + gpu.AddressSpace.Workgroup) + return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") + + class DialectTest(parameterized.TestCase): def setUp(self): @@ -84,8 +93,8 @@ def test_dialect_module_is_loaded(self): def test_initialize_barrier_op_result_memref_must_wrap_barriers(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( - ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1 - ) + ir.MemRefType.get((1, 2), ir.F32Type.get()), + llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1) with self.assertRaisesRegex( ir.MLIRError, "must be memref of barrier values" ): @@ -95,21 +104,29 @@ def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=0, - ) + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=0) with self.assertRaisesRegex(ir.MLIRError, "value is positive"): self.module.operation.verify() + def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self): + with ir.InsertionPoint(self.module.body): + mgpu.initialize_barrier( + ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), + llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")), + arrival_count=1) + with self.assertRaisesRegex(ir.MLIRError, "pointer in address space 3"): + self.module.operation.verify() + def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=1, - ) + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=1) self.assertTrue(self.module.operation.verify()) - self.assertIsInstance( - self.module.body.operations[0], mgpu.InitializeBarrierOp - ) + self.assertIsInstance(self.module.body.operations[1], + mgpu.InitializeBarrierOp) def test_async_load_op_dest_must_be_contiguous(self): with ir.InsertionPoint(self.module.body): @@ -464,8 +481,8 @@ def test_lowering_removes_mosaic_gpu_ops(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=1, - ) + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=1) lower_mgpu_dialect(self.module) self.assertEmpty( @@ -480,8 +497,8 @@ def test_lowering_traverses_regions_correctly(self): with ir.InsertionPoint(if_op.then_block): mgpu.initialize_barrier( ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=1, - ) + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=1) scf.yield_([]) lower_mgpu_dialect(self.module) @@ -497,8 +514,8 @@ def test_initialize_barrier_op_lowering_rule(self): with ir.InsertionPoint(self.module.body): mgpu.initialize_barrier( ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), - arrival_count=arrival_count, - ) + llvm.UndefOp(workgroup_ptr_ty()), + arrival_count=arrival_count) lower_mgpu_dialect(self.module) all_mbarrier_init_shared_ops = find_if( From 93599163212999a842f06a75a5818b91eb69e4ba Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 11 Nov 2024 06:42:06 -0800 Subject: [PATCH 271/698] jnp.bincount: support boolean inputs --- jax/_src/numpy/lax_numpy.py | 4 +++- tests/lax_numpy_test.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d2e89833915d..b90004e19932 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3070,6 +3070,8 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, Array([2, 1, 0, 1, 0], dtype=int32) """ util.check_arraylike("bincount", x) + if _dtype(x) == bool: + x = lax.convert_element_type(x, 'int32') if not issubdtype(_dtype(x), integer): raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}") if ndim(x) != 1: @@ -3080,7 +3082,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, x_arr = core.concrete_or_error(asarray, x, "The error occurred because of argument 'x' of jnp.bincount. " "To avoid this error, pass a static `length` argument.") - length = max(minlength, x_arr.size and int(x_arr.max()) + 1) + length = max(minlength, x_arr.size and int(max(0, x_arr.max())) + 1) else: length = core.concrete_dim_or_error(length, "The error occurred because of argument 'length' of jnp.bincount.") diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 61baa7c97df4..7c2728af415e 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4905,7 +4905,7 @@ def testAtLeastNdLiterals(self, dtype, op): @jtu.sample_product( shape=[(0,), (5,), (10,)], - dtype=int_dtypes, + dtype=int_dtypes + bool_dtypes, weights=[True, False], minlength=[0, 20], length=[None, 8], From 1d24630b418cdeb97e07631bac41668e344efcff Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 11 Nov 2024 07:53:08 -0800 Subject: [PATCH 272/698] [XLA:GPU] Change `assert` to `CHECK` in Triton sparsity extensions. A JAX test is hitting this after the latest Triton integrate cl/694073628. Disable the test until we get to the bottom of it. PiperOrigin-RevId: 695337017 --- tests/sparse_nm_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py index 9ecf30eb6229..686575ee4019 100644 --- a/tests/sparse_nm_test.py +++ b/tests/sparse_nm_test.py @@ -47,6 +47,9 @@ def setUp(self): ) @jtu.run_on_devices("gpu") def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Skipping test on Ampere because of bug b/377940729") + # Build keyword arguments kwargs = { "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), @@ -93,6 +96,9 @@ def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): ) @jtu.run_on_devices("gpu") def test_types(self, lhs_type, rhs_type, output_type): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Skipping test on Ampere because of bug b/377940729") + tile_m, tile_n, tile_k = 64, 32, 128 # Build input data From 8a7bf2e4b03bd1ba9dd5bb16d7efbe13dbfd3615 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 11 Nov 2024 08:01:12 -0800 Subject: [PATCH 273/698] [Mosaic GPU] Ensure that lowering `InitializeBarrierOp` preserves the result's type. Otherwise, the lowered IR won't be type-correct. PiperOrigin-RevId: 695339726 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 9 +++++++-- tests/mosaic/gpu_dialect_test.py | 9 ++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 927da0f30418..9bda5b5b7191 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -26,7 +26,7 @@ from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import nvvm -from .utils import c, single_thread_predicate +from .utils import c, ptr_as_memref, single_thread_predicate # mypy: ignore-errors @@ -89,7 +89,12 @@ def _initialize_barrier_op_lowering_rule( predicate=predicate ) - return initialize_barrier_op.base_pointer, + barrier_base_ptr = llvm.getelementptr( + ir.Type.parse("!llvm.ptr"), + initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type) + + return ptr_as_memref( + barrier_base_ptr, initialize_barrier_op.barriers_ref.type), def lower_mgpu_dialect(module: ir.Module): diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index 68d0d3fdd5eb..3edddaad9d12 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -25,6 +25,7 @@ from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member @@ -512,11 +513,17 @@ def test_initialize_barrier_op_lowering_rule(self): arrival_count = 1337 with ir.InsertionPoint(self.module.body): - mgpu.initialize_barrier( + barriers_ref = mgpu.initialize_barrier( ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), llvm.UndefOp(workgroup_ptr_ty()), arrival_count=arrival_count) + # Add a user for barriers_ref to make sure that the lowering keeps types + # consistent. + memref.copy(barriers_ref, barriers_ref) + + self.assertTrue(self.module.operation.verify()) lower_mgpu_dialect(self.module) + self.assertTrue(self.module.operation.verify()) all_mbarrier_init_shared_ops = find_if( self.module, From f18f62a5d2c5d300017162634f17aea1f35a28c6 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 11 Nov 2024 08:57:11 -0800 Subject: [PATCH 274/698] [XLA:GPU] Skip small tile sizes for sparse gemms on Ampere as well. Enable the JAX test again that has been failing. PiperOrigin-RevId: 695360850 --- tests/sparse_nm_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py index 686575ee4019..9ecf30eb6229 100644 --- a/tests/sparse_nm_test.py +++ b/tests/sparse_nm_test.py @@ -47,9 +47,6 @@ def setUp(self): ) @jtu.run_on_devices("gpu") def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): - if not jtu.is_cuda_compute_capability_at_least("9.0"): - self.skipTest("Skipping test on Ampere because of bug b/377940729") - # Build keyword arguments kwargs = { "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), @@ -96,9 +93,6 @@ def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): ) @jtu.run_on_devices("gpu") def test_types(self, lhs_type, rhs_type, output_type): - if not jtu.is_cuda_compute_capability_at_least("9.0"): - self.skipTest("Skipping test on Ampere because of bug b/377940729") - tile_m, tile_n, tile_k = 64, 32, 128 # Build input data From 9b562158ac7b7b425b2ca3dbf3d83f3aea21cd24 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 11 Nov 2024 09:02:36 -0800 Subject: [PATCH 275/698] Internal: create decorators for defining ufuncs --- jax/_src/numpy/ufuncs.py | 153 ++++++++++--------- jax/experimental/jax2tf/tests/jax2tf_test.py | 4 +- 2 files changed, 81 insertions(+), 76 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 8692c30a3e17..acaed78e4db7 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -57,6 +57,24 @@ def _to_bool(x: Array) -> Array: return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) +def unary_ufunc(func: Callable[[ArrayLike], Array]) -> ufunc: + """An internal helper function for defining unary ufuncs.""" + func_jit = jit(func, inline=True) + return ufunc(func_jit, name=func.__name__, nin=1, nout=1, call=func_jit) + + +def binary_ufunc(identity: Any, reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None) -> Callable[[Callable[[ArrayLike, ArrayLike], Array]], ufunc]: + """An internal helper function for defining binary ufuncs.""" + def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc: + func_jit = jit(func, inline=True) + return ufunc(func_jit, name=func.__name__, nin=2, nout=1, call=func_jit, + identity=identity, reduce=reduce, accumulate=accumulate, at=at, reduceat=reduceat) + return decorator + + @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. @@ -160,8 +178,8 @@ def invert(x: ArrayLike, /) -> Array: return lax.bitwise_not(*promote_args('invert', x)) -@partial(jit, inline=True) -def _negative(x: ArrayLike, /) -> Array: +@unary_ufunc +def negative(x: ArrayLike, /) -> Array: """Return element-wise negative values of the input. JAX implementation of :obj:`numpy.negative`. @@ -1126,8 +1144,16 @@ def cbrt(x: ArrayLike, /) -> Array: """ return lax.cbrt(*promote_args_inexact('cbrt', x)) -@partial(jit, inline=True) -def _add(x: ArrayLike, y: ArrayLike, /) -> Array: +def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.add.at.""" + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].add(b).astype(bool) + return a.at[indices].add(b) + +@binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) +def add(x: ArrayLike, y: ArrayLike, /) -> Array: """Add two arrays element-wise. JAX implementation of :obj:`numpy.add`. This is a universal function, @@ -1156,8 +1182,17 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) -@partial(jit, inline=True) -def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: +def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.multiply.at.""" + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].mul(b).astype(bool) + else: + return a.at[indices].mul(b) + +@binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) +def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: """Multiply two arrays element-wise. JAX implementation of :obj:`numpy.multiply`. This is a universal function, @@ -1186,8 +1221,8 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) -@partial(jit, inline=True) -def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=-1) +def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise AND operation elementwise. JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, @@ -1215,8 +1250,8 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) -@partial(jit, inline=True) -def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=0) +def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise OR operation elementwise. JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, @@ -1244,8 +1279,8 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) -@partial(jit, inline=True) -def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=0) +def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise XOR operation elementwise. JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, @@ -1433,8 +1468,12 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.ne(*promote_args("not_equal", x, y)) -@partial(jit, inline=True) -def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array: +def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.subtract.at.""" + return a.at[indices].subtract(b) + +@binary_ufunc(identity=None, at=_subtract_at) +def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: """Subtract two arrays element-wise. JAX implementation of :obj:`numpy.subtract`. This is a universal function, @@ -1754,8 +1793,17 @@ def spacing(x: ArrayLike, /) -> Array: # Logical ops -@partial(jit, inline=True) -def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + """Implementation of jnp.logical_and.reduce.""" + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_and.reduce()") + result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + +@binary_ufunc(identity=True, reduce=_logical_and_reduce) +def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical AND operation elementwise. JAX implementation of :obj:`numpy.logical_and`. This is a universal function, @@ -1774,8 +1822,18 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) -@partial(jit, inline=True) -def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + +def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + """Implementation of jnp.logical_or.reduce.""" + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_or.reduce()") + result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + +@binary_ufunc(identity=False, reduce=_logical_or_reduce) +def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical OR operation elementwise. JAX implementation of :obj:`numpy.logical_or`. This is a universal function, @@ -1794,8 +1852,9 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) -@partial(jit, inline=True) -def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=False) +def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical XOR operation elementwise. JAX implementation of :obj:`numpy.logical_xor`. This is a universal function, @@ -3653,57 +3712,3 @@ def _sinc_maclaurin(k, x): def _sinc_maclaurin_jvp(k, primals, tangents): (x,), (t,) = primals, tangents return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t - - -def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_and.reduce()") - result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - - -def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_or.reduce()") - result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - -def _add_at(a: Array, indices: Any, b: ArrayLike): - if a.dtype == bool: - a = a.astype('int32') - b = lax.convert_element_type(b, bool).astype('int32') - return a.at[indices].add(b).astype(bool) - return a.at[indices].add(b) - -def _subtract_at(a: Array, indices: Any, b: ArrayLike): - return a.at[indices].subtract(b) - -def _multiply_at(a: Array, indices: Any, b: ArrayLike): - if a.dtype == bool: - a = a.astype('int32') - b = lax.convert_element_type(b, bool).astype('int32') - return a.at[indices].mul(b).astype(bool) - else: - return a.at[indices].mul(b) - -# Generate ufunc interfaces for several common binary functions. -# We start with binary ufuncs that have well-defined identities.' -# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience? -# TODO(jakevdp): optimize some implementations. -# - define add.at/multiply.at in terms of scatter_add/scatter_mul -# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod -# - define all monoidal reductions in terms of lax.reduce -add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) -multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) -bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and) -bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or) -bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor) -logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce) -logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce) -logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor) -negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative) -subtract = ufunc(_subtract, name="subtract", nin=2, nout=1, call=_subtract, at=_subtract_at) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index e59084041306..8993d044cb3b 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -979,8 +979,8 @@ def caller_jax(x): self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) else: graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) + if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: + self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) def test_bfloat16_constant(self): # Re: https://github.com/jax-ml/jax/issues/3942 From 24af8a676b3484b4277c65cb8cf4e00fed5ce588 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Fri, 8 Nov 2024 22:22:31 -0500 Subject: [PATCH 276/698] [Mosaic GPU] Only run tests requiring sm90a on Hopper --- jax/_src/test_util.py | 8 ++++++++ .../mosaic/gpu/examples/flash_attention.py | 4 ++-- tests/mosaic/flash_attention_test.py | 4 ++-- tests/mosaic/matmul_test.py | 4 ++-- tests/pallas/mgpu_attention_test.py | 4 ++-- tests/pallas/mosaic_gpu_test.py | 10 ++++++++++ 6 files changed, 26 insertions(+), 8 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index bb81c979bc48..78de511d4ec4 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -538,6 +538,14 @@ def is_cuda_compute_capability_at_least(capability: str) -> bool: current = tuple(int(x) for x in d.compute_capability.split(".")) return current >= target +def is_cuda_compute_capability_equal(capability: str) -> bool: + if not is_device_cuda(): + return False + d, *_ = jax.local_devices(backend="gpu") + target = tuple(int(x) for x in capability.split(".")) + current = tuple(int(x) for x in d.compute_capability.split(".")) + return current == target + def _get_device_tags(): """returns a set of tags defined for the device under test""" if is_device_rocm(): diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 808afae8fc05..4728f00a9243 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -600,9 +600,9 @@ def ref(q, k, v): if __name__ == "__main__": if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): + not jtu.is_cuda_compute_capability_equal("9.0")): warnings.warn( - "Mosaic GPU Flash Attention requires compute capability 9.0 to run, " + "Mosaic GPU Flash Attention requires compute capability 9.0a to run, " "skipping.") exit(0) diff --git a/tests/mosaic/flash_attention_test.py b/tests/mosaic/flash_attention_test.py index 1d15159ca44e..46a2199e19cc 100644 --- a/tests/mosaic/flash_attention_test.py +++ b/tests/mosaic/flash_attention_test.py @@ -43,8 +43,8 @@ def setUp(self): if flash_attention is None: self.skipTest("Mosaic GPU not available.") if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on GPU with capability >= sm90") + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") @parameterized.product( batch_size=(1,), diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index 27ce4e3f02d7..d598d7d0c0ec 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -55,8 +55,8 @@ def setUp(self): if matmul is None: self.skipTest("Mosaic GPU not available.") if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on GPU with capability >= sm90") + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") @parameterized.named_parameters( (f"_shard{i}", i) for i in range(5) diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index 32319e45e2dc..3fa4f6a6f2dd 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -45,8 +45,8 @@ def setUp(self): if attention_mgpu is None: self.skipTest("Mosaic GPU not available.") if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on GPU with capability >= sm90") + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") @parameterized.product( batch_size=(1,), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e0945b0265fb..cbbe8da54972 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -57,6 +57,10 @@ def capture_stdout(self): # We need to cudaDeviceSynchronize to make sure printfs are flushed. mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() + def skip_unless_sm90a(self): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Only works on GPU with capability sm90a") + class PallasCallTest(PallasTest): @@ -731,6 +735,7 @@ def kernel(x_ref, o_ref): @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): + self.skip_unless_sm90a() # TensorCores can only fuse transposes of 16-bit values, and RHS # is expected to be column major by default. rhs_transpose = jnp.dtype(dtype).itemsize != 2 @@ -781,6 +786,7 @@ def scope(acc_ref): ) def test_wgmma_registers(self): + self.skip_unless_sm90a() def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -804,6 +810,7 @@ def scope(acc_ref): np.testing.assert_allclose(res, a @ b, rtol=1e-3) def test_wgmma_registers_init(self): + self.skip_unless_sm90a() def kernel(a_ref, b_ref, i_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref[...], b_ref) @@ -828,6 +835,7 @@ def scope(acc_ref): np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) def test_wgmma_sliced_ref(self): + self.skip_unless_sm90a() def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0]) @@ -863,6 +871,7 @@ def scope(acc_ref): np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) def test_wgmma_sliced_acc(self): + self.skip_unless_sm90a() swizzle = 128 elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize def kernel(a_ref, b_ref, o_ref): @@ -918,6 +927,7 @@ def kernel(a_ref, b_ref): np.testing.assert_array_equal(b, np.ones_like(a)) def test_realistic_matmul(self): + self.skip_unless_sm90a() dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize From 39e0f486a252b0bf58b7d6588dbdac7f2344000d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Nov 2024 17:40:54 +0000 Subject: [PATCH 277/698] Bump actions/cache from 4.1.1 to 4.1.2 Bumps [actions/cache](https://github.com/actions/cache) from 4.1.1 to 4.1.2. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v4.1.1...6849a6489940f00c2f30c0fb92c6274307ccb58a) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index db1477ac38b1..2b555d492644 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -35,7 +35,7 @@ jobs: with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1 + - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} From 034467deda89ace253419a67a6334162e5666a19 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 11 Nov 2024 09:46:01 -0800 Subject: [PATCH 278/698] [pallas:triton] Simplify reshape lowering rule. PiperOrigin-RevId: 695378496 --- jax/_src/pallas/triton/lowering.py | 45 ++++-------------------------- 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 19328b44800b..4b8bc062427e 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1608,17 +1608,6 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions): return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None) -def _reshape(x: ir.Value, shape: Sequence[int]) -> ir.Value: - if not shape: - raise ValueError("cannot reshape to an empty shape") - ty = ir.RankedTensorType(x.type) - return tt_dialect.reshape( - ir.RankedTensorType.get(shape, ty.element_type, ty.encoding), - x, - allow_reorder=False, - ) - - @register_lowering(lax.reshape_p) def _reshape_lowering_rule( ctx: LoweringRuleContext, a, *, new_sizes, dimensions @@ -1633,34 +1622,12 @@ def _reshape_lowering_rule( assert all(dim_size == 1 for dim_size in out_aval.shape) return _splat(a, out_aval.shape) - # TODO(slebedev): Check that the following comment still applies. - # Expand-dims or reduce-sum to handle singleton dims as `tl.reshape` is not - # currently implemented. - dst_shape = [*out_aval.shape] - i = 0 - while ( - ir.RankedTensorType.isinstance(a.type) - and (a_shape := ir.RankedTensorType(a.type).shape) != dst_shape - ): - dim_size = a_shape[i] if i < len(a_shape) else None - dst_dim_size = dst_shape[i] if i < len(dst_shape) else None - if dim_size == dst_dim_size: - i += 1 - elif dst_dim_size == 1: - a = _expand_dims(a, axis=i) - i += 1 - elif dim_size == 1: - in_shape = a_shape - out_shape = tuple(d for di, d in enumerate(a_shape) if di != i) - reduce_ctx = ctx.replace( - avals_in=[ctx.avals_in[0].update(shape=in_shape)], - avals_out=[ctx.avals_in[0].update(shape=out_shape)], - ) - a = _reduce_lowering(jnp.add, reduce_ctx, a, axes=(i,)) - else: # We expect this to fail. - return _reshape(a, dst_shape) - - return a + ty = ir.RankedTensorType(a.type) + return tt_dialect.reshape( + ir.RankedTensorType.get([*out_aval.shape], ty.element_type, ty.encoding), + a, + allow_reorder=False, + ) def _compute_pointers_from_indices( From 0995bc231c51e2ee66995be8ee2b31adf9236509 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 11 Nov 2024 11:10:06 -0800 Subject: [PATCH 279/698] [pallas:triton] Simplify lowering code. `BlockInfo` is now always present for memory refs. PiperOrigin-RevId: 695414469 --- jax/_src/pallas/triton/lowering.py | 62 ++++++++++++------------------ 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 4b8bc062427e..f5fb44de9a9d 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -87,7 +87,7 @@ class ModuleContext: class BlockInfo: full_shape_dtype: jax.ShapeDtypeStruct start_indices: Sequence[Any] - block_shape: tuple[int, ...] # TODO(necula): can this contain "mapped"? + block_shape: tuple[int | pallas_core.Mapped, ...] @dataclasses.dataclass @@ -95,7 +95,7 @@ class LoweringRuleContext: context: ModuleContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] - block_infos: Sequence[BlockInfo | None] # TODO(necula): can this be None? + block_infos: Sequence[BlockInfo | None] replace = dataclasses.replace @@ -362,14 +362,15 @@ def read_env(atom: jax_core.Atom): def read_block_info_env(atom: jax_core.Atom): if isinstance(atom, jax_core.Literal): return None - return block_info_env.get(atom, None) + return block_info_env.get(atom) def write_env(var: jax_core.Var, val): env[var] = val if block_infos is not None: for invar, block_info in zip(jaxpr.invars, block_infos): - block_info_env[invar] = block_info + if block_info is not None: + block_info_env[invar] = block_info map(write_env, jaxpr.invars, args) @@ -393,7 +394,7 @@ def write_env(var: jax_core.Var, val): raise # We only add the extra info to the innermost exception. except Exception as e: if not pallas_call._verbose_errors_enabled(): - raise + raise inval_types = map(lambda t: getattr(t, "type", None), invals) raise LoweringError( f"Exception while lowering eqn:\n {eqn}\nWith context:\n " @@ -474,14 +475,14 @@ def _atomic_lowering_rule( args_tree, atomic_type: primitives.AtomicOpType, ): + block_info, *_ = ctx.block_infos + assert block_info is not None ptr, indexers, val, mask = args_tree.unflatten(args_flat) *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) != 1: raise NotImplementedError("Only single indexer is supported.") idx = indexers[0] - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, idx) val = _ensure_ir_value(val, value_aval) if mask is not None: mask = _ensure_ir_value(mask, mask_aval) @@ -1631,21 +1632,10 @@ def _reshape_lowering_rule( def _compute_pointers_from_indices( - root_ptr: ir.Value, - block_info: BlockInfo | None, - nd_indexer: NDIndexer, - array_shape_dtype: Any, + root_ptr: ir.Value, block_info: BlockInfo, nd_indexer: NDIndexer ) -> ir.Value: - if block_info is None: # TODO(necula): is this branch dead? - full_shape = array_shape_dtype.shape - num_mapped_dims = 0 - block_shape = array_shape_dtype.shape - else: - full_shape = block_info.full_shape_dtype.shape - num_mapped_dims = sum( - b is pallas_core.mapped for b in block_info.block_shape - ) - block_shape = block_info.block_shape + full_shape = block_info.full_shape_dtype.shape + num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape) strides = pallas_utils.strides_from_shape(full_shape) indexer_shape = nd_indexer.get_indexer_shape() int_indexer_shape = nd_indexer.int_indexer_shape @@ -1653,14 +1643,10 @@ def _compute_pointers_from_indices( indices = nd_indexer.indices other_shape = indexer_shape[len(int_indexer_shape) :] other_shape_idx = 0 - if block_info is None: - start_index_offsets = [None] * len(indices) - else: - start_index_offsets = block_info.start_indices assert len(indices) + num_mapped_dims == len(full_shape) - assert len(start_index_offsets) == len(full_shape) + assert len(block_info.start_indices) == len(full_shape) - array_dtype = jnp.dtype(array_shape_dtype.dtype) + array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype) full_size = math.prod(full_shape) * array_dtype.itemsize # Use 64-bit indexing when offset might be >= 2**32 bytes. offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) @@ -1671,7 +1657,7 @@ def _compute_pointers_from_indices( indexer_iter = iter(indices) for dim_stride, dim_block_size, start_offset in zip( - strides, block_shape, start_index_offsets + strides, block_info.block_shape, block_info.start_indices ): if dim_block_size is pallas_core.mapped: index = _ir_constant(0, offset_eltype) @@ -1831,6 +1817,8 @@ def _masked_load_lowering_rule( cache_modifier, is_volatile, ): + block_info, *_ = ctx.block_infos + assert block_info is not None ptr, indexers, mask, other = args_tree.unflatten(args_flat) *_, mask_aval, other_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) > 1: @@ -1839,9 +1827,7 @@ def _masked_load_lowering_rule( if not tt_dialect.PointerType.isinstance(ptr.type): assert len(ctx.avals_in) == 1 return ptr - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, idx) if mask is not None: mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape()) if other is not None: @@ -1931,14 +1917,14 @@ def _store( def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, eviction_policy ): + block_info, *_ = ctx.block_infos + assert block_info is not None ptr, indexers, value, mask = args_tree.unflatten(args_flat) *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) if len(indexers) > 1: raise NotImplementedError("No support for multiple indexers yet.") idx = indexers[0] - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, idx) other = None if value is not None: value = _ensure_ir_value(value, value_aval) @@ -1954,6 +1940,8 @@ def _masked_swap_lowering_rule( @register_lowering(sp.addupdate_p) def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): + block_info, *_ = ctx.block_infos + assert block_info is not None indexers = tree_util.tree_unflatten(tree, idx) if not tt_dialect.PointerType.isinstance(ptr.type): assert len(indexers) == 0 @@ -1961,9 +1949,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): if len(indexers) > 1: raise NotImplementedError("No support for multiple indexers yet.") indexer = indexers[0] - ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], indexer, ctx.avals_in[0] - ) + ptr = _compute_pointers_from_indices(ptr, block_info, indexer) op = tt_dialect.RMWOp.FADD if isinstance(_element_type(value.type), ir.IntegerType): op = tt_dialect.RMWOp.ADD From 0e611e5cac4d6e92e49bdf851e3d1b9724380cc9 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 11 Nov 2024 11:13:16 -0800 Subject: [PATCH 280/698] [Pallas] Add a cost estimator for Pallas/JAX functions. Helps resolve the following issue, where invoking HLO's cost analysis triggers compilation which can OOM for larger inputs: https://github.com/jax-ml/jax/issues/24539. This cost estimator uses only abstract evaluation which should work for all input sizes. PiperOrigin-RevId: 695415760 --- jax/_src/pallas/BUILD | 1 + jax/_src/pallas/cost_estimate.py | 215 ++++++++++++++++++++++ tests/pallas/BUILD | 14 ++ tests/pallas/pallas_cost_estimate_test.py | 95 ++++++++++ 4 files changed, 325 insertions(+) create mode 100644 jax/_src/pallas/cost_estimate.py create mode 100644 tests/pallas/pallas_cost_estimate_test.py diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 4ff7062ac1e8..e1bedaf93377 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -30,6 +30,7 @@ py_library( srcs = [ "__init__.py", "core.py", + "cost_estimate.py", "pallas_call.py", "primitives.py", "utils.py", diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py new file mode 100644 index 000000000000..1bcf704b3579 --- /dev/null +++ b/jax/_src/pallas/cost_estimate.py @@ -0,0 +1,215 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helper tool for automatic cost estimation.""" +import dataclasses +import math +from typing import Any, Sequence + +from jax._src import core as jax_core +from jax._src.pallas import core as pallas_core +from jax._src import linear_util as lu +from jax._src.interpreters import partial_eval as pe +from jax._src.util import safe_map +from jax._src.util import safe_zip +from jax._src.lax import lax + +map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin +zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin + +_cost_rules = {} + +@dataclasses.dataclass(frozen=True) +class CostEstimate: + flops: int + transcendentals: int + bytes_accessed: int + + def __add__(self, other: 'CostEstimate') -> 'CostEstimate': + return CostEstimate( + flops=self.flops + other.flops, + transcendentals=self.transcendentals + other.transcendentals, + bytes_accessed=self.bytes_accessed + other.bytes_accessed, + ) + +def register_cost_rule(primitive: jax_core.Primitive, rule): + _cost_rules[primitive] = rule + +@dataclasses.dataclass(frozen=True) +class Context: + avals_in: Sequence[Any] + avals_out: Sequence[Any] + +def cost_estimate_jaxpr( + jaxpr: jax_core.ClosedJaxpr, +) -> pallas_core.CostEstimate: + """Returns the cost estimate for the given Jaxpr.""" + jaxpr, _ = jaxpr.jaxpr, jaxpr.consts + total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0) + + for eqn in jaxpr.eqns: + _, bind_params = eqn.primitive.get_bind_params(eqn.params) + rule = _cost_rules.get(eqn.primitive, None) + if rule is not None: + context = Context(avals_in=[v.aval for v in eqn.invars], + avals_out=[v.aval for v in eqn.outvars]) + op_cost = rule(context, **bind_params) + total_cost = total_cost + op_cost + return pallas_core.CostEstimate( + flops=total_cost.flops, + transcendentals=total_cost.transcendentals, + bytes_accessed=total_cost.bytes_accessed, + ) + +def cost_estimate(fun, *args) -> pallas_core.CostEstimate: + """Computes a cost estimate for the given function. + + Args: + fun: The function to compute the cost estimate for. + *args: The arguments to the function. Can be jax.ShapeDtypeStruct or + jax.Array. + + Returns: + A pallas_core.CostEstimate object containing the cost estimate. + """ + wrapped_fun = lu.wrap_init(lambda *args, **kwargs: (fun(*args, **kwargs),)) + avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in args] + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) + estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) + input_bytes = sum(math.prod(a.shape) * a.dtype.itemsize for a in args) + output_bytes = sum( + math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars) + return pallas_core.CostEstimate( + flops=estimate.flops, + transcendentals=estimate.transcendentals, + bytes_accessed=estimate.bytes_accessed + input_bytes + output_bytes, + ) + +def binary_cost_rule(ctx: Context, **_) -> CostEstimate: + aval_out, = ctx.avals_out + out_flops = math.prod(aval_out.shape) + return CostEstimate( + flops=out_flops, + transcendentals=0, + bytes_accessed=0, + ) +BINARY_OPS = [ + lax.add_p, + lax.mul_p, + lax.sub_p, + lax.div_p, + lax.min_p, + lax.max_p, + lax.or_p, + lax.and_p, + lax.xor_p, +] +for op in BINARY_OPS: + register_cost_rule(op, binary_cost_rule) + + +def unary_cost_rule(transcendental: bool): + def cost_rule(ctx: Context, **_) -> CostEstimate: + x_aval, = ctx.avals_in + new_flops = 0 + new_transcendentals = 0 + if transcendental: + new_transcendentals += math.prod(x_aval.shape) + else: + new_flops += math.prod(x_aval.shape) + return CostEstimate( + flops=new_flops, + transcendentals=new_transcendentals, + bytes_accessed=0, + ) + return cost_rule + +UN_OPS = [ + lax.neg_p, + lax.floor_p, + lax.ceil_p, + lax.round_p, + lax.not_p, +] +for op in UN_OPS: + register_cost_rule(op, unary_cost_rule(transcendental=False)) + +TRANSCENDENTAL_OPS = [ + lax.cos_p, + lax.sin_p, + lax.tan_p, + lax.sinh_p, + lax.cosh_p, + lax.tanh_p, + lax.acos_p, + lax.asin_p, + lax.atan_p, + lax.exp_p, + lax.log_p, + lax.logistic_p, + lax.sqrt_p, +] +for op in TRANSCENDENTAL_OPS: + register_cost_rule(op, unary_cost_rule(transcendental=True)) + +def _integer_pow_cost_rule(ctx: Context, *, y: int) -> CostEstimate: + x_aval, = ctx.avals_in + num_elements = math.prod(x_aval.shape) + if y == 0 or y == 1: + # No flops, the result is 0 or a copy of the input. + cost_per_element = 0 + else: + # We assume integer pow is implemented using repeated squaring. + # The cost is log(y) squarings, plus one multiply per non-zero bit. + highest_bit = math.floor(math.log(y, 2)) + cost_per_element = highest_bit + y.bit_count() + return CostEstimate( + flops=num_elements * cost_per_element, + transcendentals=0, + bytes_accessed=0, + ) +register_cost_rule(lax.integer_pow_p, _integer_pow_cost_rule) + +def dot_general_cost_rule(ctx: Context, + dimension_numbers: lax.DotDimensionNumbers, + **_) -> CostEstimate: + x_aval, y_aval = ctx.avals_in + x_shape, y_shape = x_aval.shape, y_aval.shape + (lhs_contracting_dims, rhs_contracting_dims), ( + lhs_batch_dims, rhs_batch_dims) = dimension_numbers + assert len(lhs_contracting_dims) == len(rhs_contracting_dims) + assert len(lhs_batch_dims) == len(rhs_batch_dims) + flops = 1 + # Flops along a contracting dim is 2*dim (addition and multiplication) + for i in range(len(lhs_contracting_dims)): + lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i] + assert x_shape[lhs_dim] == y_shape[rhs_dim] + flops *= 2 * x_shape[lhs_dim] + # Now we handle all other dimensions. + for i, lhs_dim in enumerate(x_shape): + if i in lhs_contracting_dims: + continue + flops *= lhs_dim + for i, rhs_dim in enumerate(y_shape): + if i in rhs_contracting_dims: + continue + # Don't double-count batch dims (we already counted for LHS) + if i in rhs_batch_dims: + continue + flops *= rhs_dim + return CostEstimate( + flops=flops, + transcendentals=0, + bytes_accessed=0, + ) +register_cost_rule(lax.dot_general_p, dot_general_cost_rule) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index f95ea53b4929..92cab875df7d 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -56,6 +56,20 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) +jax_multiplatform_test( + name = "pallas_cost_estimate_test", + srcs = [ + "pallas_cost_estimate_test.py", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", + "//jax:pallas_gpu_ops", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + jax_multiplatform_test( name = "pallas_jumble_test", srcs = [ diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py new file mode 100644 index 000000000000..74dd150fbc10 --- /dev/null +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -0,0 +1,95 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import lax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import cost_estimate + + +config.parse_flags_with_absl() + + +class PallasCostEstimateTest(jtu.JaxTestCase): + + def test_exp_add(self): + def exp_add(x, y): + return jnp.exp(x + y) + cost = cost_estimate.cost_estimate(exp_add, + jnp.ones(10, dtype=jnp.float32), + jnp.ones(10, dtype=jnp.float32)) + self.assertEqual(cost.flops, 10) + self.assertEqual(cost.transcendentals, 10) + self.assertEqual(cost.bytes_accessed, 4 * 30) + + def test_very_large_matmul(self): + def matmul(a, b): + return a @ b + m, k, n = 400_000, 800_000, 900_000 + cost = cost_estimate.cost_estimate( + matmul, + jax.ShapeDtypeStruct((m, k), jnp.bfloat16), + jax.ShapeDtypeStruct((k, n), jnp.bfloat16)) + self.assertEqual(cost.flops, 2*m*k*n) + self.assertEqual(cost.transcendentals, 0) + self.assertEqual(cost.bytes_accessed, 2*(m*k + n*k + m*n)) + + def test_batched_matmul(self): + def matmul(a, b): + return jnp.matmul(a, b) + b, m, k, n = 7, 37, 91, 23 + cost = cost_estimate.cost_estimate( + matmul, + jax.ShapeDtypeStruct((b, m, k), jnp.float32), + jax.ShapeDtypeStruct((b, k, n), jnp.float32)) + self.assertEqual(cost.flops, 2*b*m*k*n) + self.assertEqual(cost.transcendentals, 0) + self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n)) + + def test_attention(self): + qk_dim = 16 + v_dim = 4 + kv_len = 128 + q_len = 64 + def attention(q, k, v): + return jax.nn.softmax(q @ k.T, axis=-1) @ v + cost = cost_estimate.cost_estimate( + attention, + jnp.zeros((q_len, qk_dim), dtype=jnp.float32), + jnp.zeros((kv_len, qk_dim), dtype=jnp.float32), + jnp.zeros((kv_len, v_dim), dtype=jnp.float32)) + qk_cost = 2 * q_len * kv_len * qk_dim + v_cost = 2 * q_len * kv_len * v_dim + softmax_flops = kv_len * q_len + self.assertEqual(cost.flops, qk_cost + v_cost + 2 * softmax_flops + q_len) + self.assertEqual(cost.transcendentals, softmax_flops) + input_bytes = q_len * qk_dim + kv_len * qk_dim + kv_len * v_dim + output_bytes = q_len * v_dim + self.assertEqual(cost.bytes_accessed, 4 * (input_bytes + output_bytes)) + + @parameterized.parameters( + (1, 0), (7, 5), (8, 4), (9, 5) + ) + def test_integer_pow(self, power, expected_flops_per_element): + cost = cost_estimate.cost_estimate(lambda x: lax.integer_pow(x, power), + jnp.ones(10, dtype=jnp.float32)) + self.assertEqual(cost.flops, 10 * expected_flops_per_element) + self.assertEqual(cost.transcendentals, 0) + self.assertEqual(cost.bytes_accessed, 80) + +if __name__ == "__main__": + absltest.main() From 6892e628fba9b7f87069db229d95533f88e905ad Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 11 Nov 2024 13:51:36 -0800 Subject: [PATCH 281/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e93a258e4494231626c7d3b6a6447e746ea72f9c. PiperOrigin-RevId: 695470898 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ace06be47494..f74c74077198 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "cebb50cc77915aaf16e76b10d78e4c3deb955939" -XLA_SHA256 = "e4be11c05a6b59e8a090e6205d34f138f889826b74113633822fd11b65258668" +XLA_COMMIT = "e93a258e4494231626c7d3b6a6447e746ea72f9c" +XLA_SHA256 = "99f3a6b06230becf013f00009afeee4c89f52818e7a4a1ea4851157dc853830e" def repo(): tf_http_archive( From afa518aa0ef296d8b32c71bf2b2022514f520e70 Mon Sep 17 00:00:00 2001 From: Stella-S-Yan Date: Thu, 7 Nov 2024 00:24:32 +0000 Subject: [PATCH 282/698] Allow setting default_device with platform names. --- jax/_src/config.py | 11 ++++++----- jax/_src/interpreters/pxla.py | 7 ++++++- tests/api_test.py | 11 ++++++----- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index f3edde69981f..30a9ba0be9c9 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1561,7 +1561,9 @@ def _update_default_device_thread_local(val): def _validate_default_device(val): - if val is not None and not isinstance(val, xla_client.Device): + if (val is not None and + not isinstance(val, xla_client.Device) and + val not in ['cpu', 'gpu', 'tpu']): # TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when # all JAX backends use a single C++ device interface. if 'Device' in str(type(val)): @@ -1569,12 +1571,11 @@ def _validate_default_device(val): 'Allowing non-`xla_client.Device` default device: %s, type: %s', repr(val), type(val)) return - raise ValueError('jax.default_device must be passed a Device object (e.g. ' - f"`jax.devices('cpu')[0]`), got: {val!r}") + raise ValueError('jax.default_device must be passed either a Device object (e.g. ' + f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'" + f", got: {val!r}") -# TODO(skye): default_device only accepts devices for now. Make it work with -# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). default_device = string_or_object_state( name='jax_default_device', default=None, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c83d3e3a4804..2ee2ec75e1a0 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1710,7 +1710,10 @@ class DeviceAssignmentMismatchError(Exception): def _get_default_device() -> xc.Device: - return config.default_device.value or xb.local_devices()[0] + if isinstance(config.default_device.value, str): + return xb.get_backend(config.default_device.value).local_devices()[0] + else: + return config.default_device.value or xb.local_devices()[0] def _get_and_check_device_assignment( @@ -1742,6 +1745,7 @@ def _get_and_check_device_assignment( raise DeviceAssignmentMismatchError([ DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None), DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) + if first_sharding_info is None and devices: final_device_assignment = devices elif first_sharding_info is None: @@ -2190,6 +2194,7 @@ def lower_sharding_computation( assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( len(out_shardings), len(out_layouts), len(global_out_avals)) + devices_from_context = (None if context_mesh is None or context_mesh.empty else context_mesh._flat_devices_tuple) # Device assignment across all inputs, outputs and shardings inside jaxpr diff --git a/tests/api_test.py b/tests/api_test.py index 8ab5d90f6e07..e61236a2d254 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -287,13 +287,14 @@ def test_jit_default_device(self, module): self.assertEqual(f(sticky).devices(), system_default_devices) self.assertEqual(f(1).devices(), system_default_devices) - # TODO(skye): make this work! def test_jit_default_platform(self): - with self.assertRaisesWithLiteralMatch( - ValueError, "jax.default_device must be passed a Device object " - "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"): with jax.default_device("cpu"): - jax.jit(lambda x: x + 1)(1) + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, "cpu") + + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, jax.default_backend()) + def test_complex_support(self): self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j) From 478ea0dcd66182e0b672694a8f88af529fcf6efa Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 7 Nov 2024 19:46:55 -0500 Subject: [PATCH 283/698] Allow 64-bit output types from ffi_call regardless of enable_x64 flag. --- jax/_src/extend/ffi.py | 9 ++++++--- tests/extend_test.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 60db341254c6..5207e6289e26 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -27,7 +27,7 @@ from jax._src import dispatch from jax._src import effects from jax._src import util -from jax._src.callback import _check_shape_dtype, callback_batching_rule +from jax._src.callback import callback_batching_rule from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -209,11 +209,14 @@ def _lowering( def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]: avals: list[core.AbstractValue] = [] - for result in results: + for idx, result in enumerate(results): if isinstance(result, core.AbstractToken): avals.append(result) else: - _check_shape_dtype(result) + if not hasattr(result, "shape") or not hasattr(result, "dtype"): + raise ValueError( + "All elements of result_shape_dtypes must have 'shape' and 'dtype' " + f"attributes. Got {result} at position {idx}.") avals.append(core.ShapedArray(result.shape, result.dtype)) return tuple(avals) diff --git a/tests/extend_test.py b/tests/extend_test.py index 84a907c7331d..b4af8bc23e16 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -27,6 +27,7 @@ from jax._src import abstract_arrays from jax._src import api +from jax._src import config from jax._src import core from jax._src import linear_util from jax._src import prng @@ -326,6 +327,21 @@ def fun(x): "The use of ffi_call attributes requires"): jax.jit(fun).lower(jnp.ones(5)).as_text() + def testAllow64(self): + if config.enable_x64.value: + self.skipTest("Requires enable_x64=False") + def fun(): + return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))() + self.assertIn("tensor", jax.jit(fun).lower().as_text()) + + def testInvalidResultType(self): + with self.assertRaisesRegex( + ValueError, "All elements of result_shape_dtypes.*position 0"): + jex.ffi.ffi_call("test", None)() + with self.assertRaisesRegex( + ValueError, "All elements of result_shape_dtypes.*position 1"): + jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))() + def ffi_call_geqrf(x, **kwargs): if jtu.test_device_matches(["cpu"]): From 3f98c57f7b3fe35b1f7ae8f03e18a7108744461d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 11 Nov 2024 15:32:43 -0800 Subject: [PATCH 284/698] jax.scipy.linalg.toeplitz: support implicit batching --- CHANGELOG.md | 3 +++ jax/_src/scipy/linalg.py | 42 ++++++++++++++++++++++++---------------- tests/linalg_test.py | 34 +++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ab334c15904..d2a45c377889 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `platforms` instead. * Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. + * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional + inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` + on the function inputs. * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index d014e5ceb24e..1c5eba988e6a 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: - r"""Construct a Toeplitz matrix + r"""Construct a Toeplitz matrix. JAX implementation of :func:`scipy.linalg.toeplitz`. @@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: Notice this implies that :math:`r_0` is ignored. Args: - c: array specifying the first column. Will be flattened - if not 1-dimensional. - r: (optional) array specifying the first row. If not specified, defaults - to ``conj(c)``. Will be flattened if not 1-dimensional. + c: array of shape ``(..., N)`` specifying the first column. + r: (optional) array of shape ``(..., M)`` specifying the first row. Leading + dimensions must be broadcast-compatible with those of ``c``. If not specified, + ``r`` defaults to ``conj(c)``. Returns: - toeplitz matrix of shape ``(c.size, r.size)``. + A Toeplitz matrix of shape ``(... N, M)``. Examples: Specifying ``c`` only: @@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64) >>> print("M is Hermitian:", jnp.all(M == M.conj().T)) M is Hermitian: True + + For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices: + + >>> c = jnp.array([[1, 2, 3], [4, 5, 6]]) + >>> jax.scipy.linalg.toeplitz(c) + Array([[[1, 2, 3], + [2, 1, 2], + [3, 2, 1]], + + [[4, 5, 6], + [5, 4, 5], + [6, 5, 4]]], dtype=int32) """ if r is None: check_arraylike("toeplitz", c) r = jnp.conjugate(jnp.asarray(c)) else: check_arraylike("toeplitz", c, r) + return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r))) - c_arr = jnp.asarray(c).flatten() - r_arr = jnp.asarray(r).flatten() - - ncols, = c_arr.shape - nrows, = r_arr.shape - +@partial(jnp.vectorize, signature="(m),(n)->(m,n)") +def _toeplitz(c: Array, r: Array) -> Array: + ncols, = c.shape + nrows, = r.shape if ncols == 0 or nrows == 0: - return jnp.empty((ncols, nrows), - dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype)) - + return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype)) nelems = ncols + nrows - 1 - elems = jnp.concatenate((c_arr[::-1], r_arr[1:])) + elems = jnp.concatenate((c[::-1], r[1:])) patches = lax.conv_general_dilated_patches( elems.reshape((1, nelems, 1)), (nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'), precision=lax.Precision.HIGHEST)[0] return jnp.flip(patches, axis=0) - @partial(jit, static_argnames=("n",)) def hilbert(n: int) -> Array: r"""Create a Hilbert matrix of order n. diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5ace4b5ecf18..d3fe8f476722 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -53,6 +53,22 @@ def _is_required_cuda_version_satisfied(cuda_version): else: return int(version.split()[-1]) >= cuda_version + +def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: + """scipy.linalg.toeplitz with v1.17+ batching semantics.""" + if scipy_version >= (1, 17, 0): + return scipy.linalg.toeplitz(c, r) + elif r is None: + c = np.atleast_1d(c) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c) + else: + c = np.atleast_1d(c) + r = np.atleast_1d(r) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r) + + class NumpyLinalgTest(jtu.JaxTestCase): @jtu.sample_product( @@ -1990,11 +2006,11 @@ def testSqrtmEdgeCase(self, diag, expected, dtype): self.assertAllClose(root, expected, check_dtypes=False) @jtu.sample_product( - cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)], + cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)], cdtype=float_types + complex_types, - rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)], + rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)], rdtype=float_types + complex_types + int_types) - def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): + def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype): if ((rdtype in [np.float64, np.complex128] or cdtype in [np.float64, np.complex128]) and not config.enable_x64.value): @@ -2007,10 +2023,11 @@ def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)] - with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) - self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) + with jax.numpy_rank_promotion("allow"): + with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): + self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz), + jsp.linalg.toeplitz, args_maker) + self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) @jtu.sample_product( shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)], @@ -2028,8 +2045,7 @@ def testToeplitzSymmetricConstruction(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) + self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker) self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) def testToeplitzConstructionWithKnownCases(self): From 38d062dbee639c602b6d552bc068161680db0a10 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 11 Nov 2024 16:13:44 -0800 Subject: [PATCH 285/698] [Mosaic TPU] Support dynamic DMA and ref slice on the 2nd minor when memref is untiled * Generalize any untiled memref to have tiling (packing, 128) * Support dynamic index on 2nd minor. * Support dynamic shape on 2nd minor. PiperOrigin-RevId: 695516124 --- .../tpu/transforms/apply_vector_layout.cc | 4 +- .../tpu/transforms/infer_memref_layout.cc | 70 +++++++++++++++++-- .../tpu/transforms/infer_vector_layout.cc | 6 +- .../transforms/memory_space_specialization.cc | 4 ++ jaxlib/mosaic/dialect/tpu/util.cc | 67 +++++++++++++++--- jaxlib/mosaic/dialect/tpu/util.h | 5 +- tests/pallas/tpu_pallas_test.py | 34 +++++++++ 7 files changed, 165 insertions(+), 25 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 80d0e69e128c..8cb01ee67ad4 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2996,7 +2996,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, // TODO(b/295393167): need to support strided load for bitwidth < 32. } else if (layout_out.bitwidth() == 32 && canReinterpretToUntiledMemref( - memref_ty, ctx.target_shape, + load_op.getBase(), ctx.target_shape, /*allow_minormost_padding=*/true)) { // In this case, if the memref can be reinterpreted to untiled, it is // valid to use any tiling for output. But using native tiling can save us @@ -4204,7 +4204,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, // We accept padding in the minormost dim, because // apply_vector_layout will properly mask stores。 canReinterpretToUntiledMemref( - memref_ty, ctx.target_shape, + store_op.getBase(), ctx.target_shape, /*allow_minormost_padding=*/true)) { // In this case, if the memref can be reinterpreted to untiled, it is // valid to use any tiling for to_store. But using native tiling can save diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 541393fc2758..046b642f98a3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -87,6 +87,16 @@ FailureOr inferLayout(MemRefType memref_ty, int64_t leading_tile_rows = 0) { if (auto tiled_layout_attr = dyn_cast(memref_ty.getLayout())) { + if (leading_tile_rows > 0 && !tiled_layout_attr.getTiles().empty() && + tiled_layout_attr.getTiles().front().dimensions().size() == 2 && + tiled_layout_attr.getTiles().front().dimensions()[0] != + leading_tile_rows) { + return emitError(UnknownLoc::get(memref_ty.getContext()), + "Trying to infer memref layout with sublane tiling ") + << leading_tile_rows + << ", but the memref already has sublane tiling " + << tiled_layout_attr.getTiles().front().dimensions()[0]; + } return tiled_layout_attr; } if (auto affine_map_attr = dyn_cast(memref_ty.getLayout())) { @@ -226,13 +236,25 @@ LogicalResult inferOp(Operation &op, const int hardware_generation, if (auto alloca_op = dyn_cast(op)) { TypedValue arg = alloca_op.getResult(); const MemRefType memref_ty = alloca_op.getResult().getType(); - FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty, - inferMemref(memref_ty, hardware_generation, - target_shape, tpu_tiling_flags)); + // If the memref can be reinterpreted to untiled, force to use tiling + // {1, target.lane_count} for 32 bit. + int64_t leading_tile_rows = 0; + // TODO(b/375038685): generalize untiled memref with packed type which + // needs to update load/store rules. + if (memref_ty.getElementTypeBitWidth() == 32 && memref_ty.getRank() > 1 && + *(memref_ty.getShape().end() - 1) <= target_shape[1]) { + leading_tile_rows = 1; + } + FAILUREOR_ASSIGN_OR_RETURN( + const MemRefType new_memref_ty, + inferMemref(memref_ty, hardware_generation, target_shape, + tpu_tiling_flags, leading_tile_rows)); alloca_op.getResult().setType(new_memref_ty); if (memref_ty != new_memref_ty) { OpBuilder builder(alloca_op->getContext()); builder.setInsertionPointAfter(alloca_op); + // TODO(b/376130272): add a canonicalizer for EraseLayoutOp so that if we + // have erase(erase(x)) then we rewrite it to erase(x). auto erase_op = builder.create( arg.getLoc(), MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), @@ -296,22 +318,56 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation, } FAILUREOR_ASSIGN_OR_RETURN( - const MemRefType new_memref_ty, + MemRefType new_memref_ty, inferMemref(memref_ty, hardware_generation, target_shape, tpu_tiling_flags, leading_tile_rows)); arg.setType(new_memref_ty); new_arg_types.push_back(arg.getType()); if (memref_ty != new_memref_ty) { + Value val = arg; + Operation * arg_use_op = nullptr; + // If the arg memref can be reinterpreted to untiled, we can insert + // ReinterpretCastOp to use tiling {packing, target.lane_count} before + // EraseLayoutOp for only the arg memrefs and expect the rest memref + // layout inference is based on the casted layout automatically. This + // would help lift many restrictions in alignment check when consuming + // this memref. + if (canReinterpretToUntiledMemref(cast>(val), + target_shape, + /*allow_minormost_padding=*/true) && + // TODO(b/375038685): generalize untiled memref with packed type which + // needs to update load/store rules. + new_memref_ty.getElementTypeBitWidth() == 32) { + auto tiled_layout = + cast(new_memref_ty.getLayout()); + SmallVector tiles(tiled_layout.getTiles()); + SmallVector new_tile_strides(tiled_layout.getTileStrides()); + for (int i = 0; i < new_tile_strides.size() - 2; ++i) { + new_tile_strides[i] *= tiles[0].dimension(0); + } + tiles[0] = ::xla::Tile({1, target_shape[1]}); + new_memref_ty = MemRefType::get( + new_memref_ty.getShape(), new_memref_ty.getElementType(), + TiledLayoutAttr::get(new_memref_ty.getContext(), tiles, + new_tile_strides), + new_memref_ty.getMemorySpace()); + arg_use_op = builder.create(val.getLoc(), + new_memref_ty, val); + val = arg_use_op->getResult(0); + } // Some standard MLIR ops have static checks that seems unreasonable, // and we know they hold in the way they are used in Mosaic. Still, // verification with layouts likes to fail, because it can't statically // prove the properties. auto erase_op = builder.create( - arg.getLoc(), + val.getLoc(), MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(), /*layout=*/nullptr, new_memref_ty.getMemorySpace()), - arg); - arg.replaceAllUsesExcept(erase_op.getResult(), erase_op); + val); + if (!arg_use_op) { + arg_use_op = erase_op; + } + arg.replaceAllUsesExcept(erase_op.getResult(), arg_use_op); } } f.setFunctionType( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index a079815fa165..bf668b8ecb52 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1283,7 +1283,8 @@ class VectorLayoutInferer { layout_tiling, ImplicitDim::kNone)); } else if (bitwidth == 32 && canReinterpretToUntiledMemref( - src_ty, target_shape_, /*allow_minormost_padding=*/true) && + op.getBase(), target_shape_, + /*allow_minormost_padding=*/true) && *(src_ty.getShape().end() - 2) > 1) { // Since it is untiled, we can load from any arbitrary address which // means we can always set the sublane offset to 0. @@ -1620,7 +1621,8 @@ class VectorLayoutInferer { // We accept padding in the minormost dim, because // apply_vector_layout will properly mask stores. canReinterpretToUntiledMemref( - ref_ty, target_shape_, /*allow_minormost_padding=*/true)) { + op.getBase(), target_shape_, + /*allow_minormost_padding=*/true)) { // Since it is untiled, we can store to any arbitrary address which // means the sublane offset can be any value and we can fold it to // 2nd minor index. diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index 569038500067..0fd88ac1f294 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -70,6 +70,10 @@ LogicalResult specializeMemorySpace(TypedValue value, to_update.pop_back(); // Here we only have to handle the operations allowed on refs with // unspecified memory space. + if (auto op = dyn_cast(some_op)) { + updateResultFrom(op, op.getInput().getType()); + continue; + } if (auto op = dyn_cast(some_op)) { updateResultFrom(op, op.getMemRef().getType()); continue; diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index b74a43dce32f..d434837efaf5 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -23,6 +23,8 @@ limitations under the License. #include "llvm/Support/MathExtras.h" #include "absl/types/span.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" +#include "mlir/include/mlir/IR/Value.h" +#include "mlir/include/mlir/IR/ValueRange.h" #include "mlir/include/mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -69,31 +71,74 @@ std::optional> isTransposedMatmul( return std::pair{lhs_transposed, rhs_transposed}; } -bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty, +bool canReinterpretToUntiledMemref(TypedValue tiled_memref, const std::array& target_shape, bool allow_minormost_padding) { + MemRefType tiled_memref_ty = tiled_memref.getType(); auto tiled_layout = dyn_cast(tiled_memref_ty.getLayout()); + ValueRange dynamic_sizes = {}; + if (!tiled_layout) { + if (auto erase_op = tiled_memref.getDefiningOp()) { + tiled_memref = erase_op.getOperand(); + tiled_memref_ty = tiled_memref.getType(); + tiled_layout = + dyn_cast(tiled_memref_ty.getLayout()); + // TODO(b/375641258): Currently we rely on the pattern `slice -> + // (squeeze)* -> eraseLayout` to get the dynamic sizes, but other patterns + // may not work here: eg., slice -> eraseLayout -> reshape -> + // eraseLayout`. We should fix this! For now, if we can not get the + // expected dynamic sizes, we consider the memref cannot be reinterpreted + // to untiled. + Value ref = tiled_memref; + while (auto squeeze_op = ref.getDefiningOp()) { + ref = squeeze_op.getInput(); + } + if (auto slice_op = ref.getDefiningOp()) { + dynamic_sizes = slice_op.getDynamicSizes(); + } + } + } if (!tiled_layout) { // We expect the tiled memref to have a tiled layout. return false; } + if (tiled_memref_ty.getNumDynamicDims() != dynamic_sizes.size()) { + return false; + } if (tiled_layout.getTiles().empty() || tiled_layout.getTiles().front().dimensions().size() != 2 || tiled_memref_ty.getRank() < 2) { - // TODO(jevinjiang): Currently we only support >= 2D memref, we might + // TODO(b/375642202): Currently we only support >= 2D memref, we might // need to handle 1D memref if we find a use case. return false; } - if (!allow_minormost_padding && - *(tiled_memref_ty.getShape().end() - 1) != target_shape[1]) { - return false; - } + auto rank = tiled_memref_ty.getRank(); auto packing = 32 / tiled_memref_ty.getElementTypeBitWidth(); - return (*(tiled_memref_ty.getShape().end() - 1) <= target_shape[1] && - *(tiled_memref_ty.getShape().end() - 2) % packing == 0 && - *(tiled_layout.getTileStrides().end() - 1) == 1 && - *(tiled_layout.getTileStrides().end() - 2) == 1); + if (tiled_memref_ty.isDynamicDim(rank - 1)) { + // TODO(jevinjiang): we can still allow the minormost padding if we know the + // max bound of the dynamic size is not larger than the target_shape[1]. + if (!isGuaranteedDivisible(dynamic_sizes.back(), target_shape[1])) { + return false; + } + dynamic_sizes = dynamic_sizes.drop_back(); + } else { + if (!allow_minormost_padding && + tiled_memref_ty.getShape()[rank - 1] != target_shape[1]) { + return false; + } + } + if (tiled_memref_ty.isDynamicDim(rank - 2)) { + if (!isGuaranteedDivisible(dynamic_sizes.back(), packing)) { + return false; + } + } else { + if (tiled_memref_ty.getShape()[rank - 2] % packing != 0) { + return false; + } + } + // Check if the minormost dim has a single tile. + return *(tiled_layout.getTileStrides().end() - 1) == 1 && + *(tiled_layout.getTileStrides().end() - 2) == 1; } - } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 5b068fedd3fd..7c602e9a0bc9 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -2,7 +2,6 @@ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ #include -#include #include #include #include @@ -17,7 +16,7 @@ #include "mlir/Support/LogicalResult.h" #include "absl/types/span.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "tsl/platform/statusor.h" +#include "mlir/include/mlir/IR/Value.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? @@ -112,7 +111,7 @@ std::optional> isTransposedMatmul( // considered as an untiled memref, except for potential padding in the // minormost dimension up to target_shape[1] (if allow_minormost_padding is // true). -bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty, +bool canReinterpretToUntiledMemref(TypedValue tiled_memref, const std::array &target_shape, bool allow_minormost_padding = false); diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 544ed1ac3ecc..347a06c50323 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1472,6 +1472,40 @@ def kernel(index, x, y, sem): np.testing.assert_array_equal(y, i) del y + def test_dynamic_dma_on_2nd_minor(self): + def kernel(array, data, index, size, _, sem): + pltpu.async_copy( + data.at[pl.ds(0, size[0])], array.at[pl.ds(index[0], size[0])], sem + ).wait() + + def run(array, data, index, size): + return pl.pallas_call( + kernel, + out_shape=array, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + scratch_shapes=[ + pltpu.SemaphoreType.DMA, + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(array, data, index, size) + + array = jnp.zeros((1024, 128), jnp.int32) + data = jnp.ones((8, 128), jnp.int32) + index = jnp.array([3], jnp.int32) + size = jnp.array([5], jnp.int32) + + expected = array.at[index[0] : index[0] + size[0]].set( + data[index[0] : index[0] + size[0]] + ) + result = run(array, data, index, size) + np.testing.assert_array_equal(result, expected) + class PallasCallDMAInterpretTest(PallasCallDMATest): INTERPRET = True From 54e72d505413aa46e73157e1d14994c4917b46b9 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 11 Nov 2024 22:45:32 -0800 Subject: [PATCH 286/698] Add wraparound for 2x2x2 v5p PiperOrigin-RevId: 695603337 --- jax/_src/mesh_utils.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index 996a6811a20d..16e34e1afaef 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -33,6 +33,7 @@ _TPU_V4 = 'TPU v4' _TPU_V5_LITE = "TPU v5 lite" _TPU_V5E = "TPU v5e" +_TPU_V5P = "TPU v5p" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. @@ -70,6 +71,7 @@ _TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) _V5E_TRAY_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 4) _V5E_TRAY_IOTA_ORDER = (0, 4, 2, 6, 1, 5, 3, 7) +_V5P_2x2x2_ORDER = (0, 1, 3, 2, 6, 7, 5, 4) def _tpu_v2_v3_create_device_mesh( mesh_shape: Sequence[int], @@ -148,6 +150,35 @@ def _v5e_create_device_mesh( return None +def _v5p_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + """Creates device assignment for selected topologies. + + Args: + mesh_shape: Logical mesh shape used by the model. + devices: TPU devices. + **unused_kwargs: ... + + Returns: + None or reordered devices reshaped as `mesh_shape`. + """ + max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) + bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 + # Our ring re-ordering makes sense only if the passed-in devices are + # sequential, which may not always be the case. reversed() changes z-minor to + # x-minor. + sequential_devices = sorted( + devices, + key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) + + if bound_x == bound_y == 2 and bound_z == 2: + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_V5P_2x2x2_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + return None + # Registers functions to create device mesh for specific device kinds. Takes # precedence over the more general logic in create_device_mesh(). Handler may # return None; in that case, it will fall back to using the default logic. @@ -158,6 +189,7 @@ def _v5e_create_device_mesh( _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, _TPU_V5_LITE: _v5e_create_device_mesh, + _TPU_V5P: _v5p_create_device_mesh, } From 31e42d8e9154d52b8807f527eaa814f03fa6cd4b Mon Sep 17 00:00:00 2001 From: Martin Huschenbett Date: Tue, 12 Nov 2024 11:44:31 +0100 Subject: [PATCH 287/698] Make sure compilation_cache.is_cache_used always returns a bool In some cases, `compilation_cache.is_cache_used` can reach the end of the function body without returning anything. This amounts to an implicit `return None`, which is not in line with the functions return type of `bool`. We fix this by adding a final `return False` to the function. --- jax/_src/compilation_cache.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index c75d1783f356..89dd97175f00 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -84,6 +84,8 @@ def is_cache_used(backend: xla_client.Client) -> bool: _cache_used = True return _cache_used + return False + def get_file_cache(path: str) -> tuple[CacheInterface, str] | None: """Returns the file cache and the path to the cache.""" From c92507772caac5581801120a1ab6d55baf8b8dbb Mon Sep 17 00:00:00 2001 From: George Necula Date: Sun, 10 Nov 2024 10:31:52 +0200 Subject: [PATCH 288/698] Cleanup more remnants of the jax.experimental.host_callback Removes the outfeed rewriter mechanism and helper functions `jaxpr_uses_outfeed`, which were used only by `jax.experimental.host_callback`. --- jax/_src/core.py | 27 --------------------------- jax/_src/dispatch.py | 13 +------------ jax/_src/interpreters/partial_eval.py | 2 +- jax/_src/interpreters/pxla.py | 2 -- jax/core.py | 3 --- 5 files changed, 2 insertions(+), 45 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index e3aa40f75bb6..beb755348a0f 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2111,33 +2111,6 @@ def get_bind_params(self, params): closed_call_p.def_effectful_abstract_eval( lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects)) - -outfeed_primitives: set[Primitive] = set() -def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool: - """Finds if there are outfeed primitives anywhere inside a Jaxpr.""" - return any(primitive_uses_outfeed(eqn.primitive, eqn.params) - for eqn in jaxpr.eqns) - -def _param_uses_outfeed(param): - if type(param) is Jaxpr: - if jaxpr_uses_outfeed(param): - return True - elif type(param) is ClosedJaxpr: - if jaxpr_uses_outfeed(param.jaxpr): - return True - return False - -def primitive_uses_outfeed(prim: Primitive, params: dict) -> bool: - if prim in outfeed_primitives: - return True - for param in params.values(): - if isinstance(param, tuple): - if any(unsafe_map(_param_uses_outfeed, param)): - return True - elif _param_uses_outfeed(param): - return True - return False - # ------------------- Map ------------------- class MapPrimitive(Primitive): diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 8d53742bc7cf..081abf394f98 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -16,7 +16,7 @@ from __future__ import annotations import atexit -from collections.abc import Callable, Sequence +from collections.abc import Sequence import contextlib import dataclasses import enum @@ -278,17 +278,6 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool: return False -# We can optionally set a Jaxpr rewriter that can be applied just before -# compilation. This mechanism is used for compiling id_tap, we can -# remove it once we bring the id_tap implementation into the core. -outfeed_rewriter: Callable[[core.Jaxpr], core.Jaxpr] | None = None -def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr: - if outfeed_rewriter is not None: - return outfeed_rewriter(jaxpr) - else: - return jaxpr - - def check_arg(arg: Any): if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)): raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid " diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index c09a8c711984..ad97ef325f64 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1397,7 +1397,7 @@ def write(x: Atom, b: bool) -> None: def has_effects(eqn: JaxprEqn) -> bool: effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} - return bool(effs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params) + return bool(effs) new_eqns = [] map(write, jaxpr.outvars, used_outputs) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c83d3e3a4804..9a17194d46c9 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -723,7 +723,6 @@ def stage_parallel_callable( jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic( fun, sharded_avals, pe.debug_info_final(fun, "pmap")) jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info) - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) assert len(out_sharded_avals) == len(pci.out_axes), ( len(out_sharded_avals), len(pci.out_axes)) @@ -1783,7 +1782,6 @@ def _dce_jaxpr(closed_jaxpr, api_name, fun_name, donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx) del kept_const_idx - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) return closed_jaxpr, donated_invars, kept_var_idx, name_stack diff --git a/jax/core.py b/jax/core.py index fb08763fd3a1..8d7c546f0754 100644 --- a/jax/core.py +++ b/jax/core.py @@ -87,7 +87,6 @@ is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, jaxpr_as_fun as jaxpr_as_fun, - jaxpr_uses_outfeed as jaxpr_uses_outfeed, jaxprs_in_params as jaxprs_in_params, join_effects as join_effects, lattice_join as lattice_join, @@ -100,9 +99,7 @@ new_jaxpr_eqn as new_jaxpr_eqn, no_axis_name as no_axis_name, no_effects as no_effects, - outfeed_primitives as outfeed_primitives, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, - primitive_uses_outfeed as primitive_uses_outfeed, pytype_aval_mappings as pytype_aval_mappings, raise_to_shaped as raise_to_shaped, raise_to_shaped_mappings as raise_to_shaped_mappings, From fb68c97a0d37de153b3b31aa9e3f520da8429cfa Mon Sep 17 00:00:00 2001 From: George Necula Date: Sun, 10 Nov 2024 09:19:35 +0200 Subject: [PATCH 289/698] [shape_poly] Fix the handling of jvp(lax.sort) Previously, `jvp(lax.sort)` used a shape-dependent dtype, for the types of indices (either `int32` or `int64`, depending on the size of the dimension). For shape polymorphism, input shapes can affect other intermediate shapes, but not `dtype`s. In this case it is easy to just use `int46` independent of the actual shape. --- jax/_src/lax/lax.py | 12 ++++++------ tests/shape_poly_test.py | 8 ++++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8b6a517a54b3..6e1e3ea14fb1 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5317,14 +5317,14 @@ def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys): shape = primals[0].shape iotas = [] for dim, size in enumerate(shape): - dtype = np.int32 if size < np.iinfo(np.int32).max else np.int64 - iotas.append(broadcasted_iota(dtype, shape, dim)) - primals = sort_p.bind(*(primals + (iotas[dimension],)), dimension=dimension, - is_stable=is_stable, num_keys=num_keys) - idx = tuple(primals[-1] if i == dimension else iotas[i] + iotas.append(broadcasted_iota(np.int64, shape, dim)) + sorted_primals_and_idx = sort_p.bind( + *primals, iotas[dimension], dimension=dimension, + is_stable=is_stable, num_keys=num_keys) + idx = tuple(sorted_primals_and_idx[-1] if i == dimension else iotas[i] for i in range(len(shape))) tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents) - return tuple(primals[:-1]), tangents_out + return tuple(sorted_primals_and_idx[:-1]), tangents_out def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys): prototype_arg, new_bdim = next( diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index ead77e2b5053..eda4c4309960 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -3302,6 +3302,14 @@ def test_vmap_error(self): lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1 + x.shape[0] // 4, axis=0), arg_descriptors=[RandArg((13, 4), _f32)], polymorphic_shapes=["b, ..."]), + PolyHarness("sort", "", + lambda a: lax.sort(a), + arg_descriptors=[RandArg((16,), _f32)], + polymorphic_shapes=["b"]), + PolyHarness("jvp_sort", "", + lambda a: jax.jvp(lax.sort, (a,), (a,)), + arg_descriptors=[RandArg((16,), _f32)], + polymorphic_shapes=["b"]), PolyHarness("jnp_split", "idx_tuple_ct", # The indices are a tuple with constants lambda a: jnp.split(a, (2,)), From cb82609ae599261f30260cea4b3da138cde05e8c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 12 Nov 2024 04:24:36 -0800 Subject: [PATCH 290/698] [pallas:triton] Fix reshape lowering with scalar output shape. PiperOrigin-RevId: 695678909 --- jax/_src/pallas/triton/lowering.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index f5fb44de9a9d..1a0400ebf0db 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1624,6 +1624,11 @@ def _reshape_lowering_rule( return _splat(a, out_aval.shape) ty = ir.RankedTensorType(a.type) + + # Triton Reshape doesn't support scalar result types (only 0d tensors). + if not out_aval.shape: + return _reduce_lowering(jnp.add, ctx, a, axes=tuple(range(ty.rank))) + return tt_dialect.reshape( ir.RankedTensorType.get([*out_aval.shape], ty.element_type, ty.encoding), a, From 15f30a9e9c712ed414fe69a38d6f0ee3a804f41b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 12 Nov 2024 04:38:24 -0800 Subject: [PATCH 291/698] [pallas:mosaic_gpu] `emit_pipeline` now maintains the grid indices Previously, it was recomputing them at every loop iteration. PiperOrigin-RevId: 695682116 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 34 +++++++++++-------- .../mosaic/gpu/fragmented_array.py | 6 +++- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 9a17646f0758..21267b50a007 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -30,6 +30,7 @@ from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives from jax.experimental import pallas as pl +import jax.numpy as jnp map = util.safe_map @@ -72,15 +73,16 @@ def copy_out(self, slot, grid_indices): ) -def make_grid_indices( - step: jax.typing.ArrayLike, grid: Sequence[int] +def _inc_grid_by_1( + indices: tuple[jax.Array, ...], grid: Sequence[int] ) -> tuple[jax.Array, ...]: - # TODO(slebedev): Maintain the grid index through the fori_loop instead. - indices = [] - for size in reversed(grid): - indices.append(lax.rem(step, size)) - step = lax.div(step, size) - return tuple(reversed(indices)) + next_indices = [] + carry: bool | jax.Array = True + for idx, size in reversed(list(zip(indices, grid))): + next_idx = lax.select(carry, idx + 1, idx) + carry = next_idx == size + next_indices.append(lax.select(carry, 0, next_idx).astype(idx.dtype)) + return tuple(reversed(next_indices)) def emit_pipeline( @@ -143,15 +145,15 @@ def scoped_pipeline( ): map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) - def loop_body(step, _): + def loop_body(step, carry): slot = step % max_concurrent_steps + indices, fetch_indices = carry # Wait for the current GMEM->SMEM copy to complete. gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) - indices = make_grid_indices(step, grid) with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): body( *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs)) @@ -166,17 +168,19 @@ def loop_body(step, _): jax.lax.cond( fetch_step < num_steps, lambda: map( - lambda bref: bref.copy_in( - fetch_slot, make_grid_indices(fetch_step, grid), barrier_ref - ), + lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref), in_brefs, ), lambda: [None] * len(in_brefs), ) - return () + return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid) - lax.fori_loop(0, num_steps, loop_body, ()) + indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid) + fetch_indices = indices + for _ in range(max_concurrent_steps): + fetch_indices = _inc_grid_by_1(fetch_indices, grid) + lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices)) # Finalize the pipeline. gpu_primitives.wait_smem_to_gmem(0) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 0c5dd0ef793e..fd989d052917 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1184,7 +1184,11 @@ def select(self, on_true, on_false): or ir.IntegerType(self.mlir_dtype).width != 1 ): raise NotImplementedError - return self._pointwise(arith.select, on_true, on_false) + # We change the receiver here, because the return type is defined by + # `on_true` and `on_false` and not the predicate `self`. + return on_true._pointwise( + lambda t, p, f: arith.select(p, t, f), self, on_false, + ) def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): """Call a function for each value and index.""" From 2582a337a6fcf4d0a6357703cbcc308936a0dedf Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Tue, 12 Nov 2024 04:41:36 -0800 Subject: [PATCH 292/698] Explicitly raise an error if more than 65535 channels are created `xla::HostCallbackArgInfo` uses `uint16_t` for channel ids, so we should warn users explicitly when the channel ids exceed the UINT16_MAX instead of silently wrapping around. PiperOrigin-RevId: 695682871 --- jax/_src/interpreters/mlir.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2c0e26019e4d..bef465c6aa75 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -770,7 +770,13 @@ def backend(self) -> xb.XlaBackend: return self.backend_or_name def new_channel(self) -> int: - return next(self.channel_iterator) + channel = next(self.channel_iterator) + # `xla::HostCallback` requires a 16-bit channel ID. + if channel >= (1 << 16): + raise RuntimeError( + "Host callback lowering created too many channels. PjRt does not" + " support more than 65535 channels") + return channel # Adds an IFRT host callback object to the context. A reference to these # callbacks will be provided to IFRT during compilation so it can do things From 9bb6366741c86c2b8444aca2bafcdb1364e8d12d Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 12 Nov 2024 05:29:40 -0800 Subject: [PATCH 293/698] Allow more output storage types for some dot algorithms. As reported in https://github.com/jax-ml/jax/issues/24794, there were some dot products that were resulting in an unnecessary conversion. This change makes the output storage type selection more flexible. Fixes https://github.com/jax-ml/jax/issues/24794 PiperOrigin-RevId: 695694179 --- jax/_src/lax/lax.py | 69 ++++++++++++++++++++++++++++++++++----------- tests/lax_test.py | 13 +++++++++ 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 6e1e3ea14fb1..9781f67152c8 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -906,7 +906,7 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: return self.lhs_precision_type @property - def accumulation_type(self) -> DTypeLike | None: + def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: match self: case ( DotAlgorithmPreset.DEFAULT | @@ -914,8 +914,17 @@ def accumulation_type(self) -> DTypeLike | None: DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): return None + case ( + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | + DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + ): + return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn, + dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, + dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz) case DotAlgorithmPreset.F16_F16_F16: return np.float16 + case DotAlgorithmPreset.F16_F16_F32: + return (np.float32, np.float16) case DotAlgorithmPreset.BF16_BF16_BF16: return dtypes.bfloat16 case DotAlgorithmPreset.F64_F64_F64: @@ -3619,6 +3628,37 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype) +def get_algorithm_compute_types( + algorithm: DotAlgorithm | DotAlgorithmPreset, + lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike, + out_dtype: DTypeLike | None = None, +) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]: + def maybe_convert_dtype(input_dtype, target_dtype): + if target_dtype is None: + return input_dtype + if not isinstance(target_dtype, tuple): + target_dtype = (target_dtype,) + if any(input_dtype == d for d in target_dtype): + return input_dtype + return target_dtype[0] + if algorithm == DotAlgorithmPreset.BF16_BF16_F32: + lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type) + rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type) + if lhs_dtype == dtypes.bfloat16: + out_dtype = maybe_convert_dtype(out_dtype, + (np.float32, dtypes.bfloat16)) + else: + out_dtype = maybe_convert_dtype(out_dtype, np.float32) + return lhs_dtype, rhs_dtype, out_dtype + else: + return ( + maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type), + maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type), + maybe_convert_dtype(out_dtype, algorithm.accumulation_type), + ) + + def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, out_type, platform: str = "default"): @@ -3656,20 +3696,17 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): # If an explicit algorithm was specified, we always cast the input types to # the correct types. def maybe_convert_dtype(operand, operand_aval, target_dtype): - if target_dtype is None: - return operand, operand_aval.dtype - if not isinstance(target_dtype, tuple): - target_dtype = (target_dtype,) - if any(operand_aval.dtype == d for d in target_dtype): - return operand, operand_aval.dtype - aval = core.ShapedArray(operand_aval.shape, target_dtype[0]) - return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0] - - lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type) - rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type) - accumulation_type = precision.accumulation_type - if accumulation_type is not None: - accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type) + if target_dtype is None or operand_aval.dtype == target_dtype: + return operand + aval = core.ShapedArray(operand_aval.shape, target_dtype) + return mlir.convert_hlo(ctx, operand, operand_aval, aval) + + lhs_dtype, rhs_dtype, accumulation_dtype = get_algorithm_compute_types( + precision, lhs_dtype, rhs_dtype, aval_out.dtype) + lhs = maybe_convert_dtype(lhs, lhs_aval, lhs_dtype) + rhs = maybe_convert_dtype(rhs, rhs_aval, rhs_dtype) + if accumulation_dtype is not None: + accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_dtype) if precision != DotAlgorithmPreset.DEFAULT: algorithm_kwarg = { @@ -3690,7 +3727,6 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) - lhs_dtype = rhs_dtype = aval_out.dtype else: # cpu and gpu # Do not convert mixed fp8 types to output type. if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype): @@ -3698,7 +3734,6 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) - lhs_dtype = rhs_dtype = aval_out.dtype result = hlo.dot_general( mlir.aval_to_ir_type(accumulation_aval), diff --git a/tests/lax_test.py b/tests/lax_test.py index ab1557450864..17132996c429 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1146,6 +1146,19 @@ def fun(lhs, rhs): lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) self.assertEqual(fun(lhs, rhs).dtype, np.float16) + def testDotAlgorithmAllowedOutputStorage(self): + # see https://github.com/jax-ml/jax/issues/24794 + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Only supported on GPU.") + def fun(lhs, rhs): + return lax.dot(lhs, rhs, precision="F16_F16_F32", + preferred_element_type=np.float16) + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16) + self.assertNotIn("convert", jax.jit(fun).lower(lhs, rhs).as_text()) + def testDotAlgorithmConfig(self): lhs_shape = (3, 4) rhs_shape = (4, 3) From 21e98b5ce43dd1fda5d10fb7441bfdd011811095 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 12 Nov 2024 05:32:05 -0800 Subject: [PATCH 294/698] Fix overflow error in GPU batched linear algebra kernels. As reported in https://github.com/jax-ml/jax/issues/24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS. PiperOrigin-RevId: 695694648 --- jaxlib/gpu/make_batch_pointers.cu.cc | 11 +++++++---- jaxlib/gpu/make_batch_pointers.h | 5 ++++- tests/linalg_test.py | 8 ++++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc index b10655645924..3a24e355ead0 100644 --- a/jaxlib/gpu/make_batch_pointers.cu.cc +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/gpu/make_batch_pointers.h" #include +#include #include "jaxlib/gpu/vendor.h" @@ -24,8 +25,9 @@ namespace JAX_GPU_NAMESPACE { namespace { __global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out, - int batch, int batch_elem_size) { - for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch; + int64_t batch, + int64_t batch_elem_size) { + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch; idx += blockDim.x * gridDim.x) { buffer_out[idx] = buffer_in + idx * batch_elem_size; } @@ -33,8 +35,9 @@ __global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out, } // namespace void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in, - void* buffer_out, int batch, int batch_elem_size) { - const int block_dim = 128; + void* buffer_out, int64_t batch, + int64_t batch_elem_size) { + const std::size_t block_dim = 128; const std::size_t grid_dim = std::min(1024, (batch + block_dim - 1) / block_dim); MakeBatchPointersAsyncKernel<<>>( diff --git a/jaxlib/gpu/make_batch_pointers.h b/jaxlib/gpu/make_batch_pointers.h index f2fd064961e8..f43ac25c7e50 100644 --- a/jaxlib/gpu/make_batch_pointers.h +++ b/jaxlib/gpu/make_batch_pointers.h @@ -16,13 +16,16 @@ limitations under the License. #ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_ #define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_ +#include + #include "jaxlib/gpu/vendor.h" namespace jax { namespace JAX_GPU_NAMESPACE { void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in, - void* buffer_out, int batch, int batch_elem_size); + void* buffer_out, int64_t batch, + int64_t batch_elem_size); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5ace4b5ecf18..5538ce38baa3 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1450,6 +1450,14 @@ def testLuBatching(self, shape, dtype): self.assertAllClose(ls, actual_ls, rtol=5e-6) self.assertAllClose(us, actual_us) + @jtu.skip_on_devices("cpu", "tpu") + @jtu.skip_on_flag("jax_skip_slow_tests", True) + def testBatchedLuOverflow(self): + # see https://github.com/jax-ml/jax/issues/24843 + x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32) + lu, _, _ = lax.linalg.lu(x) + self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9)) + @jtu.skip_on_devices("cpu", "tpu") @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") From a99ccd9341528cff3d55fe7458bb8bf774497f45 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 12 Nov 2024 07:02:16 -0800 Subject: [PATCH 295/698] Remove GPU test with unreasonably large memory footprint. PiperOrigin-RevId: 695717589 --- tests/linalg_test.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5538ce38baa3..5ace4b5ecf18 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1450,14 +1450,6 @@ def testLuBatching(self, shape, dtype): self.assertAllClose(ls, actual_ls, rtol=5e-6) self.assertAllClose(us, actual_us) - @jtu.skip_on_devices("cpu", "tpu") - @jtu.skip_on_flag("jax_skip_slow_tests", True) - def testBatchedLuOverflow(self): - # see https://github.com/jax-ml/jax/issues/24843 - x = self.rng().standard_normal((1500000, 20, 20)).astype(np.float32) - lu, _, _ = lax.linalg.lu(x) - self.assertTrue(jnp.all(lu.std(axis=[1, 2]) > 0.9)) - @jtu.skip_on_devices("cpu", "tpu") @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") From 0720942b18784ccee4ba6e1899eba54c20a1f717 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 5 Nov 2024 15:36:14 -0800 Subject: [PATCH 296/698] Fix debug_nans false positive in jnp.quantile --- jax/_src/numpy/reductions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index fa8d73361e2b..be1e55675079 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -2360,7 +2360,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) + with jax.debug_nans(False): + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) q = lax.mul(q, n - 1) From f1caa0ed69f0aecc80661a643e9bd8bd6a2abe9e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Nov 2024 11:08:34 -0800 Subject: [PATCH 297/698] Remove some obsolete deprecation registrations PiperOrigin-RevId: 693793727 --- jax/_src/deprecations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 962244a321a9..c7a956068981 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -125,8 +125,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register("jax-numpy-astype-complex-to-real") register("jax-numpy-array-none") -register('jax-scipy-beta-args') -register('tracer-hash') register('jax-numpy-reshape-newshape') register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') From 95f7b247db96871b88c7a0d24275bae64b8250ce Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 6 Nov 2024 12:18:34 -0800 Subject: [PATCH 298/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0f6331b1881ae34c8b1cd59580900d556bc8305c. PiperOrigin-RevId: 693819727 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 3dc24da2559b..9190c136f6e8 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "5a9f79f295ba8d16afce24ea8724da525b8eb87d" -XLA_SHA256 = "83e516dd8f7c61541aa9e2cba7fe480166ea23f28a41fed445fef4c5b6d45519" +XLA_COMMIT = "0f6331b1881ae34c8b1cd59580900d556bc8305c" +XLA_SHA256 = "1e4e4317750b2bb2845c6138aaa96b0d94249484d23e9c799d2dd6ecd4b8dd3c" def repo(): tf_http_archive( From 8463eb08d886058d25fd2bd9abf8573b2121dbab Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 24 Oct 2024 17:57:06 -0700 Subject: [PATCH 299/698] Adding start index and kv_seq_len to decode kernel --- .../pallas/ops/gpu/decode_attention.py | 347 +++++++++++------- tests/pallas/gpu_attention_test.py | 31 +- 2 files changed, 243 insertions(+), 135 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index a7e1b33e1f35..d09f1fbac113 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -14,6 +14,7 @@ """Module containing decode attention.""" from __future__ import annotations +import math import functools from typing import Any @@ -24,82 +25,115 @@ from jax.experimental.pallas import triton as plgpu import jax.numpy as jnp - def attn_forward_kernel( - q_ref, # [num_heads, head_dim] - k_ref, # [k_seq_len, head_dim] - v_ref, # [k_seq_len, head_dim] - o_ref: Any, # [num_heads, head_dim] + # inputs + q_ref, # [num_heads, head_dim] + k_ref, # [k_seq_len, head_dim] + v_ref, # [k_seq_len, head_dim] + start_idx_ref, # [] (i.e., scalar) + kv_seq_len_ref, # [] (i.e., scalar) + # outputs + o_ref: Any, # [num_heads, head_dim] *residual_refs: Any, # Residual outputs: [num_heads,], [num_heads,] sm_scale: float, block_k: int, + block_h: int, + num_heads: int, ): - block_h, head_dim = q_ref.shape - k_seq_len, _ = k_ref.shape - start_q = pl.program_id(0) + _, head_dim = q_ref.shape + split_k_seq_len, _ = k_ref.shape + prog_i, prog_j = pl.program_id(0), pl.program_id(1) + q_slice = pl.ds(0, block_h) + q_mask = (jnp.arange(block_h) < num_heads - block_h * prog_i)[:, None] + + def _compute(start_idx, kv_seq_len, o, m_i, l_i): + # Load q: it will stay in L1 throughout. Indices form a matrix because we + # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. + # q tile has shape [block_h, head_dim]. + q = pl.load(q_ref, (q_slice, pl.ds(None)), mask=q_mask) + + def _dot(a, b): + # if a.shape[0] == 1: + # # Use matrix vector product + # return (a.T * b).sum(axis=0, keepdims=True) + return pl.dot(a, b) + + mask_indices = jnp.arange(block_k) + + # Loop over blocks of kv to process entire kv seq_len. + # Grid loops over q blocks over num_heads. + def body(start_k, carry): + o_prev, m_prev, l_prev = carry + curr_k_slice = pl.ds(start_k * block_k, block_k) + + k = pl.load(k_ref, (curr_k_slice, slice(None))) + qk = _dot(q, k.T) # [block_h, block_k] + if sm_scale != 1.0: + qk *= sm_scale # [block_h, block_k] + + # apply mask if start or sequence length is specified + if start_idx_ref is not None or kv_seq_len_ref is not None: + indices = (prog_j * split_k_seq_len + start_k * block_k + mask_indices) + mask = ((indices >= start_idx) & (indices < kv_seq_len))[None, :] + qk += (~mask) * (0.7 * jnp.finfo(qk.dtype).min) + + m_curr = qk.max(axis=-1) + m_next = jnp.maximum(m_prev, m_curr) + correction = jnp.exp(m_prev - m_next) + l_prev_corr = correction * l_prev + s_curr = jnp.exp( + qk - m_next[:, None] + ) # Use m_next instead of m_curr to avoid a correction on l_curr + l_curr = s_curr.sum(axis=-1) + l_next = l_prev_corr + l_curr + v = pl.load(v_ref, (curr_k_slice, slice(None))) + o_curr = _dot(s_curr.astype(v.dtype), v) + + # flash2 unscaled_o + o_next = correction[:, None] * o_prev + o_curr + return o_next, m_next, l_next + + max_it = jnp.minimum(pl.cdiv((kv_seq_len - prog_j * split_k_seq_len), + block_k), split_k_seq_len // block_k) + (o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i)) + return o, m_i, l_i # o is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention2 paper) are updated during the k,v loop. - m_i = jnp.zeros(block_h, dtype=jnp.float32) - float("inf") + m_i = jnp.zeros(block_h, dtype=jnp.float32) + jnp.finfo(jnp.float32).min l_i = jnp.zeros(block_h, dtype=jnp.float32) o = jnp.zeros((block_h, head_dim), dtype=jnp.float32) - # Load q: it will stay in L1 throughout. Indices form a matrix because we - # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. - # q tile has shape [block_h, head_dim]. - curr_q_slice = pl.dslice(start_q * block_h, block_h) - q = pl.load(q_ref, (curr_q_slice, pl.dslice(None))) - - def _dot(a, b): - # if a.shape[0] == 1: - # # Use matrix vector product - # return (a.T * b).sum(axis=0, keepdims=True) - return pl.dot(a, b) - - # Loop over blocks of kv to process entire kv seq_len. - # Grid loops over q blocks over num_heads. - def body(start_k, carry): - o_prev, m_prev, l_prev = carry - curr_k_slice = pl.dslice(start_k * block_k, block_k) - - k = pl.load(k_ref, (curr_k_slice, slice(None))) - qk = _dot(q, k.T) # [block_h, block_k] - if sm_scale != 1.0: - qk *= sm_scale # [block_h, block_k] - - m_curr = qk.max(axis=-1) - m_next = jnp.maximum(m_prev, m_curr) - correction = jnp.exp(m_prev - m_next) - l_prev_corr = correction * l_prev - s_curr = jnp.exp( - qk - m_next[:, None] - ) # Use m_next instead of m_curr to avoid a correction on l_curr - l_curr = s_curr.sum(axis=-1) - l_next = l_prev_corr + l_curr - v = pl.load(v_ref, (curr_k_slice, slice(None))) - o_curr = _dot(s_curr.astype(v.dtype), v) - - # flash2 unscaled_o - o_next = correction[:, None] * o_prev + o_curr - return o_next, m_next, l_next - - upper_bound = pl.cdiv(k_seq_len, block_k) - # o is left unscaled; it will be scaled in the final reduction step - o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) + start_idx = split_k_seq_len * prog_j + if start_idx_ref is not None: + start_idx = jnp.maximum(start_idx, pl.load(start_idx_ref, ())) + kv_seq_len = (prog_j + 1) * split_k_seq_len # lower bound on actual k_seq_len + if kv_seq_len_ref is not None: + kv_seq_len = jnp.minimum(kv_seq_len, pl.load(kv_seq_len_ref, ())) + if start_idx_ref is None and kv_seq_len is None: + o, m_i, l_i = _compute(start_idx, kv_seq_len, o, m_i, l_i) + else: + o, m_i, l_i = jax.lax.cond( + start_idx >= kv_seq_len, lambda: (o, m_i, l_i), + lambda: _compute(start_idx, kv_seq_len, o, m_i, l_i)) + + # Write output to dram. if residual_refs: l_ref, m_ref = residual_refs - pl.store(l_ref, (curr_q_slice,), l_i) - pl.store(m_ref, (curr_q_slice,), m_i) - # Write output to dram. + vec_q_mask = q_mask.reshape(-1) if q_mask is not None else None + pl.store(l_ref, q_slice, l_i, mask=vec_q_mask) + pl.store(m_ref, q_slice, m_i, mask=vec_q_mask) o = o.astype(o_ref.dtype) - pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o) + pl.store(o_ref, (q_slice, pl.ds(None)), o, mask=q_mask) -def attn_unbatched( - q, # [num_heads, head_dim] - k, # [k_seq_len, head_dim] - v, # [k_seq_len, head_dim] +def decode_attn_unbatched( + q, # [num_heads, head_dim] + k, # [k_seq_len, head_dim] + v, # [k_seq_len, head_dim] + start_idx, # [] + kv_seq_len, # [] sm_scale: float, block_h: int, block_k: int, @@ -113,12 +147,6 @@ def attn_unbatched( num_heads, head_dim = q.shape k_seq_len, _ = k.shape # Pad num query heads to 16 if needed, and slice output at the end. - original_num_heads = None - if num_heads < 16: - q = jnp.pad(q, ((0, 16 - num_heads), (0, 0))) - original_num_heads = num_heads - num_heads = q.shape[0] - block_h = min(block_h, num_heads) head_splits = pl.cdiv(num_heads, block_h) grid_ = grid if grid_ is None: @@ -127,11 +155,16 @@ def attn_unbatched( assert ( k_seq_len % k_splits == 0 ), f"{k_seq_len=} must be divisible by {k_splits=}" + assert k_seq_len // k_splits >= 16, ( + f"{k_seq_len=} divided by {k_splits=} must be >= 16.") + assert block_k >= 16, "block_k must be >= 16" k = k.reshape(k_splits, k_seq_len // k_splits, head_dim) v = v.reshape(k_splits, k_seq_len // k_splits, head_dim) - k_seq_len = k_seq_len // k_splits - assert min(num_heads, head_dim, k_seq_len) >= 16, "Minimum pl.dot size is 16" - block_k = min(block_k, k_seq_len) + split_k_seq_len = k_seq_len // k_splits + block_k = min(block_k, split_k_seq_len) + assert split_k_seq_len % block_k == 0, ( + f"Sequence length ({k_seq_len=}) split by {k_splits=} must by divisible by" + f" {block_k=}") num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 @@ -139,47 +172,49 @@ def attn_unbatched( attn_forward_kernel, sm_scale=sm_scale, block_k=block_k, + block_h=block_h, + num_heads=num_heads, ) o, l, m = pl.pallas_call( - kernel, - grid=grid_, - in_specs=[ - pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)), - pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), - pl.BlockSpec((None, k_seq_len, head_dim), lambda i, j: (j, 0, 0)), - ], - out_specs=[ - pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o - pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l - pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m - ], - compiler_params=plgpu.TritonCompilerParams( - num_warps=num_warps_, num_stages=num_stages - ), - out_shape=[ - jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o - jax.ShapeDtypeStruct( - shape=(k_splits, num_heads), dtype=jnp.float32 - ), # l - jax.ShapeDtypeStruct( - shape=(k_splits, num_heads), dtype=jnp.float32 - ), # m - ], - debug=debug, - interpret=interpret, - name="mha_forward", - )(q, k, v) + kernel, + grid=grid_, + in_specs=[ + pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)), + pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)), + pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)), + ] + + [None if start_idx is None else pl.BlockSpec((), lambda i, j: ())] + + [None if kv_seq_len is None else pl.BlockSpec((), lambda i, j: ())], + out_specs=[ + pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l + pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m + ], + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps_, num_stages=num_stages + ), + out_shape=[ + jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o + jax.ShapeDtypeStruct( + shape=(k_splits, num_heads), dtype=jnp.float32 + ), # l + jax.ShapeDtypeStruct( + shape=(k_splits, num_heads), dtype=jnp.float32 + ), # m + ], + debug=debug, + interpret=interpret, + name="mha_forward", + )(q, k, v, start_idx, kv_seq_len) # final round of flash m_next = m.max(axis=0) correction = jnp.exp(m - m_next[None]) - o = o * correction[:, :, None] + o = o * correction[:, :, None].astype(o.dtype) l_next = (l * correction).sum(axis=0) - o = o.sum(axis=0) / l_next[:, None] - - if original_num_heads is not None: - o = o[:original_num_heads, :] + eps = jnp.finfo(l_next.dtype).eps + o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps) return o @@ -198,10 +233,12 @@ def attn_unbatched( ], ) def mqa( - q, # [batch_size, num_heads, head_dim] - k, # [batch_size, k_seq_len, head_dim] - v, # [batch_size, k_seq_len, head_dim] - sm_scale: float = 1.0, + q, # [batch_size, num_heads, head_dim] + k, # [batch_size, k_seq_len, head_dim] + v, # [batch_size, k_seq_len, head_dim] + start_idx=None, # [batch_size] + kv_seq_len=None, # [batch_size] + sm_scale: float | None = None, block_h: int = 16, block_k: int = 256, k_splits: int = 16, @@ -211,8 +248,14 @@ def mqa( interpret: bool = False, debug: bool = False, ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) + bs = q.shape[0] + if start_idx is not None: + start_idx = jnp.broadcast_to(start_idx, (bs,)) + if kv_seq_len is not None: + kv_seq_len = jnp.broadcast_to(kv_seq_len, (bs,)) inner = functools.partial( - attn_unbatched, + decode_attn_unbatched, sm_scale=sm_scale, block_h=block_h, block_k=block_k, @@ -223,7 +266,7 @@ def mqa( interpret=interpret, debug=debug, ) - return jax.vmap(inner)(q, k, v) + return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len) @functools.partial( @@ -241,12 +284,14 @@ def mqa( ], ) def gqa( - q, # [batch_size, num_q_heads, head_dim] - k, # [batch_size, k_seq_len, num_kv_heads, head_dim] - v, # [batch_size, k_seq_len, num_kv_heads, head_dim] - sm_scale: float = 1.0, + q, # [batch_size, num_q_heads, head_dim] + k, # [batch_size, k_seq_len, num_kv_heads, head_dim] + v, # [batch_size, k_seq_len, num_kv_heads, head_dim] + start_idx=None, # [batch_size] + kv_seq_len=None, # [batch_size] + sm_scale: float | None = None, block_h: int = 16, - block_k: int = 256, + block_k: int = 128, k_splits: int = 16, num_warps: int | None = None, num_stages: int = 2, @@ -254,10 +299,19 @@ def gqa( interpret: bool = False, debug: bool = False, ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) batch_size, q_heads, head_dim = q.shape - kv_heads = k.shape[2] + k_seq_len, kv_heads = k.shape[1], k.shape[2] assert kv_heads == v.shape[2] assert q_heads % kv_heads == 0 + if start_idx is not None: + assert start_idx.ndim in (0, 1) + start_idx = jnp.broadcast_to(jnp.asarray(start_idx)[..., None], + (batch_size, kv_heads)) + if kv_seq_len is not None: + assert kv_seq_len.ndim in (0, 1) + kv_seq_len = jnp.broadcast_to(jnp.asarray(kv_seq_len)[..., None], + (batch_size, kv_heads)) q_heads_per_kv_head = q_heads // kv_heads q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim) k_transposed = jnp.swapaxes( @@ -267,7 +321,7 @@ def gqa( v, 1, 2 ) # [batch_size, num_kv_heads, k_seq_len, head_dim] inner = functools.partial( - attn_unbatched, + decode_attn_unbatched, sm_scale=sm_scale, block_h=block_h, block_k=block_k, @@ -279,42 +333,70 @@ def gqa( debug=debug, ) with_kv_heads = jax.vmap(inner) - o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed) + o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed, + start_idx, kv_seq_len) return o.reshape(batch_size, q_heads, head_dim) @functools.partial(jax.jit, static_argnames=["sm_scale"]) def mqa_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, head_dim] - v, # [bs, k_seq_len, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, head_dim] + v, # [bs, k_seq_len, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, ): + bs = q.shape[0] + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32) + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) return jnp.einsum("bns,bsd->bnd", weights, v) @functools.partial(jax.jit, static_argnames=["sm_scale"]) def mha_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, num_k_heads, head_dim] - v, # [bs, k_seq_len, num_v_heads, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, num_k_heads, head_dim] + v, # [bs, k_seq_len, num_v_heads, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, ): + bs = q.shape[0] + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) assert q.shape[1] == k.shape[2] logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32) + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) return jnp.einsum("bns,bsnd->bnd", weights, v) @functools.partial(jax.jit, static_argnames=["sm_scale"]) def gqa_reference( - q, # [bs, num_q_heads, head_dim] - k, # [bs, k_seq_len, num_k_heads, head_dim] - v, # [bs, k_seq_len, num_v_heads, head_dim] - sm_scale=1.0, + q, # [bs, num_q_heads, head_dim] + k, # [bs, k_seq_len, num_k_heads, head_dim] + v, # [bs, k_seq_len, num_v_heads, head_dim] + start_idx=None, # [bs] + kv_seq_len=None, # [bs] + sm_scale=None, ): + sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) bs, num_q_heads, head_dim = q.shape num_kv_heads = k.shape[2] assert num_q_heads % num_kv_heads == 0 @@ -330,6 +412,15 @@ def gqa_reference( logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype( jnp.float32 ) + if start_idx is not None or kv_seq_len is not None: + start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) + kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None + else kv_seq_len, (bs,)) + mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None]) + & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) + mask = mask[:, None, None, :] + logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed) - return o.reshape(bs, num_q_heads, head_dim) + o = o.reshape(bs, num_q_heads, head_dim) + return o diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index ed059c235329..afd2f6ae3fcf 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -62,12 +62,15 @@ class DecodeAttentionTest(PallasBaseTest): @parameterized.named_parameters(*[ ( - f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}", + (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_" + f"{start_idx=}_{kv_seq_len=}"), batch_size, seq_len, num_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ) for ( batch_size, @@ -80,6 +83,8 @@ class DecodeAttentionTest(PallasBaseTest): (2, 1024, 2, 64, {}), (1, 1024, 8, 64, {}), ] + for start_idx in [None, 123] + for kv_seq_len in [None, 250] ]) @jax.numpy_dtype_promotion("standard") def test_mqa( @@ -89,6 +94,8 @@ def test_mqa( num_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ): del kwargs @@ -97,19 +104,24 @@ def test_mqa( k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16) v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16) - o = decode_attention.mqa(q, k, v, interpret=self.INTERPRET) - o_ref = decode_attention.mqa_reference(q, k, v) + o = decode_attention.mqa(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len, interpret=self.INTERPRET) + o_ref = decode_attention.mqa_reference(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len) np.testing.assert_allclose(o, o_ref, atol=0.05) @parameterized.named_parameters(*[ ( - f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}_{kwargs=}", + (f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}" + f"_{kwargs=}_{start_idx=}_{kv_seq_len=}"), batch_size, seq_len, num_q_heads, num_kv_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ) for ( batch_size, @@ -123,6 +135,8 @@ def test_mqa( (1, 1024, 16, 16, 64, {}), (1, 1024, 32, 32, 64, {}), ] + for start_idx in [None, 123] + for kv_seq_len in [None, 250] ]) @jax.numpy_dtype_promotion("standard") def test_gqa( @@ -133,6 +147,8 @@ def test_gqa( num_kv_heads, head_dim, kwargs, + start_idx, + kv_seq_len, ): del kwargs @@ -146,9 +162,10 @@ def test_gqa( v = random.normal( k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16 ) - - o = decode_attention.gqa(q, k, v, interpret=self.INTERPRET) - o_ref = decode_attention.gqa_reference(q, k, v) + o = decode_attention.gqa(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len, interpret=self.INTERPRET) + o_ref = decode_attention.gqa_reference(q, k, v, start_idx=start_idx, + kv_seq_len=kv_seq_len) np.testing.assert_allclose(o, o_ref, atol=0.05) class DecodeAttentionInterpretTest(DecodeAttentionTest): From 0c5846585f1391852b698d9f9add808c6c682f50 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 10:20:23 -0500 Subject: [PATCH 300/698] Add workflow for nightly pull from upstream --- .../workflows/rocm-nightly-upstream-sync.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .github/workflows/rocm-nightly-upstream-sync.yml diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml new file mode 100644 index 000000000000..880ea232d307 --- /dev/null +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -0,0 +1,18 @@ +# Pulls the latest changes from upstream into main and opens a PR to merge +# them into rocm-main. + +name: ROCm Nightly Upstream Sync +on: + schedule: + - cron: '0 6 * * *' +jobs: + sync-main: + runs-on: ubuntu-latest + steps: + - run: gh repo sync rocm/jax -b main + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + open-sync-pr: + runs-on: ubuntu-latest + steps: + - run: gh pr create --repo rocm/jax --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" From 7b5b68b7c1c78fc5135587669d01120a3f95f80a Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 10:29:36 -0500 Subject: [PATCH 301/698] Only run on weekdays --- .github/workflows/rocm-nightly-upstream-sync.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index 880ea232d307..ba81edac5bc9 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -4,7 +4,7 @@ name: ROCm Nightly Upstream Sync on: schedule: - - cron: '0 6 * * *' + - cron: '0 6 * * 1-5' jobs: sync-main: runs-on: ubuntu-latest From ec3f5006953fac40dc31e930abf5964a68109946 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 10:49:29 -0500 Subject: [PATCH 302/698] Fix yaml checker --- .github/workflows/rocm-nightly-upstream-sync.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index ba81edac5bc9..dcfbc01d1db5 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -15,4 +15,5 @@ jobs: open-sync-pr: runs-on: ubuntu-latest steps: - - run: gh pr create --repo rocm/jax --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" + - run: | + gh pr create --repo $GITHUB_REPOSITORY --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" From 8faf23119bfff6c2b7c9c8b7ee4b23f76cbb623c Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 31 Oct 2024 11:39:28 -0500 Subject: [PATCH 303/698] Set runners for ROCM --- .github/workflows/ci-build.yaml | 2 +- .github/workflows/upstream-nightly.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 5c786272ee3d..75eb9d99c39d 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -43,7 +43,7 @@ jobs: build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" - runs-on: linux-x86-n2-32 + runs-on: ROCM-Ubuntu container: image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 60 diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 04df278019a5..ada9b4e5825f 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -22,7 +22,7 @@ on: jobs: upstream-dev: - runs-on: ubuntu-20.04-16core + runs-on: ROCM-Ubuntu permissions: contents: read checks: write # for upload-artifact From bf0350831b55fe26d9ab357acd60ba00e7e83f86 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Thu, 31 Oct 2024 15:43:35 -0500 Subject: [PATCH 304/698] Allow devs to kick off sync job manually (#119) --- .github/workflows/rocm-nightly-upstream-sync.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index dcfbc01d1db5..98c958c3daa0 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -3,6 +3,7 @@ name: ROCm Nightly Upstream Sync on: + workflow_dispatch: schedule: - cron: '0 6 * * 1-5' jobs: From 909f746d63a6560db29d51ee273d2a0a8057f670 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Fri, 1 Nov 2024 10:05:43 -0500 Subject: [PATCH 305/698] Unpin container in CI build and remove libssl-dev install --- .github/workflows/ci-build.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 75eb9d99c39d..6b9baa8af097 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -44,8 +44,6 @@ jobs: build: name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" runs-on: ROCM-Ubuntu - container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 60 strategy: matrix: @@ -63,10 +61,6 @@ jobs: num_generated_cases: 1 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Image Setup - run: | - apt update - apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: From d2bda084470f35b3a8c8f4827513cf7f049180e6 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Mon, 4 Nov 2024 17:10:03 -0600 Subject: [PATCH 306/698] Rename the CI flow to 'ROCm CI' and only run it on PRs to rocm-main (#126) * Rename the CI flow to 'ROCm CI' and only run it on PRs to the rocm-main branch * Change name to 'ROCm CPU CI' --- .github/workflows/ci-build.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 6b9baa8af097..bfc6bc492872 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,4 +1,4 @@ -name: CI +name: ROCm CPU CI # We test all supported Python versions as follows: # - 3.10 : Documentation build @@ -11,10 +11,10 @@ on: # but only for the main branch push: branches: - - main + - rocm-main pull_request: branches: - - main + - rocm-main permissions: contents: read # to fetch code From 249ce1560ee9aa2bb263464b12480d1731805612 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 5 Nov 2024 11:32:09 -0600 Subject: [PATCH 307/698] Fix nightly sync permissions (#124) --- .github/workflows/rocm-nightly-upstream-sync.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index 98c958c3daa0..a15e49c2e87b 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -8,13 +8,19 @@ on: - cron: '0 6 * * 1-5' jobs: sync-main: + permissions: + contents: write runs-on: ubuntu-latest steps: - run: gh repo sync rocm/jax -b main env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} open-sync-pr: + permissions: + pull-requests: write runs-on: ubuntu-latest steps: - run: | gh pr create --repo $GITHUB_REPOSITORY --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} From f09863c44ced30b1c0459bf797e58a14244154da Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 13:56:16 -0600 Subject: [PATCH 308/698] Add GHA workflow for opening PRs upstream (#116) * Add file for opening PRs upstream * Add HEAD_REF as environment variable * Fill out code for making a new branch and opening a PR to upstream * Add names for steps * Fix yaml * Fix yaml again * Leave a comment on the old PR linking to the new one * Add proper permissions for creating banches and opening PRs * Fix YAML --- .github/workflows/rocm-open-upstream-pr.yml | 39 +++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/rocm-open-upstream-pr.yml diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml new file mode 100644 index 000000000000..09dfd06e907e --- /dev/null +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -0,0 +1,39 @@ +name: ROCm Open Upstream PR +on: + pull_request: + types: [ labeled ] + branches: [ rocm-main ] +jobs: + open-upstream: + if: ${{ github.event.label.name == 'open-upstream' }} + permissions: + contents: write + pull-requests: write + runs-on: ubuntu-latest + outputs: + new-pr-link: ${{ steps.create-pr.outputs.link }} + env: + NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream" + NEW_PR_TITLE: "[ROCM] ${{ github.event.pull_request.title }}" + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Rebase code to main + run: | + git checkout -b $NEW_BRANCH_NAME ${{ github.head_ref }} + git rebase --onto main + git push origin HEAD + # TODO: Change the base of the PR to upstream main + - name: Create a PR to upstream + id: create-pr + run: | + echo link=$(gh pr create --repo rocm/jax --base main --head $NEW_BRANCH_NAME --title "$NEW_PR_TITLE" --body "${{ github.event.pull_request.body }}") >> "$GITHUB_OUTPUT" + comment-link: + needs: open-upstream + permissions: + pull-requests: write + runs-on: ubuntu-latest + steps: + - name: Leave comment on old PR + run: gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body ${{ needs.open-upstream.outputs.new-pr-link.link }} + From 7831066110cddfd1b2c95e9250dd1c46f8fe9ddb Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 14:03:26 -0600 Subject: [PATCH 309/698] Create a new branch when merging upstream main to rocm-main (#128) --- .../workflows/rocm-nightly-upstream-sync.yml | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index a15e49c2e87b..98f3d2cfa39c 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -6,6 +6,8 @@ on: workflow_dispatch: schedule: - cron: '0 6 * * 1-5' +env: + SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }} jobs: sync-main: permissions: @@ -15,12 +17,28 @@ jobs: - run: gh repo sync rocm/jax -b main env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + create-sync-branch: + needs: sync-main + permissions: + contents: write + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Create branch + run: | + git checkout -b $SYNC_BRANCH_NAME main + git push origin HEAD open-sync-pr: + needs: create-sync-branch permissions: pull-requests: write runs-on: ubuntu-latest steps: - run: | - gh pr create --repo $GITHUB_REPOSITORY --head main --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" + gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream" env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + From a5ee6dc3a80f507c6b65b4ef41866a6ecd6d41e5 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 15:31:57 -0600 Subject: [PATCH 310/698] Fix upstream sync checkout (#130) * Checkout main before trying to switch to it * Fix the checkout command --- .github/workflows/rocm-nightly-upstream-sync.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index 98f3d2cfa39c..f29bef3bc46c 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -29,7 +29,8 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Create branch run: | - git checkout -b $SYNC_BRANCH_NAME main + git checkout origin/main + git checkout -b $SYNC_BRANCH_NAME git push origin HEAD open-sync-pr: needs: create-sync-branch From 144bef026f7456ac918d13b1a2ce0fcf168e7995 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Wed, 6 Nov 2024 16:15:59 -0600 Subject: [PATCH 311/698] Fix FFI example test in CI --- .github/workflows/ci-build.yaml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index bfc6bc492872..7256bdabb884 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -216,9 +216,7 @@ jobs: ffi: name: FFI example - runs-on: linux-x86-g2-16-l4-1gpu - container: - image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12 + runs-on: ROCM-Ubuntu timeout-minutes: 30 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -237,7 +235,7 @@ jobs: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} - name: Install JAX - run: pip install .[cuda12] + run: pip install . - name: Build and install example project run: python -m pip install -v ./examples/ffi[test] env: @@ -246,7 +244,7 @@ jobs: # a different toolchain. GCC is the default compiler on the # 'ubuntu-latest' runner, but we still set this explicitly just to be # clear. - CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON + CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON - name: Run CPU tests run: python -m pytest examples/ffi/tests env: From f75705426742b4c3a2c323f81a12a0c369cfd454 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 8 Nov 2024 11:48:45 -0500 Subject: [PATCH 312/698] Update some outdated syntax in FFI tutorial. --- docs/ffi.ipynb | 35 ++++++------ docs/ffi.md | 33 +++++------- docs/ffi/rms_norm.cc | 56 +++++++++----------- examples/ffi/src/jax_ffi_example/attrs.cc | 6 +-- examples/ffi/src/jax_ffi_example/cuda_e2e.cu | 40 +++++++------- examples/ffi/src/jax_ffi_example/rms_norm.cc | 17 +++--- 6 files changed, 83 insertions(+), 104 deletions(-) diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index ea8a86fa80f1..f1a699b5c56c 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -26,10 +26,7 @@ "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", "\n", - "This tutorial comes with two supplementary files:\n", - "\n", - "* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and\n", - "* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n", + "The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n", "\n", "## A simple example\n", "\n", @@ -101,7 +98,7 @@ "\n", "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n", "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n", - "The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:\n", + "The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n", "\n", "```c++\n", "#include \n", @@ -129,12 +126,11 @@ "// A wrapper function providing the interface between the XLA FFI call and our\n", "// library function `ComputeRmsNorm` above. This function handles the batch\n", "// dimensions by calling `ComputeRmsNorm` within a loop.\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y) {\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y) {\n", " auto [totalSize, lastDim] = GetDims(x);\n", " if (lastDim == 0) {\n", - " return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n", - " \"RmsNorm input must be an array\");\n", + " return ffi::Error::InvalidArgument(\"RmsNorm input must be an array\");\n", " }\n", " for (int64_t n = 0; n < totalSize; n += lastDim) {\n", " ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n", @@ -149,8 +145,8 @@ " RmsNorm, RmsNormImpl,\n", " ffi::Ffi::Bind()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", @@ -173,8 +169,7 @@ "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n", "In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n", "\n", - "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.\n", - "The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt)." + "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble." ] }, { @@ -433,7 +428,7 @@ "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", "2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", "\n", - "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.\n", + "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n", "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", "\n", "This custom derivative rule can be wired in as follows:" @@ -508,16 +503,16 @@ "When defining our FFI wrapper for CPU, the function signature that we used was:\n", "\n", "```c++\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y)\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "To update this to interface with a CUDA kernel, this signature becomes:\n", "\n", "```c++\n", "ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n", - " ffi::Buffer x,\n", - " ffi::Result> y)\n", + " ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "And the handler definition is updated to include a `Ctx` in its binding:\n", @@ -528,8 +523,8 @@ " ffi::Ffi::Bind()\n", " .Ctx>()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", diff --git a/docs/ffi.md b/docs/ffi.md index 5afc8f809d4d..dbe901237ed4 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -34,10 +34,7 @@ JAX's FFI support is provided in two parts: In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. -This tutorial comes with two supplementary files: - -* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and -* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code. +The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi). ## A simple example @@ -96,7 +93,7 @@ and, for our example, this is the function that we want to expose to JAX via the To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla). For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call). -The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here: +The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here: ```c++ #include @@ -124,12 +121,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -144,8 +140,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` @@ -166,7 +162,6 @@ Now that we have our minimal FFI wrapper implemented, we need to expose this fun In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below. To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble. -The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt). ```{code-cell} ipython3 :tags: [hide-output] @@ -357,7 +352,7 @@ In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. 2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. -We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end. +We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end. The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. This custom derivative rule can be wired in as follows: @@ -422,16 +417,16 @@ Since this documentation page is automatically generated on a machine without ac When defining our FFI wrapper for CPU, the function signature that we used was: ```c++ -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) ``` To update this to interface with a CUDA kernel, this signature becomes: ```c++ ffi::Error RmsNormImpl(cudaStream_t stream, float eps, - ffi::Buffer x, - ffi::Result> y) + ffi::Buffer x, + ffi::ResultBuffer y) ``` And the handler definition is updated to include a `Ctx` in its binding: @@ -442,8 +437,8 @@ XLA_FFI_DEFINE_HANDLER( ffi::Ffi::Bind() .Ctx>() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` diff --git a/docs/ffi/rms_norm.cc b/docs/ffi/rms_norm.cc index 4dc8a890410c..467f13d44ac2 100644 --- a/docs/ffi/rms_norm.cc +++ b/docs/ffi/rms_norm.cc @@ -56,12 +56,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -75,17 +74,16 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x, XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); -ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -94,13 +92,12 @@ ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormFwd, RmsNormFwdImpl, - ffi::Ffi::Bind() - .Attr("eps") - .Arg>() // x - .Ret>() // y - .Ret>() // res +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y + .Ret>() // res ); void ComputeRmsNormBwd(int64_t size, float res, const float *x, @@ -115,14 +112,12 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, } } -ffi::Error RmsNormBwdImpl(ffi::Buffer res, - ffi::Buffer x, - ffi::Buffer ct_y, - ffi::Result> ct_x) { +ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, + ffi::Buffer ct_y, + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), @@ -131,11 +126,10 @@ ffi::Error RmsNormBwdImpl(ffi::Buffer res, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormBwd, RmsNormBwdImpl, - ffi::Ffi::Bind() - .Arg>() // res - .Arg>() // x - .Arg>() // ct_y - .Ret>() // ct_x +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>() // res + .Arg>() // x + .Arg>() // ct_y + .Ret>() // ct_x ); diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/attrs.cc index 2a6e8d847cf4..7ff5c98e52e1 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.cc +++ b/examples/ffi/src/jax_ffi_example/attrs.cc @@ -22,7 +22,7 @@ namespace nb = nanobind; namespace ffi = xla::ffi; ffi::Error ArrayAttrImpl(ffi::Span array, - ffi::Result> res) { + ffi::ResultBufferR0 res) { int64_t total = 0; for (int32_t x : array) { total += x; @@ -37,8 +37,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, .Ret>()); ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, - ffi::Result> secret, - ffi::Result> count) { + ffi::ResultBufferR0 secret, + ffi::ResultBufferR0 count) { auto maybe_secret = attrs.get("secret"); if (maybe_secret.has_error()) { return maybe_secret.error(); diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu b/examples/ffi/src/jax_ffi_example/cuda_e2e.cu index 858b5f8a888a..240adb6d6a8c 100644 --- a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu +++ b/examples/ffi/src/jax_ffi_example/cuda_e2e.cu @@ -44,11 +44,9 @@ __global__ void FooFwdKernel(const float *a, const float *b, float *c, // Buffer type provides buffer dimensions, so the "n" argument here is not // strictly necessary, but it allows us to demonstrate the use of attributes // (.Attr in the FFI handler definition above). -ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, - ffi::Buffer b, - ffi::Result> c, - ffi::Result> b_plus_1, - size_t n) { +ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, + ffi::Buffer b, ffi::ResultBuffer c, + ffi::ResultBuffer b_plus_1, size_t n) { const int block_dim = 128; const int grid_dim = 1; // Note how we access regular Buffer data vs Result Buffer data: @@ -73,12 +71,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooFwd, FooFwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // a - .Arg>() // b - .Ret>() // c - .Ret>() // b_plus_1 + .Arg>() // a + .Arg>() // b + .Ret>() // c + .Ret>() // b_plus_1 .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled //----------------------------------------------------------------------------// // Backward pass // @@ -106,11 +104,11 @@ __global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c } ffi::Error FooBwdHost(cudaStream_t stream, - ffi::Buffer c_grad, - ffi::Buffer a, - ffi::Result> b_plus_1, - ffi::Result> a_grad, - ffi::Result> b_grad, + ffi::Buffer c_grad, + ffi::Buffer a, + ffi::ResultBuffer b_plus_1, + ffi::ResultBuffer a_grad, + ffi::ResultBuffer b_grad, size_t n) { const int block_dim = 128; const int grid_dim = 1; @@ -131,10 +129,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooBwd, FooBwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // c_grad - .Arg>() // a - .Arg>() // b_plus_1 - .Ret>() // a_grad - .Ret>() // b_grad + .Arg>() // c_grad + .Arg>() // a + .Arg>() // b_plus_1 + .Ret>() // a_grad + .Ret>() // b_grad .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 2fb8d96c8461..455a0e557620 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -59,11 +59,10 @@ std::pair GetDims(const ffi::Buffer &buffer) { // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -82,12 +81,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ); ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -118,11 +116,10 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, ffi::Buffer ct_y, - ffi::Result> ct_x) { + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), From f08648366ec2cde4bffefaf24307e6abe72e6c57 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 20 Sep 2024 13:52:42 -0400 Subject: [PATCH 313/698] Add an example FFI call to demonstrate the use of global state. --- examples/ffi/CMakeLists.txt | 17 ++++--- examples/ffi/README.md | 25 ++++++++-- examples/ffi/src/jax_ffi_example/counter.cc | 53 ++++++++++++++++++++ examples/ffi/src/jax_ffi_example/counter.py | 38 ++++++++++++++ examples/ffi/tests/counter_test.py | 55 +++++++++++++++++++++ 5 files changed, 179 insertions(+), 9 deletions(-) create mode 100644 examples/ffi/src/jax_ffi_example/counter.cc create mode 100644 examples/ffi/src/jax_ffi_example/counter.py create mode 100644 examples/ffi/tests/counter_test.py diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 4179f4bd9ad4..9f9090e2b7ef 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -12,13 +12,18 @@ message(STATUS "XLA include directory: ${XLA_DIR}") find_package(nanobind CONFIG REQUIRED) -nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc") -target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) -install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +set( + JAX_FFI_EXAMPLE_PROJECTS + "rms_norm" + "attrs" + "counter" +) -nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc") -target_include_directories(_attrs PUBLIC ${XLA_DIR}) -install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS}) + nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc") + target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR}) + install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +endforeach() if(JAX_FFI_EXAMPLE_ENABLE_CUDA) enable_language(CUDA) diff --git a/examples/ffi/README.md b/examples/ffi/README.md index cc7018782a25..eb730b483b76 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -3,7 +3,26 @@ This directory includes an example project demonstrating the use of JAX's foreign function interface (FFI). The JAX docs provide more information about this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), -but the example in this directory explicitly demonstrates: +but the example in this directory complements that document by demonstrating +(and testing!) the full packaging workflow, and some more advanced use cases. +Within the example project, there are several example calls: -1. One way to package and distribute FFI targets, and -2. Some more advanced use cases. +1. `rms_norm`: This is the example from the tutorial on the JAX docs, and it + demonstrates the most basic use of the FFI. It also includes customization of + behavior under automatic differentiation using `jax.custom_vjp`. + +2. `counter`: This example demonstrates a common pattern for how an FFI call can + use global cache to maintain state between calls. This pattern is useful when + an FFI call requires an expensive initialization step which shouldn't be + run on every execution, or if there is other shared state that could be + reused between calls. In this simple example we just count the number of + times the call was executed. + +3. `attrs`: An example demonstrating the different ways that attributes can be + passed to the FFI. For example, we can pass arrays, variadic attributes, and + user-defined types. Full support of user-defined types isn't yet supported + by XLA, so that example will be added in the future. + +4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with + CUDA. The specifics of the kernels are not very important, but the general + structure, and packaging of the extension are useful for testing. diff --git a/examples/ffi/src/jax_ffi_example/counter.cc b/examples/ffi/src/jax_ffi_example/counter.cc new file mode 100644 index 000000000000..d7f17e730fd6 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/counter.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0 out) { + static std::mutex mutex; + static auto& cache = *new std::unordered_map(); + { + const std::lock_guard lock(mutex); + auto it = cache.find(index); + if (it != cache.end()) { + out->typed_data()[0] = ++it->second; + } else { + cache.insert({index, 0}); + out->typed_data()[0] = 0; + } + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Counter, CounterImpl, + ffi::Ffi::Bind().Attr("index").Ret>()); + +NB_MODULE(_counter, m) { + m.def("registrations", []() { + nb::dict registrations; + registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); + return registrations; + }); +} diff --git a/examples/ffi/src/jax_ffi_example/counter.py b/examples/ffi/src/jax_ffi_example/counter.py new file mode 100644 index 000000000000..12c7f015bf58 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/counter.py @@ -0,0 +1,38 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example demonstrating how an FFI call can maintain "state" between calls + +In this case, the ``counter`` call simply accumulates the number of times it +was executed, but this pattern can also be used for more advanced use cases. +For example, this pattern is used in jaxlib for: + +1. The GPU solver linear algebra kernels which require an expensive "handler" + initialization, and +2. The ``triton_call`` function which caches the compiled triton modules after + their first use. +""" + +import jax +import jax.extend as jex + +from jax_ffi_example import _counter + +for name, target in _counter.registrations().items(): + jex.ffi.register_ffi_target(name, target) + + +def counter(index): + return jex.ffi.ffi_call( + "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/examples/ffi/tests/counter_test.py b/examples/ffi/tests/counter_test.py new file mode 100644 index 000000000000..1e2ad38a363f --- /dev/null +++ b/examples/ffi/tests/counter_test.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +from jax._src import test_util as jtu + +from jax_ffi_example import counter + +jax.config.parse_flags_with_absl() + + +class CounterTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + def test_basic(self): + self.assertEqual(counter.counter(0), 0) + self.assertEqual(counter.counter(0), 1) + self.assertEqual(counter.counter(0), 2) + self.assertEqual(counter.counter(1), 0) + self.assertEqual(counter.counter(0), 3) + + def test_jit(self): + @jax.jit + def counter_fun(x): + return x, counter.counter(2) + + self.assertEqual(counter_fun(0)[1], 0) + self.assertEqual(counter_fun(0)[1], 1) + + # Persists across different cache hits + self.assertEqual(counter_fun(1)[1], 2) + + # Persists after the cache is cleared + counter_fun.clear_cache() + self.assertEqual(counter_fun(0)[1], 3) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From eedd01118b100de5a6a74fdfb5bf19c47920c05a Mon Sep 17 00:00:00 2001 From: Jaroslav Sevcik Date: Wed, 6 Nov 2024 08:51:24 -0800 Subject: [PATCH 314/698] Add an option to specify mock GPU topology --- jax/_src/xla_bridge.py | 41 ++++++++++++++++------ tests/BUILD | 15 +++++++++ tests/mock_gpu_topology_test.py | 60 +++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 tests/mock_gpu_topology_test.py diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 23b255ef1750..28148761c8a4 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -90,6 +90,13 @@ help="Mock number of JAX processes in GPU client. Value zero turns " "off mocking.", ) +_MOCK_GPU_TOPOLOGY = config.string_flag( + name="jax_mock_gpu_topology", + default="", + help='Mock multi-host GPU topology in GPU client. The value should ' + 'be of the form " x x ' + '". Empty string turns off mocking.', +) _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( name="jax_cpu_enable_gloo_collectives", @@ -425,6 +432,14 @@ def _version_check(name: str, f'following issues with CUDA components:\n' f'{join_str.join(errors)}') +def _get_num_nodes_from_gpu_topology(topology: str) -> int: + try: + slices_str, hosts_per_slice_str, _ = topology.split("x", 2) + return int(slices_str) * int(hosts_per_slice_str) + except (IndexError, ValueError): + raise ValueError('Mock topology must be of the form ' + '" x x ' + '".') def make_gpu_client( *, platform_name: str, visible_devices_flag: config.Flag[str] @@ -434,12 +449,14 @@ def make_gpu_client( if visible_devices != "all": allowed_devices = {int(x) for x in visible_devices.split(",")} - use_mock_gpu_client = _MOCK_NUM_GPU_PROCESSES.value > 0 - num_nodes = ( - _MOCK_NUM_GPU_PROCESSES.value - if use_mock_gpu_client - else distributed.global_state.num_processes - ) + mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None + mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + + use_mock_gpu_client = mock_num_gpu_processes > 0 + num_nodes = (mock_num_gpu_processes if use_mock_gpu_client + else distributed.global_state.num_processes) + if platform_name == "cuda": if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): _check_cuda_versions() @@ -634,10 +651,14 @@ def _options_from_jax_configs(plugin_name): visible_devices = CUDA_VISIBLE_DEVICES.value if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_processes = _MOCK_NUM_GPU_PROCESSES.value - options['enable_mock_nccl'] = mock_gpu_processes > 0 - if options['enable_mock_nccl']: - options['num_nodes'] = mock_gpu_processes + mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + options['enable_mock_nccl'] = mock_num_processes > 0 + if mock_num_processes > 0: + options['num_nodes'] = mock_num_processes + if mock_gpu_topology: + options['mock_gpu_topology'] = mock_gpu_topology return options diff --git a/tests/BUILD b/tests/BUILD index dc81c408c4ce..1f6ea90b7a47 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -321,6 +321,21 @@ jax_multiplatform_test( ], ) +jax_multiplatform_test( + name = "mock_gpu_topology_test", + srcs = ["mock_gpu_topology_test.py"], + enable_backends = ["gpu"], + enable_configs = [ + "gpu_h100", + ], + tags = [ + "config-cuda-only", + ], + deps = [ + "//jax:experimental", + ], +) + jax_multiplatform_test( name = "array_test", srcs = ["array_test.py"], diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py new file mode 100644 index 000000000000..44ec4e2f9529 --- /dev/null +++ b/tests/mock_gpu_topology_test.py @@ -0,0 +1,60 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + +jax.config.parse_flags_with_absl() + +NUM_SLICES = 2 +NUM_HOSTS_PER_SLICE = 4 + + +@jtu.with_config( + jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1", + jax_cuda_visible_devices="0") +class MockGPUTopologyTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Mocking devices only works on the GPU backend.") + super().setUp() + + @jtu.skip_under_pytest("Test must run in an isolated process") + def testMockDeviceCount(self): + self.assertEqual(jax.device_count(), NUM_SLICES * NUM_HOSTS_PER_SLICE) + + @jtu.skip_under_pytest("Test must run in an isolated process") + def testMockWithSharding(self): + mesh = jax.sharding.Mesh(jax.devices(), ('x',)) + f = jax.jit(jnp.sum, + in_shardings=NamedSharding(mesh, P('x')), + out_shardings=NamedSharding(mesh, P())) + + f_lowered = f.lower(jnp.arange(16)) + hlo = f_lowered.compiler_ir() + + mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE + self.assertIn(f'num_partitions = {mocked_count}', str(hlo)) + self.assertIn( + f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"', + str(hlo) + ) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) From 5808170a108a36ba518f661c8e1a59102410b074 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 12 Nov 2024 08:57:52 -0800 Subject: [PATCH 315/698] Add GPU overflow bugfix (#24846) to changelog. --- CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ab334c15904..17d15c740b7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. declared inline via {func}`dataclasses.field`. See the function documentation for examples. +* Bug fixes + * Fixed a bug where the GPU implementations of LU and QR decomposition would + result in an indexing overflow for batch sizes close to int32 max. See + {jax-issue}`#24843` for more details. + ## jax 0.4.35 (Oct 22, 2024) * Breaking Changes @@ -79,7 +84,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The semi-public API `jax.lib.xla_client.register_custom_call_target` has been deprecated. Use the JAX FFI instead. * The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, - `jax.lib.xla_client.ops`, + `jax.lib.xla_client.ops`, `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO From 64fcb9d3e91620f4608ca9b0bfe5ceb44261ee41 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 12 Nov 2024 09:23:02 -0800 Subject: [PATCH 316/698] Fix pgle profiling, broken in previous change. PiperOrigin-RevId: 695762690 --- jax/_src/pjit.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6ab8c90811a6..e2c50f2dc1a9 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1608,15 +1608,22 @@ def _resolve_and_lower( lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) +_pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore + def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): pgle_compile_options, pgle_profiler = {}, None if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: - pgle_profiler = profiler.PGLEProfiler( - config.pgle_profiling_runs.value, - config.pgle_aggregation_percentile.value) + compilation_target_key = jaxpr + pgle_profiler = _pgle_profiler_dict.get(compilation_target_key) + if pgle_profiler is None: + pgle_profiler = profiler.PGLEProfiler( + config.pgle_profiling_runs.value, + config.pgle_aggregation_percentile.value) + _pgle_profiler_dict[compilation_target_key] = pgle_profiler + # The method below will return FDO profile when module was profiled # config.jax_pgle_profiling_runs amount of times, otherwise the result will # be None. From 6e1aa3c1e7f119e4d22e5e3757e7dba4e24dc807 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 8 Nov 2024 15:23:45 -0800 Subject: [PATCH 317/698] Specialize ufunc.reduce for monoidal binary ufuncs. --- jax/_src/numpy/reductions.py | 62 ++++++++++++++++++++++++++++++++++++ jax/_src/numpy/ufuncs.py | 32 ++++--------------- 2 files changed, 69 insertions(+), 25 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index be1e55675079..2293c8b17cc7 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -192,6 +192,11 @@ def _cast_to_bool(operand: ArrayLike) -> Array: def _cast_to_numeric(operand: ArrayLike) -> Array: return promote_dtypes_numeric(operand)[0] +def _require_integer(operand: ArrayLike) -> Array: + arr = lax_internal.asarray(operand) + if not dtypes.isdtype(arr, ("bool", "integral")): + raise ValueError(f"integer argument required; got dtype={arr.dtype}") + return arr def _ensure_optional_axes(x: Axis) -> Axis: def force(x): @@ -652,6 +657,63 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + arr = lax_internal.asarray(a) + init_val = np.array(-1, dtype=dtype or arr.dtype) + return _reduction(arr, "reduce_bitwise_and", None, lax.bitwise_and, init_val, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, "reduce_bitwise_or", None, lax.bitwise_or, 0, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, "reduce_bitwise_xor", None, lax.bitwise_xor, 0, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, "reduce_logical_and", None, lax.bitwise_and, True, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, "reduce_logical_or", None, lax.bitwise_or, False, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, "reduce_logical_xor", None, lax.bitwise_xor, False, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index acaed78e4db7..9dd370d599f2 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -31,7 +31,7 @@ from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax from jax._src.lax import other as lax_other -from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, @@ -1221,7 +1221,7 @@ def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) -@binary_ufunc(identity=-1) +@binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and) def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise AND operation elementwise. @@ -1250,7 +1250,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) -@binary_ufunc(identity=0) +@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or) def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise OR operation elementwise. @@ -1279,7 +1279,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) -@binary_ufunc(identity=0) +@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor) def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise XOR operation elementwise. @@ -1793,16 +1793,7 @@ def spacing(x: ArrayLike, /) -> Array: # Logical ops -def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - """Implementation of jnp.logical_and.reduce.""" - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_and.reduce()") - result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - -@binary_ufunc(identity=True, reduce=_logical_and_reduce) +@binary_ufunc(identity=True, reduce=reductions._reduce_logical_and) def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical AND operation elementwise. @@ -1823,16 +1814,7 @@ def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) -def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - """Implementation of jnp.logical_or.reduce.""" - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_or.reduce()") - result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - -@binary_ufunc(identity=False, reduce=_logical_or_reduce) +@binary_ufunc(identity=False, reduce=reductions._reduce_logical_or) def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical OR operation elementwise. @@ -1853,7 +1835,7 @@ def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) -@binary_ufunc(identity=False) +@binary_ufunc(identity=False, reduce=reductions._reduce_logical_xor) def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical XOR operation elementwise. From 310ff7347c0c4e2487f14f558d47a044be20effb Mon Sep 17 00:00:00 2001 From: James Martens Date: Tue, 12 Nov 2024 10:36:14 -0800 Subject: [PATCH 318/698] Change to internal dead code elimination. Now the functions in `dce_rules` are responsible for checking if the equation has no used outputs or effects, and behaving appropriately in that case (which usually means eliminating said equation). PiperOrigin-RevId: 695789033 --- jax/_src/ad_checkpoint.py | 2 ++ jax/_src/custom_derivatives.py | 2 ++ jax/_src/interpreters/partial_eval.py | 33 +++++++++++++---------- jax/_src/interpreters/pxla.py | 2 ++ jax/_src/lax/control_flow/conditionals.py | 6 ++++- jax/_src/lax/control_flow/loops.py | 4 ++- jax/_src/pjit.py | 4 +++ jax/experimental/shard_map.py | 2 ++ jax/interpreters/partial_eval.py | 1 + 9 files changed, 40 insertions(+), 16 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 5ed0b0192a7b..fc135ac8f28c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -716,6 +716,8 @@ def remat_vmap(axis_data, args, dims, *, jaxpr, **params): # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) new_params = dict(eqn.params, jaxpr=new_jaxpr) if (not any(used_inputs) and not any(used_outputs) and diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 375efeb712b8..e37494c4fe41 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1531,6 +1531,8 @@ def _remat_opt_transpose( "remat optimization for custom_vjp does not support higher-order AD") def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): + if not any(used_outs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]]) outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] if any(used_res): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ad97ef325f64..5431762d66d2 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -360,7 +360,7 @@ def const_out_axes_thunk(): staged_out_axes, _ = partition_list(out_knowns, out_axes) staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,) - # Create the input tracers for the staged-out (unkonwn-value) call. + # Create the input tracers for the staged-out (unknown-value) call. const_tracers = map(self.new_instantiated_const, res) env_tracers = map(self.to_jaxpr_tracer, env) unknown_arg_tracers = [t for t in tracers if not t.is_known()] @@ -1382,6 +1382,11 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], return new_jaxpr, used_consts, used_inputs +def has_effects(eqn: JaxprEqn) -> bool: + effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} + return bool(effs) + + @weakref_lru_cache def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], instantiate: tuple[bool, ...] @@ -1395,21 +1400,14 @@ def write(x: Atom, b: bool) -> None: if type(x) is Var: env[x] = read(x) or b - def has_effects(eqn: JaxprEqn) -> bool: - effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} - return bool(effs) - new_eqns = [] map(write, jaxpr.outvars, used_outputs) for eqn in jaxpr.eqns[::-1]: used_outs = map(read, eqn.outvars) - if not any(used_outs) and not has_effects(eqn): - used_ins = [False] * len(eqn.invars) - else: - rule = dce_rules.get(eqn.primitive, _default_dce_rule) - used_ins, new_eqn = rule(used_outs, eqn) - if new_eqn is not None: - new_eqns.append(new_eqn) + rule = dce_rules.get(eqn.primitive, _default_dce_rule) + used_ins, new_eqn = rule(used_outs, eqn) + if new_eqn is not None: + new_eqns.append(new_eqn) map(write, eqn.invars, used_ins) used_inputs = map(read, jaxpr.invars) used_inputs = map(op.or_, instantiate, used_inputs) @@ -1433,7 +1431,9 @@ def has_effects(eqn: JaxprEqn) -> bool: def _default_dce_rule( used_outs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outs) and not has_effects(eqn): + return [False] * len(eqn.invars), None return [True] * len(eqn.invars), eqn dce_rules: dict[Primitive, DCERule] = {} @@ -1441,6 +1441,8 @@ def _default_dce_rule( def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) new_params = dict(eqn.params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(eqn.primitive) @@ -1454,6 +1456,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn [v for v, used in zip(eqn.outvars, used_outputs) if used], eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn + dce_rules[core.call_p] = dce_jaxpr_call_rule @@ -1465,8 +1468,10 @@ def _cached_closed_call_dce(jaxpr_, used_outputs: tuple[bool, ...] return core.ClosedJaxpr(new_jaxpr, consts), used_inputs def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: # TODO(mattjj): de-duplicate with above rule? + if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr_ = eqn.params['call_jaxpr'] closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs)) new_params = dict(eqn.params, call_jaxpr=closed_jaxpr) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9a17194d46c9..316fbc077ceb 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1353,6 +1353,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval): def _pmap_dce_rule(used_outputs, eqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None axis_name = eqn.params["axis_name"] with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]): new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 6333638deae6..9e1f7e04c741 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -642,7 +642,11 @@ def _ordered_unique(xs): return list(d.keys()) def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + closed_branches = eqn.params['branches'] branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index b5bb8658e675..d15917b8b1da 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -944,7 +944,9 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params) def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_xs = len(jaxpr.in_avals) - num_consts - num_carry diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index e2c50f2dc1a9..f1844c7ba13b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2326,6 +2326,10 @@ def _dce_jaxpr_pjit( def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + dced_jaxpr, used_inputs = _dce_jaxpr_pjit( eqn.params['jaxpr'], tuple(used_outputs)) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 7ddd3805b5d0..3a9446862456 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1660,6 +1660,8 @@ def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]: # TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] with core.extend_axis_env_nd(mesh.shape.items()): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 1aa3ebc67b06..dca438996229 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -63,6 +63,7 @@ debug_info_final as debug_info_final, def_trivial_padding as def_trivial_padding, forwarding_rules as forwarding_rules, + has_effects as has_effects, infer_lambda_input_type as infer_lambda_input_type, instantiate_const_at as instantiate_const_at, make_jaxpr_effects as make_jaxpr_effects, From 1221da84677827d6467d5dfb21c9ecf80bfe9aa2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 12 Nov 2024 10:54:33 -0800 Subject: [PATCH 319/698] [Mosaic] Fix mask creation for packed sublanes Unaligned concat used to be f32 only, but implicitly protected via unimplemented support for multi-row-shift in sub32 types. When this was added, we started invoking unaligned concat flow w/ sub32 types, but the masking code that assumed full rows (unpacked types) was no longer sufficient - we need better granularity for these cases. This only affects sublanes, as that is where we pack, we don't have partial lanes. This CL, as a small benefit, also adds better error messages to the ops involved in lower_to_llo.cc. PiperOrigin-RevId: 695796095 --- .../tpu/transforms/apply_vector_layout.cc | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8cb01ee67ad4..c9c4a81e668d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2664,7 +2664,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const auto bitwidth = res_ty.getElementTypeBitWidth(); const int packing = res_layout->packing(); - SmallVector out_idx; vreg.Each([&](absl::Span idx, Value *v) { out_idx.assign(idx.begin(), idx.end()); @@ -2674,17 +2673,29 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const VectorType vmask_ty = getNativeVregOrVmaskType( builder.getI1Type(), bitwidth, ctx.target_shape); if (tiling_dim.value() == 0) { // sublane - mask = builder.create( - op.getLoc(), vmask_ty, - ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(operand_offset * packing), - boundIdxConst(layout->tiling()[1])}); + if (operand_offset % packing != 0) { + // Packed case, degenerate where we have a half or quarter + // sublane. + // TODO(mvoz): We can probably always use the + // CreateSubelementMaskOp if (1) optimize it on TPUv4 and (2) Add + // support for unpacked types in some of the invariants in + // lower_to_llo. + mask = builder.create( + op.getLoc(), vmask_ty, 0, operand_offset, packing); + } else { + auto sublane_offset = operand_offset / packing; + mask = builder.create( + op.getLoc(), vmask_ty, + ArrayRef{boundIdxConst(0), boundIdxConst(0)}, + ArrayRef{boundIdxConst(sublane_offset), + boundIdxConst(layout->tiling()[1])}); + } } else { // lane mask = builder.create( op.getLoc(), vmask_ty, ArrayRef{boundIdxConst(0), boundIdxConst(0)}, ArrayRef{boundIdxConst(layout->tiling()[0]), - boundIdxConst(operand_offset * packing)}); + boundIdxConst(operand_offset)}); } // Blend the current value with the existing value in the output. *v = builder.create(op.getLoc(), mask, From d304025a4102f144bee89db6d8562b29ed6d3c2d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 12 Nov 2024 11:01:18 -0800 Subject: [PATCH 320/698] [mosaic_gpu] The profiler now uses FFI calls for creating events and computing elapsed time PiperOrigin-RevId: 695798787 --- jax/BUILD | 1 + jax/experimental/mosaic/gpu/profiler.py | 114 +++++++++++----------- jaxlib/mosaic/gpu/BUILD | 4 +- jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 122 +++++++++++++++++------- 4 files changed, 145 insertions(+), 96 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 71be67368f3b..e8e817ece3e9 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -722,6 +722,7 @@ py_library( ":jax", ":mlir", "//jax/_src/lib", + "//jax/extend:ffi", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:func_dialect", diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index e4949b325507..337581c54b86 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -14,16 +14,15 @@ # ============================================================================== import contextlib -import ctypes -import functools import itertools import json import math +from typing import Callable, ParamSpec, TypeVar import warnings import jax -from jax._src.interpreters import mlir from jax._src.lib import xla_client +from jax.extend import ffi import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -34,72 +33,71 @@ from .utils import * # noqa: F403 - try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib - - xla_client.register_custom_call_target( - "mosaic_gpu_record_event", - mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(), - platform="CUDA", - ) except ImportError: pass +else: + for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): + xla_client.register_custom_call_target( + name, handler, platform="CUDA", api_version=1 + ) # ruff: noqa: F405 # mypy: ignore-errors +T = TypeVar("T") +P = ParamSpec("P") -record_event_p = jax.core.Primitive("record_event") -record_event_p.multiple_results = True - -@record_event_p.def_abstract_eval -def _record_event_abstract_eval(*args, event): - del event # Unused. - return args - -@functools.partial(mlir.register_lowering, record_event_p, platform="cuda") -def _record_event_lowering_rule(ctx, *args, event): - ptr_bytes = ctypes.cast(event, ctypes.c_void_p).value.to_bytes( - 8, byteorder="little" - ) # pytype: disable=attribute-error - op = mlir.custom_call( - "mosaic_gpu_record_event", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - backend_config=ptr_bytes, - operand_output_aliases={i: i for i in range(len(args))}, - ) - return op.results - -def _record_event(args, event): +def _event_record(args, *, copy_before): flat_args, treedef = jax.tree.flatten(args) - return jax.tree.unflatten( - treedef, record_event_p.bind(*flat_args, event=event) - ) - -def measure(f, *args, **kwargs): - # TODO(apaszke): Raise if this is called under jit. - start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() - end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() - try: - - @jax.jit - def run(*args, **kwargs): - flat_args, treedef = jax.tree.flatten((args, kwargs)) - flat_args = _record_event(flat_args, start_event) - args, kwargs = jax.tree.unflatten(treedef, flat_args) - return _record_event(f(*args, **kwargs), end_event) - - jax.block_until_ready(run(*args, **kwargs)) # Warmup. - results = jax.block_until_ready(run(*args, **kwargs)) - elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed( - start_event, end_event + event, *flat_outs = ffi.ffi_call( + "mgpu_event_record", + result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args), + input_output_aliases={i: i + 1 for i in range(len(flat_args))}, + )(*flat_args, copy_before=copy_before) + return event, treedef.unflatten(flat_outs) + + +def _event_elapsed(start_event, end_event): + return ffi.ffi_call( + "mgpu_event_elapsed", + result_shape_dtypes=jax.core.ShapedArray((), jnp.float32), + )(start_event, end_event) + + +def measure( + f: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> tuple[T, float]: + """Measures the time it takes to execute the function on the GPU. + + Args: + f: The function to measure. It must accept at least one argument and return + at least one output to be measurable. + *args: The arguments to pass to ``f``. + **kwargs: The keyword arguments to pass to ``f``. + + Returns: + The return value of ``f`` and the elapsed time in milliseconds. + """ + if not (args or kwargs): + # We require at least one argument and at least one output to ensure + # that there is a data dependency between `_event_record` calls in + # the resulting HLO program. + raise ValueError("Can only measure functions with arguments") + + @jax.jit + def run(*args, **kwargs): + start_event, (args, kwargs) = _event_record( + (args, kwargs), copy_before=True ) - finally: - mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(start_event) - mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(end_event) - return results, elapsed + end_event, outs = _event_record(f(*args, **kwargs), copy_before=False) + if jax.tree.structure(outs).num_leaves == 0: + raise ValueError("Can only measure functions with at least one output") + return outs, _event_elapsed(start_event, end_event) + + outs, elapsed = run(*args, **kwargs) + return outs, float(elapsed) class ProfilerSpec: diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 2fb8f0103e65..1f78782a0891 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -185,9 +185,11 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/strings", "@nanobind", - "@xla//xla/service:custom_call_status", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/tsl/cuda:cudart", ], ) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 922d13d213f5..608270239882 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -13,19 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include #include "nanobind/nanobind.h" +#include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/service/custom_call_status.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" namespace jax::cuda { namespace { +namespace ffi = xla::ffi; +namespace nb = nanobind; + static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -38,45 +43,88 @@ static std::string ToString(CUresult result) { return absl::StrCat(error_name, ": ", error_string); } -void EventRecordCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto* event = reinterpret_cast(opaque); - if (auto res = gpuEventRecord(**event, reinterpret_cast(stream)); - res) { - auto message = absl::StrCat("Failed to record event: ", ToString(res)); - XlaCustomCallStatusSetFailure(status, message.c_str(), message.size()); - } +// Ensure it is safe to store gpuEvent_t in a uint64_t buffer. +static_assert(sizeof(gpuEvent_t) <= sizeof(uint64_t)); + +static const auto* kEventRecord = + ffi::Ffi::Bind() + .Ctx>() + .Attr("copy_before") + .RemainingArgs() + .Ret>() // event + .RemainingRets() + .To([](gpuStream_t stream, bool copy_before, + auto remaining_args, auto ret, auto remaining_rets) { + static auto* event = new gpuEvent_t; + if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); + res) { + return ffi::Error::Internal( + absl::StrCat("Failed to create event: ", ToString(res))); + } + auto do_copy = [&]() { + gpuMemcpyAsync(ret->untyped_data(), event, + sizeof(gpuEvent_t), gpuMemcpyHostToDevice, stream); + }; + if (copy_before) { + do_copy(); + } + if (auto res = gpuEventRecord(*event, stream); res) { + return ffi::Error::Internal( + absl::StrCat("Failed to record event: ", ToString(res))); + } + if (!copy_before) { + do_copy(); + } + return ffi::Error::Success(); + }) + .release(); + +XLA_FFI_Error* EventRecord(XLA_FFI_CallFrame* call_frame) { + return kEventRecord->Call(call_frame); +} + +static const auto* kEventElapsed = + ffi::Ffi::Bind() + .Ctx>() + .Arg>() // start_event + .Arg>() // end_event + .Ret>() // elapsed_ms + .To([](gpuStream_t stream, auto start, auto end, auto out) { + gpuStreamSynchronize(stream); + auto start_event = std::make_unique(); + auto end_event = std::make_unique(); + absl::MakeCleanup([&]() { + gpuEventDestroy(*start_event); + gpuEventDestroy(*end_event); + }); + gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpyDeviceToHost); + gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpyDeviceToHost); + float elapsed; + if (auto res = + gpuEventElapsedTime(&elapsed, *start_event, *end_event); + res) { + return ffi::Error::Internal(absl::StrCat( + "Failed to get elapsed time between events: ", ToString(res))); + } + gpuMemcpy(out->untyped_data(), &elapsed, sizeof(float), + gpuMemcpyHostToDevice); + return ffi::Error::Success(); + }) + .release(); + +XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) { + return kEventElapsed->Call(call_frame); } NB_MODULE(_mosaic_gpu_ext, m) { - m.def("_gpu_event_create", []() { - gpuEvent_t* event = new gpuEvent_t(); - if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); res) { - throw std::runtime_error( - absl::StrCat("Failed to create event: ", ToString(res))); - } - return reinterpret_cast(event); - }); - m.def("_gpu_event_destroy", [](uintptr_t event) { - if (auto res = gpuEventDestroy(*reinterpret_cast(event)); - res) { - throw std::runtime_error( - absl::StrCat("Failed to destroy event: ", ToString(res))); - } - }); - m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) { - float elapsed_ms = -1; - if (auto res = gpuEventElapsedTime( - &elapsed_ms, *reinterpret_cast(start_event), - *reinterpret_cast(end_event)); - res) { - throw std::runtime_error(absl::StrCat( - "Failed to get elapsed time between events: ", ToString(res))); - } - return elapsed_ms; + m.def("registrations", []() { + return nb::make_tuple( + nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)), + nb::make_tuple("mgpu_event_elapsed", EncapsulateFunction(EventElapsed)) + ); }); - m.def("_record_event_capsule", - []() { return EncapsulateFunction(EventRecordCall); }); m.def("_sync_all_devices", []() { int devices = 0; if (cudaGetDeviceCount(&devices) != gpuSuccess) { From c4a0369f5c3a401aed68474a25280aa7cc114dfe Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 12 Nov 2024 12:51:45 -0800 Subject: [PATCH 321/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f17344020c3240dcf28cabd12eadc97df178a1e6. PiperOrigin-RevId: 695838490 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f74c74077198..b38a984edc8d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e93a258e4494231626c7d3b6a6447e746ea72f9c" -XLA_SHA256 = "99f3a6b06230becf013f00009afeee4c89f52818e7a4a1ea4851157dc853830e" +XLA_COMMIT = "f17344020c3240dcf28cabd12eadc97df178a1e6" +XLA_SHA256 = "ce306964e6f44a4a3a6a804a7455fdc8c88e5f1c6370c9bae56dd8cbd365cdff" def repo(): tf_http_archive( From 370c4a70bb48dec4c225b2d8957165334e68cb3a Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 12 Nov 2024 13:44:08 -0800 Subject: [PATCH 322/698] Change the assumed width of the bool packing in the early-lowering checks in pallas PiperOrigin-RevId: 695856621 --- jax/_src/pallas/mosaic/lowering.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 489aae59dcd2..ebf172bc612d 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -502,8 +502,13 @@ def err_details(): ) else: assert rank == 1 - # TODO(necula): test this for bool. What should it do? - tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype)) + # bools get a bitwidth of 32 due to how mosaic handles them + if bm.array_shape_dtype.dtype == jnp.bool_: + bitwidth = 32 + else: + bitwidth = lax_internal._bit_width(bm.array_shape_dtype.dtype) + packing = 32 // bitwidth + tiling_size = 128 * packing evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) if not evenly_divisible: raise ValueError( From c32db46e6c7294bfb6243dabe86e3a6d27268a99 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Tue, 12 Nov 2024 14:37:08 -0800 Subject: [PATCH 323/698] [Mosaic] Add parameter names to tpu.sem_signal and add tests This CLs adds parameter names to the optional parameters of `tpu.sem_signal` -- `device_id`, `core_id` -- to remove the ambiguity upon deserialization. Adds LIT tests of signalling on TC with parameter names. PiperOrigin-RevId: 695875037 --- jaxlib/mosaic/dialect/tpu/tpu.td | 2 +- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 3affd31e51d6..b312bca7a7d3 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -640,7 +640,7 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { Optional:$core_id // For megacore ); let assemblyFormat = [{ - $semaphore `,` $amount (`,` $device_id^)? (`,` $core_id^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 3f6050f31dab..fd68c9e6c95e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -92,6 +92,9 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) { EnqueueDMAOp>::getOperandSegmentSizeAttr(), mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); } else if (op->getNumOperands() == 3) { // Remote signal. + // Hardcoding that one optional value is device_id, not core_id. This + // could misinterpret sem_signals where core_id is specified, but + // device_id isn't. op->setAttr(OpTrait::AttrSizedOperandSegments< EnqueueDMAOp>::getOperandSegmentSizeAttr(), mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); From d47e254100f46efd82b6133f551e1ef6289d620a Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Tue, 12 Nov 2024 15:50:50 -0800 Subject: [PATCH 324/698] Dedent your yields! Fixes a surprising interaction between the generator system in linear_util.py and the try/finally python context managers we use for managing tracing context. The `finally` block wasn't always being called until garbage collection, so the context stack pushes/pops weren't always correctly nested. Dedenting the yield fixes this particular bug but long-term we should get rid of linear_util altogether. PiperOrigin-RevId: 695898528 --- jax/_src/interpreters/batching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 0adb582a7993..b6325ed81ce1 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -653,7 +653,7 @@ def batch_subtrace(tag, axis_data, in_dims, *in_vals): outs = yield in_tracers, {} out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) segment_lens, out_dims = indirectify_ragged_axes(out_dims) - yield (*segment_lens, *out_vals), out_dims + yield (*segment_lens, *out_vals), out_dims def indirectify_ragged_axes(dims): if not any(type(d) is RaggedAxis for d in dims): @@ -803,7 +803,7 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) new_out_axes = indirectify_ragged_axes_against_inputs_outputs( out_axes, in_vals, out_vals) - yield out_vals, new_out_axes + yield out_vals, new_out_axes @lu.transformation_with_aux def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, From 195d4070813ebd04e93afe3d4f8aeb2c1a270698 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 12 Nov 2024 19:59:57 -0800 Subject: [PATCH 325/698] Add new CI scripts for running Bazel CPU presubmits This commit introduces new CI scripts and environment files for running Bazel CPU presubmits. * Adds a ci directory at the root of the repository to store these files. * Environment files are located in ci/envs and define new JAXCI_ environment variables to control CI build behavior. * The build script sources these environment files and set up the build environment before running the build commands. PiperOrigin-RevId: 695957540 --- .github/workflows/bazel_cpu_rbe.yml | 41 ++++++++++++++ ci/README.md | 10 ++++ ci/envs/default.env | 37 +++++++++++++ ci/run_bazel_test_cpu_rbe.sh | 68 +++++++++++++++++++++++ ci/utilities/setup_build_environment.sh | 71 +++++++++++++++++++++++++ 5 files changed, 227 insertions(+) create mode 100644 .github/workflows/bazel_cpu_rbe.yml create mode 100644 ci/README.md create mode 100644 ci/envs/default.env create mode 100755 ci/run_bazel_test_cpu_rbe.sh create mode 100644 ci/utilities/setup_build_environment.sh diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml new file mode 100644 index 000000000000..4a2e2ecb7fe6 --- /dev/null +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -0,0 +1,41 @@ +name: CI - Bazel CPU tests (RBE) + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run_tests: + if: github.event.repository.fork == false + strategy: + matrix: + runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"] + + runs-on: ${{ matrix.runner }} + # TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + + steps: + - uses: actions/checkout@v3 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CPU Tests with RBE + run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file diff --git a/ci/README.md b/ci/README.md new file mode 100644 index 000000000000..ea867df52f97 --- /dev/null +++ b/ci/README.md @@ -0,0 +1,10 @@ +# JAX continuous integration + +> [!WARNING] +> This folder is still under construction. It is part of an ongoing +> effort to improve the structure of CI and build related files within the +> JAX repo. This warning will be removed when the contents of this +> directory are stable and appropriate documentation around its usage is in +> place. + +******************************************************************************** \ No newline at end of file diff --git a/ci/envs/default.env b/ci/envs/default.env new file mode 100644 index 000000000000..528c02701acc --- /dev/null +++ b/ci/envs/default.env @@ -0,0 +1,37 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file contains all the default values for the "JAXCI_" environment +# variables used in the CI scripts. These variables are used to control the +# behavior of the CI scripts such as the Python version used, path to JAX/XLA +# repo, if to clone XLA repo, etc. + +# The path to the JAX git repository. +export JAXCI_JAX_GIT_DIR=$(pwd) + +# Controls the version of Hermetic Python to use. Use system default if not +# set. +export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} + +# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local +# copy of XLA instead of the pinned version in the WORKSPACE. When +# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically. +export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} + +# If set to 1, the builds will clone the XLA repository at HEAD and set its +# path in JAXCI_XLA_GIT_DIR. +export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} + +# Allows overriding the XLA commit that is used. +export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} \ No newline at end of file diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh new file mode 100755 index 000000000000..6ba9f6dce239 --- /dev/null +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel CPU tests with RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel CPU tests with RBE. +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# When running on Mac or Linux Aarch64, we only build the test targets and +# not run them. These platforms do not have native RBE support so we +# RBE cross-compile them on remote Linux x86 machines. As the tests still +# need to be run on the host machine and because running the tests on a +# single machine can take a long time, we skip running them on these +# platforms. +if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then + echo "Building RBE CPU tests..." + bazel build --config=rbe_cross_compile_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64=0 \ + --test_output=errors \ + --color=yes \ + //tests:cpu_tests //tests:backend_independent_tests +else + echo "Running RBE CPU tests..." + bazel test --config=rbe_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64=0 \ + --test_output=errors \ + --color=yes \ + //tests:cpu_tests //tests:backend_independent_tests +fi \ No newline at end of file diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh new file mode 100644 index 000000000000..e77e84f3c07f --- /dev/null +++ b/ci/utilities/setup_build_environment.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Set up the build environment for JAX CI jobs. This script depends on the +# "JAXCI_" environment variables set or sourced in the build script. + +# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI +# jobs running on Linux runners in GitHub Actions. Without this, git complains +# that the directory has dubious ownership and refuses to run any commands. +# Avoid running on Windows runners as git runs into issues with not being able +# to lock the config file. Other git commands seem to work on the Windows +# runners so we can skip this step for Windows. +# TODO(b/375073267): Remove this once we understand why git repositories are +# being marked as unsafe inside the self-hosted runners. +if [[ ! $(uname -s) =~ "MSYS_NT" ]]; then + git config --global --add safe.directory $JAXCI_JAX_GIT_DIR +fi + +function clone_main_xla() { + echo "Cloning XLA at HEAD to $(pwd)/xla" + git clone --depth=1 https://github.com/openxla/xla.git $(pwd)/xla + export JAXCI_XLA_GIT_DIR=$(pwd)/xla +} + +# Clone XLA at HEAD if required. +if [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then + # Clone only if $(pwd)/xla does not exist to avoid failure on re-runs. + if [[ ! -d $(pwd)/xla ]]; then + clone_main_xla + else + echo "JAXCI_CLONE_MAIN_XLA set but local XLA folder already exists: $(pwd)/xla so using that instead." + # Set JAXCI_XLA_GIT_DIR if local XLA already exists + export JAXCI_XLA_GIT_DIR=$(pwd)/xla + fi +fi + +# If a XLA commit is provided, check out XLA at that commit. +if [[ ! -z "$JAXCI_XLA_COMMIT" ]]; then + # Clone XLA at HEAD if a path to local XLA is not provided. + if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + clone_main_xla + fi + pushd "$JAXCI_XLA_GIT_DIR" + + git fetch --depth=1 origin "$JAXCI_XLA_COMMIT" + echo "JAXCI_XLA_COMMIT is set. Checking out XLA at $JAXCI_XLA_COMMIT" + git checkout "$JAXCI_XLA_COMMIT" + + popd +fi + +if [[ ! -z ${JAXCI_XLA_GIT_DIR} ]]; then + echo "INFO: Overriding XLA to be read from $JAXCI_XLA_GIT_DIR instead of the" + echo "pinned version in the WORKSPACE." + echo "If you would like to revert this behavior, unset JAXCI_CLONE_MAIN_XLA" + echo "and JAXCI_XLA_COMMIT in your environment. Note that the Bazel RBE test" + echo "commands overrides the XLA repository and thus require a local copy of" + echo "XLA to run." +fi \ No newline at end of file From f2a25cc2314083e50d588fc045c24b5ce754427c Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Tue, 12 Nov 2024 21:59:59 -0800 Subject: [PATCH 326/698] [XLA] Make our LLVM usage more googley With the advent of heterogenuous compute, XLA compilation now encompasses sub-compilation for multiple devices. These all can use LLVM, but with different settings. Today this means it is possible for one XLA client to reinitialize LLVM's global state while another client is in the middle of compilation. Add a global lock around our LLVM usage. Concurrent compilation is still allowed, as long as both invocations have the same set of options. This means from within the same client multiple compilation invocations should still be non-blocking. PiperOrigin-RevId: 695981613 --- tests/layout_test.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/layout_test.py b/tests/layout_test.py index 31f3d71d0537..afddab916723 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -655,6 +655,44 @@ def f(x): f(sparecore_arr) + def test_sparsecore_and_host_compute(self): + if not ( + jax.devices()[0].device_kind == 'TPU v5' + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest('Does not have a sparsecore present') + shape = (128, 128) + inp = jnp.arange(math.prod(shape)).reshape(shape) + s = SingleDeviceSharding(jax.devices()[0]) + + sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) + sparse_layout = Layout(sparse_dll, s) + sparecore_arr = jax.device_put(inp, sparse_layout) + + host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),)) + host_layout = Layout(host_dll, s) + host_arr = jax.device_put(inp, host_layout) + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_compute(x): + return x * x + + @compute_on('device_host') + @jax.jit + def host_compute(x): + return x + x + + @partial( + jax.jit, + in_shardings=(sparse_layout, host_layout), + out_shardings=(sparse_layout, host_layout), + ) + def f(x, y): + return sparsecore_compute(x), host_compute(y) + + f(sparecore_arr, host_arr) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 9a28b561a666a7ed4de7da9f870c1d1988518723 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 13 Nov 2024 01:31:41 -0800 Subject: [PATCH 327/698] Fix parallel pgle-tests execution. PiperOrigin-RevId: 696031645 --- tests/pgle_test.py | 84 +++++++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 5f0c28541b62..f34beb797846 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -12,50 +12,79 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import ExitStack from functools import partial import glob import logging import math import os import tempfile -import unittest from absl.testing import absltest import jax +from jax._src import api +from jax._src import compilation_cache as cc from jax._src import config -from jax._src import profiler -from jax._src import pjit from jax._src import monitoring +from jax._src import pjit +from jax._src import profiler from jax._src import test_util as jtu -from jax._src import api from jax.experimental import profiler as exp_profiler -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec -from jax._src import compilation_cache as cc -import numpy as np - from jax.experimental.serialize_executable import ( deserialize_and_load, serialize, ) +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec +import numpy as np jax.config.parse_flags_with_absl() -dump_dir = tempfile.TemporaryDirectory().name -os.environ['XLA_FLAGS'] = ( - f'--xla_dump_to={dump_dir}' - ' --xla_gpu_experimental_dump_fdo_profiles=true' - ' --xla_gpu_enable_latency_hiding_scheduler=true' -) @jtu.pytest_mark_if_available('multiaccelerator') class PgleTest(jtu.JaxTestCase): + _dump_exit_stack: ExitStack | None = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._dump_exit_stack = ExitStack() + + cls.dump_dir = cls._dump_exit_stack.enter_context(tempfile.TemporaryDirectory()) + if 'XLA_FLAGS' in os.environ: + cls.old_xla_flags = os.environ['XLA_FLAGS'] + else: + cls.old_xla_flags = None + + os.environ['XLA_FLAGS'] = ( + f'--xla_dump_to={cls.dump_dir}' + ' --xla_gpu_experimental_dump_fdo_profiles=true' + ' --xla_gpu_enable_latency_hiding_scheduler=true' + # TODO(patrios): Remove this flag once b/376647494 is fixed. + ' --xla_gpu_graph_level=0' + ) + if cls.old_xla_flags: + os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags + + @classmethod + def tearDownClass(cls): + if cls.old_xla_flags: + os.environ['XLA_FLAGS'] = cls.old_xla_flags + cls._dump_exit_stack.close() + super().tearDownClass() + def setUp(self): super().setUp() cc.set_cache_dir(None) cc.reset_cache() def tearDown(self): + # Cleanup dump directory + for file in os.listdir(self.dump_dir): + file_path = os.path.join(self.dump_dir, file) + if os.path.isfile(file_path): + os.remove(file_path) + cc.set_cache_dir(None) super().tearDown() @@ -87,7 +116,6 @@ def f(x, y): self.assertIsNotNone(fdo_profile) self.assertIn(b'custom', fdo_profile) - @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfileLarge(self): mesh = jtu.create_mesh((2,), ('x',)) its = 500 @@ -106,14 +134,10 @@ def f(x): shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - with config.pgle_profiling_runs(0): - f_lowered = f.lower(x) - f_compiled = f_lowered.compile() - pgle_profiler = profiler.PGLEProfiler(1, 90) with config.enable_pgle(False): with profiler.PGLEProfiler.trace(pgle_profiler): - f_compiled(x) + f(x) fdo_profile = pgle_profiler.consume_fdo_profile() self.assertEqual(fdo_profile.count(b'custom'), its) @@ -177,7 +201,6 @@ def f(x): self.assertArraysEqual(compiled(x), expected) self.assertEqual(cache_miss_count[0], 0) - @unittest.skip("Test failing in CI") def testAutoPgleWithPersistentCache(self): its = 50 mesh = jtu.create_mesh((2,), ('x',)) @@ -206,11 +229,12 @@ def f(x): config.persistent_cache_min_compile_time_secs(0), config.pgle_profiling_runs(2), tempfile.TemporaryDirectory() as cache_dir): + cc.reset_cache() cc.set_cache_dir(cache_dir) # Run 1: Module should be compiled without FDO with jtu.count_cached_compilation_cache_miss() as cache_miss_count: f(x) - self.assertEqual(cache_miss_count[0], 1) + self.assertGreater(cache_miss_count[0], 0) # Non-pgle profiled version of module should be saved non_pgle_profiled_files = os.listdir(cache_dir) @@ -221,26 +245,24 @@ def f(x): f(x) self.assertEqual(cache_miss_count[0], 0) - module_before_pgle = os.listdir(dump_dir) - print(module_before_pgle) + module_before_pgle = os.listdir(self.dump_dir) self.assertNotEmpty(module_before_pgle) # Run 3: Module should be compiled with FDO and stored to persistent cache with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - # Add xla_dump_to to env flags f(x) - self.assertEqual(cache_miss_count[0], 1) + self.assertGreater(cache_miss_count[0], 0) # Check if FDO profile file of the biggest module is not empty module_after_pgle = [ x - for x in os.listdir(dump_dir) + for x in os.listdir(self.dump_dir) if x not in module_before_pgle ] self.assertNotEmpty(module_after_pgle) biggest_module_after_pgle = max( module_after_pgle, key=lambda x: os.path.getsize( - os.path.join(dump_dir, x) + os.path.join(self.dump_dir, x) ), ) base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) @@ -251,7 +273,7 @@ def f(x): '.fdo_profile' ): self.assertGreater( - os.path.getsize(os.path.join(dump_dir, module)), 0 + os.path.getsize(os.path.join(self.dump_dir, module)), 0 ) for pgle_profiler in profilers_dict.values(): @@ -283,7 +305,7 @@ def check_if_cache_hit(event): f(x) monitoring._unregister_event_listener_by_callback(check_if_cache_hit) - self.assertEqual(cache_hit, 1) + self.assertGreater(cache_hit, 0) def testPassingFDOProfile(self): mesh = jtu.create_mesh((2,), ('x',)) From dfabcb027d62f2864b339c095b9ecfeb2984c050 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 13 Nov 2024 06:14:24 -0800 Subject: [PATCH 328/698] Add a shard map replication rule for cond_p. --- jax/experimental/shard_map.py | 26 ++++++++++++++++ tests/shard_map_test.py | 57 +++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3a9446862456..1d4e347ffaeb 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1274,6 +1274,32 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) return out_vals, out_rep +@register_check(control_flow.conditionals.cond_p) +def _cond_rule(mesh, *in_rep, branches): + _, *args_rep = in_rep + true_out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) + false_out_rep = _check_rep(mesh, branches[1].jaxpr, args_rep) + if not true_out_rep == false_out_rep: + raise Exception("The true and false branches of cond produced mismatched " + f"replication types {true_out_rep} and {false_out_rep}. " + "Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return true_out_rep + +@register_rewrite(control_flow.conditionals.cond_p) +def _cond_rewrite(mesh, in_rep, *args, branches): + pred_rep, *args_rep = in_rep + _, true_out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) + _, false_out_rep = _replication_rewrite_nomatch(mesh, branches[1], args_rep) + out_rep = map(op.and_, true_out_rep, false_out_rep) + out_rep = map(partial(op.and_, pred_rep), out_rep) + branches_ = ( + _replication_rewrite_match(mesh, branches[0], args_rep, out_rep), + _replication_rewrite_match(mesh, branches[1], args_rep, out_rep), + ) + out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) + return out_vals, out_rep @register_rewrite(core.closed_call_p) def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 48850c8da66a..df24315ce110 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -993,6 +993,63 @@ def body(c, _): shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_cond_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + x = jnp.arange(4) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(True, true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return y + return jax.lax.cond(True, true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + + # https://github.com/jax-ml/jax/issues/24418 + def f(a): + c = jax.lax.cond(jnp.any(a), lambda: 1, lambda: 0) + return jnp.reshape(c, a.shape) + + mesh = jtu.create_mesh((2,), ('x',)) + a = jnp.array([True, False]) + shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + def test_eager_custom_jvp_basic(self): @jax.custom_jvp def foo(x): From bc82203a5c8488f493da931f06920d2dda868dd9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 13 Nov 2024 08:19:13 -0800 Subject: [PATCH 329/698] Avoid using a contextmanager in Primitive.bind. It's slightly faster to inline the context manager code into the implementation of bind. PiperOrigin-RevId: 696142743 --- jax/_src/core.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index beb755348a0f..96aecfde3a74 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -446,8 +446,16 @@ def bind(self, *args, **params): # TODO: figure out how to handle function arguments # assert (not config.enable_checks.value or # all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args - with take_current_trace() as cur_trace: - return self.bind_with_trace(cur_trace, args, params) + + # This is equivalent to "with take_current_trace()", but the bind() code + # is called frequently and it's slightly faster to avoid using a context + # manager object. + prev_trace = trace_ctx.trace + trace_ctx.set_trace(eval_trace) + try: + return self.bind_with_trace(prev_trace, args, params) + finally: + trace_ctx.set_trace(prev_trace) def bind_with_trace(self, trace, args, params): return trace.process_primitive(self, args, params) From be3c8be186b860fbbd459960775e1890c4c84f02 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 13 Nov 2024 08:47:24 -0800 Subject: [PATCH 330/698] Fix bug where the Python wrapper to ParseArguments() didn't intern the static argnames strings, causing false mismatches when searching for static arguments. Fixes https://github.com/jax-ml/jax/issues/24857 PiperOrigin-RevId: 696151287 --- tests/pjit_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index a9760d02fc0f..8fe46c3b83e5 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -17,6 +17,7 @@ import re from functools import partial import logging +import json import math import textwrap import threading @@ -59,6 +60,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -3825,6 +3827,16 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P())) self.assertEqual(out.sharding.memory_kind, 'device') + @unittest.skipIf(xla_extension_version < 297, "Requires jaxlib 0.4.36+") + def test_jit_static_argnames_non_interned(self): + def do_nothing(foobar: int): + return foobar + + argname = "foobar" + # Has the side effect of ensuring argname is not interned. + argname = str(json.loads(json.dumps(argname))) + jax.jit(do_nothing, static_argnames=[argname])(foobar=2) # doesn't crash + def test_most_recent_executable_outer_inner_cache(self): x = np.zeros((20, 20), dtype=jnp.float64) From 93a1f9d3178c124293f16d3b156d7df335d23308 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 13 Nov 2024 09:02:40 -0800 Subject: [PATCH 331/698] [AutoPGLE] Fix test after pjrt cache refactoring PiperOrigin-RevId: 696156229 --- tests/pgle_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index f34beb797846..a27f3ec0b9ac 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -219,8 +219,6 @@ def f(x): shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - profilers_dict = ( - pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict) with (config.enable_compilation_cache(True), config.enable_pgle(True), config.raise_persistent_cache_errors(True), @@ -276,7 +274,7 @@ def f(x): os.path.getsize(os.path.join(self.dump_dir, module)), 0 ) - for pgle_profiler in profilers_dict.values(): + for pgle_profiler in pjit._pgle_profiler_dict.values(): self.assertTrue(pgle_profiler.is_enabled()) self.assertTrue(pgle_profiler.is_fdo_consumed()) @@ -291,7 +289,7 @@ def f(x): os.remove(os.path.join(cache_dir, non_pgle_file)) api.clear_caches() - profilers_dict.clear() + pjit._pgle_profiler_dict.clear() # Run 4: Persistent compilation cache should be hit PGLE profiler should # be disabled From 4d0a007d573091531775f7c7d25e1ae7963e8d39 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 13 Nov 2024 11:14:16 +0200 Subject: [PATCH 332/698] Add square_p --- jax/_src/lax/lax.py | 23 ++++++++++++++++++++++- jax/_src/numpy/ufuncs.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 9 +++++++++ jax/_src/pallas/mosaic_gpu/lowering.py | 5 +++++ jax/_src/pallas/triton/lowering.py | 1 + jax/experimental/jax2tf/jax2tf.py | 1 + jax/experimental/jet.py | 1 + jax/experimental/sparse/transform.py | 1 + jax/extend/core/primitives.py | 1 + jax/lax/__init__.py | 1 + tests/lax_test.py | 8 +------- 11 files changed, 44 insertions(+), 9 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9781f67152c8..7e15f46c3ef1 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1915,7 +1915,7 @@ def batch_matmul(lhs: Array, rhs: Array, def square(x: ArrayLike) -> Array: r"""Elementwise square: :math:`x^2`.""" - return integer_pow(x, 2) + return square_p.bind(x) def reciprocal(x: ArrayLike) -> Array: r"""Elementwise reciprocal: :math:`1 \over x`.""" @@ -2524,6 +2524,27 @@ def _abs_jvp_rule(g, ans, x): lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +square_p = standard_unop(_int | _float | _complex, 'square') + +def _square_complex(x): + a, b = real(x), imag(x) + # zero square(x).real is handled explicitly for abs(a)==abs(b) cases + # where for finite a, 2 * a is non-finite: + zero_re = is_finite(a) & (eq(a, b) | eq(a, -b)) + # equivalent to a**2 - b**2 but avoids overflow errors for large a + # and large b cases: + re = (a - b) * (a + b) + im = a * b * 2 + return select(zero_re, complex(_const(a, 0), im), complex(re, im)) + +def _square_lower_hlo(ctx, x): + if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): + return mlir.lower_fun(_square_complex, multiple_results=False)(ctx, x) + return [hlo.multiply(x, x)] + +ad.defjvp2(square_p, lambda g, ans, x: mul(g, mul(_const(x, 2), x))) +mlir.register_lowering(square_p, _square_lower_hlo) # TODO(pearu): use chlo.square + def _pow_dtype_rule(x, y): if (dtypes.issubdtype(x.dtype, np.inexact) and dtypes.issubdtype(y.dtype, np.integer)): diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index acaed78e4db7..29fa97411ee4 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -3107,7 +3107,7 @@ def square(x: ArrayLike, /) -> Array: """ check_arraylike("square", x) x, = promote_dtypes_numeric(x) - return lax.integer_pow(x, 2) + return lax.square(x) @partial(jit, inline=True) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index ebf172bc612d..3dbb410be29f 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2084,6 +2084,15 @@ def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule +def _square_lowering_rule(ctx: LoweringRuleContext, x): + if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer): + return arith.muli(x, x) + return arith.mulf(x, x) + + +lowering_rules[lax.square_p] = _square_lowering_rule + + def _exp_lowering_rule(ctx: LoweringRuleContext, x): return math.exp(x) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ba343cd923c3..dc46f1b81b6a 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1160,6 +1160,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): return x * x return NotImplementedError +@register_lowering_rule(lax.square_p) +def _square_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + x = _ensure_fa(x, x_aval.dtype) + return x * x @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 1a0400ebf0db..12a6e6b7965e 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -780,6 +780,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], ), + lax.square_p: lambda ctx, x: _mul(x, x), lax.pow_p: _make_dispatch_table( "pow", cuda=[ diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c6d920918074..1ad00f091b7c 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1726,6 +1726,7 @@ def _atan2(y, x, **kwargs): tf_impl[lax.asinh_p] = tf.math.asinh tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.square_p] = tf.math.square tf_impl[lax.rsqrt_p] = tf.math.rsqrt def _cbrt(x): diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 827e4d01b390..9ae2a3f139cb 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -405,6 +405,7 @@ def def_comp(prim, comp): def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) def_comp(lax.log1p_p, lambda x: lax.log(1 + x)) def_comp(lax.sqrt_p, lambda x: x ** 0.5) +def_comp(lax.square_p, lambda x: x * x) def_comp(lax.rsqrt_p, lambda x: x ** -0.5) def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1))) def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1))) diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 7c5a966500f7..f85142493e41 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -97,6 +97,7 @@ lax.sin_p, lax.sinh_p, lax.sqrt_p, + lax.square_p, lax.tan_p, lax.tanh_p, lax.convert_element_type_p, diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index feb70b5171be..02f0657cc371 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -127,6 +127,7 @@ sinh_p as sinh_p, sort_p as sort_p, sqrt_p as sqrt_p, + square_p as square_p, squeeze_p as squeeze_p, sub_p as sub_p, tan_p as tan_p, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index d2fb6a9bae3c..d569ed641138 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -206,6 +206,7 @@ sqrt as sqrt, sqrt_p as sqrt_p, square as square, + square_p as square_p, squeeze as squeeze, squeeze_p as squeeze_p, stop_gradient as stop_gradient, diff --git a/tests/lax_test.py b/tests/lax_test.py index 17132996c429..14f453b38e7c 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4362,12 +4362,6 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'sign': regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4') - elif name == 'square': - if is_cuda: - regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real') - if is_cpu: - regions_with_inaccuracies_keep('ninf.real', 'pinf.real', 'q1.real', 'q2.real', 'q3.real', 'q4.real') - elif name == 'log': regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag') @@ -4411,7 +4405,7 @@ def regions_with_inaccuracies_keep(*to_keep): regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan', - 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh'}: + 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh', 'square'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable From 558ebb9fb14c0b7e1038f4bc00962f461430862b Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Wed, 13 Nov 2024 10:25:15 -0800 Subject: [PATCH 333/698] Add Pallas Triton lowering for jax.lax.bitcast_convert_type. Only handles the case where operand type and target type have the same bitwidth. PiperOrigin-RevId: 696184251 --- jax/_src/pallas/triton/lowering.py | 21 +++++++++++++++ tests/pallas/ops_test.py | 41 ++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 1a0400ebf0db..e4d0244735b8 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2612,3 +2612,24 @@ def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type: # All integer types in Triton are signless. return ir.IntegerType.get_signless(dtype.itemsize * 8) return mlir.dtype_to_ir_type(dtype) + + +@register_lowering(lax.bitcast_convert_type_p) +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, operand: ir.Value, *, new_dtype +) -> ir.Value: + # TODO(petebu) Handle case where src and dst types have different bitwidths + src_elem_type = _element_type(operand.type) + dst_elem_type = _element_type(_dtype_to_ir_type(new_dtype)) + assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) + assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) + if src_elem_type.width != dst_elem_type.width: + raise NotImplementedError( + f"cannot cast {operand} to {new_dtype} because of different widths" + ) + if ir.RankedTensorType.isinstance(operand.type): + shape = ir.RankedTensorType(operand.type).shape + result_type = ir.RankedTensorType.get(shape, dst_elem_type) + else: + result_type = dst_elem_type + return tt_dialect.bitcast(result_type, operand) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index b8a42ecf1835..41670137c39f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1927,6 +1927,47 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.triu(x, k=k)) + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.float16), + (jnp.int16, jnp.bfloat16), + (jnp.float32, jnp.int32), + ) + def test_bitcast_convert_type(self, in_dtype, out_dtype): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + m, n = 4, 4 + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () + + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) + + x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + + def test_bitcast_convert_type_scalar(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + x = jnp.int32(42) + out_dtype = jnp.float32 + out_shape = jax.ShapeDtypeStruct(x.shape, out_dtype) + grid = () + + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_dtype) + + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + + class OpsInterpretTest(OpsTest): INTERPRET = True From a79d307ac7f314fc2706026e3fb8283637d073c5 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 29 Jul 2024 16:13:01 -0700 Subject: [PATCH 334/698] When caching is enabled, also enable XLA caching features as well Add unit test Fix typechecker Set caching mode depending on process id --- docs/persistent_compilation_cache.md | 18 ++++++++++ jax/BUILD | 1 + jax/_src/compiler.py | 27 +++++++++++++++ jax/_src/config.py | 9 +++++ tests/compilation_cache_test.py | 50 ++++++++++++++++++++++++++++ 5 files changed, 105 insertions(+) diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 246d3a6cb084..37afa2f594e3 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -24,6 +24,7 @@ import jax.numpy as jnp jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) +jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") @jax.jit def f(x): @@ -87,6 +88,23 @@ cc.set_cache_dir("/tmp/jax_cache") Note that both criteria need to be satisfied for a function to be cached. +### Additional caching + +XLA supports additional caching mechanism which can be enabled alongside JAX's +persistent compilation cache to further improve recompilation time. + +* `jax_persistent_cache_enable_xla_caches`: Possible values: + + * `all`: enable all XLA caching features + + * `none`: don't enable any extra XLA caching features + + * `xla_gpu_kernel_cache_file`: only enable the kernel cache + + * `xla_gpu_per_fusion_autotune_cache_dir`: (default value) only enable the + autotuning cache + + ### Google Cloud When running on Google Cloud, the compilation cache can be placed on a Google diff --git a/jax/BUILD b/jax/BUILD index e8e817ece3e9..0da99677dc7b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -430,6 +430,7 @@ pytype_strict_library( ":config", ":mlir", ":monitoring", + ":path", ":profiler", ":traceback_util", ":xla_bridge", diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 113f7507c4b0..ebb1a2b54855 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -29,10 +29,12 @@ from jax._src import distributed from jax._src import lib from jax._src import monitoring +from jax._src import path as pathlib from jax._src import profiler from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir import numpy as np @@ -241,6 +243,31 @@ def get_compile_options( debug_options.xla_detailed_logging = detailed_logging + # If persistent cache is enabled, also enable additional XLA caching features. + if compilation_cache.is_persistent_cache_enabled() and jaxlib_version > (0, 4, 35): + # compilation_cache_dir can't be None here, but the type checker is a bit + # strict. + path = pathlib.Path(config.compilation_cache_dir.value or "") + enabled_flags = config.persistent_cache_enable_xla_caches.value or "" + + if enabled_flags == "all" or "xla_gpu_kernel_cache_file" in enabled_flags: + kernel_cache_path = path / "xla_gpu_kernel_cache_file" + debug_options.xla_gpu_kernel_cache_file = str(kernel_cache_path) + # This option is required to use the kernel cache. + debug_options.xla_gpu_enable_llvm_module_compilation_parallelism = True + logger.debug("Enabling XLA kernel cache at '%s'", kernel_cache_path) + + if enabled_flags == "all" or "xla_gpu_per_fusion_autotune_cache_dir" in enabled_flags: + autotune_cache_path = path / "xla_gpu_per_fusion_autotune_cache_dir" + debug_options.xla_gpu_per_fusion_autotune_cache_dir = str(autotune_cache_path) + logger.debug("Enabling XLA autotuning cache at '%s'", autotune_cache_path) + + # Set caching mode so that only process 0 can write to the cache. + if distributed.global_state.process_id == 0: + debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.UPDATE + else: + debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.READ + return compile_options diff --git a/jax/_src/config.py b/jax/_src/config.py index f3edde69981f..a44b0125a210 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1369,6 +1369,15 @@ def _update_jax_memories_thread_local(val): ' filesystem being used for the cache. ' '* > 0: the actual minimum size desired; no overrides.')) +# TODO: Change default to all +persistent_cache_enable_xla_caches = optional_string_state( + name='jax_persistent_cache_enable_xla_caches', + default='xla_gpu_per_fusion_autotune_cache_dir', + help=('When the persistent cache is enabled, additional XLA caching will ' + 'also be enabled automatically. This option can be used to configure' + 'which XLA caching methods will be enabled.'), +) + compilation_cache_include_metadata_in_key = bool_state( name='jax_compilation_cache_include_metadata_in_key', default=False, diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 40c2181a9e3c..d10558afbe16 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -40,6 +40,8 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface +from jax._src.lib import xla_client as xc +from jax._src.lib import version as jaxlib_version from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -535,6 +537,42 @@ def test_backend_serialization_deserialization(self): self.assertEqual( executable.fingerprint, deserialized_executable.fingerprint) + def test_persistent_cache_enable_xla_caches(self): + if jaxlib_version <= (0, 4, 35): + self.skipTest("Test requires AutotuneCacheMode bindings") + with config.compilation_cache_dir("jax-cache"): + with config.persistent_cache_enable_xla_caches("none"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("all"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "jax-cache/xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "jax-cache/xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("xla_gpu_kernel_cache_file"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "jax-cache/xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("xla_gpu_per_fusion_autotune_cache_dir"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "jax-cache/xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) @jtu.with_config( jax_enable_compilation_cache=False, @@ -570,5 +608,17 @@ def test_tasks_disable_cache_metric(self): "/jax/compilation_cache/task_disabled_cache"] self.assertEqual(count_after_second_use, count_after_first_use) + def test_persistent_cache_enable_xla_caches_disabled(self): + if jaxlib_version <= (0, 4, 35): + self.skipTest("Test requires AutotuneCacheMode bindings") + with config.enable_compilation_cache(False): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 72a4692b943f87076bd65a3fad947a8c5fa54119 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 13 Nov 2024 12:44:05 -0800 Subject: [PATCH 335/698] doc: link directly to installation on the main page --- docs/_static/style.css | 6 +++--- docs/index.rst | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/_static/style.css b/docs/_static/style.css index 296912ace2c8..2c1dfcbcbf08 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -8,15 +8,15 @@ background-color: #fff; } -.getting-started { +.installation { background-color: rgba(78, 150, 253, var(--block-bg-opacity)); } -.user-guides { +.getting-started { background-color: rgba(0, 169, 154, var(--block-bg-opacity)); } -.developer-docs { +.user-guides { background-color: rgba(171, 0, 182, var(--block-bg-opacity)); } diff --git a/docs/index.rst b/docs/index.rst index ba724f8e77ab..5f3bce5cf7da 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,12 @@ designed for high-performance numerical computing and large-scale machine learni .. grid:: 3 + .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Installation + :columns: 12 6 6 4 + :link: installation + :link-type: ref + :class-card: installation + .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting started :columns: 12 6 6 4 :link: beginner-guide @@ -44,12 +50,6 @@ designed for high-performance numerical computing and large-scale machine learni :link-type: ref :class-card: user-guides - .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer notes - :columns: 12 6 6 4 - :link: contributor-guide - :link-type: ref - :class-card: developer-docs - If you're looking to train neural networks, use Flax_ and start with its tutorials. For an end-to-end transformer library built on JAX, see MaxText_. From 307e88f2801464e2ec760e7e1d75a57ae9d2887d Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Wed, 13 Nov 2024 15:58:45 -0500 Subject: [PATCH 336/698] Fix typos: Change 'arugments' to 'arguments'. --- docs/Custom_Operation_for_GPUs.md | 4 ++-- docs/Custom_Operation_for_GPUs.py | 2 +- jax/_src/callback.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index fcb7b570e493..f4b61cbcf7dc 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -679,7 +679,7 @@ class RmsNormFwdClass: NamedSharding(mesh, PartitionSpec(None, None))) invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) output_shardings = (arg_shardings[0], invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables impl = partial(RmsNormFwdClass.impl, eps=eps) @@ -739,7 +739,7 @@ class RmsNormBwdClass: output_shardings = (output_sharding, invvar_sharding, invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables def impl(g, invvar, x, weight): grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind( diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py index 31a00c49071e..1cdf67c41a90 100644 --- a/docs/Custom_Operation_for_GPUs.py +++ b/docs/Custom_Operation_for_GPUs.py @@ -353,7 +353,7 @@ def partition(eps: float, mesh : jax.sharding.Mesh, NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything. invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) output_shardings = (arg_shardings[0], invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables impl = partial(RmsNormFwdClass.impl, eps=eps) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 71886b453bef..013b766b8550 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -343,7 +343,7 @@ def pure_callback( * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` is deprecated and it will eventually raise ``NotImplementedError``. * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over - the batched arugments, calling ``callback`` once for each batch element. + the batched arguments, calling ``callback`` once for each batch element. * ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1`` added as the leading dimension unbatched inputs. * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the From 4a884d4184b4d1b4d9084abab64abda62c0c6c0b Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 13 Nov 2024 13:40:13 -0800 Subject: [PATCH 337/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/2a7890387f812c17fb5f17eec961ee52ac3e059d. PiperOrigin-RevId: 696255293 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b38a984edc8d..fdb6b1607816 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f17344020c3240dcf28cabd12eadc97df178a1e6" -XLA_SHA256 = "ce306964e6f44a4a3a6a804a7455fdc8c88e5f1c6370c9bae56dd8cbd365cdff" +XLA_COMMIT = "2a7890387f812c17fb5f17eec961ee52ac3e059d" +XLA_SHA256 = "cfe1eebc643355f55e6422451cbd750ac6a7f096ed8d6a0605238e4d8ce6d0d1" def repo(): tf_http_archive( From 1c9b23c566bcc0a373f2f9c8716bcc46851000f1 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 12 Nov 2024 22:39:26 -0800 Subject: [PATCH 338/698] Stop using generators for linear_util transformations. They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces. --- jax/_src/api_util.py | 62 +++++++++-------- jax/_src/checkify.py | 19 +++--- jax/_src/custom_derivatives.py | 47 +++++++------ jax/_src/interpreters/ad.py | 61 +++++++++-------- jax/_src/interpreters/batching.py | 85 +++++++++++++----------- jax/_src/interpreters/partial_eval.py | 64 ++++++++++-------- jax/_src/interpreters/pxla.py | 8 +-- jax/_src/linear_util.py | 96 ++++++++++++--------------- jax/_src/pallas/primitives.py | 7 +- jax/experimental/attrs.py | 31 ++++----- jax/experimental/jax2tf/jax2tf.py | 8 +-- jax/experimental/jet.py | 35 +++++----- jax/experimental/ode.py | 8 +-- jax/experimental/shard_map.py | 46 +++++++------ jax/experimental/sparse/transform.py | 9 +-- jax/extend/linear_util.py | 2 + tests/util_test.py | 10 +-- 17 files changed, 311 insertions(+), 287 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 329abd6b7570..1bfce85d592c 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]: else: return tuple(map(_ensure_str, x)) -@lu.transformation_with_aux -def flatten_fun(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun(fun, io_tree, *py_args): in_tree_expected, out_tree = io_tree @@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args): ans = fun(*args) return tree_unflatten(out_tree, ans) -@lu.transformation_with_aux -def flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} - yield tree_flatten(ans) + ans = f(*py_args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun_nokwargs(fun, io_tree, py_args): in_tree_expected, out_tree = io_tree @@ -118,17 +122,18 @@ def flattened_fun_in_tree( else: return in_tree, lambda: out_tree_store.val, has_kwargs -@lu.transformation_with_aux -def flatten_fun_nokwargs2(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs2(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - pair = yield py_args, {} + pair = f(*py_args) if not isinstance(pair, (list, tuple)) or len(pair) != 2: raise TypeError("expected function with aux output to return a two-element " f"tuple, but got type {type(pair)} with value {pair!r}") ans, aux = pair ans_flat, ans_tree = tree_flatten(ans) aux_flat, aux_tree = tree_flatten(aux) - yield (ans_flat, aux_flat), (ans_tree, aux_tree) + store.store((ans_tree, aux_tree)) + return ans_flat, aux_flat class _HashableWithStrictTypeEquality: """Box object used when comparing static arguments as a jit key. @@ -277,8 +282,8 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args -@lu.transformation -def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): +@lu.transformation2 +def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(fixed_args) + len(dyn_args)) for i, arg in zip(dyn_argnums, dyn_args): @@ -286,9 +291,7 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): fixed_args_ = iter(fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel - ans = yield args, kwargs - yield ans - + return f(*args, **kwargs) def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], kwargs: dict[str, Any]): @@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs -@lu.transformation -def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): +@lu.transformation2 +def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs) - ans = yield args, kwargs - yield ans + return f(*args, **kwargs) @lru_cache(maxsize=4096) @@ -435,9 +437,9 @@ def flat_out_axes( f, out_axes = _flat_out_axes(f, tuple(leaves), treedef) return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) -@lu.transformation_with_aux -def _flat_out_axes(leaves, treedef, *args, **kwargs): - ans = yield args, kwargs +@lu.transformation_with_aux2 +def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): + ans = f(*args, **kwargs) spec = tree_unflatten(treedef, leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) @@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs): "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None - yield ans, spec_flat + store.store(spec_flat) + return ans def check_callable(fun): # In Python 3.10+, the only thing stopping us from supporting staticmethods @@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() for path, l in generate_key_paths(x) if l is not static) -@lu.transformation_with_aux -def result_paths(*args, **kwargs): +@lu.transformation_with_aux2 +def result_paths(f, store, *args, **kwargs): "linear_util transform to get output pytree paths of pre-flattened function." - ans = yield args, kwargs - yield ans, [keystr(path) for path, _ in generate_key_paths(ans)] + ans = f(*args, **kwargs) + store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + return ans def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, result_paths: tuple[str, ...] | None = None, diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 55db5d13e848..22fde8bd1cb5 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -330,11 +330,12 @@ def update_error(error, pred, code, metadata, payload, effect_type): ## Checkify transformation for plumbing functional error values. -@lu.transformation_with_aux -def _flatten_and_get_error_metadata_thunk(*invals): - error, out = yield invals, {} +@lu.transformation_with_aux2 +def _flatten_and_get_error_metadata_thunk(f, store, *invals): + error, out = f(*invals) out_vals, out_tree = jtu.tree_flatten((error, out)) - yield out_vals, (out_tree, set(error._pred.keys())) + store.store((out_tree, set(error._pred.keys()))) + return out_vals def default_checkify_rule(primitive: core.Primitive, error: Error, enabled_errors, *invals: core.Value, @@ -438,10 +439,12 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors, consts = tuple(c.x for c in hashable_consts) return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args) -@lu.transformation_with_aux -def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) +@lu.transformation_with_aux2 +def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def _reduce_any_error(error: Error): diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e37494c4fe41..69130cc1831e 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -75,13 +75,14 @@ def _zeros_like_pytree(x): # like the api_util.py function, but also grabs output avals for error checking -@lu.transformation_with_aux -def _flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def _flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} + ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) ans_avals = [core.get_aval(x) for x in ans_flat] - yield ans_flat, (ans_tree, ans_avals) + store.store((ans_tree, ans_avals)) + return ans_flat ### JVPs @@ -266,18 +267,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable def _add_args(f, extra_args): return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args)) -@lu.transformation -def _add_args_(extra_args, *args, **kwargs): +@lu.transformation2 +def _add_args_(f, extra_args, *args, **kwargs): extra_args = tuple(arg.val for arg in extra_args) all_args = (extra_args + args) - yield (yield all_args, kwargs) + return f(*all_args, **kwargs) -@partial(lu.transformation_with_aux, use_eq_store=True) -def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args): primals_in, tangents_in = split_list(args, [len(args) // 2]) py_primals = tree_unflatten(in_tree, primals_in) py_tangents = tree_unflatten(in_tree, tangents_in) - pair_out = yield (py_primals, py_tangents), {} + pair_out = f(py_primals, py_tangents) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} " "must produce a pair (list or tuple of length two) representing " @@ -348,7 +349,8 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): if av_et != av_t) raise TypeError(msg.format('\n'.join(disagreements))) - yield primals_out + tangents_out, (out_tree, primal_avals) + store.store((out_tree, primal_avals)) + return primals_out + tangents_out class CustomJVPCallPrimitive(core.Primitive): multiple_results = True @@ -652,15 +654,15 @@ def _check_for_tracers(x): "arguments should typically not be indicated as nondiff_argnums.") raise UnexpectedTracerError(msg) -@partial(lu.transformation_with_aux, use_eq_store=True) -def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, *args): if symbolic_zeros: args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])] else: args = args[::2] py_args = tree_unflatten(in_tree, args) - pair_out = yield py_args, {} + pair_out = f(*py_args) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -710,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - yield (*res, *primals_out), (out_tree, res_tree) + store.store((out_tree, res_tree)) + return (*res, *primals_out) -@lu.transformation -def _flatten_bwd(in_tree, in_avals, out_trees, *args): +@lu.transformation2 +def _flatten_bwd(f, in_tree, in_avals, out_trees, *args): out_tree, res_tree = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) py_res = tree_unflatten(res_tree, res) py_cts_out = tree_unflatten(out_tree, cts_out) - py_cts_in = yield (py_res, py_cts_out), {} + py_cts_in = f(py_res, py_cts_out) if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)): py_cts_in = tuple(py_cts_in) # For each None in py_cts_in, indicating an argument for which the rule @@ -775,7 +778,7 @@ def append(x, d): f"to an input of shape/dtype {a.str_short()}.") raise ValueError(msg) results.append(ct) - yield results + return results # TODO(mattjj): remove both these exceptions to cotangent compatibility check def _temporary_dtype_exception(a, a_) -> bool: @@ -1425,11 +1428,11 @@ def fun_jaxpr_thunk(): return wrapped_fwd -@lu.transformation -def _fix_fwd_args(*args): +@lu.transformation2 +def _fix_fwd_args(f, *args): args = [(x, True) for x in args] args = [x for pair in args for x in pair] - yield (yield args, {}) + return f(*args) def _remat_opt_impl( *args, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index d080aae759a6..99340e728545 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -68,42 +68,43 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, fun, aux = jvp_subtrace_aux(fun) return jvpfun(fun, instantiate, transform_stack), aux -@lu.transformation -def jvpfun(instantiate, transform_stack, primals, tangents): +@lu.transformation2 +def jvpfun(f, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) with ctx: - out_primals, out_tangents = yield (tag, primals, tangents), {} + out_primals, out_tangents = f(tag, primals, tangents) if type(instantiate) is bool: instantiate = [instantiate] * len(out_tangents) out_tangents = [instantiate_zeros(t) if inst else t for t, inst in zip(out_tangents, instantiate)] - yield out_primals, out_tangents + return out_primals, out_tangents -@lu.transformation -def jvp_subtrace(tag, primals, tangents): +@lu.transformation2 +def jvp_subtrace(f, tag, primals, tangents): with core.take_current_trace() as parent_trace: trace = JVPTrace(parent_trace, tag) in_tracers = [maybe_jvp_tracer(trace, x, t) for x, t in zip(primals, tangents)] with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out = unzip2(map(trace.to_primal_tangent_pair, ans)) - yield out + return out -@lu.transformation_with_aux -def jvp_subtrace_aux(tag, primals, tangents): +@lu.transformation_with_aux2 +def jvp_subtrace_aux(f, store, tag, primals, tangents): with core.take_current_trace() as parent_trace: trace = JVPTrace(parent_trace, tag) with core.set_current_trace(trace): - ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {} + ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents))) out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag else x for x in aux] - yield (out_primals, out_tangents), aux_primals + store.store(aux_primals) + return out_primals, out_tangents def linearize(traceable, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) @@ -262,10 +263,11 @@ def get_primitive_transpose(p): "Transpose rule (for reverse-mode differentiation) for '{}' " "not implemented".format(p)) from err -@lu.transformation_with_aux -def nonzero_tangent_outputs(*args, **kwargs): - results = (_, tangents_out) = yield args, kwargs - yield results, [type(r) is not Zero for r in tangents_out] +@lu.transformation_with_aux2 +def nonzero_tangent_outputs(f, store, *args, **kwargs): + results = (_, tangents_out) = f(*args, **kwargs) + store.store([type(r) is not Zero for r in tangents_out]) + return results class JVPTrace(Trace): @@ -543,15 +545,16 @@ def zero_jvp(primitive, primals, tangents, **params): def instantiate_zeros(tangent): return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent -@lu.transformation_with_aux -def traceable(in_tree, *primals_and_tangents): +@lu.transformation_with_aux2 +def traceable(f, store, in_tree, *primals_and_tangents): primals, tangents = tree_unflatten(in_tree, primals_and_tangents) tangents = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primals, tangents)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) tangents_out = [None if type(t) is Zero else t for t in tangents_out] out_flat, out_tree = tree_flatten((primals_out, tangents_out)) - yield out_flat, out_tree + store.store(out_tree) + return out_flat def call_transpose(primitive, params, call_jaxpr, args, ct, _): @@ -588,10 +591,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals): primitive_transposes[core.closed_call_p] = _closed_call_transpose -@lu.transformation_with_aux -def nonzero_outputs(*args, **kwargs): - results = yield args, kwargs - yield results, [type(r) is not Zero for r in results] +@lu.transformation_with_aux2 +def nonzero_outputs(f, store, *args, **kwargs): + results = f(*args, **kwargs) + store.store([type(r) is not Zero for r in results]) + return results def map_transpose(primitive, params, call_jaxpr, args, ct, _): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts @@ -655,17 +659,18 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate): jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() -@lu.transformation_with_aux -def f_jvp_traceable(nonzeros, *primals_and_nztangents): +@lu.transformation_with_aux2 +def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents): num_primals = len(nonzeros) primals = list(primals_and_nztangents[:num_primals]) nonzero_tangents = iter(primals_and_nztangents[num_primals:]) tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p) for p, nz in zip(primals, nonzeros)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) out_nonzeros = [type(t) is not Zero for t in tangents_out] nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero] - yield list(primals_out) + nonzero_tangents_out, out_nonzeros + store.store(out_nonzeros) + return list(primals_out) + nonzero_tangents_out def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index b6325ed81ce1..f4658ec2be29 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -327,11 +327,13 @@ def unregister_vmappable(data_type: type) -> None: def is_vmappable(x: Any) -> bool: return type(x) is Jumble or type(x) in vmappables -@lu.transformation_with_aux -def flatten_fun_for_vmap(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_for_vmap(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans, is_leaf=is_vmappable) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable) + store.store(out_tree) + return ans # Propagate ragged masking rules from invars to outvars # rule([params], [raggedness_per_invar], outvars) -> @@ -580,16 +582,16 @@ def batch(fun: lu.WrappedFun, axis_data, f = _batch_inner(fun, axis_data, out_dim_dests) return _batch_outer(f, axis_data, in_dims) -@lu.transformation -def _batch_outer(axis_data, in_dims, *in_vals): +@lu.transformation2 +def _batch_outer(f, axis_data, in_dims, *in_vals): tag = TraceTag() with source_info_util.transform_name_stack('vmap'): - outs, trace = yield (tag, in_dims, *in_vals), {} + outs, trace = f(tag, in_dims, *in_vals) with core.ensure_no_leaks(trace): del trace - yield outs + return outs -@lu.transformation -def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): +@lu.transformation2 +def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) @@ -599,13 +601,13 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): - outs = yield in_tracers, {} + outs = f(*in_tracers) out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), outs, out_dim_dests) - yield out_vals, trace + return out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. def vtile(f_flat: lu.WrappedFun, @@ -628,21 +630,21 @@ def untile_axis(out, axis: int | None): shape[axis:axis+2] = [shape[axis] * shape[axis+1]] return out.reshape(shape) - @lu.transformation - def _map_to_tile(*args_flat): + @lu.transformation2 + def _map_to_tile(f, *args_flat): sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None) tile_size_ = tile_size or next(sizes, None) assert tile_size_ is not None, "No mapped arguments?" - outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {} - yield map(untile_axis, outputs_flat, out_axes_flat) + outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat)) + return map(untile_axis, outputs_flat, out_axes_flat) axis_data = AxisData(axis_name, tile_size, None) return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat)) ### API for batching functions with jaxpr type inputs and outputs -@lu.transformation_with_aux -def batch_subtrace(tag, axis_data, in_dims, *in_vals): +@lu.transformation_with_aux2 +def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) with core.set_current_trace(trace): @@ -650,10 +652,11 @@ def batch_subtrace(tag, axis_data, in_dims, *in_vals): in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) if dim is not None else x for x, dim in zip(in_vals, in_dims)] - outs = yield in_tracers, {} + outs = f(*in_tracers) out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) segment_lens, out_dims = indirectify_ragged_axes(out_dims) - yield (*segment_lens, *out_vals), out_dims + store.store(out_dims) + return (*segment_lens, *out_vals) def indirectify_ragged_axes(dims): if not any(type(d) is RaggedAxis for d in dims): @@ -789,8 +792,8 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) return core.ClosedJaxpr(jaxpr_out, consts), out_batched() -@lu.transformation_with_aux -def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): +@lu.transformation_with_aux2 +def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) _, in_axes = resolve_ragged_axes(in_vals, in_axes) @@ -799,16 +802,17 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): - outs = yield in_tracers, {} + outs = f(*in_tracers) out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) new_out_axes = indirectify_ragged_axes_against_inputs_outputs( out_axes, in_vals, out_vals) - yield out_vals, new_out_axes + store.store(new_out_axes) + return out_vals -@lu.transformation_with_aux -def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, +@lu.transformation_with_aux2 +def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes, *in_vals): - out_vals = yield (trace, in_axes, *in_vals), {} + out_vals = f(trace, in_axes, *in_vals) out_axes = out_axes() out_axes_dest = [(None if src is not_mapped else 0) if dst is zero_if_mapped else dst @@ -819,16 +823,16 @@ def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, out_vals = map(partial(matchaxis, axis_data.name, axis_data.size), out_axes, out_axes_dest, out_vals) out_batched = [dst is not None for dst in out_axes_dest] - yield out_vals, out_batched + store.store(out_batched) + return out_vals -@lu.transformation -def _batch_jaxpr_outer(axis_data, in_dims, *in_vals): +@lu.transformation2 +def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) else ax for x, ax in unsafe_zip(in_vals, in_dims)] tag = TraceTag() - out_vals = yield (tag, in_dims, *in_vals), {} - yield out_vals + return f(tag, in_dims, *in_vals) def _merge_bdims(x, y): if x == y: @@ -845,8 +849,8 @@ class ZeroIfMapped: pass ### functions for handling custom_vjp -@lu.transformation_with_aux -def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): +@lu.transformation_with_aux2 +def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): size = axis_data.size with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) @@ -855,7 +859,7 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): if type(val) is SymbolicZero else BatchTracer(trace, val, dim) for val, dim in zip(in_vals, in_dims * 2)] with core.set_current_trace(trace): - outs = yield in_tracers, {} + outs = f(*in_tracers) # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can # be wasteful in the rare case it actually triggers; handle symbolically! outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] @@ -868,7 +872,8 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): out_primal_bds, out_dims, out_primals) out_tangents = map(partial(matchaxis, trace.axis_data.name, size), out_tangent_bds, out_dims, out_tangents) - yield out_primals + out_tangents, out_dims * 2 + store.store(out_dims * 2) + return out_primals + out_tangents def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): axis_size = axis_data.size @@ -886,11 +891,11 @@ def new_bwd(*args): return bwd_.call_wrapped(*args) return new_bwd -@lu.transformation -def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): +@lu.transformation2 +def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): # this is like _match_axes, but we do reduce-sums as needed - out_vals = yield in_vals, {} - yield map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, + out_vals = f(*in_vals) + return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, sum_match=True), out_dims_thunk(), out_dim_dests, out_vals) def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5431762d66d2..943c15b6ea49 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -475,18 +475,19 @@ def partition_pvals( consts = [pval.get_known() for pval in pvals if pval.is_known()] return knowns, avals, consts -@lu.transformation_with_aux +@lu.transformation_with_aux2 def partial_eval_wrapper_nounits( - in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], + f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], *in_consts: Any): in_avals_, in_consts_ = iter(in_avals), iter(in_consts) in_pvals = [PartialVal.known(next(in_consts_)) if known else PartialVal.unknown(next(in_avals_)) for known in in_knowns] sentinel = object() assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel - jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {} + jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals) out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) + store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env)) + return (*out_consts, *res) custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} @@ -574,20 +575,22 @@ def trace_to_jaxpr_nounits( return jaxpr, out_pvals, consts # TODO(mattjj): superfluous wrapper...? -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits( + f, trace: JaxprTrace, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + return jaxpr, (out_pvals, out_consts, env) -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits2( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -596,19 +599,19 @@ def trace_to_subjaxpr_nounits2( current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + return jaxpr, (out_pvals, out_consts, env) -def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): +def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals): in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] in_args = merge_lists(in_knowns, in_tracers, in_consts) with core.set_current_trace(trace): - ans = yield in_args, {} + ans = f(*in_args) assert isinstance(ans, (list, tuple)), ( f"Got unexpected return type when tracing function to jaxpr: {ans}") assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( @@ -625,8 +628,9 @@ def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): # The below variant implements an optimization where residuals which are also # inputs are indicated in auxiliary data rather than passed as outputs. # TODO(mattjj): update all callers to use this version, delete other version. -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -635,8 +639,8 @@ def trace_to_subjaxpr_nounits_fwd( with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) with core.set_current_trace(trace): - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. @@ -646,15 +650,16 @@ def trace_to_subjaxpr_nounits_fwd( pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] del out_tracers - yield jaxpr, (fwds, out_pvals, pruned_consts, env) + return jaxpr, (fwds, out_pvals, pruned_consts, env) # The below variant implements two optimizations: # 1. residuals that are also primal inputs are indicated in aux data rather # than passed as outputs; # 2. residuals that are also primal outputs are indicated in aux data rather # than passed as redundant outputs. -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd2( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -662,8 +667,8 @@ def trace_to_subjaxpr_nounits_fwd2( current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) - out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] # Which consts (aka residuals) are just forwarded inputs? Check obj id. @@ -680,7 +685,7 @@ def trace_to_subjaxpr_nounits_fwd2( if f1 is None and f2 is None] del out_tracers - yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) + return jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) FreeVar = namedtuple('FreeVar', ['val']) @@ -2066,10 +2071,10 @@ def transpose_jaxpr_thunk(): custom_staging_rules: dict[Primitive, Callable] = {} -@lu.transformation -def _interleave_fun(every_others, *args, **kwargs): +@lu.transformation2 +def _interleave_fun(f, every_others, *args, **kwargs): args_ = [x for pair in zip(args, every_others) for x in pair] - yield (yield (args_, kwargs)) + return f(*args_, **kwargs) # TODO: consider renaming to "lazy_thunk" def _memoize(fn): @@ -2083,18 +2088,19 @@ def memoized(*args): return out return memoized -@lu.transformation_with_aux -def _jvp_jaxpr_zeros(in_zeros, zero_avals, *primal_tangent_avals): +@lu.transformation_with_aux2 +def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): in_primals, nz_in_tangents = split_list(primal_tangent_avals, [len(in_zeros)]) symbolic_zeros = map(ad_util.SymbolicZero, zero_avals) tangents = merge_lists(in_zeros, nz_in_tangents, symbolic_zeros) - out = yield (*in_primals, *tangents), {} + out = f(*in_primals, *tangents) n, ragged = divmod(len(out), 2) assert not ragged out_primals, out_tangents = out[:n], out[n:] out_zeros = [type(t) is ad_util.SymbolicZero for t in out_tangents] out_nz_tangents, _ = partition_list(out_zeros, out_tangents) - yield [*out_primals, *out_nz_tangents], out_zeros + store.store(out_zeros) + return [*out_primals, *out_nz_tangents] # TODO(mattjj): remove this DebugInfo and helper functions, replace with # api_util.py versions diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 316fbc077ceb..caa414741666 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -690,15 +690,15 @@ def find_replicas( num_global_replicas = global_axis_size * jaxpr_replicas return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas) -@lu.transformation -def _change_argument_ranks(in_axes, out_axes_thunk, *args): +@lu.transformation2 +def _change_argument_ranks(f, in_axes, out_axes_thunk, *args): args = tuple( arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,)) for in_axis, arg in zip(in_axes, args) ) - results = yield (args, {}) + results = f(*args) out_axes = out_axes_thunk() - yield tuple( + return tuple( x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,)) for x, axis in zip(results, out_axes) ) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 08f94c6e8eda..37d812dec619 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -64,6 +64,7 @@ def trans1(static_arg, *dynamic_args, **kwargs): from __future__ import annotations from collections.abc import Callable +from functools import partial from typing import Any, NamedTuple import weakref @@ -149,10 +150,11 @@ class WrappedFun: params: extra parameters to pass as keyword arguments to `f`, along with the transformed keyword arguments. """ - __slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info") + __slots__ = ("f", "f_transformed", "transforms", "stores", "params", "in_type", "debug_info") - def __init__(self, f, transforms, stores, params, in_type, debug_info): + def __init__(self, f, f_transformed, transforms, stores, params, in_type, debug_info): self.f = f + self.f_transformed = f_transformed self.transforms = transforms self.stores = stores self.params = params @@ -165,8 +167,14 @@ def __name__(self): def wrap(self, gen, gen_static_args, out_store) -> WrappedFun: """Add another transform and its store.""" - return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None, None) + if out_store is None: + return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) + else: + return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) def populate_stores(self, stores): """Copy the values from the `stores` into `self.stores`.""" @@ -175,47 +183,8 @@ def populate_stores(self, stores): self_store.store(other_store.val) def call_wrapped(self, *args, **kwargs): - """Calls the underlying function, applying the transforms. - - The positional `args` and keyword `kwargs` are passed to the first - transformation generator. - """ - stack = [] - for (gen, gen_static_args), out_store in zip(self.transforms, self.stores): - gen = gen(*(gen_static_args + tuple(args)), **kwargs) - args, kwargs = next(gen) - stack.append((gen, out_store)) - gen = gen_static_args = out_store = None - - try: - ans = self.f(*args, **dict(self.params, **kwargs)) - except: - # Some transformations yield from inside context managers, so we have to - # interrupt them before reraising the exception. Otherwise they will only - # get garbage-collected at some later time, running their cleanup tasks - # only after this exception is handled, which can corrupt the global - # state. - while stack: - stack.pop()[0].close() - raise - - args = kwargs = None - while stack: - gen, out_store = stack.pop() - try: - ans = gen.send(ans) - except: - # As above does for the first half of the transformation, exceptions - # raised in the second half of the transformation also require us to - # clean up references here. - while stack: - stack.pop()[0].close() - raise - if out_store is not None: - ans, side = ans - out_store.store(side) - - return ans + """Calls the transformed function""" + return self.f_transformed(*args, **kwargs) def __repr__(self): def transform_to_str(x): @@ -234,7 +203,7 @@ def __eq__(self, other): self.debug_info == other.debug_info) @curry -def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: +def transformation2(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """Adds one more transformation to a WrappedFun. Args: @@ -244,8 +213,28 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """ return fun.wrap(gen, gen_static_args, None) +# Backwards compat only. TODO: deprecate +@curry +def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + return gen_inst.send(f(*args_, **kwargs_)) + return transformation2(gen2, fun, *gen_static_args)() + +# Backwards compat only. TODO: deprecate +@curry +def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, store, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + ans, aux = gen_inst.send(f(*args_, **kwargs_)) + store.store(aux) + return ans + return transformation_with_aux2(gen2, fun, *gen_static_args)() + @curry -def transformation_with_aux( +def transformation_with_aux2( gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False ) -> tuple[WrappedFun, Callable[[], Any]]: """Adds one more transformation with auxiliary output to a WrappedFun.""" @@ -261,8 +250,9 @@ def fun_name(f): def wrap_init(f, params=None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" + params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) - return WrappedFun(f, (), (), params, None, None) + return WrappedFun(f, partial(f, **params_dict), (), (), params, None, None) def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: @@ -270,7 +260,7 @@ def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: if in_type is None: return f _check_input_type(in_type) - return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info) def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed @@ -317,7 +307,7 @@ def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None assert f.debug_info is None if debug_info is None: return f - return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info) def cache(call: Callable, *, explain: Callable | None = None): @@ -357,9 +347,9 @@ def _evict_function(f): cache_clearing_funs.add(memoized_fun.cache_clear) return memoized_fun -@transformation -def hashable_partial(*args): - yield (yield args, {}) +@transformation2 +def hashable_partial(f, *args): + return f(*args) def merge_linear_aux(aux1, aux2): diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index c7bd7dd7178f..d77ca86c152a 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -824,14 +824,13 @@ def debug_print_lowering_rule(ctx, *args, **params): # because they should appear as atomic JAX values to the users. # TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU # inferred by the compiler. -@lu.transformation -def wrap_with_transforms(transforms, *args): +@lu.transformation2 +def wrap_with_transforms(f, transforms, *args): new_args = tuple( state_types.TransformedRef(a, t) if t else a for a, t in zip(args, transforms) ) - res = yield new_args, {} - yield res + return f(*new_args) run_scoped_p = jax_core.Primitive("run_scoped") diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index a25d93a35c51..b4adbadfa6c5 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -97,34 +97,34 @@ def jvp(f, primals, tangents, attr_tangents): out_tangents = tree_unflatten(out_tree(), out_tangents_flat) return out_primals, out_tangents, tangent_attrs_out -@lu.transformation -def _set_attrs(attrs, attr_vals, *args): +@lu.transformation2 +def _set_attrs(f, attrs, attr_vals, *args): for (o, a), x in zip(attrs, attr_vals): jax_setattr(o, a, x) - yield (yield args, {}) + return f(*args) def _jvp(fun: lu.WrappedFun): return jvpfun2(jvp_subtrace2(fun)) -@lu.transformation -def jvpfun2(primals, tangents): +@lu.transformation2 +def jvpfun2(f, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = source_info_util.transform_name_stack('jvp') with ctx: - out_primals, out_tangents, tangent_attrs_out = yield (tag, primals, tangents), {} - yield out_primals, out_tangents, tangent_attrs_out + out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents) + return out_primals, out_tangents, tangent_attrs_out -@lu.transformation -def jvp_subtrace2(tag, primals, tangents): +@lu.transformation2 +def jvp_subtrace2(f, tag, primals, tangents): with core.take_current_trace() as parent_trace: trace = ad.JVPTrace(parent_trace, tag) tag.attrs_tracked = [] # attrs written to in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x for x, t in zip(primals, tangents)] with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) tangent_attrs_out = [] for (obj, name) in tag.attrs_tracked: @@ -133,7 +133,7 @@ def jvp_subtrace2(tag, primals, tangents): if type(tangent) is not ad.Zero: tangent_attrs_out.append((obj, name, tangent)) del tag.attrs_tracked - yield out_primals, out_tangents, tangent_attrs_out + return out_primals, out_tangents, tangent_attrs_out def _setattr_jvp(trace, obj, attr, maybe_tracer): primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) @@ -175,11 +175,12 @@ def _linearize(traceable: lu.WrappedFun, *primals): return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], jaxpr, consts, attrs()) -@lu.transformation_with_aux -def _split_attrs(*args, **kwargs): - primals, tangents, tangent_attrs = yield args, kwargs +@lu.transformation_with_aux2 +def _split_attrs(f, store, *args, **kwargs): + primals, tangents, tangent_attrs = f(*args, **kwargs) attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) - yield (primals, tangents, tangent_attr_vals), attrs + store.store(attrs) + return primals, tangents, tangent_attr_vals def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): in_tree, out_tree = io_tree diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c6d920918074..b8acb0d1a14a 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1040,20 +1040,20 @@ def impl_multiple_results_jax(*args_jax): return wrapped_tf -@lu.transformation -def _interpret_subtrace(in_avals: Sequence[core.ShapedArray], +@lu.transformation2 +def _interpret_subtrace(f, in_avals: Sequence[core.ShapedArray], *in_vals: TfVal): trace = TensorFlowTrace() in_tracers = tuple( TensorFlowTracer(trace, val, aval) for val, aval in zip(in_vals, in_avals)) with core.set_current_trace(trace): - outs = yield in_tracers, {} # type: Sequence[TfVal] + outs = f(*in_tracers) out_tracers: Iterable[TensorFlowTracer] = ( map(trace.to_tf_tracer, outs)) out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( tuple((t.val, t.aval) for t in out_tracers)) - yield out_vals_with_avals + return out_vals_with_avals def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal, diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 827e4d01b390..75b040a28f4f 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -141,40 +141,43 @@ def jet(fun, primals, series): if not treedef_is_leaf(treedef): raise ValueError(f"term {j} for argument {i} is not an array") - @lu.transformation_with_aux - def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) + @lu.transformation_with_aux2 + def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, tree = tree_flatten(ans) + store.store(tree) + return ans f, out_tree = flatten_fun_output(lu.wrap_init(fun)) out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series) return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms) -@lu.transformation -def jet_fun(order, primals, series): +@lu.transformation2 +def jet_fun(f, order, primals, series): tag = core.TraceTag() - out_primals, out_terms = yield (tag, order, primals, series), {} + out_primals, out_terms = f(tag, order, primals, series) out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s for p, s in zip(out_primals, out_terms)] - yield out_primals, out_terms + return out_primals, out_terms -@lu.transformation -def jet_subtrace(tag, order, primals, series): +@lu.transformation2 +def jet_subtrace(f, tag, order, primals, series): with core.take_current_trace() as parent_trace: trace = JetTrace(tag, parent_trace, order) in_tracers = map(partial(JetTracer, trace), primals, series) with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans)) - yield out_primals, out_terms + return out_primals, out_terms -@lu.transformation_with_aux -def traceable(in_tree_def, *primals_and_series): +@lu.transformation_with_aux2 +def traceable(f, store, in_tree_def, *primals_and_series): primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series) - primals_out, series_out = yield (primals_in, series_in), {} + primals_out, series_out = f(primals_in, series_in) out_flat, out_tree_def = tree_flatten((primals_out, series_out)) - yield out_flat, out_tree_def + store.store(out_tree_def) + return out_flat class JetTracer(core.Tracer): diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index b8e3daee48c8..987e461a39b2 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -47,12 +47,12 @@ def ravel_first_arg(f, unravel): return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped -@lu.transformation -def ravel_first_arg_(unravel, y_flat, *args): +@lu.transformation2 +def ravel_first_arg_(f, unravel, y_flat, *args): y = unravel(y_flat) - ans = yield (y,) + args, {} + ans = f(y, *args) ans_flat, _ = ravel_pytree(ans) - yield ans_flat + return ans_flat def interp_fit_dopri(y0, y1, k, dt): # Fit a polynomial to the results of a Runge-Kutta step. diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3a9446862456..c658ddd3a4b8 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1479,15 +1479,15 @@ def known_out_names(): return pe.merge_lists(out_knowns, out_tracers, out_consts) pe.JaxprTrace.process_shard_map = _shard_map_partial_eval -@lu.transformation -def _promote_scalar_residuals(*args, **kwargs): - jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs +@lu.transformation2 +def _promote_scalar_residuals(f, *args, **kwargs): + jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) which = [f1 is None and f2 is None and not v.aval.shape for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x for x in out_consts] - yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) + return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) def _promote_scalar_residuals_jaxpr(jaxpr, which): @lu.wrap_init @@ -1728,13 +1728,13 @@ def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): check_rep=False, auto=frozenset()), in_specs, out_specs) -@lu.transformation -def _handle_reshapes(in_axes, out_axes_thunk, *args, **kwargs): +@lu.transformation2 +def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), list(args), list(in_axes)) - out = yield args, {} - yield tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), - list(out), list(out_axes_thunk())) + out = f(*args) + return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), + list(out), list(out_axes_thunk())) def _axis_to_spec(axis_name, ax): if isinstance(ax, int): @@ -1855,27 +1855,28 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) return _match_rep(fun, mesh, out_reps_src, out_reps_dst) -@lu.transformation_with_aux -def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args): +@lu.transformation_with_aux2 +def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): with core.take_current_trace() as parent: tag = core.TraceTag() t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, args) with core.set_current_trace(t): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) del t, in_tracers, ans - yield out_vals, out_reps + store.store(out_reps) + return out_vals -@lu.transformation -def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): - outs = yield args, {} +@lu.transformation2 +def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): + outs = f(*args) out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ _check_reps2(mesh, out_reps_dst, out_reps_src) outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)] - yield outs + return outs # TODO(mattjj): caching def _replication_rewrite_match( @@ -1901,16 +1902,17 @@ def _replication_rewrite_nomatch( jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts), out_rep() -@lu.transformation_with_aux -def _rewrite_subtrace(tag, mesh, in_reps, *in_vals): +@lu.transformation_with_aux2 +def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals): with core.take_current_trace() as parent_trace: assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) t = RewriteTrace(parent_trace, tag, mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) with core.set_current_trace(t): - outs = yield in_tracers, {} - ans = unzip2(map(t.to_val_rep_pair, outs)) - yield ans + outs = f(*in_tracers) + out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs)) + store.store(out_reps) + return out_vals def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): def new_bwd(*args): diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 7c5a966500f7..050d0a5e0a49 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -340,16 +340,17 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zero with core.set_current_trace(self): return fun.call_wrapped(*tracers) -@lu.transformation_with_aux -def sparsify_subtrace(tag, spenv, spvalues, *bufs): +@lu.transformation_with_aux2 +def sparsify_subtrace(f, store, tag, spenv, spvalues, *bufs): with core.take_current_trace() as parent: trace = SparseTrace(parent, tag, spenv) with core.set_current_trace(trace): in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] - outs = yield in_tracers, {} + outs = f(*in_tracers) out_traces = [trace.to_sparse_tracer(out) for out in outs] buffers = spenv._buffers - yield buffers, [out._spvalue for out in out_traces] + store.store([out._spvalue for out in out_traces]) + return buffers def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]): tag = core.TraceTag() diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 74c52dddbae8..8b80d033fa5c 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -22,5 +22,7 @@ merge_linear_aux as merge_linear_aux, transformation as transformation, transformation_with_aux as transformation_with_aux, + transformation2 as transformation2, + transformation_with_aux2 as transformation_with_aux2, wrap_init as wrap_init, ) diff --git a/tests/util_test.py b/tests/util_test.py index 5f07d2f50880..5e99fff4b347 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -42,8 +42,8 @@ def f(*args, **kwargs): assert not kwargs return tuple(a * factor for a in args) - @lu.transformation_with_aux - def kw_to_positional(factor, *args, **kwargs): + @lu.transformation_with_aux2 + def kw_to_positional(f, store, factor, *args, **kwargs): """A transformation with auxiliary output. Turns all keyword parameters into positional ones. @@ -55,12 +55,12 @@ def kw_to_positional(factor, *args, **kwargs): kwargs_keys = kwargs.keys() new_args = tuple(kwargs[k] for k in kwargs_keys) new_kwargs = dict(factor=factor) - results = yield args + new_args, new_kwargs # Yield transformed (args, kwargs) + results = f(*(args + new_args), **new_kwargs) # Yield transformed (args, kwargs) # Assume results correspond 1:1 to the args + new_args assert len(results) == len(args) + len(new_args) aux_output = len(new_args) - yield (results[0:len(args)], - dict(zip(kwargs_keys, results[len(args):]))), aux_output + store.store(aux_output) + return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):]))) wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`. wf, out_thunk = kw_to_positional(wf, 2) From 94f9a488b1eeea4e28d78b12c22e9d6c60fa0aba Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 13 Nov 2024 22:11:39 +0000 Subject: [PATCH 339/698] Don't override --xla_tpu_use_enhanced_launch_barrier if explicitly set --- jax/_src/cloud_tpu_init.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index c7665da961af..8ff52bd2f559 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,8 @@ def cloud_tpu_init() -> None: os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" + if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']: + os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true' # this makes tensorstore serialization work better on TPU os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') From 2463bf1f943c7fab74dd19d1262a463fd7213463 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 13 Nov 2024 16:13:26 -0800 Subject: [PATCH 340/698] Avoid repeatedly rebuilding a tuple in issubdtype. PiperOrigin-RevId: 696304594 --- jax/_src/dtypes.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index ac0418932b83..f5b0c3fd68b1 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -339,6 +339,8 @@ def _issubclass(a: Any, b: Any) -> bool: return False +_types_for_issubdtype = (type, np.dtype, ExtendedDType) + # TODO(jakevdp): consider whether to disallow None here. We allow it # because np.issubdtype allows it (and treats it as equivalent to float64). def issubdtype(a: DTypeLike | ExtendedDType | None, @@ -360,8 +362,8 @@ def issubdtype(a: DTypeLike | ExtendedDType | None, # unhashable (e.g. custom objects with a dtype attribute). The following check is # fast and covers the majority of calls to this function within JAX library code. return _issubdtype_cached( - a if isinstance(a, (type, np.dtype, ExtendedDType)) else np.dtype(a), # type: ignore[arg-type] - b if isinstance(b, (type, np.dtype, ExtendedDType)) else np.dtype(b), # type: ignore[arg-type] + a if isinstance(a, _types_for_issubdtype) else np.dtype(a), # type: ignore[arg-type] + b if isinstance(b, _types_for_issubdtype) else np.dtype(b), # type: ignore[arg-type] ) From aefe6215ca6252f283fe87f5563d188ddfef10eb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 14 Nov 2024 02:37:08 -0800 Subject: [PATCH 341/698] [pallas:mosaic_gpu] Ported two pipelining optimizations to `emit_pipeline` * Skip SMEM->GMEM copy if the destination buffer is being revisited * Skip SMEM->GMEM copy if the corresponding index map does not use grid indices PiperOrigin-RevId: 696448043 --- jax/_src/pallas/mosaic_gpu/BUILD | 3 + jax/_src/pallas/mosaic_gpu/lowering.py | 10 +- jax/_src/pallas/mosaic_gpu/pipeline.py | 147 +++++++++++++++++++---- jax/_src/pallas/mosaic_gpu/primitives.py | 32 ++++- tests/pallas/mosaic_gpu_test.py | 31 +++++ 5 files changed, 195 insertions(+), 28 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 6f98c83fdfd8..ad418e2b936d 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -107,7 +107,10 @@ pytype_strict_library( ":core", ":primitives", "//jax", + "//jax:core", + "//jax:mosaic_gpu", "//jax:pallas", + "//jax:partial_eval", "//jax:util", "//jax/_src/pallas", ], diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index dc46f1b81b6a..6d30cdb0d4a3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -264,6 +264,7 @@ def scratch_view( class LoweringRuleContext: module_ctx: ModuleContext launch_ctx: mgpu.LaunchContext + predicate: ir.Value avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] @@ -878,6 +879,7 @@ def write_env(var: jax_core.Var, val): rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, + predicate=mgpu.single_thread_predicate(per_block=False), avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], ) @@ -1120,6 +1122,12 @@ def _convert_element_type_lowering_rule( ) +mosaic_lowering_rules.update({ + lax.neg_p: lambda ctx, x: -x, + lax.not_p: lambda ctx, x: ~x, +}) + + def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return impl(x, y) @@ -1576,4 +1584,4 @@ def _as_index(v: object) -> ir.Value: case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()): return _as_index(v.registers.item()) case _: - raise ValueError(f"Unsupported index: {v}") + raise ValueError(f"Unsupported index: {v} of type {type(v)}") diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 21267b50a007..91e1e1c45429 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools import itertools as it @@ -25,7 +25,10 @@ import jax from jax import lax +from jax._src import core +from jax._src import linear_util as lu from jax._src import util +from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives @@ -37,17 +40,19 @@ zip = util.safe_zip +@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class BufferedRef: - spec: pallas_core.BlockSpec + spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True}) + is_index_invariant: bool = dataclasses.field(metadata={"static": True}) gmem_ref: pallas_core.AbstractMemoryRef smem_ref: pallas_core.AbstractMemoryRef # [num_slots, *spec.block_shape] - def compute_gmem_slice(self, grid_indices) -> tuple[Any, ...]: + def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: index_map = self.spec.index_map assert index_map is not None return tuple( - pl.ds(idx * size, size) + pl.Slice(idx * size, size) # type: ignore[arg-type] for idx, size in zip( index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type] ) @@ -61,16 +66,31 @@ def copy_in(self, slot, grid_indices, barrier_ref): barrier=barrier_ref.at[slot], ) - def copy_out(self, slot, grid_indices): + def copy_out(self, slot, grid_indices, predicate=None): gmem_slices = self.compute_gmem_slice(grid_indices) gpu_primitives.copy_smem_to_gmem( - self.smem_ref.at[slot], self.gmem_ref.at[gmem_slices] # pytype: disable=unsupported-operands + self.smem_ref.at[slot], + self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands + predicate=predicate, ) -jax.tree_util.register_dataclass( - BufferedRef, data_fields=["gmem_ref", "smem_ref"], meta_fields=["spec"] -) +def _uses_arguments( + index_map: Callable[..., Any], num_args: int +) -> Sequence[bool]: + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args + ) + _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars)) + return used_inputs + + +def _is_index_invariant( + spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid +) -> bool: + index_map = spec.index_map + assert index_map is not None + return not any(_uses_arguments(index_map, len(grid))) def _inc_grid_by_1( @@ -85,6 +105,25 @@ def _inc_grid_by_1( return tuple(reversed(next_indices)) +# ``pl.Slice`` uses a different pytree encoding, depending on whether the +# start/size are static or dynamic. This leads to pytree structure mismatch +# in the pipeline body. So, we define a different ``Slice`` class below. + + +@dataclasses.dataclass(frozen=True) +class _Slice: + start: int | jax.Array + size: int | jax.Array + + def __eq__(self, other: _Slice) -> jax.Array: # type: ignore + return lax.bitwise_and(self.start == other.start, self.size == other.size) + + +jax.tree_util.register_dataclass( + _Slice, data_fields=["start", "size"], meta_fields=[] +) + + def emit_pipeline( body, *, @@ -102,6 +141,16 @@ def emit_pipeline( max_concurrent_steps = num_steps def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): + for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)): + if any( + spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore + for idx in range(1, len(grid) + 1) + ): + raise NotImplementedError( + f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block" + f" shape {spec.block_shape}." + ) + in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( map( @@ -132,13 +181,18 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): def scoped_pipeline( *, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref ): - - in_brefs: Sequence[BufferedRef] = map( - BufferedRef, in_specs, in_gmem_refs, in_smem_refs - ) - out_brefs: Sequence[BufferedRef] = map( - BufferedRef, out_specs, out_gmem_refs, out_smem_refs - ) + in_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref) + for spec, gmem_ref, smem_ref in zip( + in_specs, in_gmem_refs, in_smem_refs + ) + ] + out_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref) + for spec, gmem_ref, smem_ref in zip( + out_specs, out_gmem_refs, out_smem_refs + ) + ] for step, indices in enumerate( it.islice(it.product(*map(range, grid)), max_concurrent_steps) @@ -147,10 +201,11 @@ def scoped_pipeline( def loop_body(step, carry): slot = step % max_concurrent_steps - indices, fetch_indices = carry + indices, fetch_indices, last_store_slices = carry - # Wait for the current GMEM->SMEM copy to complete. - gpu_primitives.barrier_wait(barrier_ref.at[slot]) + if in_specs: + # Wait for the current GMEM->SMEM copy to complete. + gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) @@ -159,9 +214,34 @@ def loop_body(step, carry): *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs)) ) + if not all(bref.is_index_invariant for bref in out_brefs): + gpu_primitives.commit_smem() + # Copy the output from SMEM to GMEM. - gpu_primitives.commit_smem() - map(lambda bref: bref.copy_out(slot, indices), out_brefs) + new_store_slices = last_store_slices[:] + for idx, bref in enumerate(out_brefs): + if bref.is_index_invariant: + assert last_store_slices[idx] is None + continue + assert last_store_slices[idx] is not None + new_store_slices[idx] = tuple( + _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices) + ) + are_same_slices = map( + lambda old, new: old == new, + last_store_slices[idx], + new_store_slices[idx], + ) + slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) + is_last_step = step == num_steps - 1 + # TODO(apaszke,slebedev): This still diverges significantly from the + # TPU semantics in that it will move on to the next SMEM output slice + # even if it's not storing the previous one. + bref.copy_out( + slot, + indices, + predicate=lax.bitwise_or(slices_changed, is_last_step), + ) fetch_step = step + max_concurrent_steps fetch_slot = slot # (x + y) % y == x % y @@ -174,13 +254,34 @@ def loop_body(step, carry): lambda: [None] * len(in_brefs), ) - return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid) + return ( + _inc_grid_by_1(indices, grid), + _inc_grid_by_1(fetch_indices, grid), + new_store_slices, + ) indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid) fetch_indices = indices for _ in range(max_concurrent_steps): fetch_indices = _inc_grid_by_1(fetch_indices, grid) - lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices)) + last_store_slices = [ + None + if bref.is_index_invariant + else (_Slice(-1, -1),) * len(bref.spec.block_shape) + for bref in out_brefs + ] + last_indices, _, _ = lax.fori_loop( + 0, num_steps, loop_body, (indices, fetch_indices, last_store_slices) + ) + + # Outputs invariant to the sequential axis are never written from inside the + # loop. This is the only place where we store them. + if all(bref.is_index_invariant for bref in out_brefs): + gpu_primitives.commit_smem() + last_slot = (num_steps - 1) % max_concurrent_steps + for bref in out_brefs: + if bref.is_index_invariant: + bref.copy_out(last_slot, last_indices, predicate=None) # Finalize the pipeline. gpu_primitives.wait_smem_to_gmem(0) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1ced213394ff..5fc4ed5e7afc 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -26,6 +26,7 @@ from jax._src import tree_util from jax._src import util from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -34,6 +35,7 @@ from jax._src.state import indexing from jax._src.state import primitives as state_primitives import jax.experimental.mosaic.gpu as mgpu +import jax.numpy as jnp WARPGROUP_SIZE = 128 @@ -54,19 +56,31 @@ def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, src, dst, - *flat_transforms, + *flat_args, src_transforms_treedef, dst_transforms_treedef, + has_user_predicate, ): + predicate = ctx.predicate + if has_user_predicate: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + predicate = arith_dialect.andi( + predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) + ) flat_src_transforms, flat_dst_transforms = util.split_list( - flat_transforms, + flat_args, [src_transforms_treedef.num_leaves], ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) src, src_transforms = lowering._handle_indexing(src, src_transforms) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params) + ctx.launch_ctx.async_copy( + src_ref=src, + dst_ref=dst, + predicate=predicate, + **copy_params, + ) return () @@ -98,10 +112,18 @@ def _extract_smem_copy_params(transforms): def copy_smem_to_gmem( - src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef + src: pallas_core.AbstractMemoryRef, + dst: pallas_core.AbstractMemoryRef, + predicate: jax.Array | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. + Args: + src: The SMEM reference to copy from. + dst: The GMEM reference to copy to. + predicate: A boolean indicating whether the copy should be performed. If + ``None``, the copy is always performed. + See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` :func:`jax.experimental.mosaic.gpu.commit_smem` @@ -127,8 +149,10 @@ def copy_smem_to_gmem( dst, *flat_src_transforms, *flat_dst_transforms, + *[] if predicate is None else [predicate], src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, + has_user_predicate=predicate is not None, ) return None diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index cbbe8da54972..83202937503d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1146,6 +1146,37 @@ def kernel_body(x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit_with_grid_invariant_output(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + y = jnp.empty_like(x) + for i in range(num_steps): + i_slice = slice(16 * i, 16 * (i + 1)) + y = y.at[:, :16].set(x[:, i_slice] + 1) + # We only compare the elements in the first 16 columns, because the rest + # are never written to. + np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16]) + def test_emit_with_parallel_grid(self): self.skipTest("Enable once we support multiple levels of indexing") From ad5a062198eea3a837b57784d2bb299817bb1f46 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 14 Nov 2024 07:02:36 -0800 Subject: [PATCH 342/698] Make the jaxpr for jnp.pad in "constant" mode more succinct. Example before: ``` $ print(jax.jit(lambda x: jnp.pad(x, ((0, 0), (1, 0), (0, 1)), constant_values=7)).lower(jnp.ones((3,4,5))).as_text()) module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) { %c = stablehlo.constant dense<7> : tensor %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor) -> tensor<3x5x6xf32> return %0 : tensor<3x5x6xf32> } func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor) -> tensor<3x5x6xf32> { %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<3x2xi32> %1 = stablehlo.convert %0 : (tensor<3x2xi32>) -> tensor<3x2xf32> %2 = stablehlo.slice %1 [0:1, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32> %3 = stablehlo.reshape %2 : (tensor<1x1xf32>) -> tensor %4 = stablehlo.pad %arg0, %3, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> %5 = stablehlo.slice %1 [0:1, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32> %6 = stablehlo.reshape %5 : (tensor<1x1xf32>) -> tensor %7 = stablehlo.pad %4, %6, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> %8 = stablehlo.slice %1 [1:2, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32> %9 = stablehlo.reshape %8 : (tensor<1x1xf32>) -> tensor %10 = stablehlo.pad %7, %9, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor) -> tensor<3x5x5xf32> %11 = stablehlo.slice %1 [1:2, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32> %12 = stablehlo.reshape %11 : (tensor<1x1xf32>) -> tensor %13 = stablehlo.pad %10, %12, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor) -> tensor<3x5x5xf32> %14 = stablehlo.slice %1 [2:3, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32> %15 = stablehlo.reshape %14 : (tensor<1x1xf32>) -> tensor %16 = stablehlo.pad %13, %15, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor) -> tensor<3x5x5xf32> %17 = stablehlo.slice %1 [2:3, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32> %18 = stablehlo.reshape %17 : (tensor<1x1xf32>) -> tensor %19 = stablehlo.pad %16, %18, low = [0, 0, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor) -> tensor<3x5x6xf32> return %19 : tensor<3x5x6xf32> } } ``` After: ``` module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) { %c = stablehlo.constant dense<7> : tensor %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor) -> tensor<3x5x6xf32> return %0 : tensor<3x5x6xf32> } func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor) -> tensor<3x5x6xf32> { %0 = stablehlo.convert %arg1 : (tensor) -> tensor %1 = stablehlo.pad %arg0, %0, low = [0, 1, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor) -> tensor<3x5x6xf32> return %1 : tensor<3x5x6xf32> } } ``` --- jax/_src/numpy/lax_numpy.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b90004e19932..88ddc85a0a40 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4048,15 +4048,37 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str): def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array: nd = ndim(array) - constant_values = broadcast_to(constant_values, (nd, 2)) constant_values = lax_internal._convert_element_type( constant_values, array.dtype, dtypes.is_weakly_typed(array)) + constant_values_nd = ndim(constant_values) + + if constant_values_nd == 0: + widths = [(low, high, 0) for (low, high) in pad_width] + return lax.pad(array, constant_values, widths) + + if constant_values_nd == 1: + if constant_values.shape[-1] == 1: + widths = [(low, high, 0) for (low, high) in pad_width] + return lax.pad(array, squeeze(constant_values), widths) + elif constant_values.shape[-1] == 2: + widths = [(low, 0, 0) for (low, _) in pad_width] + array = lax.pad(array, constant_values[0], widths) + widths = [(0, high, 0) for (_, high) in pad_width] + return lax.pad(array, constant_values[1], widths) + else: + raise ValueError("jnp.pad: constant_values has unsupported shape " + f"{constant_values.shape}. If the shape is 1D or 2D, the " + "last dimension must be of size 1 or 2.") + + constant_values = broadcast_to(constant_values, (nd, 2)) for i in range(nd): widths = [(0, 0, 0)] * nd - widths[i] = (pad_width[i][0], 0, 0) - array = lax.pad(array, constant_values[i, 0], widths) - widths[i] = (0, pad_width[i][1], 0) - array = lax.pad(array, constant_values[i, 1], widths) + if pad_width[i][0] != 0: + widths[i] = (pad_width[i][0], 0, 0) + array = lax.pad(array, constant_values[i, 0], widths) + if pad_width[i][1] != 0: + widths[i] = (0, pad_width[i][1], 0) + array = lax.pad(array, constant_values[i, 1], widths) return array From 081eaeaaccdcc10af292e8c5c333f1a0945acd76 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 14 Nov 2024 08:17:10 -0800 Subject: [PATCH 343/698] Don't use an out-of-line lowering for integer_pow for small powers. This yields a smaller stablehlo output. Add a fast path for y == 1 and y == -1, which turn out to be reasonably common. --- jax/_src/lax/lax.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 7e15f46c3ef1..a1ad88f7c249 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2633,24 +2633,24 @@ def _integer_pow(x, *, y): def _integer_pow_lowering(ctx, x, *, y): # These cases are subsumed by the general case, but it's faster to emit these # common cases directly. - if y == 2: + if y == 1: + out = x + elif y == 2: out = hlo.multiply(x, x) elif y == 3: out = hlo.multiply(hlo.multiply(x, x), x) + elif y == -1: + out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x) else: lowering = mlir.lower_fun(_integer_pow, multiple_results=False) - # TODO(b/217551391): emitting an out-of-line call leads to a large - # expansion when the MLIR is lowered to HLO, because the HLO lowering - # clones the callee. Consider unconditionally caching when the MLIR->HLO - # lowering doesn't expand the program. - lowering = mlir.cache_lowering(lowering) - out = lowering(ctx, x, y=y) + if builtins.abs(y) >= 3: + lowering = mlir.cache_lowering(lowering) + out, = lowering(ctx, x, y=y) if config.sharding_in_types.value: aval_out, = ctx.avals_out proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - out = out[0] if isinstance(out, list) else out return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] - return out if isinstance(out, list) else [out] + return [out] mlir.register_lowering(integer_pow_p, _integer_pow_lowering) From 6644bbc9c149ed0481d97a0ae396b5a4fd6cb378 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 14 Nov 2024 17:59:58 +0100 Subject: [PATCH 344/698] Fixed typo in asan.yaml --- .github/workflows/asan.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 9a49ed2a3e61..ea87d4e29e40 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -12,7 +12,7 @@ on: branches: - main paths: - - '**/workflows/asan.yml' + - '**/workflows/asan.yaml' jobs: asan: From f401c97967584331710273f89acc6b7048026913 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 12 Nov 2024 12:26:34 -0800 Subject: [PATCH 345/698] finalize deprecation of jax.clear_backends --- CHANGELOG.md | 1 + jax/__init__.py | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 78351047b09b..10b1fc808970 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` on the function inputs. + * `jax.clear_backends` was removed after being deprecated in v0.4.26. * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for diff --git a/jax/__init__.py b/jax/__init__.py index 7916ef0e3962..8ca7721da445 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -83,7 +83,6 @@ from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies -from jax._src.api import clear_backends as _deprecated_clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import custom_gradient as custom_gradient @@ -218,16 +217,15 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), - # Added Mar 18, 2024 + # Finalized Nov 12 2024; remove after Feb 12 2025 "clear_backends": ( - "jax.clear_backends is deprecated.", - _deprecated_clear_backends + "jax.clear_backends was removed in JAX v0.4.36", + None ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.api import clear_backends as clear_backends from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves From de07be5cb67691ac822ebf03fac5cd6aa8752deb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 14 Nov 2024 09:46:12 -0800 Subject: [PATCH 346/698] Fix pgle test handling of directory removal. This is required after https://github.com/jax-ml/jax/pull/22899. PiperOrigin-RevId: 696555876 --- tests/pgle_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index a27f3ec0b9ac..fa574df18f29 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -18,6 +18,7 @@ import logging import math import os +import shutil import tempfile from absl.testing import absltest @@ -286,7 +287,11 @@ def f(x): # Removing non-pgle profiled module from cache to check that later pgle # profiled version will be used. for non_pgle_file in non_pgle_profiled_files: - os.remove(os.path.join(cache_dir, non_pgle_file)) + path = os.path.join(cache_dir, non_pgle_file) + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) api.clear_caches() pjit._pgle_profiler_dict.clear() From a8464ce76169ed11f2dbd38b1d5124da857447ab Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 14 Nov 2024 09:51:33 -0800 Subject: [PATCH 347/698] [Mosaic][TPU] Omit short circuiting of relayout (we should always relayout!) and implement product mismatch case for where we relayout from replicated to offset, and the number of vregs changes. PiperOrigin-RevId: 696557463 --- .../tpu/transforms/apply_vector_layout.cc | 59 +++++++++++++++---- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index c9c4a81e668d..8792503f4636 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4723,6 +4723,11 @@ FailureOr> disassemble( TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value()); TPU_ASSERT_LOC(val.getLoc(), def_layout->generalizes(layout, vty.getShape(), target_shape)); + auto layout_product = + xla::Product(layout.tileArrayShape(vty.getShape(), target_shape)); + auto def_layout_product = + xla::Product(def_layout->tileArrayShape(vty.getShape(), target_shape)); + TPU_ASSERT_LOC(val.getLoc(), layout_product == def_layout_product); // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of // having `tileArrayShape` and `tileArrayImplicitShape`. SmallVector layout_shape = @@ -6324,11 +6329,50 @@ FailureOr> relayout(RewriteContext &ctx, if (src.generalizes(dst, vty.getShape(), target_shape)) { // A value with a replicated offset might use fewer vregs than a value with // a non-zero offset. - if (xla::Product(src.tileArrayShape(vty.getShape(), target_shape)) != - xla::Product(dst.tileArrayShape(vty.getShape(), target_shape))) { - return emitError(v.getLoc(), - "Not implemented: source layout is more general, but " - "vreg count changes"); + auto src_product = + xla::Product(src.tileArrayShape(vty.getShape(), target_shape)); + auto dst_product = + xla::Product(dst.tileArrayShape(vty.getShape(), target_shape)); + if (src_product != dst_product) { + TPU_ASSERT_LOC(v.getLoc(), dst_product > src_product); + auto src_offsets = src.offsets(); + + TPU_ASSERT_LOC(v.getLoc(), src_offsets != dst.offsets()); + TPU_ASSERT_LOC(v.getLoc(), src.bitwidth() == dst.bitwidth()); + + if (src.implicit_dim() != dst.implicit_dim()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and implicit dims are mismatched"); + } + + if (src.tiling() != dst.tiling()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and tiling are mismatched"); + } + + // This case is moving from a replicated to a non replicated layout. + // As such, we need to make a new destination shape that is the + // materialization of the src shape with replication. + FAILUREOR_ASSIGN_OR_RETURN(auto src_vregs, + disassemble(builder, src, v, target_shape, + /*use_implicit_shape=*/true)); + auto dst_vregs_shape = dst.tileArrayShape(vty.getShape(), target_shape); + xla::Array dst_vregs(dst_vregs_shape); + dst_vregs.Each([&](const absl::Span idx, Value *vreg) { + SmallVector local_idx(idx.begin(), idx.end()); + if (!src_offsets[0].has_value()) { + local_idx[local_idx.size() - 2] = 0; + } + if (!src_offsets[1].has_value()) { + local_idx[local_idx.size() - 1] = 0; + } + *vreg = src_vregs(local_idx); + }); + return assemble(builder, vty, dst, std::move(dst_vregs), target_shape, + /*use_implicit_shape=*/true) + .getResult(); } src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); return assemble(builder, vty, dst, std::move(src_tiles), target_shape, @@ -6411,8 +6455,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { if (vector_operand == nullptr) { continue; } - auto vty = vector_operand.getType(); - // The operand should always be an Operation (and not a BlockArgument) // since we expect the FuncOp to have only memrefs and semaphores as // arguments. @@ -6427,9 +6469,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); - if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) { - continue; - } OpBuilder builder(&op); FAILUREOR_ASSIGN_OR_RETURN( Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, From 05716b58b0d812bdd5af9769c10eeb026b5ee5af Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 14 Nov 2024 09:57:24 -0800 Subject: [PATCH 348/698] [sharding_in_types] Support shard_map with sharding in types. Right now only full manual mode is supported. This change also adds AxisTypes to Mesh which are `User`, `Auto` and `Collective`. In the following changes, I'll remove the `config.sharding_in_types` flag and we'll enter into various modes via AxisTypes mentioned on the mesh. PiperOrigin-RevId: 696559375 --- jax/_src/core.py | 29 ++++++++++++++++---- jax/_src/lax/lax.py | 14 +++++----- jax/_src/mesh.py | 50 ++++++++++++++++++++++++++++------- jax/_src/partition_spec.py | 20 ++++++++++++++ jax/_src/sharding_impls.py | 14 +--------- jax/experimental/shard_map.py | 25 +++++++++++++----- tests/pjit_test.py | 24 +++++++++++++++++ 7 files changed, 134 insertions(+), 42 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 96aecfde3a74..a1fcdac65df0 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1656,8 +1656,10 @@ def str_short(self, short_dtypes=False): self.dtype.name) dt_str = dt_str.replace('void', 'float0') if hasattr(self, 'sharding') and self.sharding is not None: - shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec)) - return f'{dt_str}[{shapestr}]' + shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) + axis_types = self.sharding.mesh.axis_types + axt = _get_axis_type_str(axis_types) if axis_types is not None else '' + return f'{dt_str}[{shapestr}]{axt}' else: shapestr = ','.join(map(str, self.shape)) return f'{dt_str}[{shapestr}]' @@ -1669,15 +1671,32 @@ def _len(self, ignored_tracer): raise TypeError("len() of unsized object") from err # same as numpy error +def _get_axis_type_str(axis_types): + from jax._src.mesh import AxisTypes # type: ignore + + out = [] + for t, axes in axis_types.items(): + a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes + if t == AxisTypes.Collective: + out.append(f"C:{a}") + elif t == AxisTypes.User: + out.append(f"U:{a}") + else: + assert t == AxisTypes.Auto + out.append(f"A:{a}") + return f"{{{', '.join(out)}}}" + def _get_shape_sharding_str(shape, spec): + out = [] for s1, s2 in zip(shape, spec): if s2 is None: - yield f"{s1}" + out.append(f"{s1}") elif isinstance(s2, tuple): ss = ','.join(s for s in s2) - yield f"{s1}@({ss})" + out.append(f"{s1}@({ss})") else: - yield f"{s1}@{s2}" + out.append(f"{s1}@{s2}") + return ','.join(out) def _get_abstract_sharding(val): from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 7e15f46c3ef1..359640e69c5e 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2203,14 +2203,13 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): for op, in_aval in zip(ops, in_avals): if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: out.append(op) + elif in_aval.sharding.mesh.are_all_axes_collective: + out.append(op) else: # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains # CompilerShardingAxis, then specify `unspecified_dims` via # `wrap_with_sharding_op`. - if config.use_shardy_partitioner.value: - sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim) - else: - sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) return out @@ -2227,10 +2226,9 @@ def _nary_lower_hlo(op: Callable, ctx, out = op(*args) if config.sharding_in_types.value: - if config.use_shardy_partitioner.value: - out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim) - else: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + if aval_out.sharding.mesh.are_all_axes_collective: + return [out] + out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] else: return [out] diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 43791f2e5f72..082c443fade4 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -18,6 +18,7 @@ import collections from collections.abc import Hashable, Sequence import contextlib +import enum import functools import math import threading @@ -101,6 +102,12 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names) +class AxisTypes(enum.Enum): + Auto = enum.auto() + User = enum.auto() + Collective = enum.auto() + + _mesh_object_dict = {} # type: ignore @@ -157,9 +164,11 @@ class Mesh(contextlib.ContextDecorator): devices: np.ndarray axis_names: tuple[MeshAxisName, ...] + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None def __new__(cls, devices: np.ndarray | Sequence[xc.Device], - axis_names: str | Sequence[MeshAxisName]): + axis_names: str | Sequence[MeshAxisName], + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None): if not isinstance(devices, np.ndarray): devices = np.array(devices) if isinstance(axis_names, str): @@ -175,7 +184,10 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - key = (axis_names, devices.shape, tuple(devices.flat)) + # TODO(yashkatariya): If axis_types is None, set all axes to AUTO. + axis_types_tuple = (None if axis_types is None else + tuple(axis_types.items())) + key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple) val = _mesh_object_dict.get(key, None) if val is not None: return val @@ -184,11 +196,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], self.devices = devices.copy() self.devices.flags.writeable = False self.axis_names = axis_names + self.axis_types = axis_types + self._axis_types_tuple = axis_types_tuple _mesh_object_dict[key] = self return self def __reduce__(self): - return (type(self), (self.devices, self.axis_names)) + return (type(self), (self.devices, self.axis_names, self.axis_types)) def __eq__(self, other): if not isinstance(other, Mesh): @@ -199,12 +213,14 @@ def __eq__(self, other): return True return (self.axis_names == other.axis_names and self.devices.shape == other.devices.shape and + self._axis_types_tuple == other._axis_types_tuple and self._internal_device_list == other._internal_device_list) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( - (self.axis_names, self._internal_device_list, self.devices.shape)) + (self.axis_names, self._internal_device_list, self.devices.shape, + self._axis_types_tuple)) return self._hash def __setattr__(self, name, value): @@ -301,7 +317,8 @@ def __str__(self): def _repr(self): if self.empty: return "Mesh(device_ids=[], axis_names=())" - return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})" + atr = '' if self.axis_types is None else f", axis_types={self.axis_types}" + return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})" def __repr__(self): return self._repr @@ -313,7 +330,7 @@ def local_devices(self): @functools.cached_property def abstract_mesh(self): - return AbstractMesh(self.shape_tuple) + return AbstractMesh(self.shape_tuple, self.axis_types) EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -338,25 +355,32 @@ class AbstractMesh: details. """ - def __init__(self, shape_tuple: tuple[tuple[str, int], ...]): + def __init__(self, shape_tuple: tuple[tuple[str, int], ...], + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None): self.shape_tuple = shape_tuple + self.axis_types = axis_types if self.shape_tuple: self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) else: self._axis_names, self._axis_sizes = (), () + # TODO(yashkatariya): If axis_types is None, set all axes to AUTO. + self._axis_types_tuple = (None if axis_types is None else + tuple(axis_types.items())) def __hash__(self): - return hash(self.shape_tuple) + return hash((self.shape_tuple, self._axis_types_tuple)) def __eq__(self, other): if not isinstance(other, AbstractMesh): return False if id(self) == id(other): return True - return self.shape_tuple == other.shape_tuple + return (self.shape_tuple == other.shape_tuple and + self._axis_types_tuple == other._axis_types_tuple) def __repr__(self): - return f"AbstractMesh({self.shape_tuple})" + atr = '' if self.axis_types is None else f", axis_types={self.axis_types}" + return f"AbstractMesh({self.shape_tuple}{atr})" @property def axis_names(self): @@ -382,6 +406,12 @@ def _internal_device_list(self): def empty(self): return self.size == 0 + @functools.cached_property + def are_all_axes_collective(self) -> bool: + if self.axis_types is None: + return False + return all(t == AxisTypes.Collective for t in self.axis_types.keys()) + @property def devices(self): _raise_value_error("devices") diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 18e7d18d931d..f9bc2b60cee9 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + class _UnconstrainedPartitionSingleton: def __repr__(self): @@ -48,3 +50,21 @@ def __repr__(self): def __reduce__(self): return (PartitionSpec, tuple(self)) + + def _normalized_spec(self, ndim: int) -> PartitionSpec: + out = [] # type: ignore + for p in self: + if p is None: + out.append(None) + elif p == self.UNCONSTRAINED: + out.append(p) + elif isinstance(p, (list, tuple)): + if len(p) == 1: + out.append(p[0]) + else: + out.append(tuple(p)) + else: + out.append(p) + if len(out) < ndim: + out.extend([None] * (ndim - len(out))) + return PartitionSpec(*out) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index fa65bbe9328d..9b847f15d86a 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -361,19 +361,7 @@ def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) def _normalized_spec(self, ndim: int) -> PartitionSpec: - out = [] # type: ignore - for p in self._parsed_pspec: - if p is None: - raise ValueError("UNCONSTRAINED is not supported yet.") - if not p: - out.append(None) - elif isinstance(p, tuple) and len(p) == 1: - out.append(p[0]) - else: - out.append(p) - if len(out) < ndim: - out.extend([None] * (ndim - len(out))) - return PartitionSpec(*out) + return self.spec._normalized_spec(ndim) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index edfcb031703f..9391d7ddf546 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -46,7 +46,7 @@ from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer -from jax._src.mesh import AbstractMesh, Mesh +from jax._src.mesh import AbstractMesh, Mesh, AxisTypes from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, @@ -528,17 +528,30 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue raise NotImplementedError(f"Unsupported aval type: {type(aval)}") def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: + ) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + if config.sharding_in_types.value: + new_mesh = AbstractMesh( + mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names}) + new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim)) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array def _unshard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue,) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + if config.sharding_in_types.value: + spec = _names_to_pspec(names)._normalized_spec(aval.ndim) + new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8fe46c3b83e5..8a63bbe39099 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5201,6 +5201,30 @@ def f(x): self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_shard_map_full_manual(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh.are_all_axes_collective) + self.assertTrue(y.sharding.mesh.are_all_axes_collective) + return x * y + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', 'y'))(x, y) + self.assertEqual(z.sharding.spec, P('x', 'y')) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', 'y')) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp * np_inp) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From e6f6a8af8d2bd3bec601dfd029b06d2baecd6130 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Wed, 6 Nov 2024 10:43:17 -0800 Subject: [PATCH 349/698] Move Control Flow text from Sharp Bits into its own tutorial. --- README.md | 5 +- docs/control-flow.md | 361 ++++++++++++ docs/faq.rst | 2 +- docs/jit-compilation.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 607 +-------------------- docs/notebooks/Common_Gotchas_in_JAX.md | 328 +---------- docs/stateful-computations.md | 1 + docs/tutorials.rst | 1 + 8 files changed, 379 insertions(+), 928 deletions(-) create mode 100644 docs/control-flow.md diff --git a/README.md b/README.md index 89fe51212638..ce695dd6a26d 100644 --- a/README.md +++ b/README.md @@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like. Using `jit` puts constraints on the kind of Python control flow the function can use; see -the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT) +the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` @@ -369,7 +368,7 @@ Some standouts: and NumPy types aren't preserved, namely `np.add(1, np.array([2], np.float32)).dtype` is `float64` rather than `float32`. 1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). + flow](https://jax.readthedocs.io/en/latest/control-flow.html). You'll always get loud errors if something goes wrong. You might have to use [`jit`'s `static_argnums` parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), diff --git a/docs/control-flow.md b/docs/control-flow.md new file mode 100644 index 000000000000..04eb3cac8d24 --- /dev/null +++ b/docs/control-flow.md @@ -0,0 +1,361 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + ++++ {"id": "rg4CpMZ8c3ri"} + +(control-flow)= +# Control flow and logical operators with JIT + + + +When executing eagerly (outside of `jit`), JAX code works with Python control flow and logical operators just like Numpy code. Using control flow and logical operators with `jit` is more complicated. + +In a nutshell, Python control flow and logical operators are evaluated at JIT compile time, such that the compiled function represents a single path through the [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph) (logical operators affect the path via short-circuiting). If the path depends on the values of the inputs, the function (by default) cannot be JIT compiled. The path may depend on the shape or dtype of the inputs, and the function is re-compiled every time it is called on an input with a new shape or dtype. + +```{code-cell} +from jax import grad, jit +import jax.numpy as jnp +``` + +For example, this works: + +```{code-cell} +:id: OZ_BJX0CplNC +:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c + +@jit +def f(x): + for i in range(3): + x = 2 * x + return x + +print(f(3)) +``` + ++++ {"id": "22RzeJ4QqAuX"} + +So does this: + +```{code-cell} +:id: pinVnmRWp6w6 +:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422 + +@jit +def g(x): + y = 0. + for i in range(x.shape[0]): + y = y + x[i] + return y + +print(g(jnp.array([1., 2., 3.]))) +``` + ++++ {"id": "TStltU2dqf8A"} + +But this doesn't, at least by default: + +```{code-cell} +:id: 9z38AIKclRNM +:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac +:tags: [raises-exception] + +@jit +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +# This will fail! +f(2) +``` + +Neither does this: + +```{code-cell} +:tags: [raises-exception] + +@jit +def g(x): + return (x > 0) and (x < 3) + +# This will fail! +g(2) +``` + ++++ {"id": "pIbr4TVPqtDN"} + +__What gives!?__ + +When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation. + +For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. + +To get a view of your Python code that is valid for many different argument values, JAX traces it with the `ShapedArray` abstraction as input, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. + +But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace. + +The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnames` (or `static_argnums`) argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again: + +```{code-cell} +:id: -Tzp0H7Bt1Sn +:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a + +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +f = jit(f, static_argnames='x') + +print(f(2.)) +``` + ++++ {"id": "MHm1hIQAvBVs"} + +Here's another example, this time involving a loop: + +```{code-cell} +:id: iwY86_JKvD6b +:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937 + +def f(x, n): + y = 0. + for i in range(n): + y = y + x[i] + return y + +f = jit(f, static_argnames='n') + +f(jnp.array([2., 3., 4.]), 2) +``` + ++++ {"id": "nSPTOX8DvOeO"} + +In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation + ++++ {"id": "wWdg8LTYwCW3"} + +️⚠️ **functions with argument-__value__ dependent shapes** + +These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`. + +```{code-cell} +:id: Tqe9uLmUI_Gv +:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883 + +def example_fun(length, val): + return jnp.ones((length,)) * val +# un-jit'd works fine +print(example_fun(5, 4)) +``` + +```{code-cell} +:id: fOlR54XRgHpd +:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71 +:tags: [raises-exception] + +bad_example_jit = jit(example_fun) +# this will fail: +bad_example_jit(10, 4) +``` + +```{code-cell} +:id: kH0lOD4GgFyI +:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade + +# static_argnames tells JAX to recompile on changes at these argument positions: +good_example_jit = jit(example_fun, static_argnames='length') +# first compile +print(good_example_jit(10, 4)) +# recompiles +print(good_example_jit(5, 4)) +``` + ++++ {"id": "MStx_r2oKxpp"} + +`static_argnames` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! + +Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions: + +```{code-cell} +:id: m2ABpRd8K094 +:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3 + +@jit +def f(x): + print(x) + y = 2 * x + print(y) + return y +f(2) +``` + ++++ {"id": "uCDcWG4MnVn-"} + +## Structured control flow primitives + +There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives: + + - `lax.cond` _differentiable_ + - `lax.while_loop` __fwd-mode-differentiable__ + - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static. + - `lax.scan` _differentiable_ + ++++ {"id": "Sd9xrLMXeK3A"} + +### `cond` +python equivalent: + +```python +def cond(pred, true_fun, false_fun, operand): + if pred: + return true_fun(operand) + else: + return false_fun(operand) +``` + +```{code-cell} +:id: SGxz9JOWeiyH +:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3 + +from jax import lax + +operand = jnp.array([0.]) +lax.cond(True, lambda x: x+1, lambda x: x-1, operand) +# --> array([1.], dtype=float32) +lax.cond(False, lambda x: x+1, lambda x: x-1, operand) +# --> array([-1.], dtype=float32) +``` + ++++ {"id": "lIYdn1woOS1n"} + +`jax.lax` provides two other functions that allow branching on dynamic predicates: + +- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is + like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays + rather than as functions. +- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is + like `lax.cond`, but allows switching between any number of callable choices. + +In addition, `jax.numpy` provides several numpy-style interfaces to these functions: + +- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with + three arguments is the numpy-style wrapper of `lax.select`. +- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) + is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. +- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has + an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather + than as functions. It is implemented in terms of multiple calls to `lax.select`. + ++++ {"id": "xkOFAw24eOMg"} + +### `while_loop` + +python equivalent: +``` +def while_loop(cond_fun, body_fun, init_val): + val = init_val + while cond_fun(val): + val = body_fun(val) + return val +``` + +```{code-cell} +:id: jM-D39a-c436 +:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e + +init_val = 0 +cond_fun = lambda x: x < 10 +body_fun = lambda x: x+1 +lax.while_loop(cond_fun, body_fun, init_val) +# --> array(10, dtype=int32) +``` + ++++ {"id": "apo3n3HAeQY_"} + +### `fori_loop` +python equivalent: +``` +def fori_loop(start, stop, body_fun, init_val): + val = init_val + for i in range(start, stop): + val = body_fun(i, val) + return val +``` + +```{code-cell} +:id: dt3tUpOmeR8u +:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380 + +init_val = 0 +start = 0 +stop = 10 +body_fun = lambda i,x: x+i +lax.fori_loop(start, stop, body_fun, init_val) +# --> array(45, dtype=int32) +``` + ++++ {"id": "SipXS5qiqk8e"} + +### Summary + +$$ +\begin{array} {r|rr} +\hline \ +\textrm{construct} +& \textrm{jit} +& \textrm{grad} \\ +\hline \ +\textrm{if} & ❌ & ✔ \\ +\textrm{for} & ✔* & ✔\\ +\textrm{while} & ✔* & ✔\\ +\textrm{lax.cond} & ✔ & ✔\\ +\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ +\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ +\textrm{lax.scan} & ✔ & ✔\\ +\hline +\end{array} +$$ + +
+ +$\ast$ = argument-value-independent loop condition - unrolls the loop + +
+ +## Logical operators + +`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`. + ++++ {"id": "izLTvT24dAq0"} + +## Python control flow + autodiff + +Remember that the above constraints on control flow and logical operators are relevant only with `jit`. If you just want to apply `grad` to your python functions, without `jit`, you can use regular Python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). + +```{code-cell} +:id: aAx0T3F8lLtu +:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52 + +def f(x): + if x < 3: + return 3. * x ** 2 + else: + return -4 * x + +print(grad(f)(2.)) # ok! +print(grad(f)(4.)) # ok! +``` diff --git a/docs/faq.rst b/docs/faq.rst index 1d2bb204f24c..44267f6f5f7d 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 51322fda9476..5e5be308068a 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -170,7 +170,7 @@ jax.jit(g)(10, 20) # Raises an error The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values. Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as `shape` or `dtype`, and not via their values. -For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow). +For more detail on the interaction between Python control flow and JAX, see {ref}`control-flow`. One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical. In that case, you can consider JIT-compiling only part of the function. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 71bd4527644a..92c736957db6 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -34,7 +34,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "from jax import grad, jit\n", + "from jax import jit\n", "from jax import lax\n", "from jax import random\n", "import jax\n", @@ -1175,610 +1175,14 @@ }, { "cell_type": "markdown", + "id": "1dc0e6b2", "metadata": { "id": "rg4CpMZ8c3ri" }, "source": [ - "## 🔪 Control flow" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "izLTvT24dAq0" - }, - "source": [ - "### ✔ Python control_flow + autodiff ✔\n", - "\n", - "If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "id": "aAx0T3F8lLtu", - "outputId": "383b7bfa-1634-4d23-8497-49cb9452ca52" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.0\n", - "-4.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "print(grad(f)(2.)) # ok!\n", - "print(grad(f)(4.)) # ok!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hIfPT7WMmZ2H" - }, - "source": [ - "### Python control flow + JIT\n", - "\n", - "Using control flow with `jit` is more complicated, and by default it has more constraints.\n", - "\n", - "This works:" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "id": "OZ_BJX0CplNC", - "outputId": "60c902a2-eba1-49d7-c8c8-2f68616d660c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "24\n" - ] - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " for i in range(3):\n", - " x = 2 * x\n", - " return x\n", - "\n", - "print(f(3))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "22RzeJ4QqAuX" - }, - "source": [ - "So does this:" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "id": "pinVnmRWp6w6", - "outputId": "25e06cf2-474f-4782-af7c-4f5514b64422" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "6.0\n" - ] - } - ], - "source": [ - "@jit\n", - "def g(x):\n", - " y = 0.\n", - " for i in range(x.shape[0]):\n", - " y = y + x[i]\n", - " return y\n", - "\n", - "print(g(jnp.array([1., 2., 3.])))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TStltU2dqf8A" - }, - "source": [ - "But this doesn't, at least by default:" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "id": "9z38AIKclRNM", - "outputId": "38dd2075-92fc-4b81-fee0-b9dff8da1fac", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ConcretizationTypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Tracedwith\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n" - ] - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "# This will fail!\n", - "f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pIbr4TVPqtDN" - }, - "source": [ - "__What gives!?__\n", - "\n", - "When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n", - "\n", - "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n", - "\n", - "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", - "\n", - "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n", - "\n", - "But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n", - "\n", - "The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": { - "id": "-Tzp0H7Bt1Sn", - "outputId": "f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.0\n" - ] - } - ], - "source": [ - "def f(x):\n", - " if x < 3:\n", - " return 3. * x ** 2\n", - " else:\n", - " return -4 * x\n", - "\n", - "f = jit(f, static_argnums=(0,))\n", - "\n", - "print(f(2.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MHm1hIQAvBVs" - }, - "source": [ - "Here's another example, this time involving a loop:" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "id": "iwY86_JKvD6b", - "outputId": "48f9b51f-bd32-466f-eac1-cd23444ce937" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(5., dtype=float32)" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def f(x, n):\n", - " y = 0.\n", - " for i in range(n):\n", - " y = y + x[i]\n", - " return y\n", - "\n", - "f = jit(f, static_argnums=(1,))\n", - "\n", - "f(jnp.array([2., 3., 4.]), 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nSPTOX8DvOeO" - }, - "source": [ - "In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wWdg8LTYwCW3" - }, - "source": [ - "️⚠️ **functions with argument-__value__ dependent shapes**\n", - "\n", - "These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`." - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": { - "id": "Tqe9uLmUI_Gv", - "outputId": "989be121-dfce-4bb3-c78e-a10829c5f883" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[4. 4. 4. 4. 4.]\n" - ] - } - ], - "source": [ - "def example_fun(length, val):\n", - " return jnp.ones((length,)) * val\n", - "# un-jit'd works fine\n", - "print(example_fun(5, 4))" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": { - "id": "fOlR54XRgHpd", - "outputId": "cf31d798-a4ce-4069-8e3e-8f9631ff4b71", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "TypeError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Shapes must be 1D sequences of concrete values of integer type, got (Tracedwith,).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n" - ] - } - ], - "source": [ - "bad_example_jit = jit(example_fun)\n", - "# this will fail:\n", - "bad_example_jit(10, 4)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": { - "id": "kH0lOD4GgFyI", - "outputId": "d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n", - "[4. 4. 4. 4. 4.]\n" - ] - } - ], - "source": [ - "# static_argnums tells JAX to recompile on changes at these argument positions:\n", - "good_example_jit = jit(example_fun, static_argnums=(0,))\n", - "# first compile\n", - "print(good_example_jit(10, 4))\n", - "# recompiles\n", - "print(good_example_jit(5, 4))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MStx_r2oKxpp" - }, - "source": [ - "`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!\n", - "\n", - "Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": { - "id": "m2ABpRd8K094", - "outputId": "4f7ebe17-ade4-4e18-bd8c-4b24087c33c3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tracedwith\n", - "Tracedwith\n" - ] - }, - { - "data": { - "text/plain": [ - "Array(4, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "@jit\n", - "def f(x):\n", - " print(x)\n", - " y = 2 * x\n", - " print(y)\n", - " return y\n", - "f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uCDcWG4MnVn-" - }, - "source": [ - "### Structured control flow primitives\n", - "\n", - "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n", - "\n", - " - `lax.cond` _differentiable_\n", - " - `lax.while_loop` __fwd-mode-differentiable__\n", - " - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.\n", - " - `lax.scan` _differentiable_" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sd9xrLMXeK3A" - }, - "source": [ - "#### `cond`\n", - "python equivalent:\n", - "\n", - "```python\n", - "def cond(pred, true_fun, false_fun, operand):\n", - " if pred:\n", - " return true_fun(operand)\n", - " else:\n", - " return false_fun(operand)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": { - "id": "SGxz9JOWeiyH", - "outputId": "942a8d0e-5ff6-4702-c499-b3941f529ca3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([-1.], dtype=float32)" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import lax\n", - "\n", - "operand = jnp.array([0.])\n", - "lax.cond(True, lambda x: x+1, lambda x: x-1, operand)\n", - "# --> array([1.], dtype=float32)\n", - "lax.cond(False, lambda x: x+1, lambda x: x-1, operand)\n", - "# --> array([-1.], dtype=float32)" - ] - }, - { - "cell_type": "markdown", - "id": "e6622244", - "metadata": { - "id": "lIYdn1woOS1n" - }, - "source": [ - "`jax.lax` provides two other functions that allow branching on dynamic predicates:\n", - "\n", - "- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is\n", - " like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays\n", - " rather than as functions.\n", - "- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is\n", - " like `lax.cond`, but allows switching between any number of callable choices.\n", - "\n", - "In addition, `jax.numpy` provides several numpy-style interfaces to these functions:\n", - "\n", - "- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with\n", - " three arguments is the numpy-style wrapper of `lax.select`.\n", - "- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)\n", - " is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.\n", - "- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has\n", - " an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather\n", - " than as functions. It is implemented in terms of multiple calls to `lax.select`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xkOFAw24eOMg" - }, - "source": [ - "#### `while_loop`\n", - "\n", - "python equivalent:\n", - "```\n", - "def while_loop(cond_fun, body_fun, init_val):\n", - " val = init_val\n", - " while cond_fun(val):\n", - " val = body_fun(val)\n", - " return val\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": { - "id": "jM-D39a-c436", - "outputId": "552fe42f-4d32-4e25-c8c2-b951160a3f4e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(10, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "init_val = 0\n", - "cond_fun = lambda x: x < 10\n", - "body_fun = lambda x: x+1\n", - "lax.while_loop(cond_fun, body_fun, init_val)\n", - "# --> array(10, dtype=int32)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "apo3n3HAeQY_" - }, - "source": [ - "#### `fori_loop`\n", - "python equivalent:\n", - "```\n", - "def fori_loop(start, stop, body_fun, init_val):\n", - " val = init_val\n", - " for i in range(start, stop):\n", - " val = body_fun(i, val)\n", - " return val\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "id": "dt3tUpOmeR8u", - "outputId": "7819ca7c-1433-4d85-b542-f6159b0e8380" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(45, dtype=int32, weak_type=True)" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "init_val = 0\n", - "start = 0\n", - "stop = 10\n", - "body_fun = lambda i,x: x+i\n", - "lax.fori_loop(start, stop, body_fun, init_val)\n", - "# --> array(45, dtype=int32)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SipXS5qiqk8e" - }, - "source": [ - "#### Summary\n", - "\n", - "$$\n", - "\\begin{array} {r|rr}\n", - "\\hline \\\n", - "\\textrm{construct}\n", - "& \\textrm{jit}\n", - "& \\textrm{grad} \\\\\n", - "\\hline \\\n", - "\\textrm{if} & ❌ & ✔ \\\\\n", - "\\textrm{for} & ✔* & ✔\\\\\n", - "\\textrm{while} & ✔* & ✔\\\\\n", - "\\textrm{lax.cond} & ✔ & ✔\\\\\n", - "\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n", - "\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n", - "\\textrm{lax.scan} & ✔ & ✔\\\\\n", - "\\hline\n", - "\\end{array}\n", - "$$\n", - "\n", - "
\n", - "\n", - "$\\ast$ = argument-value-independent loop condition - unrolls the loop\n", + "## 🔪 Control flow\n", "\n", - "
" + "Moved to {ref}`control-flow`." ] }, { @@ -2209,6 +1613,9 @@ " ```\n", " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n", "\n", + "## 🔪 Sharp bits covered in tutorials\n", + "- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.\n", + "- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions.\n", "\n", "## Fin.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 741fa3af063c..00955de236e7 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -31,7 +31,7 @@ JAX works great for many numerical and scientific programs, but __only if they a :id: GoK_PCxPeYcy import numpy as np -from jax import grad, jit +from jax import jit from jax import lax from jax import random import jax @@ -536,328 +536,7 @@ for subkey in subkeys: ## 🔪 Control flow -+++ {"id": "izLTvT24dAq0"} - -### ✔ Python control_flow + autodiff ✔ - -If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). - -```{code-cell} ipython3 -:id: aAx0T3F8lLtu -:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52 - -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -print(grad(f)(2.)) # ok! -print(grad(f)(4.)) # ok! -``` - -+++ {"id": "hIfPT7WMmZ2H"} - -### Python control flow + JIT - -Using control flow with `jit` is more complicated, and by default it has more constraints. - -This works: - -```{code-cell} ipython3 -:id: OZ_BJX0CplNC -:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c - -@jit -def f(x): - for i in range(3): - x = 2 * x - return x - -print(f(3)) -``` - -+++ {"id": "22RzeJ4QqAuX"} - -So does this: - -```{code-cell} ipython3 -:id: pinVnmRWp6w6 -:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422 - -@jit -def g(x): - y = 0. - for i in range(x.shape[0]): - y = y + x[i] - return y - -print(g(jnp.array([1., 2., 3.]))) -``` - -+++ {"id": "TStltU2dqf8A"} - -But this doesn't, at least by default: - -```{code-cell} ipython3 -:id: 9z38AIKclRNM -:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac -:tags: [raises-exception] - -@jit -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -# This will fail! -f(2) -``` - -+++ {"id": "pIbr4TVPqtDN"} - -__What gives!?__ - -When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation. - -For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. - -To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. - -By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. - -But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace. - -The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again: - -```{code-cell} ipython3 -:id: -Tzp0H7Bt1Sn -:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a - -def f(x): - if x < 3: - return 3. * x ** 2 - else: - return -4 * x - -f = jit(f, static_argnums=(0,)) - -print(f(2.)) -``` - -+++ {"id": "MHm1hIQAvBVs"} - -Here's another example, this time involving a loop: - -```{code-cell} ipython3 -:id: iwY86_JKvD6b -:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937 - -def f(x, n): - y = 0. - for i in range(n): - y = y + x[i] - return y - -f = jit(f, static_argnums=(1,)) - -f(jnp.array([2., 3., 4.]), 2) -``` - -+++ {"id": "nSPTOX8DvOeO"} - -In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation - -+++ {"id": "wWdg8LTYwCW3"} - -️⚠️ **functions with argument-__value__ dependent shapes** - -These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`. - -```{code-cell} ipython3 -:id: Tqe9uLmUI_Gv -:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883 - -def example_fun(length, val): - return jnp.ones((length,)) * val -# un-jit'd works fine -print(example_fun(5, 4)) -``` - -```{code-cell} ipython3 -:id: fOlR54XRgHpd -:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71 -:tags: [raises-exception] - -bad_example_jit = jit(example_fun) -# this will fail: -bad_example_jit(10, 4) -``` - -```{code-cell} ipython3 -:id: kH0lOD4GgFyI -:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade - -# static_argnums tells JAX to recompile on changes at these argument positions: -good_example_jit = jit(example_fun, static_argnums=(0,)) -# first compile -print(good_example_jit(10, 4)) -# recompiles -print(good_example_jit(5, 4)) -``` - -+++ {"id": "MStx_r2oKxpp"} - -`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! - -Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions: - -```{code-cell} ipython3 -:id: m2ABpRd8K094 -:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3 - -@jit -def f(x): - print(x) - y = 2 * x - print(y) - return y -f(2) -``` - -+++ {"id": "uCDcWG4MnVn-"} - -### Structured control flow primitives - -There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives: - - - `lax.cond` _differentiable_ - - `lax.while_loop` __fwd-mode-differentiable__ - - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static. - - `lax.scan` _differentiable_ - -+++ {"id": "Sd9xrLMXeK3A"} - -#### `cond` -python equivalent: - -```python -def cond(pred, true_fun, false_fun, operand): - if pred: - return true_fun(operand) - else: - return false_fun(operand) -``` - -```{code-cell} ipython3 -:id: SGxz9JOWeiyH -:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3 - -from jax import lax - -operand = jnp.array([0.]) -lax.cond(True, lambda x: x+1, lambda x: x-1, operand) -# --> array([1.], dtype=float32) -lax.cond(False, lambda x: x+1, lambda x: x-1, operand) -# --> array([-1.], dtype=float32) -``` - -+++ {"id": "lIYdn1woOS1n"} - -`jax.lax` provides two other functions that allow branching on dynamic predicates: - -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is - like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays - rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is - like `lax.cond`, but allows switching between any number of callable choices. - -In addition, `jax.numpy` provides several numpy-style interfaces to these functions: - -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with - three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) - is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has - an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather - than as functions. It is implemented in terms of multiple calls to `lax.select`. - -+++ {"id": "xkOFAw24eOMg"} - -#### `while_loop` - -python equivalent: -``` -def while_loop(cond_fun, body_fun, init_val): - val = init_val - while cond_fun(val): - val = body_fun(val) - return val -``` - -```{code-cell} ipython3 -:id: jM-D39a-c436 -:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e - -init_val = 0 -cond_fun = lambda x: x < 10 -body_fun = lambda x: x+1 -lax.while_loop(cond_fun, body_fun, init_val) -# --> array(10, dtype=int32) -``` - -+++ {"id": "apo3n3HAeQY_"} - -#### `fori_loop` -python equivalent: -``` -def fori_loop(start, stop, body_fun, init_val): - val = init_val - for i in range(start, stop): - val = body_fun(i, val) - return val -``` - -```{code-cell} ipython3 -:id: dt3tUpOmeR8u -:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380 - -init_val = 0 -start = 0 -stop = 10 -body_fun = lambda i,x: x+i -lax.fori_loop(start, stop, body_fun, init_val) -# --> array(45, dtype=int32) -``` - -+++ {"id": "SipXS5qiqk8e"} - -#### Summary - -$$ -\begin{array} {r|rr} -\hline \ -\textrm{construct} -& \textrm{jit} -& \textrm{grad} \\ -\hline \ -\textrm{if} & ❌ & ✔ \\ -\textrm{for} & ✔* & ✔\\ -\textrm{while} & ✔* & ✔\\ -\textrm{lax.cond} & ✔ & ✔\\ -\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ -\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ -\textrm{lax.scan} & ✔ & ✔\\ -\hline -\end{array} -$$ - -
- -$\ast$ = argument-value-independent loop condition - unrolls the loop - -
+Moved to {ref}`control-flow`. +++ {"id": "OxLsZUyRt_kF"} @@ -1145,6 +824,9 @@ Many such cases are discussed in detail in the sections above; here we list seve ``` This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa. +## 🔪 Sharp bits covered in tutorials +- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators. +- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions. ## Fin. diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 2ff82e0431e2..fe84fc0d7f0a 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -12,6 +12,7 @@ kernelspec: name: python3 --- +(stateful-computations)= # Stateful computations diff --git a/docs/tutorials.rst b/docs/tutorials.rst index a31517155e1a..c9c2fdb1dcc7 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -16,6 +16,7 @@ Tutorials working-with-pytrees sharded-computation stateful-computations + control-flow .. toctree:: :maxdepth: 1 From d823f1720dccf1e58d37ca91111950c4384efa02 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 12 Nov 2024 11:51:55 -0800 Subject: [PATCH 350/698] jnp.logaddexp2: simplify implementation --- jax/_src/numpy/ufuncs.py | 39 ++------------------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 93e116fa4b6a..a844ecbc28ac 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2630,16 +2630,6 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) -def _wrap_between(x, _a): - """Wraps `x` between `[-a, a]`.""" - a = _constant_like(x, _a) - two_a = _constant_like(x, 2 * _a) - zero = _constant_like(x, 0) - rem = lax.rem(lax.add(x, a), two_a) - rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) - return lax.sub(rem, a) - - @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. @@ -2668,33 +2658,8 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array(True, dtype=bool) """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) - return _logaddexp2(x1, x2) - - -@custom_jvp -def _logaddexp2(x1, x2): - amax = lax.max(x1, x2) - if dtypes.issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(lax._isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. - lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), - _constant_like(x1, np.log(2))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) - out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2)))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) - - -@_logaddexp2.defjvp -def _logaddexp2_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2) - primal_out = logaddexp2(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out + ln2 = float(np.log(2)) + return logaddexp(x1 * ln2, x2 * ln2) / ln2 @partial(jit, inline=True) From d0f36666ff9f9ae8847b0ca645b6f1ce581907e9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 14 Nov 2024 11:52:21 -0800 Subject: [PATCH 351/698] Update array-api-tests commit --- .github/workflows/jax-array-api.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 942034169e09..84dda34752f0 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻 + ref: '4bbe6be32c6995772f8f46a6ef050ba766581104' # Latest commit as of 2024-11-14 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} From 34d9633b12a1886fcd4e68b42fd8f448d4820a66 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Thu, 14 Nov 2024 13:57:33 -0600 Subject: [PATCH 352/698] Add commit to see if it triggers CI --- .github/workflows/rocm-nightly-upstream-sync.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index e915ccba390d..e8cb5f480313 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -1,5 +1,5 @@ # Pulls the latest changes from upstream into main and opens a PR to merge -# them into rocm-main. +# them into rocm-main branch. name: ROCm Nightly Upstream Sync on: From 1f114b1cf79803462dba76bdb3a9576a6018f618 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Thu, 14 Nov 2024 15:23:26 -0500 Subject: [PATCH 353/698] Add numpy.put_along_axis. --- CHANGELOG.md | 1 + docs/jax.numpy.rst | 1 + jax/_src/numpy/lax_numpy.py | 101 +++++++++++++++++++++++++++++++++++- jax/_src/test_util.py | 25 +++++++++ jax/_src/util.py | 4 ++ jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 2 + tests/lax_numpy_test.py | 42 ++++++++++++++- 8 files changed, 174 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17d15c740b7e..b0b64ac71fd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.tree_util.register_dataclass` now allows metadata fields to be declared inline via {func}`dataclasses.field`. See the function documentation for examples. + * Added {func}`jax.numpy.put_along_axis`. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 3922c92d98de..30553a360155 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -337,6 +337,7 @@ namespace; they are listed below. promote_types ptp put + put_along_axis quantile r_ rad2deg diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b90004e19932..3ff38f16b38a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2) + ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map @@ -11433,6 +11433,105 @@ def replace(tup, val): mode="fill" if mode is None else mode, fill_value=fill_value) +_indices = indices # argument below named 'indices' shadows the function + + +def _make_along_axis_idx(shape, indices, axis): + return tuple_replace(_indices(shape, sparse=True), axis, indices) + + +@partial(jit, static_argnames=('axis', 'inplace', 'mode')) +def put_along_axis( + arr: ArrayLike, + indices: ArrayLike, + values: ArrayLike, + axis: int | None, + inplace: bool = True, + *, + mode: str | None = None, +) -> Array: + """Put values into the destination array by matching 1d index and data slices. + + JAX implementation of :func:`numpy.put_along_axis`. + + The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which + is not possible for JAX's immutable arrays. The JAX version returns a modified + copy of the input, and adds the ``inplace`` parameter which must be set to + `False`` by the user as a reminder of this API difference. + + Args: + arr: array into which values will be put. + indices: array of indices at which to put values. + values: array of values to put into the array. + axis: the axis along which to put values. If not specified, the array will + be flattened before indexing is applied. + inplace: must be set to False to indicate that the input is not modified + in-place, but rather a modified copy is returned. + mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options, + see :attr:`jax.numpy.ndarray.at`. + + Returns: + A copy of ``a`` with specified entries updated. + + See Also: + - :func:`jax.numpy.put`: put elements into an array at given indices. + - :func:`jax.numpy.place`: place elements into an array via boolean mask. + - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing. + - :func:`jax.numpy.take`: extract values from an array at given indices. + - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis. + + Examples: + >>> from jax import numpy as jnp + >>> a = jnp.array([[10, 30, 20], [60, 40, 50]]) + >>> i = jnp.argmax(a, axis=1, keepdims=True) + >>> print(i) + [[1] + [0]] + >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False) + >>> print(b) + [[10 99 20] + [99 40 50]] + """ + if inplace: + raise ValueError( + "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays" + "are immutable. Pass inplace=False to instead return an updated array.") + + util.check_arraylike("put_along_axis", arr, indices, values) + arr = asarray(arr) + indices = asarray(indices) + values = asarray(values) + + original_axis = axis + original_arr_shape = arr.shape + + if axis is None: + arr = arr.ravel() + axis = 0 + + if not arr.ndim == indices.ndim: + raise ValueError( + "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got " + f"{arr.ndim=} and {indices.ndim=}." + ) + + try: + values = broadcast_to(values, indices.shape) + except ValueError: + raise ValueError( + "put_along_axis argument 'values' must be broadcastable to 'indices'. Got " + f"{values.shape=} and {indices.shape=}." + ) + + idx = _make_along_axis_idx(arr.shape, indices, axis) + result = arr.at[idx].set(values, mode=mode) + + if original_axis is None: + result = result.reshape(original_arr_shape) + + return result + + ### Indexing def _is_integer_index(idx: Any) -> bool: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 78de511d4ec4..e546ebd2a0f3 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -965,6 +965,31 @@ def fn(shape, dtype): size=shape, replace=False) return fn +def rand_indices_unique_along_axis(rng): + """Sample an array of given shape containing indices up to dim (exclusive), + such that the indices are unique along the given axis. + Optionally, convert some of the resulting indices to negative indices.""" + def fn(dim, shape, axis, allow_negative=True): + batch_size = math.prod(shape[:axis] + shape[axis:][1:]) + idx = [ + rng.choice(dim, size=shape[axis], replace=False) + for _ in range(batch_size) + ] + idx = np.array(idx).reshape(batch_size, shape[axis]) + idx = idx.reshape(shape[:axis] + shape[axis:][1:] + (shape[axis],)) + idx = np.moveaxis(idx, -1, axis) + + # assert that indices are unique along the given axis + count = partial(np.bincount, minlength=dim) + assert (np.apply_along_axis(count, axis, idx) <= 1).all() + + if allow_negative: + mask = rng.choice([False, True], idx.shape) + idx[mask] -= dim + return idx + + return fn + def rand_bool(rng): def generator(shape, dtype): return _cast_to_shape( diff --git a/jax/_src/util.py b/jax/_src/util.py index fce342c493ed..8dcc5eaa5804 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -453,6 +453,10 @@ def tuple_update(t, idx, val): assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] +def tuple_replace(tupl, index, item): + # unlike tuple_update, works with negative indices as well + return tupl[:index] + (item,) + tupl[index:][1:] + class HashableFunction: """Decouples function equality and hash from its identity. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9be73e96adcf..2ab0a0e3d1ab 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -202,6 +202,7 @@ printoptions as printoptions, promote_types as promote_types, put as put, + put_along_axis as put_along_axis, ravel as ravel, ravel_multi_index as ravel_multi_index, repeat as repeat, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d391abd46e13..339174136234 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -742,6 +742,8 @@ def ptp(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ...) -> Array: ... def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = ..., *, inplace: builtins.bool = ...) -> Array: ... +def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, + axis: int | None, inplace: bool = True, *, mode: str | None = None) -> Array: ... def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7c2728af415e..a1817f528f27 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,7 +51,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning +from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace config.parse_flags_with_absl() @@ -5962,6 +5962,45 @@ def np_fun(a, i, v): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis) + for a_shape in nonempty_array_shapes + for axis in list(range(-len(a_shape), len(a_shape))) + for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)] + for v_shape in [(), (1,), i_shape] + ] + [ + dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None) + for a_shape in nonempty_array_shapes + for i_shape in [(J,) for J in range(math.prod(a_shape) + 1)] + for v_shape in [(), (1,), i_shape] + ], + dtype=jtu.dtypes.all, + mode=[None, "promise_in_bounds", "clip"], + ) + def testPutAlongAxis(self, a_shape, i_shape, v_shape, axis, dtype, mode): + a_rng = jtu.rand_default(self.rng()) + if axis is None: + size = math.prod(a_shape) + else: + size = a_shape[axis] + i_rng = jtu.rand_indices_unique_along_axis(self.rng()) + + def args_maker(): + a = a_rng(a_shape, dtype) + i = i_rng(dim=size, shape=i_shape, axis=0 if axis is None else axis) + v = a_rng(v_shape, dtype) + return a, i, v + + def np_fun(a, i, v): + a_copy = a.copy() + np.put_along_axis(a_copy, i, v, axis=axis) + return a_copy + + jnp_fun = partial(jnp.put_along_axis, axis=axis, inplace=False, mode=mode) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + def test_rot90_error(self): with self.assertRaisesRegex( ValueError, @@ -6229,7 +6268,6 @@ def testWrappedSignaturesMatch(self): 'nditer', 'nested_iters', 'poly1d', - 'put_along_axis', 'putmask', 'real_if_close', 'recarray', From 4a3e1155b9ae4d3c941f3aacbb12c4cf43cdf059 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 14 Nov 2024 13:07:15 -0800 Subject: [PATCH 354/698] cleanup: delete unused argument from internal reduction helper --- jax/_src/numpy/reductions.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 08f11d0cb6ad..5acad86eabef 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -82,7 +82,7 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: ReductionOp = Callable[[Any, Any], Any] -def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike, +def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, *, has_identity: bool = True, preproc: Callable[[ArrayLike], ArrayLike] | None = None, bool_op: ReductionOp | None = None, @@ -215,7 +215,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: - return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric, + return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric, bool_op=lax.bitwise_or, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.psum, @@ -301,7 +301,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: - return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric, + return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric, bool_op=lax.bitwise_and, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) @@ -386,7 +386,7 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, + return _reduction(a, "max", lax.max, -np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) @@ -468,7 +468,7 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None, def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, + return _reduction(a, "min", lax.min, np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) @@ -548,7 +548,7 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, + return _reduction(a, "all", lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @@ -604,7 +604,7 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, + return _reduction(a, "any", lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @@ -664,7 +664,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: arr = lax_internal.asarray(a) init_val = np.array(-1, dtype=dtype or arr.dtype) - return _reduction(arr, name="reduce_bitwise_and", np_fun=None, op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, + return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -673,7 +673,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_bitwise_or", np_fun=None, op=lax.bitwise_or, init_val=0, preproc=_require_integer, + return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -682,7 +682,7 @@ def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_bitwise_xor", np_fun=None, op=lax.bitwise_xor, init_val=0, preproc=_require_integer, + return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -691,7 +691,7 @@ def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_logical_and", np_fun=None, op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool, + return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -700,7 +700,7 @@ def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_logical_or", np_fun=None, op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool, + return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -709,7 +709,7 @@ def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: - return _reduction(a, name="reduce_logical_xor", np_fun=None, op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool, + return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) From c40d405e439ac782eb47949dd94362aeed29bc00 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 14 Nov 2024 14:03:58 -0800 Subject: [PATCH 355/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ecdba3f23b20e684c5e67a5ddb4f004de724f6df. PiperOrigin-RevId: 696642961 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fdb6b1607816..043b9d019eb1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "2a7890387f812c17fb5f17eec961ee52ac3e059d" -XLA_SHA256 = "cfe1eebc643355f55e6422451cbd750ac6a7f096ed8d6a0605238e4d8ce6d0d1" +XLA_COMMIT = "ecdba3f23b20e684c5e67a5ddb4f004de724f6df" +XLA_SHA256 = "bfb87208d43324cdb20e03c9802360a580062b913e975b1470148dd99dfbb0d1" def repo(): tf_http_archive( From 41a0493e56154d66cda48fac20a2e6c1e7b13a50 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 12 Nov 2024 20:34:18 -0800 Subject: [PATCH 356/698] Add shard map replication rule for ffi_call. --- jax/experimental/shard_map.py | 60 +++++++++++++++++++---------------- tests/extend_test.py | 16 ++++++++++ tests/shard_map_test.py | 10 ++++++ 3 files changed, 59 insertions(+), 27 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 9391d7ddf546..4ad248c17ee2 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -51,6 +51,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) +from jax._src.extend import ffi from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, @@ -1290,30 +1291,38 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): @register_check(control_flow.conditionals.cond_p) def _cond_rule(mesh, *in_rep, branches): _, *args_rep = in_rep - true_out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) - false_out_rep = _check_rep(mesh, branches[1].jaxpr, args_rep) - if not true_out_rep == false_out_rep: - raise Exception("The true and false branches of cond produced mismatched " - f"replication types {true_out_rep} and {false_out_rep}. " - "Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return true_out_rep + out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) + for branch in branches[1:]: + out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) + if not out_rep_ == out_rep: + raise Exception("The branches of cond produced mismatched replication " + "types. Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_rep=False argument " + "to shard_map") + return out_rep @register_rewrite(control_flow.conditionals.cond_p) def _cond_rewrite(mesh, in_rep, *args, branches): pred_rep, *args_rep = in_rep - _, true_out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) - _, false_out_rep = _replication_rewrite_nomatch(mesh, branches[1], args_rep) - out_rep = map(op.and_, true_out_rep, false_out_rep) + _, out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) + for branch in branches[1:]: + _, out_rep_ = _replication_rewrite_nomatch(mesh, branch, args_rep) + if out_rep: + out_rep = map(op.and_, out_rep, out_rep_) + else: + out_rep = out_rep_ out_rep = map(partial(op.and_, pred_rep), out_rep) - branches_ = ( - _replication_rewrite_match(mesh, branches[0], args_rep, out_rep), - _replication_rewrite_match(mesh, branches[1], args_rep, out_rep), - ) + branches_ = tuple(_replication_rewrite_match(mesh, branch, args_rep, out_rep) + for branch in branches) out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) return out_vals, out_rep +@register_check(control_flow.conditionals.platform_index_p) +def _platform_index_rule(mesh, *_, **__): + return set(mesh.axis_names) +register_norewrite(control_flow.conditionals.platform_index_p) + @register_rewrite(core.closed_call_p) def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep) @@ -1363,20 +1372,17 @@ def fwd_jaxpr_thunk_(*zeros): def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) - -# TODO(mattjj): make standard_check handle multiple outputs, share code @register_check(control_flow.solves.linear_solve_p) -def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs): - in_rep_ = [r for r in in_rep if r is not None] - assert in_rep - if not in_rep_[:-1] == in_rep_[1:]: - msg = ("shard_map check_rep rewrite failed. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a workaround pass the " - "check_rep=False argument to shard_map") - raise Exception(msg) - return [in_rep_[0]] * len(jaxprs.solve.out_avals) +def _linear_solve_check(mesh, *in_rep, jaxprs, **_): + out_rep = _standard_check(control_flow.solves.linear_solve_p, mesh, *in_rep) + return [out_rep] * len(jaxprs.solve.out_avals) register_standard_rewrite(control_flow.solves.linear_solve_p) +@register_check(ffi.ffi_call_p) +def _ffi_call_check(mesh, *in_rep, result_avals, **_): + out_rep = _standard_check(ffi.ffi_call_p, mesh, *in_rep) + return [out_rep] * len(result_avals) +register_standard_rewrite(ffi.ffi_call_p) del _check_rules[lax.tie_p] diff --git a/tests/extend_test.py b/tests/extend_test.py index b4af8bc23e16..a59c94eab5fc 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -24,6 +24,7 @@ from jax import lax import jax.extend as jex import jax.numpy as jnp +import jax.sharding as shd from jax._src import abstract_arrays from jax._src import api @@ -38,6 +39,7 @@ from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo from jax._src.lax import linalg as lax_linalg_internal +from jax.experimental.shard_map import shard_map jax.config.parse_flags_with_absl() @@ -342,6 +344,20 @@ def testInvalidResultType(self): ValueError, "All elements of result_shape_dtypes.*position 1"): jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))() + @jtu.run_on_devices("gpu", "cpu") + def testShardMap(self): + mesh = jtu.create_mesh((1,), ("i",)) + x = self.rng().randn(8, 4, 5).astype(np.float32) + + @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'), + out_specs=shd.PartitionSpec('i')) + def f(x): + return ffi_call_geqrf(x) + + f(x) # eager mode doesn't crash + jax.jit(f)(x) # neither does JIT + self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text()) + def ffi_call_geqrf(x, **kwargs): if jtu.test_device_matches(["cpu"]): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index df24315ce110..84017bab5122 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1050,6 +1050,16 @@ def f(a): a = jnp.array([True, False]) shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + def test_switch_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + x = jnp.arange(4) + + def f(n, x, y): + return jax.lax.switch( + n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y) + + shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) + def test_eager_custom_jvp_basic(self): @jax.custom_jvp def foo(x): From a115b2cec508787ebf94061daa5d62feefd60cb3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 14 Nov 2024 16:05:30 -0800 Subject: [PATCH 357/698] Update array-api-tests commit --- .github/workflows/jax-array-api.yml | 2 +- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 84dda34752f0..763a4c04be5d 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '4bbe6be32c6995772f8f46a6ef050ba766581104' # Latest commit as of 2024-11-14 + ref: 'a3f3f376308e64f0ac15b307dfe27be945409e41' # Latest commit as of 2024-11-14 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index 6e625e708d7a..73e1c51fc8af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ filterwarnings = [ # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", - "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", "default:The .* method is good for exploring strategies.*", # NOTE: this is probably not where you want to add code to suppress a From 9a0e9e55d81e8ea1b1fd2fa4eaf67074f5908bec Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 14 Nov 2024 17:31:16 -0800 Subject: [PATCH 358/698] [sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in `wrap_with_sharding_op`. Also lower shardings with `Collective` axes correctly to HloSharding. PiperOrigin-RevId: 696703030 --- jax/_src/interpreters/mlir.py | 14 ++++++++++++ jax/_src/lax/lax.py | 41 ++++++++++------------------------- jax/_src/lax/parallel.py | 40 +++++++++++++++++++--------------- jax/_src/mesh.py | 19 ++++++++++++++++ jax/_src/sharding_impls.py | 7 ++++-- tests/pjit_test.py | 26 ++++++++++++++++++++++ 6 files changed, 99 insertions(+), 48 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index bef465c6aa75..ee3c929b26f7 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2474,6 +2474,20 @@ def _wrap_with_spmd_op(name: str, wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape") +def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): + if sharding_proto is None: + proto = aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + else: + proto = sharding_proto + # TODO(yashkatariya): Setting all axes as unspecified should work even when + # any axes is Collective because that's what happens in partial auto shmap. + # Do that after tests for it exists. + unspecified_dims = (set(range(aval.ndim)) + if aval.sharding.mesh.are_all_axes_collective else None) + return wrap_with_sharding_op( + ctx, op, aval, proto, unspecified_dims=unspecified_dims) + + def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c45d8f5c80b2..b780aab870e9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2203,14 +2203,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): for op, in_aval in zip(ops, in_avals): if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: out.append(op) - elif in_aval.sharding.mesh.are_all_axes_collective: - out.append(op) else: - # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains - # CompilerShardingAxis, then specify `unspecified_dims` via - # `wrap_with_sharding_op`. - sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() - out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) + proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto)) return out @@ -2226,10 +2221,7 @@ def _nary_lower_hlo(op: Callable, ctx, out = op(*args) if config.sharding_in_types.value: - if aval_out.sharding.mesh.are_all_axes_collective: - return [out] - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] else: return [out] @@ -2646,8 +2638,7 @@ def _integer_pow_lowering(ctx, x, *, y): out, = lowering(ctx, x, y=y) if config.sharding_in_types.value: aval_out, = ctx.avals_out - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(integer_pow_p, _integer_pow_lowering) @@ -3029,8 +3020,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, if config.sharding_in_types.value: if sharding is not None: assert aval_out.sharding == sharding - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -3765,8 +3755,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): if config.sharding_in_types.value: if out_type is not None: assert aval_out.sharding == out_type - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp) + result = mlir.lower_sharding_under_shit(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) return [result] @@ -4231,8 +4220,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, if config.sharding_in_types.value: if sharding is not None: assert sharding == aval_out.sharding - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, @@ -4645,8 +4633,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) out = mlir.reshape(ctx, x, aval_out) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] def _reshape_staging_rule( @@ -4726,8 +4713,7 @@ def _transpose_lower(ctx, x, *, permutation): permutation = [*permutation, *trailing_dims] out = hlo.transpose(x, mlir.dense_int_array(permutation)) if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] transpose_p = standard_primitive( @@ -4868,8 +4854,7 @@ def _select_hlo_lowering_opaque(ctx, which, *cases): def _add_shit_to_select(ctx, op, aval_out): if config.sharding_in_types.value: - proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return mlir.wrap_with_sharding_op(ctx, op, aval_out, proto) + return mlir.lower_sharding_under_shit(ctx, op, aval_out) return op def _select_hlo_lowering(ctx, which, *cases): @@ -5241,8 +5226,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): with ir.InsertionPoint(reducer_region): hlo.return_([reducer(*reducer_region.arguments)]) if config.sharding_in_types.value: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, op.result, aval_out, out_sp)] + return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)] return op.results mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp, @@ -5941,8 +5925,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding): out = mlir.iota(ctx, aval_out, dimension=dimension) if config.sharding_in_types.value: assert aval_out.sharding == sharding - proto = sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] mlir.register_lowering(iota_p, _iota_lower) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 3a1c1ef3bcf1..c8cea6a9df5b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,9 +24,11 @@ from jax import tree_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes -from jax._src import sharding_impls +from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, + NamedSharding, PartitionSpec as P) from jax._src.core import AxisName, ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -635,9 +637,15 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") - out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), - arg.dtype) for arg in args] + if config.sharding_in_types.value: + out_avals = [ + ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) + for arg in args + ] + else: + out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype) + for arg in args] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def _check_axis_names(axes): @@ -673,10 +681,7 @@ def _positional_reduce(aval, arg): _replica_groups(ctx.module_context.axis_env, named_axes, axis_index_groups)) axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) def all_reduce(aval, x): if is_spmd: @@ -694,7 +699,11 @@ def all_reduce(aval, x): else: op = hlo.AllReduceOp( [x.type], [x], replica_groups=replica_groups, **other_args) - scalar_aval = core.ShapedArray((), aval.dtype) + if config.sharding_in_types.value: + scalar_aval = core.ShapedArray( + (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P())) + else: + scalar_aval = core.ShapedArray((), aval.dtype) scalar_type = mlir.aval_to_ir_type(scalar_aval) reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_block): @@ -778,7 +787,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): axis_context = ctx.module_context.axis_context is_manual = ( - isinstance(axis_context, sharding_impls.SPMDAxisContext) + isinstance(axis_context, SPMDAxisContext) and axis_context.manual_axes ) if is_manual: @@ -896,7 +905,7 @@ def _all_to_all_lowering( raise ValueError('Replica groups must be equally sized') is_spmd = isinstance( ctx.module_context.axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -1129,10 +1138,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, x_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), - ) + is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext)) if not tiled: new_shape = list(x_aval.shape) new_shape.insert(all_gather_dimension, 1) @@ -1260,7 +1266,7 @@ def _reduce_scatter_lowering( axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique @@ -1489,7 +1495,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_context = ctx.module_context.axis_context is_spmd = isinstance( axis_context, - (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + (SPMDAxisContext, ShardingContext), ) if is_spmd: device_id = hlo.partition_id() diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 082c443fade4..6c6017c4b2b7 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -107,6 +107,17 @@ class AxisTypes(enum.Enum): User = enum.auto() Collective = enum.auto() +def axis_names_to_types(axis_types) -> dict[str, AxisTypes]: + if axis_types is None: + return {} + d = {} + for t, names in axis_types.items(): + if isinstance(names, tuple): + for n in names: + d[n] = t + else: + d[names] = t + return d _mesh_object_dict = {} # type: ignore @@ -269,6 +280,10 @@ def shape_tuple(self): def axis_sizes(self) -> tuple[int, ...]: return self.devices.shape + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @property def size(self): return math.prod(self.shape.values()) if self.devices.ndim else 0 @@ -390,6 +405,10 @@ def axis_names(self): def axis_sizes(self) -> tuple[int, ...]: return self._axis_sizes + @functools.cached_property + def _name_to_type(self): + return axis_names_to_types(self.axis_types) + @functools.cached_property def size(self): return math.prod(self._axis_sizes) if self._axis_sizes else 0 diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 9b847f15d86a..8957a6186339 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -137,9 +137,12 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - if self._manual_axes: + mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() + if t == mesh_lib.AxisTypes.Collective} + manual_axes = self._manual_axes.union(mesh_manual_axes) + if manual_axes: axis_names = self.mesh.axis_names - for manual_axis in self._manual_axes: + for manual_axis in manual_axes: special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL replicated_mesh_axes = [] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8a63bbe39099..7196a6335960 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5225,6 +5225,32 @@ def f(x, y): self.assertArraysEqual(out, (np_inp * np_inp) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + def test_shard_map_dot(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh.are_all_axes_collective) + self.assertTrue(y.sharding.mesh.are_all_axes_collective) + allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) + z = x @ allgatherd_y + return jax.lax.psum(z, axis_name='y') + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', None))(x, y) + self.assertEqual(z.sharding.spec, P('x', None)) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', None)) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From f652b6ad6aa44e586ee8989a39ca95b63205cec3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 15 Nov 2024 06:03:54 -0800 Subject: [PATCH 359/698] Set __module__ attribute for objects in jax.numpy --- jax/_src/dtypes.py | 3 + jax/_src/numpy/index_tricks.py | 12 +- jax/_src/numpy/lax_numpy.py | 188 +++++++++++++++++++++++++++++++- jax/_src/numpy/polynomial.py | 14 +++ jax/_src/numpy/reductions.py | 40 ++++++- jax/_src/numpy/setops.py | 14 ++- jax/_src/numpy/ufunc_api.py | 5 +- jax/_src/numpy/ufuncs.py | 114 +++++++++++++++++++ jax/_src/numpy/vectorize.py | 5 +- tests/package_structure_test.py | 11 +- 10 files changed, 396 insertions(+), 10 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f5b0c3fd68b1..1c5e285ba08a 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -343,6 +343,7 @@ def _issubclass(a: Any, b: Any) -> bool: # TODO(jakevdp): consider whether to disallow None here. We allow it # because np.issubdtype allows it (and treats it as equivalent to float64). +@set_module('jax.numpy') def issubdtype(a: DTypeLike | ExtendedDType | None, b: DTypeLike | ExtendedDType | None) -> bool: """Returns True if first argument is a typecode lower/equal in type hierarchy. @@ -458,6 +459,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, } +@set_module('jax.numpy') def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool: """Returns a boolean indicating whether a provided dtype is of a specified kind. @@ -650,6 +652,7 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy "JAX's internal logic; please report it to the JAX maintainers." ) +@set_module('jax.numpy') def promote_types(a: DTypeLike, b: DTypeLike) -> DType: """Returns the type to which a binary operation should cast its arguments. diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index 90a17000cf16..ec67d7489f30 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -24,10 +24,14 @@ arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose ) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module import numpy as np +export = set_module('jax.numpy') + + __all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] @@ -87,7 +91,7 @@ def __getitem__(self, key: slice | tuple[slice, ...]) -> Array: return stack(output_arr, 0) -mgrid = _Mgrid() +mgrid = export(_Mgrid()) class _Ogrid: @@ -129,7 +133,7 @@ def __getitem__( return meshgrid(*output, indexing='ij', sparse=True) -ogrid = _Ogrid() +ogrid = export(_Ogrid()) _IndexType = Union[ArrayLike, str, slice] @@ -279,7 +283,7 @@ class RClass(_AxisConcat): op_name = "r_" -r_ = RClass() +r_ = export(RClass()) class CClass(_AxisConcat): @@ -327,7 +331,7 @@ class CClass(_AxisConcat): op_name = "c_" -c_ = CClass() +c_ = export(CClass()) s_ = np.s_ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4cf37f6f7d67..4c261d11196a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,13 +68,16 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, subvals,unzip2, tuple_replace) + ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, + tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) from jax.tree_util import tree_flatten, tree_leaves, tree_map import numpy as np import opt_einsum +export = set_module('jax.numpy') + for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']: try: cuda_plugin_extension = importlib.import_module( @@ -116,6 +119,7 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: printoptions = np.printoptions set_printoptions = np.set_printoptions +@export def iscomplexobj(x: Any) -> bool: """Check if the input is a complex number or an array containing complex elements. @@ -327,6 +331,7 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: return clip(val, min_val, max_val).astype(dtype) +@export def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: """Load JAX arrays from npy files. @@ -376,6 +381,7 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> ### implementations of numpy functions in terms of lax +@export @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise minimum of the input arrays. @@ -427,6 +433,7 @@ def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise maximum of the input arrays. @@ -476,6 +483,7 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) +@export def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: """Return True if arg1 is equal or lower than arg2 in the type hierarchy. @@ -522,6 +530,7 @@ def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: return dtypes.issubdtype(arg1, arg2) +@export def isscalar(element: Any) -> bool: """Return True if the input is a scalar. @@ -620,6 +629,7 @@ def isscalar(element: Any) -> bool: iterable = np.iterable +@export def result_type(*args: Any) -> DType: """Return the result of applying JAX promotion rules to the inputs. @@ -663,6 +673,7 @@ def result_type(*args: Any) -> DType: return dtypes.result_type(*args) +@export @jit def trunc(x: ArrayLike) -> Array: """Round input to the nearest integer towards zero. @@ -739,6 +750,7 @@ def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, return result[0, 0, out_order] +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision: PrecisionLike = None, @@ -814,6 +826,7 @@ def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *, precision=precision, preferred_element_type=preferred_element_type) +@export @partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type')) def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision: PrecisionLike = None, @@ -899,6 +912,7 @@ def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *, precision=precision, preferred_element_type=preferred_element_type) +@export def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, range: None | Array | Sequence[ArrayLike] = None, weights: ArrayLike | None = None) -> Array: @@ -950,6 +964,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10, return linspace(range[0], range[1], bins_int + 1, dtype=dtype) +@export def histogram(a: ArrayLike, bins: ArrayLike = 10, range: Sequence[ArrayLike] | None = None, weights: ArrayLike | None = None, @@ -1031,6 +1046,7 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, return counts, bin_edges +@export def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1120,6 +1136,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = return hist, edges[0], edges[1] +@export def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, @@ -1229,6 +1246,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, return hist, bin_edges_by_dim +@export def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. @@ -1307,6 +1325,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: return lax.transpose(a, axes_) +@export def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: """Permute the axes/dimensions of an array. @@ -1336,6 +1355,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: return lax.transpose(a, axes) +@export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose the last two dimensions of an array. @@ -1389,6 +1409,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return lax.transpose(x, axes) +@export @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: """Rotate an array by 90 degrees counterclockwise in the plane specified by axes. @@ -1472,6 +1493,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: return flip(transpose(m, perm), ax2) +@export def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Reverse the order of elements of an array along the given axis. @@ -1539,6 +1561,7 @@ def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array: return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis]) +@export def fliplr(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 1. @@ -1565,6 +1588,7 @@ def fliplr(m: ArrayLike) -> Array: return _flip(asarray(m), 1) +@export def flipud(m: ArrayLike) -> Array: """Reverse the order of elements of an array along axis 0. @@ -1590,6 +1614,8 @@ def flipud(m: ArrayLike) -> Array: util.check_arraylike("flipud", m) return _flip(asarray(m), 0) + +@export @jit def iscomplex(x: ArrayLike) -> Array: """Return boolean array showing where the input is complex. @@ -1613,6 +1639,8 @@ def iscomplex(x: ArrayLike) -> Array: i = ufuncs.imag(x) return lax.ne(i, _lax_const(i, 0)) + +@export @jit def isreal(x: ArrayLike) -> Array: """Return boolean array showing where the input is real. @@ -1637,6 +1665,7 @@ def isreal(x: ArrayLike) -> Array: return lax.eq(i, _lax_const(i, 0)) +@export @partial(jit, static_argnames=['deg']) def angle(z: ArrayLike, deg: bool = False) -> Array: """Return the angle of a complex valued number or array. @@ -1688,6 +1717,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: return ufuncs.degrees(result) if deg else result +@export @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, @@ -1800,6 +1830,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, return arr +@export @jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: @@ -1862,6 +1893,8 @@ def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, result = concatenate((result, ravel(asarray(to_end, dtype=arr.dtype)))) return result + +@export @partial(jit, static_argnames=("axis", "edge_order")) def gradient( f: ArrayLike, @@ -1992,6 +2025,7 @@ def gradient_along_axis(a, h, axis): return a_grad[0] if len(axis_tuple) == 1 else a_grad +@export def isrealobj(x: Any) -> bool: """Check if the input is not a complex number or an array containing complex elements. @@ -2026,6 +2060,7 @@ def isrealobj(x: Any) -> bool: return not iscomplexobj(x) +@export def reshape( a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), @@ -2129,6 +2164,7 @@ def reshape( return asarray(a).reshape(shape, order=order) +@export @partial(jit, static_argnames=('order',), inline=True) def ravel(a: ArrayLike, order: str = "C") -> Array: """Flatten array into a 1-dimensional shape. @@ -2182,6 +2218,7 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: return reshape(a, (size(a),), order) +@export def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = 'raise', order: str = 'C') -> Array: """Convert multi-dimensional indices into flat indices. @@ -2273,6 +2310,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], return result +@export def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: """Convert flat indices into multi-dimensional indices. @@ -2336,6 +2374,7 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: for s, i in safe_zip(shape, out_indices)) +@export @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: """Return a new array with specified shape. @@ -2387,6 +2426,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: return reshape(arr, new_shape) +@export def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array: """Remove one or more length-1 axes from array @@ -2457,6 +2497,7 @@ def _squeeze(a: Array, axis: tuple[int, ...]) -> Array: return lax.squeeze(a, axis) +@export def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: """Insert dimensions of length 1 into array @@ -2527,6 +2568,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: return lax.expand_dims(a, axis) +@export @partial(jit, static_argnames=('axis1', 'axis2'), inline=True) def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: """Swap two axes of an array. @@ -2574,6 +2616,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: return lax.transpose(a, list(perm)) +@export def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: """Move an array axis to a new position @@ -2639,6 +2682,7 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) - return lax.transpose(a, perm) +@export @partial(jit, static_argnames=('equal_nan',)) def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -2783,6 +2827,7 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f +@export def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, @@ -2865,6 +2910,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, ) -> Array | tuple[Array, ...]: ... +@export def where(condition, x=None, y=None, /, *, size=None, fill_value=None): """Select elements from two arrays based on a condition. @@ -2940,6 +2986,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): return util._where(condition, x, y) +@export def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], @@ -3007,6 +3054,7 @@ def select( return lax.select_n(*broadcast_arrays(idx, *choicelist)) +@export def bincount(x: ArrayLike, weights: ArrayLike | None = None, minlength: int = 0, *, length: int | None = None ) -> Array: @@ -3099,6 +3147,7 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | core.Tracer] ) -> tuple[int | core.Tracer, ...]: ... +@export def broadcast_shapes(*shapes): """Broadcast input shapes to a common output shape. @@ -3139,6 +3188,7 @@ def broadcast_shapes(*shapes): return lax.broadcast_shapes(*shapes) +@export def broadcast_arrays(*args: ArrayLike) -> list[Array]: """Broadcast arrays to a common shape. @@ -3178,6 +3228,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: return util._broadcast_arrays(*args) +@export def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: """Broadcast an array to a specified shape. @@ -3254,6 +3305,7 @@ def _split(op: str, ary: ArrayLike, for start, end in zip(split_indices[:-1], split_indices[1:])] +@export def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3317,6 +3369,7 @@ def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, return _split("split", ary, indices_or_sections, axis=axis) +@export def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays vertically. @@ -3351,6 +3404,7 @@ def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("vsplit", ary, indices_or_sections, axis=0) +@export def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays horizontally. @@ -3391,6 +3445,7 @@ def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1) +@export def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: """Split an array into sub-arrays depth-wise. @@ -3432,6 +3487,7 @@ def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) return _split("dsplit", ary, indices_or_sections, axis=2) +@export def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: """Split an array into sub-arrays. @@ -3457,6 +3513,7 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array return _split("array_split", ary, indices_or_sections, axis=axis) +@export @jit def clip( arr: ArrayLike | None = None, @@ -3528,6 +3585,7 @@ def clip( return asarray(arr) +@export @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Round input evenly to the given number of decimals. @@ -3599,12 +3657,14 @@ def _round_float(x: ArrayLike) -> Array: return _round_float(a) +@export @partial(jit, static_argnames=('decimals',)) def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: """Alias of :func:`jax.numpy.round`""" return round(a, decimals, out) +@export @jit def fix(x: ArrayLike, out: None = None) -> Array: """Round input to the nearest integer towards zero. @@ -3643,6 +3703,7 @@ def fix(x: ArrayLike, out: None = None) -> Array: return where(lax.ge(x, zero), ufuncs.floor(x), ufuncs.ceil(x)) +@export @jit def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, posinf: ArrayLike | None = None, @@ -3708,6 +3769,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0, return out +@export @partial(jit, static_argnames=('equal_nan',)) def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array: @@ -3756,6 +3818,7 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, return reductions.all(isclose(a, b, rtol, atol, equal_nan)) +@export def nonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: @@ -3863,6 +3926,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, return out +@export def flatnonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array: """Return indices of nonzero elements in a flattened array @@ -3908,6 +3972,7 @@ def flatnonzero(a: ArrayLike, *, size: int | None = None, return nonzero(ravel(a), size=size, fill_value=fill_value)[0] +@export @partial(jit, static_argnames=('axis',)) def unwrap(p: ArrayLike, discont: ArrayLike | None = None, axis: int = -1, period: ArrayLike = 2 * pi) -> Array: @@ -4337,6 +4402,7 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") +@export def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], mode: str | Callable[..., Any] = "constant", **kwargs) -> Array: """Add padding to an array. @@ -4493,6 +4559,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], ### Array-creation functions +@export def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: """Join arrays along a new axis. @@ -4559,6 +4626,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], return concatenate(new_arrays, axis=axis, dtype=dtype) +@export @partial(jit, static_argnames="axis") def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: """Unstack an array along an axis. @@ -4599,6 +4667,8 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: ) return tuple(moveaxis(x, axis, 0)) + +@export def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: """Construct an array by repeating ``A`` along specified dimensions. @@ -4662,6 +4732,7 @@ def _concatenate_array(arr: ArrayLike, axis: int | None, return lax.reshape(arr, shape, dimensions) +@export def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> Array: """Join arrays along an existing axis. @@ -4725,6 +4796,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] +@export def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: """Join arrays along an existing axis. @@ -4765,6 +4837,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: return jax.numpy.concatenate(arrays, axis=axis) +@export def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Vertically stack arrays. @@ -4825,6 +4898,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0, dtype=dtype) +@export def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Horizontally stack arrays. @@ -4885,6 +4959,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) +@export def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: """Stack arrays depth-wise. @@ -4945,6 +5020,7 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=2, dtype=dtype) +@export def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: """Stack arrays column-wise. @@ -5005,6 +5081,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: return concatenate(arrs, axis=1) +@export def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: """Construct an array by stacking slices of choice arrays. @@ -5129,6 +5206,7 @@ def _block(xs: ArrayLike | list[ArrayLike]) -> tuple[Array, int]: return asarray(xs), 1 +@export @jit def block(arrays: ArrayLike | list[ArrayLike]) -> Array: """Create an array from a list of blocks. @@ -5212,6 +5290,7 @@ def atleast_1d(x: ArrayLike, /) -> Array: @overload def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 1 dimension. @@ -5266,6 +5345,7 @@ def atleast_2d(x: ArrayLike, /) -> Array: @overload def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 2 dimensions. @@ -5329,6 +5409,7 @@ def atleast_3d(x: ArrayLike, /) -> Array: @overload def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +@export @jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: """Convert inputs to arrays with at least 3 dimensions. @@ -5405,6 +5486,7 @@ def _supports_buffer_protocol(obj): return True +@export def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5597,6 +5679,7 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x +@export def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: @@ -5662,6 +5745,7 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result +@export def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, *, copy: bool | None = None, device: xc.Device | Sharding | None = None) -> Array: @@ -5743,6 +5827,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) +@export def copy(a: ArrayLike, order: str | None = None) -> Array: """Return a copy of the array. @@ -5791,6 +5876,7 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: return array(a, copy=True, order=order) +@export def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5833,6 +5919,7 @@ def zeros_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5875,6 +5962,7 @@ def ones_like(a: ArrayLike | DuckTypedArray, return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) +@export def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -5924,6 +6012,7 @@ def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | No return device +@export def full(shape: Any, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -5972,6 +6061,7 @@ def full(shape: Any, fill_value: ArrayLike, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, shape: Any = None, *, @@ -6028,6 +6118,7 @@ def full_like(a: ArrayLike | DuckTypedArray, broadcast_to(asarray(fill_value, dtype=dtype), shape), device) +@export def zeros(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of zeros. @@ -6064,6 +6155,7 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def ones(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an array full of ones. @@ -6100,6 +6192,7 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) +@export def empty(shape: Any, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: """Create an empty array. @@ -6143,6 +6236,7 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore "with a single tuple argument for the shape?") +@export def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: """Check if two arrays are element-wise equal. @@ -6184,6 +6278,7 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: return reductions.all(eq) +@export def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: """Check if two arrays are element-wise equal. @@ -6224,6 +6319,7 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. +@export def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: r"""Convert a buffer into a 1-D JAX array. @@ -6271,6 +6367,7 @@ def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) +@export def fromfile(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromfile. @@ -6289,6 +6386,7 @@ def fromfile(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def fromiter(*args, **kwargs): """Unimplemented JAX wrapper for jnp.fromiter. @@ -6307,6 +6405,7 @@ def fromiter(*args, **kwargs): "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") +@export def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None) -> Array: """Construct a JAX array via DLPack. @@ -6367,6 +6466,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, return from_dlpack(x, device=device, copy=copy) +@export def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = float, **kwargs) -> Array: """Create an array from a function applied over indices. @@ -6453,6 +6553,7 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) +@export def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: """Convert a string of text into 1-D JAX array. @@ -6481,6 +6582,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) +@export def eye(N: DimSize, M: DimSize | None = None, k: int | ArrayLike = 0, dtype: DTypeLike | None = None, @@ -6560,6 +6662,7 @@ def _eye(N: DimSize, M: DimSize | None = None, return (i + offset == j).astype(dtype) +@export def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: """Create a square identity matrix @@ -6593,6 +6696,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: return eye(n, dtype=dtype) +@export def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, step: ArrayLike | None = None, dtype: DTypeLike | None = None, *, device: xc.Device | Sharding | None = None) -> Array: @@ -6760,6 +6864,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@export def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, @@ -6885,6 +6990,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return (result, delta) if retstep else result +@export def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: DTypeLike | None = None, axis: int = 0) -> Array: @@ -6970,6 +7076,7 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(ufuncs.power(base, lin), dtype) +@export def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: DTypeLike | None = None, axis: int = 0) -> Array: """Generate geometrically-spaced values. @@ -7044,6 +7151,7 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) +@export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: """Construct N-dimensional grid arrays from N 1-dimensional vectors. @@ -7125,6 +7233,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output +@export @jit def i0(x: ArrayLike) -> Array: r"""Calculate modified Bessel function of first kind, zeroth order. @@ -7174,6 +7283,7 @@ def _i0_jvp(primals, tangents): primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) +@export def ix_(*args: ArrayLike) -> tuple[Array, ...]: """Return a multi-dimensional grid (open mesh) from N one-dimensional sequences. @@ -7237,6 +7347,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @overload def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: ... +@export def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) -> Array | tuple[Array, ...]: """Generate arrays of grid indices. @@ -7287,6 +7398,7 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, return stack(output, 0) if output else array([], dtype=dtype) +@export def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None) -> Array: """Construct an array from repeated elements. @@ -7431,6 +7543,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, return take(a, gather_indices, axis=axis) +@export @partial(jit, static_argnames=('axis',)) def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array: @@ -7490,6 +7603,7 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) +@export def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None) -> Array: r"""Return an array with ones on and below the diagonal and zeros elsewhere. @@ -7546,6 +7660,7 @@ def tri(N: int, M: int | None = None, k: int = 0, dtype: DTypeLike | None = None return lax_internal._tri(dtype, (N, M), k) +@export @partial(jit, static_argnames=('k',)) def tril(m: ArrayLike, k: int = 0) -> Array: r"""Return lower triangle of an array. @@ -7607,6 +7722,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) +@export @partial(jit, static_argnames=('k',)) def triu(m: ArrayLike, k: int = 0) -> Array: r"""Return upper triangle of an array. @@ -7672,6 +7788,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) +@export @partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -7737,6 +7854,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int return reductions.sum(a, axis=(-2, -1), dtype=dtype) +@export def mask_indices(n: int, mask_func: Callable[[ArrayLike, int], Array], k: int = 0, *, size: int | None = None) -> tuple[Array, Array]: @@ -7796,6 +7914,7 @@ def _triu_size(n, m, k): return mk * (mk + 1) // 2 + mk * (m - k - mk) +@export def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of upper triangle of an array of size ``(n, m)``. @@ -7854,6 +7973,7 @@ def triu_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array]: """Return the indices of lower triangle of an array of size ``(n, m)``. @@ -7912,6 +8032,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array return i, j +@export def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. @@ -7969,6 +8090,7 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. @@ -8026,6 +8148,7 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) +@export def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, inplace: bool = True) -> Array: """Return a copy of the array with the diagonal overwritten. @@ -8107,6 +8230,7 @@ def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: bool = False, *, return a.at[idx].set(val if val.ndim == 0 else _tile_to_size(val.ravel(), n)) +@export def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a multidimensional array. @@ -8142,6 +8266,8 @@ def diag_indices(n: int, ndim: int = 2) -> tuple[Array, ...]: .format(ndim)) return (lax.iota(int_, n),) * ndim + +@export def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: """Return indices for accessing the main diagonal of a given array. @@ -8183,6 +8309,8 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: return diag_indices(s[0], ndim=nd) + +@export @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: @@ -8234,6 +8362,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, return a[..., i, j] if offset >= 0 else a[..., j, i] +@export def diag(v: ArrayLike, k: int = 0) -> Array: """Returns the specified diagonal or constructs a diagonal array. @@ -8297,6 +8426,8 @@ def _diag(v, k): else: raise ValueError("diag input must be 1d or 2d") + +@export def diagflat(v: ArrayLike, k: int = 0) -> Array: """Return a 2-D array with the flattened input array laid out on the diagonal. @@ -8353,6 +8484,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: # TODO(jakevdp): add support for N-dimensional inputs as in NumPy v2.2 +@export def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: """Trim leading and/or trailing zeros of the input array. @@ -8407,6 +8539,8 @@ def trim_zeros_tol(filt, tol, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] + +@export @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None @@ -8461,6 +8595,7 @@ def append( return concatenate([arr, values], axis=axis) +@export def delete( arr: ArrayLike, obj: ArrayLike | slice, @@ -8585,6 +8720,7 @@ def delete( return a[tuple(slice(None) for i in range(axis)) + (mask,)] +@export def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, axis: int | None = None) -> Array: """Insert entries into an array at specified indices. @@ -8684,6 +8820,7 @@ def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, return out +@export def apply_along_axis( func1d: Callable, axis: int, arr: ArrayLike, *args, **kwargs ) -> Array: @@ -8761,6 +8898,7 @@ def apply_along_axis( return func(arr) +@export def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, axes: Sequence[int]) -> Array: """Apply a function repeatedly over specified axes. @@ -8819,6 +8957,7 @@ def apply_over_axes(func: Callable[[ArrayLike, int], Array], a: ArrayLike, ### Tensor contraction operations +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -8908,6 +9047,7 @@ def dot(a: ArrayLike, b: ArrayLike, *, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def matmul(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -9031,6 +9171,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def vdot( a: ArrayLike, b: ArrayLike, *, @@ -9079,6 +9220,7 @@ def vdot( preferred_element_type=preferred_element_type) +@export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -9134,6 +9276,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, signature="(n),(n)->()")(x1_arr, x2_arr) +@export def tensordot(a: ArrayLike, b: ArrayLike, axes: int | Sequence[int] | Sequence[Sequence[int]] = 2, *, precision: PrecisionLike = None, @@ -9279,6 +9422,7 @@ def einsum( out_type=None, ) -> Array: ... +@export def einsum( subscripts, /, *operands, @@ -9554,6 +9698,7 @@ def einsum_path( optimize: bool | str | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... +@export def einsum_path( subscripts, /, *operands, @@ -9787,6 +9932,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): output_weak_type) +@export @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) def inner( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, @@ -9843,6 +9989,7 @@ def inner( preferred_element_type=preferred_element_type) +@export @partial(jit, inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """Compute the outer product of two arrays. @@ -9877,6 +10024,7 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: return ravel(a)[:, None] * ravel(b)[None, :] +@export @partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis')) def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int | None = None): @@ -9977,6 +10125,7 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) +@export @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: """Compute the Kronecker product of two input arrays. @@ -10022,6 +10171,7 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) +@export @partial(jit, static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False @@ -10085,6 +10235,7 @@ def vander( ### Misc +@export def argwhere( a: ArrayLike, *, @@ -10150,6 +10301,7 @@ def argwhere( return result.reshape(result.shape[0], ndim(a)) +@export def argmax(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the maximum value of an array. @@ -10205,6 +10357,7 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: """Return the index of the minimum value of an array. @@ -10260,6 +10413,7 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: return expand_dims(result, dims) if keepdims else result +@export def nanargmax( a: ArrayLike, axis: int | None = None, @@ -10327,6 +10481,7 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export def nanargmin( a: ArrayLike, axis: int | None = None, @@ -10387,6 +10542,7 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res) +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def sort( a: ArrayLike, @@ -10450,6 +10606,7 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result +@export @jit def sort_complex(a: ArrayLike) -> Array: """Return a sorted copy of complex array. @@ -10487,6 +10644,7 @@ def sort_complex(a: ArrayLike) -> Array: return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) +@export @partial(jit, static_argnames=('axis',)) def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array: """Sort a sequence of keys in lexicographic order. @@ -10564,6 +10722,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1] +@export @partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def argsort( a: ArrayLike, @@ -10644,6 +10803,7 @@ def argsort( return lax.rev(indices, dimensions=[dimension]) if descending else indices +@export @partial(jit, static_argnames=['kth', 'axis']) def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns a partially-sorted copy of an array. @@ -10714,6 +10874,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: return swapaxes(out, -1, axis) +@export @partial(jit, static_argnames=['kth', 'axis']) def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: """Returns indices that partially sort an array. @@ -10818,6 +10979,8 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: dimension=ax) return a + +@export def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], axis: int | Sequence[int] | None = None) -> Array: """Roll the elements of an array along a specified axis. @@ -10871,6 +11034,7 @@ def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], return _roll_static(arr, shift, axis) +@export @partial(jit, static_argnames=('axis', 'start')) def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: """Roll the specified axis to a given position. @@ -10936,6 +11100,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: return moveaxis(a, axis, start) +@export @partial(jit, static_argnames=('axis', 'bitorder')) def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array: """Pack array of bits into a uint8 array. @@ -11020,6 +11185,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar return swapaxes(packed, axis, -1) +@export @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, @@ -11111,6 +11277,7 @@ def unpackbits( return swapaxes(unpacked, axis, -1) +@export def take( a: ArrayLike, indices: ArrayLike, @@ -11268,6 +11435,7 @@ def _normalize_index(index, axis_size): return lax.select(index < 0, lax.add(index, axis_size_val), index) +@export @partial(jit, static_argnames=('axis', 'mode', 'fill_value')) def take_along_axis( arr: ArrayLike, @@ -11462,6 +11630,7 @@ def _make_along_axis_idx(shape, indices, axis): return tuple_replace(_indices(shape, sparse=True), axis, indices) +@export @partial(jit, static_argnames=('axis', 'inplace', 'mode')) def put_along_axis( arr: ArrayLike, @@ -12206,6 +12375,7 @@ def clamp_index(i: DimSize, which: str): return start, step, slice_size +@export def blackman(M: int) -> Array: """Return a Blackman window of size M. @@ -12236,6 +12406,7 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) +@export def bartlett(M: int) -> Array: """Return a Bartlett window of size M. @@ -12266,6 +12437,7 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) +@export def hamming(M: int) -> Array: """Return a Hamming window of size M. @@ -12296,6 +12468,7 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) +@export def hanning(M: int) -> Array: """Return a Hanning window of size M. @@ -12326,6 +12499,7 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) +@export def kaiser(M: int, beta: ArrayLike) -> Array: """Return a Kaiser window of size M. @@ -12368,6 +12542,8 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) + +@export @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the greatest common divisor of two arrays. @@ -12414,6 +12590,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd +@export @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: """Compute the least common multiple of two arrays. @@ -12461,6 +12638,7 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: ufuncs.multiply(x1, ufuncs.floor_divide(x2, d))) +@export def extract(condition: ArrayLike, arr: ArrayLike, *, size: int | None = None, fill_value: ArrayLike = 0) -> Array: """Return the elements of an array that satisfy a condition. @@ -12522,6 +12700,7 @@ def extract(condition: ArrayLike, arr: ArrayLike, return compress(ravel(condition), ravel(arr), size=size, fill_value=fill_value) +@export def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, *, size: int | None = None, fill_value: ArrayLike = 0, out: None = None) -> Array: """Compress an array along a given axis using a boolean condition. @@ -12616,6 +12795,7 @@ def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = None, return moveaxis(result, 0, axis) +@export @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, bias: bool = False, ddof: int | None = None, @@ -12774,6 +12954,7 @@ def cov(m: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True, return ufuncs.true_divide(dot(X, X_T.conj()), f).squeeze() +@export @partial(jit, static_argnames=('rowvar',)) def corrcoef(x: ArrayLike, y: ArrayLike | None = None, rowvar: bool = True) -> Array: r"""Compute the Pearson correlation coefficients. @@ -12903,6 +13084,7 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt return comparisons.sum(dtype=dtype, axis=0) +@export @partial(jit, static_argnames=('side', 'method')) def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: @@ -12992,6 +13174,7 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', return impl(asarray(a), asarray(v), side, dtype) # type: ignore +@export @partial(jit, static_argnames=('right', 'method')) def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str | None = None) -> Array: @@ -13047,6 +13230,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, ) +@export def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], funclist: list[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: @@ -13154,6 +13338,7 @@ def _tile_to_size(arr: Array, size: int) -> Array: return arr[:size] if arr.size > size else arr +@export def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: bool = True) -> Array: """Update array elements based on a mask. @@ -13229,6 +13414,7 @@ def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape) +@export def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, mode: str | None = None, *, inplace: bool = True) -> Array: """Put elements into an array at given indices. diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 10cc90575cef..19388b903e5d 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -33,6 +33,10 @@ from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, _where) from jax._src.typing import Array, ArrayLike +from jax._src.util import set_module + + +export = set_module('jax.numpy') @jit @@ -57,6 +61,7 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array: return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan)) +@export def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: r"""Returns the roots of a polynomial given the coefficients ``p``. @@ -116,6 +121,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: return _roots_with_zeros(p_arr, num_leading_zeros) +@export @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, full: bool = False, w: ArrayLike | None = None, cov: bool = False @@ -287,6 +293,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, return c +@export @jit def poly(seq_of_zeros: ArrayLike) -> Array: r"""Returns the coefficients of a polynomial for the given sequence of roots. @@ -369,6 +376,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: return a +@export @partial(jit, static_argnames=['unroll']) def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: r"""Evaluates the polynomial at specific values. @@ -432,6 +440,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: return y +@export @jit def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the sum of the two polynomials. @@ -489,6 +498,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) +@export @partial(jit, static_argnames=('m',)) def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: r"""Returns the coefficients of the integration of specified order of a polynomial. @@ -557,6 +567,7 @@ def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array return true_divide(concatenate((p_arr, k_arr)), coeff) +@export @partial(jit, static_argnames=('m',)) def polyder(p: ArrayLike, m: int = 1) -> Array: r"""Returns the coefficients of the derivative of specified order of a polynomial. @@ -607,6 +618,7 @@ def polyder(p: ArrayLike, m: int = 1) -> Array: return p_arr[:-m] * coeff[::-1] +@export def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: r"""Returns the product of two polynomials. @@ -673,6 +685,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - return convolve(a1_arr, a2_arr, mode='full') +@export def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: r"""Returns the quotient and remainder of polynomial division. @@ -732,6 +745,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> return q, u_arr +@export @jit def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: r"""Returns the difference of two polynomials. diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 5acad86eabef..bc85bc3e8761 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -37,9 +37,11 @@ from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( canonicalize_axis as _canonicalize_axis, maybe_named_axis, - NumpyComplexWarning) + set_module, NumpyComplexWarning) +export = set_module('jax.numpy') + _all = builtins.all _lax_const = lax_internal._const @@ -222,6 +224,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) +@export def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: @@ -296,6 +299,7 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, promote_integers=promote_integers) + @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, @@ -307,6 +311,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None initial=initial, where_=where, promote_integers=promote_integers) +@export def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, @@ -391,6 +396,7 @@ def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where_=where, parallel_reduce=lax.pmax) +@export def max(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -473,6 +479,7 @@ def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where_=where, parallel_reduce=lax.pmin) +@export def min(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -552,6 +559,7 @@ def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def all(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether all array elements along a given axis evaluate to True. @@ -608,6 +616,7 @@ def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, axis=axis, out=out, keepdims=keepdims, where_=where) +@export def any(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: r"""Test whether any of the array elements along a given axis evaluate to True. @@ -714,6 +723,7 @@ def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) +@export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -721,6 +731,7 @@ def amin(a: ArrayLike, axis: Axis = None, out: None = None, return min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) +@export def amax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -740,6 +751,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): return size +@export def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: @@ -843,6 +855,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, * @overload def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: ... +@export def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: """Compute the weighed average. @@ -953,6 +966,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, return avg +@export def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -1093,6 +1107,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy return _upcast_f16(computation_dtype), np.dtype(dtype) +@export def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: @@ -1185,6 +1200,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) +@export def ptp(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: r"""Return the peak-to-peak range along a given axis. @@ -1236,6 +1252,7 @@ def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, return lax.sub(x, y) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def count_nonzero(a: ArrayLike, axis: Axis = None, keepdims: bool = False) -> Array: @@ -1295,6 +1312,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], return out +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1377,6 +1395,7 @@ def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'keepdims')) def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1459,6 +1478,7 @@ def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1542,6 +1562,7 @@ def nansum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -1625,6 +1646,7 @@ def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out initial=initial, where=where) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, where: ArrayLike | None = None) -> Array: @@ -1716,6 +1738,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out return td +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1818,6 +1841,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: return lax.convert_element_type(result, dtype) +@export @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, @@ -1939,6 +1963,7 @@ def _cumulative_reduction( return result +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -1975,6 +2000,7 @@ def cumsum(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def cumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2010,6 +2036,7 @@ def cumprod(a: ArrayLike, axis: int | None = None, return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumsum(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2059,6 +2086,7 @@ def nancumsum(a: ArrayLike, axis: int | None = None, fill_nan=True, fill_value=0) +@export @partial(api.jit, static_argnames=('axis', 'dtype')) def nancumprod(a: ArrayLike, axis: int | None = None, dtype: DTypeLike | None = None, out: None = None) -> Array: @@ -2115,6 +2143,7 @@ def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None, a, axis, dtype, out, promote_integers=True) +@export def cumulative_sum( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, @@ -2176,6 +2205,7 @@ def cumulative_sum( return out +@export def cumulative_prod( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, @@ -2239,6 +2269,7 @@ def cumulative_prod( # Quantiles # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2295,6 +2326,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", @@ -2475,7 +2507,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = result.reshape(keepdim) return lax.convert_element_type(result, a.dtype) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2531,7 +2565,9 @@ def percentile(a: ArrayLike, q: ArrayLike, return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) + # TODO(jakevdp): interpolation argument deprecated 2024-05-16 +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, @@ -2591,6 +2627,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, method=method, keepdims=keepdims) +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, @@ -2642,6 +2679,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, keepdims=keepdims, method='midpoint') +@export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 6491a7617d8d..0d5ea905becc 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -35,10 +35,12 @@ from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan from jax._src.numpy.util import check_arraylike, promote_dtypes -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module from jax._src.typing import Array, ArrayLike +export = set_module('jax.numpy') + _lax_const = lax_internal._const @@ -88,6 +90,7 @@ def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: return arr, num_unique1 + num_unique2 +@export def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set difference of two 1D arrays. @@ -175,6 +178,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value) +@export def union1d(ar1: ArrayLike, ar2: ArrayLike, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set union of two 1D arrays. @@ -278,6 +282,7 @@ def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, return where(arange(len(vals)) < num_results, vals, fill_value) +@export def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set-wise xor of elements in two arrays. @@ -417,6 +422,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as return vals +@export def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return_indices: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]: @@ -524,6 +530,7 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return int1d +@export def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: bool = False, invert: bool = False, *, method='auto') -> Array: @@ -652,6 +659,7 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo return ret[0] if len(ret) == 1 else ret +@export def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False, return_counts: bool = False, axis: int | None = None, *, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None): @@ -863,6 +871,7 @@ class _UniqueInverseResult(NamedTuple): inverse_indices: Array +@export def unique_all(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueAllResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -945,6 +954,7 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) +@export def unique_counts(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueCountsResult: """Return unique values from x, along with counts. @@ -1005,6 +1015,7 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, return _UniqueCountsResult(values=values, counts=counts) +@export def unique_inverse(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> _UniqueInverseResult: """Return unique values from x, along with indices, inverse indices, and counts. @@ -1070,6 +1081,7 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None, return _UniqueInverseResult(values=values, inverse_indices=inverse_indices) +@export def unique_values(x: ArrayLike, /, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Return unique values from x, along with indices, inverse indices, and counts. diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 27e2973b212b..5dbd67e62a9f 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -33,6 +33,8 @@ import numpy as np +export = set_module("jax.numpy") + _AT_INPLACE_WARNING = """\ Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g. @@ -40,7 +42,7 @@ """ -@set_module('jax.numpy') +@export class ufunc: """Universal functions which operation element-by-element on arrays. @@ -586,6 +588,7 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: return result.reshape(*np.shape(A), *np.shape(B)) +@export def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, *, identity: Any = None) -> ufunc: """Create a JAX ufunc from an arbitrary JAX-compatible scalar function. diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index a844ecbc28ac..bbbce9733aa5 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -38,6 +38,10 @@ promote_shapes, _where, check_no_float0s) from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy import reductions +from jax._src.util import set_module + + +export = set_module('jax.numpy') _lax_const = lax._const @@ -75,6 +79,7 @@ def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc: return decorator +@export @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. @@ -119,18 +124,21 @@ def fabs(x: ArrayLike, /) -> Array: return lax.abs(*promote_args_inexact('fabs', x)) +@export @partial(jit, inline=True) def bitwise_invert(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_invert', x)) +@export @partial(jit, inline=True) def bitwise_not(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.invert`.""" return lax.bitwise_not(*promote_args('bitwise_not', x)) +@export @partial(jit, inline=True) def invert(x: ArrayLike, /) -> Array: """Compute the bitwise inversion of an input. @@ -223,6 +231,7 @@ def negative(x: ArrayLike, /) -> Array: return lax.neg(*promote_args('negative', x)) +@export @partial(jit, inline=True) def positive(x: ArrayLike, /) -> Array: """Return element-wise positive values of the input. @@ -271,6 +280,7 @@ def positive(x: ArrayLike, /) -> Array: return lax.asarray(*promote_args('positive', x)) +@export @partial(jit, inline=True) def sign(x: ArrayLike, /) -> Array: r"""Return an element-wise indication of sign of the input. @@ -321,6 +331,7 @@ def sign(x: ArrayLike, /) -> Array: return lax.sign(*promote_args('sign', x)) +@export @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: """Round input to the nearest integer downwards. @@ -359,6 +370,7 @@ def floor(x: ArrayLike, /) -> Array: return lax.floor(*promote_args_inexact('floor', x)) +@export @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: """Round input to the nearest integer upwards. @@ -397,6 +409,7 @@ def ceil(x: ArrayLike, /) -> Array: return lax.ceil(*promote_args_inexact('ceil', x)) +@export @partial(jit, inline=True) def exp(x: ArrayLike, /) -> Array: """Calculate element-wise exponential of the input. @@ -438,6 +451,7 @@ def exp(x: ArrayLike, /) -> Array: return lax.exp(*promote_args_inexact('exp', x)) +@export @partial(jit, inline=True) def log(x: ArrayLike, /) -> Array: """Calculate element-wise natural logarithm of the input. @@ -475,6 +489,7 @@ def log(x: ArrayLike, /) -> Array: return lax.log(*promote_args_inexact('log', x)) +@export @partial(jit, inline=True) def expm1(x: ArrayLike, /) -> Array: """Calculate ``exp(x)-1`` of each element of the input. @@ -519,6 +534,7 @@ def expm1(x: ArrayLike, /) -> Array: return lax.expm1(*promote_args_inexact('expm1', x)) +@export @partial(jit, inline=True) def log1p(x: ArrayLike, /) -> Array: """Calculates element-wise logarithm of one plus input, ``log(x+1)``. @@ -559,6 +575,7 @@ def log1p(x: ArrayLike, /) -> Array: return lax.log1p(*promote_args_inexact('log1p', x)) +@export @partial(jit, inline=True) def sin(x: ArrayLike, /) -> Array: """Compute a trigonometric sine of each element of input. @@ -590,6 +607,7 @@ def sin(x: ArrayLike, /) -> Array: return lax.sin(*promote_args_inexact('sin', x)) +@export @partial(jit, inline=True) def cos(x: ArrayLike, /) -> Array: """Compute a trigonometric cosine of each element of input. @@ -620,6 +638,7 @@ def cos(x: ArrayLike, /) -> Array: return lax.cos(*promote_args_inexact('cos', x)) +@export @partial(jit, inline=True) def tan(x: ArrayLike, /) -> Array: """Compute a trigonometric tangent of each element of input. @@ -650,6 +669,7 @@ def tan(x: ArrayLike, /) -> Array: return lax.tan(*promote_args_inexact('tan', x)) +@export @partial(jit, inline=True) def arcsin(x: ArrayLike, /) -> Array: r"""Compute element-wise inverse of trigonometric sine of input. @@ -691,6 +711,7 @@ def arcsin(x: ArrayLike, /) -> Array: return lax.asin(*promote_args_inexact('arcsin', x)) +@export @partial(jit, inline=True) def arccos(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric cosine of input. @@ -733,6 +754,7 @@ def arccos(x: ArrayLike, /) -> Array: return lax.acos(*promote_args_inexact('arccos', x)) +@export @partial(jit, inline=True) def arctan(x: ArrayLike, /) -> Array: """Compute element-wise inverse of trigonometric tangent of input. @@ -773,6 +795,7 @@ def arctan(x: ArrayLike, /) -> Array: return lax.atan(*promote_args_inexact('arctan', x)) +@export @partial(jit, inline=True) def sinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic sine of input. @@ -827,6 +850,7 @@ def sinh(x: ArrayLike, /) -> Array: return lax.sinh(*promote_args_inexact('sinh', x)) +@export @partial(jit, inline=True) def cosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic cosine of input. @@ -880,6 +904,7 @@ def cosh(x: ArrayLike, /) -> Array: return lax.cosh(*promote_args_inexact('cosh', x)) +@export @partial(jit, inline=True) def arcsinh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic sine of input. @@ -929,6 +954,7 @@ def arcsinh(x: ArrayLike, /) -> Array: return lax.asinh(*promote_args_inexact('arcsinh', x)) +@export @jit def arccosh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic cosine of input. @@ -984,6 +1010,7 @@ def arccosh(x: ArrayLike, /) -> Array: return result +@export @partial(jit, inline=True) def tanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise hyperbolic tangent of input. @@ -1037,6 +1064,7 @@ def tanh(x: ArrayLike, /) -> Array: return lax.tanh(*promote_args_inexact('tanh', x)) +@export @partial(jit, inline=True) def arctanh(x: ArrayLike, /) -> Array: r"""Calculate element-wise inverse of hyperbolic tangent of input. @@ -1085,6 +1113,7 @@ def arctanh(x: ArrayLike, /) -> Array: return lax.atanh(*promote_args_inexact('arctanh', x)) +@export @partial(jit, inline=True) def sqrt(x: ArrayLike, /) -> Array: """Calculates element-wise non-negative square root of the input array. @@ -1117,6 +1146,7 @@ def sqrt(x: ArrayLike, /) -> Array: return lax.sqrt(*promote_args_inexact('sqrt', x)) +@export @partial(jit, inline=True) def cbrt(x: ArrayLike, /) -> Array: """Calculates element-wise cube root of the input array. @@ -1144,6 +1174,7 @@ def cbrt(x: ArrayLike, /) -> Array: """ return lax.cbrt(*promote_args_inexact('cbrt', x)) + def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.add.at.""" if a.dtype == bool: @@ -1152,6 +1183,7 @@ def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: return a.at[indices].add(b).astype(bool) return a.at[indices].add(b) + @binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) def add(x: ArrayLike, y: ArrayLike, /) -> Array: """Add two arrays element-wise. @@ -1182,6 +1214,7 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.multiply.at.""" if a.dtype == bool: @@ -1191,6 +1224,7 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: else: return a.at[indices].mul(b) + @binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: """Multiply two arrays element-wise. @@ -1221,6 +1255,7 @@ def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) + @binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and) def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise AND operation elementwise. @@ -1250,6 +1285,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) + @binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or) def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise OR operation elementwise. @@ -1279,6 +1315,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) + @binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor) def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise XOR operation elementwise. @@ -1309,6 +1346,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) +@export @partial(jit, inline=True) def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: r"""Shift bits of ``x`` to left by the amount specified in ``y``, element-wise. @@ -1364,12 +1402,14 @@ def left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.shift_left(*promote_args_numeric("left_shift", x, y)) +@export @partial(jit, inline=True) def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.left_shift`.""" return lax.shift_left(*promote_args_numeric("bitwise_left_shift", x, y)) +@export @partial(jit, inline=True) def equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x == y``. @@ -1419,6 +1459,7 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.eq(*promote_args("equal", x, y)) +@export @partial(jit, inline=True) def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Returns element-wise truth value of ``x != y``. @@ -1472,6 +1513,7 @@ def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array: """Implementation of jnp.subtract.at.""" return a.at[indices].subtract(b) + @binary_ufunc(identity=None, at=_subtract_at) def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: """Subtract two arrays element-wise. @@ -1502,6 +1544,7 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.sub(*promote_args("subtract", x, y)) +@export @partial(jit, inline=True) def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the arctangent of x1/x2, choosing the correct quadrant. @@ -1557,6 +1600,7 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) +@export @partial(jit, inline=True) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise minimum of the input arrays. @@ -1617,6 +1661,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.min(*promote_args("minimum", x, y)) +@export @partial(jit, inline=True) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise maximum of the input arrays. @@ -1676,6 +1721,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.max(*promote_args("maximum", x, y)) +@export @partial(jit, inline=True) def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: """Calculate element-wise base ``x`` exponential of ``y``. @@ -1722,6 +1768,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.pow(*promote_args_inexact("float_power", x, y)) +@export @partial(jit, inline=True) def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise next floating point value after ``x`` towards ``y``. @@ -1749,6 +1796,7 @@ def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.nextafter(*promote_args_inexact("nextafter", x, y)) +@export @partial(jit, inline=True) def spacing(x: ArrayLike, /) -> Array: """Return the spacing between ``x`` and the next adjacent number. @@ -1856,6 +1904,7 @@ def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) +@export @partial(jit, inline=True) def logical_not(x: ArrayLike, /) -> Array: """Compute NOT bool(x) element-wise. @@ -1901,6 +1950,8 @@ def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], lax_op(x.real, y.real)) return lax_op(x, y) + +@export @partial(jit, inline=True) def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x >= y``. @@ -1946,6 +1997,7 @@ def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) +@export @partial(jit, inline=True) def greater(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x > y``. @@ -1992,6 +2044,7 @@ def greater(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.gt, *promote_args("greater", x, y)) +@export @partial(jit, inline=True) def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x <= y``. @@ -2038,6 +2091,7 @@ def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) +@export @partial(jit, inline=True) def less(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise truth value of ``x < y``. @@ -2083,42 +2137,58 @@ def less(x: ArrayLike, y: ArrayLike, /) -> Array: """ return _complex_comparison(lax.lt, *promote_args("less", x, y)) + # Array API aliases +@export @partial(jit, inline=True) def acos(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccos`""" return arccos(*promote_args('acos', x)) + +@export @partial(jit, inline=True) def acosh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arccosh`""" return arccosh(*promote_args('acosh', x)) + +@export @partial(jit, inline=True) def asin(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsin`""" return arcsin(*promote_args('asin', x)) + +@export @partial(jit, inline=True) def asinh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arcsinh`""" return arcsinh(*promote_args('asinh', x)) + +@export @partial(jit, inline=True) def atan(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan`""" return arctan(*promote_args('atan', x)) + +@export @partial(jit, inline=True) def atanh(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctanh`""" return arctanh(*promote_args('atanh', x)) + +@export @partial(jit, inline=True) def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan2`""" return arctan2(*promote_args('atan2', x1, x2)) + +@export @jit def bitwise_count(x: ArrayLike, /) -> Array: r"""Counts the number of 1 bits in the binary representation of the absolute value @@ -2154,6 +2224,8 @@ def bitwise_count(x: ArrayLike, /) -> Array: # Following numpy we take the absolute value and return uint8. return lax.population_count(abs(x)).astype('uint8') + +@export @partial(jit, inline=True) def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. @@ -2205,12 +2277,14 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_fn(x1, x2) +@export @partial(jit, inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.right_shift`.""" return right_shift(x1, x2) +@export @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. @@ -2246,12 +2320,14 @@ def absolute(x: ArrayLike, /) -> Array: return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) +@export @partial(jit, inline=True) def abs(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.absolute`.""" return absolute(x) +@export @jit def rint(x: ArrayLike, /) -> Array: """Rounds the elements of x to the nearest integer @@ -2291,6 +2367,7 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) +@export @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. @@ -2330,6 +2407,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) +@export @partial(jit, inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the division of x1 by x2 element-wise @@ -2368,11 +2446,13 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.div(x1, x2) +@export def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.true_divide`.""" return true_divide(x1, x2) +@export @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the floor division of x1 by x2 element-wise @@ -2427,6 +2507,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _float_divmod(x1, x2)[0] +@export @jit def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: """Calculates the integer quotient and remainder of x1 by x2 element-wise @@ -2481,6 +2562,7 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod +@export def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise base ``x1`` exponential of ``x2``. @@ -2565,6 +2647,7 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) +@export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.power`""" return power(x1, x2) @@ -2604,6 +2687,7 @@ def _pow_int_int(x1, x2): return acc +@export @jit def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. @@ -2630,6 +2714,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) +@export @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. @@ -2662,6 +2747,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return logaddexp(x1 * ln2, x2 * ln2) / ln2 +@export @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: """Calculates the base-2 logarithm of ``x`` element-wise. @@ -2684,6 +2770,7 @@ def log2(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) +@export @partial(jit, inline=True) def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise @@ -2707,6 +2794,7 @@ def log10(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) +@export @partial(jit, inline=True) def exp2(x: ArrayLike, /) -> Array: """Calculate element-wise base-2 exponential of input. @@ -2741,6 +2829,7 @@ def exp2(x: ArrayLike, /) -> Array: return lax.exp2(x) +@export @jit def signbit(x: ArrayLike, /) -> Array: """Return the sign bit of array elements. @@ -2813,6 +2902,7 @@ def _normalize_float(x): return lax.bitcast_convert_type(x1, int_type), x2 +@export @jit def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute x1 * 2 ** x2 @@ -2862,6 +2952,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(isinf(x1) | (x1 == 0), x1, x) +@export @jit def frexp(x: ArrayLike, /) -> tuple[Array, Array]: """Split floating point values into mantissa and twos exponent. @@ -2915,6 +3006,7 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) +@export @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Returns element-wise remainder of the division. @@ -2962,11 +3054,13 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) +@export def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.remainder`""" return remainder(x1, x2) +@export @jit def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculate element-wise floating-point modulo operation. @@ -3008,6 +3102,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.rem(*promote_args_numeric("fmod", x1, x2)) +@export @partial(jit, inline=True) def square(x: ArrayLike, /) -> Array: """Calculate element-wise square of the input array. @@ -3057,6 +3152,7 @@ def square(x: ArrayLike, /) -> Array: return lax.square(x) +@export @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: r"""Convert angles from degrees to radians. @@ -3091,6 +3187,7 @@ def deg2rad(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, np.pi / 180)) +@export @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: r"""Convert angles from radians to degrees. @@ -3126,15 +3223,19 @@ def rad2deg(x: ArrayLike, /) -> Array: return lax.mul(x, _lax_const(x, 180 / np.pi)) +@export def degrees(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.rad2deg`""" return rad2deg(x) + +@export def radians(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.deg2rad`""" return deg2rad(x) +@export @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: """Return element-wise complex-conjugate of the input. @@ -3164,11 +3265,13 @@ def conjugate(x: ArrayLike, /) -> Array: return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) +@export def conj(x: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.conjugate`""" return conjugate(x) +@export @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: """Return element-wise imaginary of part of the complex argument. @@ -3200,6 +3303,7 @@ def imag(val: ArrayLike, /) -> Array: return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) +@export @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: """Return element-wise real part of the complex argument. @@ -3231,6 +3335,7 @@ def real(val: ArrayLike, /) -> Array: return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) +@export @jit def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: """Return element-wise fractional and integral parts of the input array. @@ -3264,6 +3369,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: return x - whole, whole +@export @partial(jit, inline=True) def isfinite(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is finite. @@ -3304,6 +3410,7 @@ def isfinite(x: ArrayLike, /) -> Array: return lax.full_like(x, True, dtype=np.bool_) +@export @jit def isinf(x: ArrayLike, /) -> Array: """Return a boolean array indicating whether each element of input is infinite. @@ -3359,6 +3466,7 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: return lax.full_like(x, False, dtype=np.bool_) +@export def isposinf(x, /, out=None): """ Return boolean array indicating whether each element of input is positive infinite. @@ -3392,6 +3500,7 @@ def isposinf(x, /, out=None): return _isposneginf(np.inf, x, out) +@export def isneginf(x, /, out=None): """ Return boolean array indicating whether each element of input is negative infinite. @@ -3425,6 +3534,7 @@ def isneginf(x, /, out=None): return _isposneginf(-np.inf, x, out) +@export @partial(jit, inline=True) def isnan(x: ArrayLike, /) -> Array: """Returns a boolean array indicating whether each element of input is ``NaN``. @@ -3459,6 +3569,7 @@ def isnan(x: ArrayLike, /) -> Array: return lax.ne(x, x) +@export @jit def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Compute the heaviside step function. @@ -3508,6 +3619,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) +@export @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: r""" @@ -3556,6 +3668,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(idx_inf, _lax_const(x, np.inf), x) +@export @partial(jit, inline=True) def reciprocal(x: ArrayLike, /) -> Array: """Calculate element-wise reciprocal of the input. @@ -3589,6 +3702,7 @@ def reciprocal(x: ArrayLike, /) -> Array: return lax.integer_pow(x, -1) +@export @jit def sinc(x: ArrayLike, /) -> Array: r"""Calculate the normalized sinc function. diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index e7a0e2142327..f1e6d399b97b 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -23,9 +23,11 @@ from jax._src import config from jax import lax from jax._src.numpy import lax_numpy as jnp -from jax._src.util import safe_map as map, safe_zip as zip +from jax._src.util import set_module, safe_map as map, safe_zip as zip +export = set_module('jax.numpy') + # See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html _DIMENSION_NAME = r'\w+' _CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME) @@ -185,6 +187,7 @@ def new_func(*args, **kwargs): return new_func, dynamic_args, dynamic_kwargs +@export def vectorize(pyfunc, *, excluded=frozenset(), signature=None): """Define a vectorized function with broadcasting. diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 71d48c2b121c..9bc8d0f6d71c 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -32,6 +32,14 @@ class PackageStructureTest(jtu.JaxTestCase): @parameterized.parameters([ # TODO(jakevdp): expand test to other public modules. _mod("jax.errors", exclude=["JaxRuntimeError"]), + _mod( + "jax.numpy", + exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating", + "dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo", + "flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim", + "number", "object_", "printoptions", "save", "savez", "set_printoptions", + "shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"] + ), _mod("jax.nn.initializers"), _mod( "jax.tree_util", @@ -46,7 +54,8 @@ def test_exported_names_match_module(self, module_name, include, exclude): if name not in include and (name.startswith('_') or name in exclude): continue obj = getattr(module, name) - if isinstance(obj, types.ModuleType): + if obj is None or isinstance(obj, (bool, int, float, complex, types.ModuleType)): + # No __module__ attribute expected. continue self.assertEqual(obj.__module__, module_name, f"{obj} has {obj.__module__=}, expected {module_name}") From 1471702adc286bcf40e87c42877d538b4d589f90 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 15 Nov 2024 06:41:14 -0800 Subject: [PATCH 360/698] [Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to treat 1D (N,) as (1, N) and then tile it as (1, 128) PiperOrigin-RevId: 696870258 --- jaxlib/mosaic/dialect/tpu/layout.h | 33 +++++++---- .../tpu/transforms/apply_vector_layout.cc | 56 ++++++++++++------ .../tpu/transforms/infer_vector_layout.cc | 59 +++++++++++-------- 3 files changed, 95 insertions(+), 53 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 66217858fa7d..6edad713b17a 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" @@ -39,6 +38,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "absl/log/check.h" namespace mlir::tpu { @@ -259,18 +259,23 @@ class VectorLayout { int layout_rank() const { return layout_rank(implicit_dim_); } bool operator==(const VectorLayout &other) const; - bool operator!=(const VectorLayout &other) const { - return !(*this == other); - } - - // How many tiles fit in each vector register. - int64_t tilesPerVreg(const std::array target_shape) const { - const int64_t tile_elems = tiling_[0] * tiling_[1]; - const int64_t vreg_capacity = packing() * target_shape[0] * target_shape[1]; + bool operator!=(const VectorLayout &other) const { return !(*this == other); } + + static int64_t tilesPerVreg(const std::array target_shape, + const int8_t bitwidth, + const std::array tiling) { + CHECK_NE(0, bitwidth) << "bitwidth cannot be 0"; + const int64_t tile_elems = tiling[0] * tiling[1]; + const int64_t vreg_capacity = + (32 / bitwidth) * target_shape[0] * target_shape[1]; const auto [tiles_per_vreg, rem] = std::div(vreg_capacity, tile_elems); CHECK_EQ(rem, 0); return tiles_per_vreg; } + // How many tiles fit in each vector register. + int64_t tilesPerVreg(const std::array target_shape) const { + return VectorLayout::tilesPerVreg(target_shape, bitwidth_, tiling_); + } int64_t sublanesPerTile(const std::array target_shape) const { auto [sublanes_per_tile, rem] = @@ -283,8 +288,16 @@ class VectorLayout { // // We never reuse the same vector register to store data of multiple rows, // so only the minormost dimension can increase. + static std::array vregSlice(std::array target_shape, + const int8_t bitwidth, + const std::array tiling) { + return { + tiling[0], + VectorLayout::tilesPerVreg(target_shape, bitwidth, tiling) * tiling[1]}; + } + std::array vregSlice(std::array target_shape) const { - return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]}; + return VectorLayout::vregSlice(target_shape, bitwidth_, tiling_); } template diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8792503f4636..2732b63d7638 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2554,7 +2554,10 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(res_layout.has_value()); auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank(); - if (dimension >= num_untiled_dims) { + if (res_ty.getRank() == 1 && + res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) { + tiling_dim = 1; + } else if (dimension >= num_untiled_dims) { tiling_dim = dimension - num_untiled_dims; } @@ -2576,6 +2579,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: result/input offsets mismatch."); } + if (layout.implicit_dim() != res_layout->implicit_dim()) { + return op.emitOpError( + "Not implemented: result/input implicit dim mismatch."); + } + if (i > 1) { auto curr_offsets = layout.offsets(); auto last_operand_offsets = layouts_in[i - 1]->offsets(); @@ -2611,29 +2619,47 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, if (!tiling_dim.has_value()) { out_vregs = concatenate(operand_vregs, dimension); } else { - if (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) { + bool is_rank1_with_no_implicit_dim = res_ty.getRank() == 1 && + res_layout->implicit_dim() == + VectorLayout::ImplicitDim::kNone; + if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kMinor || + is_rank1_with_no_implicit_dim) { return op.emitOpError("Not implemented: implicit dim"); } + if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && + res_layout->bitwidth() != 32) { + return op.emitOpError( + "Not implemented: only 32-bit bitwidth supported for SecondMinor " + "implicit dim"); + } if (res_layout->offsets()[tiling_dim.value()] != 0) { return op.emitOpError("Not implemented: result non-zero offset."); } - if (!res_layout->hasNativeTiling(ctx.target_shape)) { + if (!res_layout->hasNativeTiling(ctx.target_shape) && + res_ty.getRank() != 1) { return op.emitOpError("Not implemented: Non native tiling in concat."); } int64_t offset_at_dim = 0; { for (int i = 0; i < op.getNumOperands(); ++i) { - auto operand = op.getOperand(i); - auto const &layout = *layouts_in[i]; - - auto vty = cast(operand.getType()); - auto shape = vty.getShape(); - - auto starting_point = offset_at_dim; - auto offset_amount = - starting_point % layout.tiling()[tiling_dim.value()]; - if (offset_amount != layout.offsets()[tiling_dim.value()]) { + Value operand = op.getOperand(i); + const Layout &layout = *layouts_in[i]; + xla::Array vreg_array = operand_vregs[i]; + std::array vreg_slice = layout->vregSlice(ctx.target_shape); + std::array tiling = layout->tiling(); + + VectorType vty = cast(operand.getType()); + ArrayRef shape = vty.getShape(); + + int64_t starting_point = offset_at_dim; + int64_t offset_amount = + starting_point % vreg_slice[tiling_dim.value()]; + if (offset_amount >= tiling[tiling_dim.value()]) { + return op.emitError( + "Not implemented: Input offsets outside of the first tile"); + } + if (offset_amount != layout->offsets()[tiling_dim.value()]) { return op.emitOpError( "Not implemented: Relayout not called, unaligned dims " "concatenated without proper offsets. Ensure that " @@ -2649,10 +2675,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, auto &vreg = operand_vregs[i]; const auto &layout = layouts_in[i]; - if (layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return op.emitOpError("Not implemented: implicit dim"); - } - const int64_t operand_offset = *layout->offsets()[tiling_dim.value()]; if (operand_offset != 0) { // We are offset, so we must blend with the previous vreg. diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index bf668b8ecb52..30486b6e995c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -770,14 +770,11 @@ class VectorLayoutInferer { LogicalResult infer(tpu::ConcatenateOp op) { TPU_CHECK_OP(!op.getSources().empty(), "Need at least one vector to concatenate"); - auto res_rank = op.getType().getRank(); - auto dimension = op.getDimension(); + int64_t res_rank = op.getType().getRank(); + uint32_t dimension = op.getDimension(); TPU_CHECK_OP(0 <= dimension && dimension < res_rank, "Expect a valid concatenate dimension"); - if (res_rank == 1) { - NYI("Support concatenation with 1D vectors"); - } - auto res_ty = op.getResult().getType(); + VectorType res_ty = op.getResult().getType(); int8_t bitwidth = res_ty.getElementTypeBitWidth(); std::optional tiling_dim; @@ -790,29 +787,39 @@ class VectorLayoutInferer { if (tiling_dim.has_value()) { int64_t starting_point = 0; - auto first_layout = getLayout(op.getSources().front()); - auto op_layouts = getLayoutFromOperands(op); + Layout first_layout = getLayout(op.getSources().front()); + SmallVector op_layouts = getLayoutFromOperands(op); SmallVector in_layouts; in_layouts.reserve(op.getSources().size()); - auto native_tiling = nativeTiling(bitwidth); - + // Set implicit dim to treat 1D as (1, N) and tile it as (1, 128) + std::array tiling = + res_rank == 1 ? std::array{1L, target_shape_[1]} + : nativeTiling(bitwidth); + ImplicitDim implicit_dim = + res_rank == 1 ? ImplicitDim::kSecondMinor : ImplicitDim::kNone; + std::array vreg_slice = + VectorLayout::vregSlice(target_shape_, bitwidth, tiling); for (int i = 0; i < op.getSources().size(); ++i) { // Compute the offset per source. // Ex: for a cat of (10, 128), (10, 128) on dim 0, where the - // vreg_sice for that dim is 8, the first source starts at + // vreg_slice for that dim is 8, the first source starts at // offset 0, and overflows the vreg // by 2, so the offset for the second input is 2. - auto op_shape = + ArrayRef op_shape = cast(op.getSources()[i].getType()).getShape(); - auto offset_amount = starting_point % native_tiling[tiling_dim.value()]; - auto op_layout = op_layouts[i]; + Layout op_layout = op_layouts[i]; + int64_t offset_amount = starting_point % vreg_slice[tiling_dim.value()]; + if (offset_amount >= tiling[tiling_dim.value()]) { + return op.emitError( + "Not implemented: Input offsets outside of the first tile"); + } SmallVector in_idx{op_layout->offsets()[0].value_or(0), op_layout->offsets()[1].value_or(0)}; in_idx[tiling_dim.value()] = offset_amount; starting_point += op_shape[dimension]; in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]}, - native_tiling, ImplicitDim::kNone)); + tiling, implicit_dim)); } SmallVector res_layout_offsets( {first_layout->offsets()[0].value_or(0), @@ -821,13 +828,13 @@ class VectorLayoutInferer { // TODO(mvoz): A tiny optimization we could do here later is to // no-op setting tiling when sublane dim size is aligned to sublane // tiling. - auto res_layout = + VectorLayout res_layout = VectorLayout(bitwidth, {res_layout_offsets[0], res_layout_offsets[1]}, - native_tiling, ImplicitDim::kNone); + tiling, implicit_dim); setLayout(op, in_layouts, res_layout); return success(); } else { - auto layout = getLayout(op.getSources().front()); + Layout layout = getLayout(op.getSources().front()); // When concatenating vectors with replicated offsets, we want to reset // the replicated offset to zero. Because we are not sure if the // replicated value from each vector are same. @@ -1464,11 +1471,11 @@ class VectorLayoutInferer { // unfolding, it's still a no-op, but we need to // add support in apply-vector-layout. LayoutOffsets offsets = {0, layout.offsets()[1]}; - setLayout(op, - VectorLayout(layout.bitwidth(), offsets, tiling, - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), offsets, tiling, - implicit_dim)); + setLayout( + op, + VectorLayout(layout.bitwidth(), offsets, tiling, + layout.implicit_dim()), + VectorLayout(layout.bitwidth(), offsets, tiling, implicit_dim)); return success(); } sublane_tiling /= 2; @@ -1845,9 +1852,9 @@ class VectorLayoutInferer { "only 32-bit random bit generation supported"); // TODO: b/342054464 - Support implicit dims for PRNGRandomBitsOp. LayoutOffsets offsets = {0, 0}; - setOutLayout(op, VectorLayout( - kNativeBitwidth, offsets, nativeTiling(kNativeBitwidth), - ImplicitDim::kNone)); + setOutLayout( + op, VectorLayout(kNativeBitwidth, offsets, + nativeTiling(kNativeBitwidth), ImplicitDim::kNone)); return success(); } From 23e9142d2873436472991b4a96f14234f472d8df Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 15 Nov 2024 08:49:35 -0800 Subject: [PATCH 361/698] Lower threefry as an out-of-line MLIR function on TPU. On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size. --- jax/_src/prng.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index d2df5d8bbace..2256e12da1d4 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -891,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): return tuple(x) -_threefry2x32_lowering_rule = mlir.lower_fun( +# Since the unrolled lowering is large, emit it as an out-of-line function. +_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=False), - multiple_results=True) + multiple_results=True)) _threefry2x32_cpu_lowering_rule = mlir.lower_fun( partial(_threefry2x32_lowering, use_rolled_loops=True), From 5f9428443219afb80192e16eb078368eeb7c48ef Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 15 Nov 2024 12:14:55 -0800 Subject: [PATCH 362/698] Add missing functions to jax.numpy type interface --- jax/numpy/__init__.pyi | 45 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 339174136234..af7b056fcbb0 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -29,6 +29,46 @@ _Device = Device ComplexWarning: type +class ufunc: + def __init__(self, func: Callable[..., Any], /, + nin: int, nout: int, *, + name: str | None = None, + nargs: int | None = None, + identity: Any = None, + call: Callable[..., Any] | None = None, + reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None, + ): ... + @property + def nin(self) -> int: ... + @property + def nout(self) -> int: ... + @property + def nargs(self) -> int: ... + @property + def identity(self) -> builtins.bool | int | float: ... + def __call__(self, *args: ArrayLike) -> Any: ... + def reduce(self, a: ArrayLike, /, *, + axis: int | None = 0, + dtype: DTypeLike | None = None, + out: None = None, + keepdims: builtins.bool = False, + initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def accumulate(self, a: ArrayLike, /, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, + inplace: builtins.bool = True) -> Array: ... + def reduceat(self, a: ArrayLike, indices: Any, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ... + class BinaryUfunc(Protocol): @property def nin(self) -> int: ... @@ -39,9 +79,10 @@ class BinaryUfunc(Protocol): @property def identity(self) -> builtins.bool | int | float: ... def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ... - def reduce(self, arr: ArrayLike, /, *, + def reduce(self, a: ArrayLike, /, *, axis: int | None = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: builtins.bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: ... @@ -434,6 +475,8 @@ def fromfile(*args, **kwargs): ... def fromfunction(function: Callable[..., Array], shape: Any, *, dtype: DTypeLike = ..., **kwargs) -> Array: ... def fromiter(*args, **kwargs): ... +def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, + *, identity: Any = None) -> ufunc: ... def fromstring( string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str ) -> Array: ... From 5f1e3f5644b6705b21b5e030d241a514c244c2c4 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 15 Nov 2024 11:26:52 -0800 Subject: [PATCH 363/698] Add an example on logical operators to the tutorial. --- docs/control-flow.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/control-flow.md b/docs/control-flow.md index 04eb3cac8d24..7cb959f3e434 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -340,6 +340,39 @@ $\ast$ = argument-value-independent loop condition - unrolls the loop `jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`. +For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar. + +```{code-cell} +def python_check_positive_even(x): + is_even = x % 2 == 0 + # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated. + return is_even and (x > 0) + +@jit +def jax_check_positive_even(x): + is_even = x % 2 == 0 + # `logical_and` does not short circuit, so `x > 0` is always evaluated. + return jnp.logical_and(is_even, x > 0) + +print(python_check_positive_even(24)) +print(jax_check_positive_even(24)) +``` + +When the JAX version with `logical_and` is applied to an array, it returns elementwise values. + +```{code-cell} +x = jnp.array([-1, 2, 5]) +print(jax_check_positive_even(x)) +``` + +Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior. + +```{code-cell} +:tags: [raises-exception] + +print(python_check_positive_even(x)) +``` + +++ {"id": "izLTvT24dAq0"} ## Python control flow + autodiff From 1780ff2964803c292dfa81adbba0f738ebafc0b9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 15 Nov 2024 13:27:42 -0800 Subject: [PATCH 364/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/195f45b7082930033f6533a160b0f8f7f1cbfb40. PiperOrigin-RevId: 696984108 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 043b9d019eb1..e7ae7fe718a6 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ecdba3f23b20e684c5e67a5ddb4f004de724f6df" -XLA_SHA256 = "bfb87208d43324cdb20e03c9802360a580062b913e975b1470148dd99dfbb0d1" +XLA_COMMIT = "195f45b7082930033f6533a160b0f8f7f1cbfb40" +XLA_SHA256 = "75e77091bae789175f3de24efee9debf8835b167770490db75571bf65c27b727" def repo(): tf_http_archive( From 81cdc882aee6ab1ddb48dea6144fa52d0dc7a9c9 Mon Sep 17 00:00:00 2001 From: barnesjoseph Date: Fri, 15 Nov 2024 13:44:31 -0800 Subject: [PATCH 365/698] DOC: update main landing page style Co-authored-by: Jake VanderPlas --- docs/_static/jax-hero.svg | 118 ++++++++++++++++++ docs/_static/style.css | 255 +++++++++++++++++++++++++++++++++++++- docs/hero.html | 8 ++ docs/index.rst | 19 ++- docs/requirements.txt | 3 +- 5 files changed, 394 insertions(+), 9 deletions(-) create mode 100644 docs/_static/jax-hero.svg create mode 100644 docs/hero.html diff --git a/docs/_static/jax-hero.svg b/docs/_static/jax-hero.svg new file mode 100644 index 000000000000..04626f43eacd --- /dev/null +++ b/docs/_static/jax-hero.svg @@ -0,0 +1,118 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/_static/style.css b/docs/_static/style.css index 2c1dfcbcbf08..32033940e8c4 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1,34 +1,279 @@ @import url("theme.css"); +/* Base LP sidebar modifications */ +body:has(.hero) .sidebar-toggle, +body:has(.hero) .bd-sidebar-secondary { + display: none !important; +} + +body:has(.hero) .search-button { + display: flex !important; +} + +body:has(.hero) .primary-toggle { + display: inline-block !important; +} + +body:has(.hero) .prev-next-footer { + display: none; +} + +body:has(.hero) .bd-article-container { + max-width: unset !important; +} + +body:has(.hero) .bd-page-width { + max-width: unset !important; +} + +body:has(.hero) .bd-article { + display: flex; + flex-direction: column; + padding: 0; +} + +body:has(.hero) .bd-container { + flex-direction: column; +} + +@media (min-width: 960px) { + body:has(.hero) .bd-header-article { + justify-content: center; + } + + body:has(.hero) .header-article-items, + body:has(.hero) .bd-article > section { + max-width: 65rem !important; + align-self: center; + width: -moz-available; + width: -webkit-fill-available; + width: fill-available; + } +} + +/* Custom CSS */ + :root { --block-bg-opacity: .5; } +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) { + padding: 0; +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * { + padding-inline: 2rem !important; +} + +@media (max-width: 768px) { + .bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > * { + padding-inline: 1rem !important; + } +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) h1 { + display: none; +} + .wy-side-nav-search { background-color: #fff; } +.getting-started, +.user-guides, .installation { - background-color: rgba(78, 150, 253, var(--block-bg-opacity)); + background: #3C4043; + color: white; + height: 170px; + border: none !important; + border-radius: 12px; +} + +.getting-started:hover, +.user-guides:hover, +.installation:hover { + background: #AECBFA; + color: #202124; + transform: unset !important; +} + +.getting-started .sd-card-body, +.user-guides .sd-card-body, +.installation .sd-card-body { + display: flex; + align-items: center; + justify-content: center; + font: 500 24px 'Roboto'; +} + +.getting-started .sd-card-title, +.user-guides .sd-card-title, +.installation .sd-card-title { + display: flex; + flex-direction: column; + align-items: center; + gap: 12px; +} + +.getting-started svg, +.user-guides svg, +.installation svg { + color: #8AB4F8; +} + +.getting-started:hover svg, +.user-guides:hover svg, +.installation:hover svg { + color: #3C4043; +} + +.bd-main .bd-content .bd-article-container .bd-article:has(.hero) > section > .hero { + padding-inline: 2rem 0 !important; } -.getting-started { - background-color: rgba(0, 169, 154, var(--block-bg-opacity)); +.hero { + display: grid; + grid: auto-flow / 1fr .6fr; + align-items: center; + background: rgb(32,33,36); + background: linear-gradient(90deg, rgba(32,33,36,1) 0%, rgba(39,45,56,1) 100%); + position: relative; + overflow: hidden; + border-radius: 24px; } -.user-guides { - background-color: rgba(171, 0, 182, var(--block-bg-opacity)); +.hero > img { + position: absolute; + top: 0; + right: 0; + height: 100%; + background: transparent !important; +} + +.hero-left { + padding-block: 24px; + display: flex; + flex-direction: column; +} + +.hero-left img { + width: 100px; + height: auto; + position: relative; + margin-bottom: 16px; + background: transparent !important; +} + +.hero-left h2 { + font: 500 32px 'Google Sans'; + color: white; + margin-top: 0; +} + +.hero-left p { + font: 400 16px 'Roboto'; + color: white; +} + +@media (max-width: 1295px) { + .hero > img { + right: -75px; + } +} + +@media (max-width: 750px) { + .hero { + grid: auto-flow / 1fr; + } + + .hero-left { + padding-right: 2rem; + } + + .hero > img { + display: none; + } +} + +.product-offerings { + margin-block: 32px !important; +} + +.product-offerings .sd-card-title { + font: 400 24px 'Google Sans'; +} + +.color-cards { + background: #E8EAED; + color: #222832; + padding: 48px 12px 0 12px; + margin-bottom: 0 !important; + border-radius: 24px 24px 0 0; +} + +.color-cards > div { + gap: 24px 0; +} + +.color-cards + p { + background: #E8EAED; + padding: 24px 12px 48px 12px; + font-weight: 600; + color: #222832; + border-radius: 0 0 24px 24px; +} + +.color-cards + p > a { + color: #222832; +} + +.color-cards + p > a:hover, +html[data-theme="dark"] .color-cards + p > a:hover { + color: #e89217; +} + +html[data-theme="dark"] .color-cards, +html[data-theme="dark"] .hero, +html[data-theme="dark"] .color-cards + p, +html[data-theme="dark"] .color-cards + p > a { + background: #202124; + color: white; } .ecosystem-grid { font-size: smaller; } +.ecosystem-grid > div { + gap: 20px; +} + +.ecosystem-grid .sd-col { + border: 1px solid #dadce0; + border-radius: 8px; + width: calc(50% - 10px); + padding: 16px; +} + +.ecosystem-grid .sd-col > p { + display: flex; + flex-direction: column; + gap: 10px; +} + +.ecosystem-grid .sd-col > p > svg { + color: #00897B; +} + .ecosystem-grid ul { list-style-type: none; padding-inline-start: 0.5em; } +.ecosystem-grid a { + text-decoration: none; +} + div.red-background pre { background-color: rgba(244, 204, 204, var(--block-bg-opacity)); } diff --git a/docs/hero.html b/docs/hero.html new file mode 100644 index 000000000000..a2ee3b8e206f --- /dev/null +++ b/docs/hero.html @@ -0,0 +1,8 @@ +
+
+ +

High performance array computing

+

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

+
+ +
\ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 5f3bce5cf7da..ba8ebcbdd128 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,10 +1,22 @@ JAX: High performance array computing ===================================== -JAX is a Python library for accelerator-oriented array computation and program transformation, -designed for high-performance numerical computing and large-scale machine learning. +.. raw:: html + + + + +.. raw:: html + :file: hero.html .. grid:: 3 + :class-container: product-offerings :margin: 0 :padding: 0 :gutter: 0 @@ -31,6 +43,7 @@ designed for high-performance numerical computing and large-scale machine learni The same code executes on multiple backends, including CPU, GPU, & TPU .. grid:: 3 + :class-container: color-cards .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Installation :columns: 12 6 6 4 @@ -59,7 +72,7 @@ JAX itself is narrowly-scoped and focuses on efficient array operations & progra transformations. Built around JAX is an evolving ecosystem of machine learning and numerical computing tools; the following is just a small sample of what is out there: -.. grid:: 4 +.. grid:: 2 :class-container: ecosystem-grid .. grid-item:: :material-outlined:`hub;2em` **Neural networks** diff --git a/docs/requirements.txt b/docs/requirements.txt index 41d8aa6d9ee7..bfbb4e271d42 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,8 @@ absl-py ipython>=8.8.0 # 8.7.0 has ipython3 lexer error +pydata-sphinx-theme==0.14.4 # v0.15 breaks sidebar toggling sphinx>=7.3.2,<8.0 # 7.3.0 breaks sphinx-book-theme; 8.0 breaks myst-nb 1.1 -sphinx-book-theme>=1.0.1 # Older versions fail to pin pydata-sphinx-theme +sphinx-book-theme==1.1.1 # v1.1.2 requires pydata-sphinx-theme v0.15 sphinx-copybutton>=0.5.0 sphinx-remove-toctrees sphinx-design From 225a2a5f8bfe710e6a4aecb182d5bdd87683193b Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 15 Nov 2024 10:30:13 -0800 Subject: [PATCH 366/698] Consolidate material on PRNGs and add a short summary to Key Concepts. --- README.md | 2 +- docs/key-concepts.md | 40 +++ docs/notebooks/Common_Gotchas_in_JAX.ipynb | 307 +-------------------- docs/notebooks/Common_Gotchas_in_JAX.md | 148 +--------- docs/random-numbers.md | 20 +- jax/_src/errors.py | 4 +- 6 files changed, 65 insertions(+), 456 deletions(-) diff --git a/README.md b/README.md index 1395ae23a46e..b001a8ceeb15 100644 --- a/README.md +++ b/README.md @@ -348,7 +348,7 @@ Some standouts: 1. [In-place mutating updates of arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). + different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. diff --git a/docs/key-concepts.md b/docs/key-concepts.md index daab2c9fdde4..91f0c953462e 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -189,3 +189,43 @@ tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the le in a tree. You can learn more in the {ref}`working-with-pytrees` tutorial. + +(key-concepts-prngs)= +## Pseudorandom numbers + +Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception. NumPy supports a method of pseudo random number generation that is based on a global `state`, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`: + +```{code-cell} +from jax import random + +key = random.key(43) +print(key) +``` + +The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {func}`jax.random` functions. +Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated. + +```{code-cell} +print(random.normal(key)) +print(random.normal(key)) +``` + +**The rule of thumb is: never reuse keys (unless you want identical outputs).** + +In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: + +```{code-cell} +for i in range(3): + new_key, subkey = random.split(key) + del key # The old key is consumed by split() -- we must never use it again. + + val = random.normal(subkey) + del subkey # The subkey is consumed by normal(). + + print(f"draw {i}: {val}") + key = new_key # new_key is safe to use in the next iteration. +``` + +Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. + +For more on pseudo random numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 92c736957db6..02077d2a6b00 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -865,312 +865,9 @@ "id": "MUycRNh6e50W" }, "source": [ - "## 🔪 Random numbers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O8vvaVt3MRG2" - }, - "source": [ - "> _If all scientific papers whose results are in doubt because of bad\n", - "> `rand()`s were to disappear from library shelves, there would be a\n", - "> gap on each shelf about as big as your fist._ - Numerical Recipes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qikt9pPW9L5K" - }, - "source": [ - "### RNGs and state\n", - "You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "rr9FeP41fynt", - "outputId": "df0ceb15-96ec-4a78-e327-c77f7ea3a745" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.2726690048900553\n", - "0.6304191979771206\n", - "0.6933648856441533\n" - ] - } - ], - "source": [ - "print(np.random.random())\n", - "print(np.random.random())\n", - "print(np.random.random())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ORMVVGZJgSVi" - }, - "source": [ - "Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "7Pyp2ajzfPO2" - }, - "outputs": [], - "source": [ - "np.random.seed(0)\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n", - "# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n", - "# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aJIxHVXCiM6m" - }, - "source": [ - "This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "GAHaDCYafpAF" - }, - "outputs": [], - "source": [ - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n", - "\n", - "# Let's exhaust the entropy in this PRNG statevector\n", - "for i in range(311):\n", - " _ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "#print(rng_state)\n", - "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", - "# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n", + "## 🔪 Random numbers\n", "\n", - "# Next call iterates the RNG state for a new batch of fake \"entropy\".\n", - "_ = np.random.uniform()\n", - "rng_state = np.random.get_state()\n", - "# print(rng_state)\n", - "# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n", - "# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "N_mWnleNogps" - }, - "source": [ - "The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n", - "\n", - "The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Uvq7nV-j4vKK" - }, - "source": [ - "### JAX PRNG" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "COjzGBpO4tzL" - }, - "source": [ - "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", - "\n", - "The random state is described by a special array element that we call a __key__:" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "id": "yPHE7KTWgAWs", - "outputId": "ae8af0ee-f19e-474e-81b6-45e894eb2fc3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([0, 0], dtype=uint32)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "key = random.key(0)\n", - "key" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XjYyWYNfq0hW" - }, - "source": [ - "JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n", - "\n", - "Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "7zUdQMynoE5e", - "outputId": "23a86b72-dfb9-410a-8e68-22b48dc10805" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.20584226]\n", - "[0 0]\n", - "[-0.20584226]\n", - "[0 0]\n" - ] - } - ], - "source": [ - "print(random.normal(key, shape=(1,)))\n", - "print(key)\n", - "# No no no!\n", - "print(random.normal(key, shape=(1,)))\n", - "print(key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hQN9van8rJgd" - }, - "source": [ - "Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "ASj0_rSzqgGh", - "outputId": "2f13f249-85d1-47bb-d503-823eca6961aa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [0 0]\n", - " \\---SPLIT --> new key [4146024105 967050713]\n", - " \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tqtFVE4MthO3" - }, - "source": [ - "We propagate the __key__ and make new __subkeys__ whenever we need a new random number:" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "id": "jbC34XLor2Ek", - "outputId": "4059a2e2-0205-40bc-ad55-17709d538871" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "old key [4146024105 967050713]\n", - " \\---SPLIT --> new key [2384771982 3928867769]\n", - " \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n" - ] - } - ], - "source": [ - "print(\"old key\", key)\n", - "key, subkey = random.split(key)\n", - "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(r\" \\---SPLIT --> new key \", key)\n", - "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0KLYUluz3lN3" - }, - "source": [ - "We can generate more than one __subkey__ at a time:" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "id": "lEi08PJ4tfkX", - "outputId": "1f280560-155d-4c04-98e8-c41d72ee5b01" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-0.37533438]\n", - "[0.98645043]\n", - "[0.14553197]\n" - ] - } - ], - "source": [ - "key, *subkeys = random.split(key, 4)\n", - "for subkey in subkeys:\n", - " print(random.normal(subkey, shape=(1,)))" + "JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial." ] }, { diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 00955de236e7..f35c5ead13b7 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -384,153 +384,7 @@ jnp.sum(jnp.array(x)) ## 🔪 Random numbers -+++ {"id": "O8vvaVt3MRG2"} - -> _If all scientific papers whose results are in doubt because of bad -> `rand()`s were to disappear from library shelves, there would be a -> gap on each shelf about as big as your fist._ - Numerical Recipes - -+++ {"id": "Qikt9pPW9L5K"} - -### RNGs and state -You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness: - -```{code-cell} ipython3 -:id: rr9FeP41fynt -:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 - -print(np.random.random()) -print(np.random.random()) -print(np.random.random()) -``` - -+++ {"id": "ORMVVGZJgSVi"} - -Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this "entropy" has been used up. - -```{code-cell} ipython3 -:id: 7Pyp2ajzfPO2 - -np.random.seed(0) -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044, -# 2481403966, 4042607538, 337614300, ... 614 more numbers..., -# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0) -``` - -+++ {"id": "aJIxHVXCiM6m"} - -This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne twister state vector: - -```{code-cell} ipython3 -:id: GAHaDCYafpAF - -_ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0) - -# Let's exhaust the entropy in this PRNG statevector -for i in range(311): - _ = np.random.uniform() -rng_state = np.random.get_state() -#print(rng_state) -# --> ('MT19937', array([2443250962, 1093594115, 1878467924, -# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0) - -# Next call iterates the RNG state for a new batch of fake "entropy". -_ = np.random.uniform() -rng_state = np.random.get_state() -# print(rng_state) -# --> ('MT19937', array([1499117434, 2949980591, 2242547484, -# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0) -``` - -+++ {"id": "N_mWnleNogps"} - -The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user. - -The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems, it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow. - -+++ {"id": "Uvq7nV-j4vKK"} - -### JAX PRNG - -+++ {"id": "COjzGBpO4tzL"} - -JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. - -The random state is described by a special array element that we call a __key__: - -```{code-cell} ipython3 -:id: yPHE7KTWgAWs -:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 - -key = random.key(0) -key -``` - -+++ {"id": "XjYyWYNfq0hW"} - -JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! - -Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__: - -```{code-cell} ipython3 -:id: 7zUdQMynoE5e -:outputId: 23a86b72-dfb9-410a-8e68-22b48dc10805 - -print(random.normal(key, shape=(1,))) -print(key) -# No no no! -print(random.normal(key, shape=(1,))) -print(key) -``` - -+++ {"id": "hQN9van8rJgd"} - -Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number: - -```{code-cell} ipython3 -:id: ASj0_rSzqgGh -:outputId: 2f13f249-85d1-47bb-d503-823eca6961aa - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "tqtFVE4MthO3"} - -We propagate the __key__ and make new __subkeys__ whenever we need a new random number: - -```{code-cell} ipython3 -:id: jbC34XLor2Ek -:outputId: 4059a2e2-0205-40bc-ad55-17709d538871 - -print("old key", key) -key, subkey = random.split(key) -normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(r" \---SPLIT --> new key ", key) -print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) -``` - -+++ {"id": "0KLYUluz3lN3"} - -We can generate more than one __subkey__ at a time: - -```{code-cell} ipython3 -:id: lEi08PJ4tfkX -:outputId: 1f280560-155d-4c04-98e8-c41d72ee5b01 - -key, *subkeys = random.split(key, 4) -for subkey in subkeys: - print(random.normal(subkey, shape=(1,))) -``` +JAX's pseudo-random number generation differs from Numpy's in important ways. For a quick how-to, see {ref}`key-concepts-prngs`. For more details, see the {ref}`pseudorandom-numbers` tutorial. +++ {"id": "rg4CpMZ8c3ri"} diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 2ad1eadb0968..00f77e3473bb 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -17,6 +17,10 @@ kernelspec: +> _If all scientific papers whose results are in doubt because of bad +> `rand()`s were to disappear from library shelves, there would be a +> gap on each shelf about as big as your fist._ - Numerical Recipes + In this section we focus on {mod}`jax.random` and pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next. @@ -35,6 +39,19 @@ import numpy as np np.random.seed(0) ``` +Repeated calls to NumPy's stateful pseudorandom number generators (PRNGs) mutate the global state and give a stream of pseudorandom numbers: + +```{code-cell} +:id: rr9FeP41fynt +:outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745 + +print(np.random.random()) +print(np.random.random()) +print(np.random.random()) +``` + +Underneath the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up. + You can inspect the content of the state using the following command. ```{code-cell} @@ -109,7 +126,7 @@ Further, when executing in multi-device environments, execution efficiency would ### Explicit random state -To avoid this issue, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: +To avoid these issues, JAX avoids implicit global random state, and instead tracks state explicitly via a random `key`: ```{code-cell} from jax import random @@ -137,6 +154,7 @@ Re-using the same key, even with different {mod}`~jax.random` APIs, can result i **The rule of thumb is: never reuse keys (unless you want identical outputs).** +JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: ```{code-cell} diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 590f68ac0b3b..6540fd1f5d41 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -677,7 +677,7 @@ class KeyReuseError(JAXTypeError): KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 This sort of key reuse is problematic because the JAX PRNG is stateless, and keys - must be manually split; For more information on this see `Sharp Bits: Random Numbers - `_. + must be manually split; For more information on this see `the Pseudorandom Numbers + tutorial `_. """ pass From 8525ef2b23f12affcff23b9a54d4d2515acb671f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 15 Nov 2024 17:41:42 -0800 Subject: [PATCH 367/698] [sharding_in_types] Don't emit a wsc under full manual mode to avoid increasing HLO size by a lot PiperOrigin-RevId: 697048126 --- jax/_src/interpreters/mlir.py | 20 +++++++++----------- jax/_src/mesh.py | 2 +- tests/pjit_test.py | 8 ++++---- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index ee3c929b26f7..a1d326162a1c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2475,17 +2475,15 @@ def _wrap_with_spmd_op(name: str, def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None): - if sharding_proto is None: - proto = aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() - else: - proto = sharding_proto - # TODO(yashkatariya): Setting all axes as unspecified should work even when - # any axes is Collective because that's what happens in partial auto shmap. - # Do that after tests for it exists. - unspecified_dims = (set(range(aval.ndim)) - if aval.sharding.mesh.are_all_axes_collective else None) - return wrap_with_sharding_op( - ctx, op, aval, proto, unspecified_dims=unspecified_dims) + # Don't emit a wsc under full manual mode to avoid increasing HLO size. + if aval.sharding.mesh._are_all_axes_collective: + return op + proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + if sharding_proto is None else sharding_proto) + # TODO(yashkatariya): Enable this + # unspecified_dims = (set(range(aval.ndim)) + # if aval.sharding.mesh._any_axis_collective else None) + return wrap_with_sharding_op(ctx, op, aval, proto) def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 6c6017c4b2b7..a2ab261fa0e9 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -426,7 +426,7 @@ def empty(self): return self.size == 0 @functools.cached_property - def are_all_axes_collective(self) -> bool: + def _are_all_axes_collective(self) -> bool: if self.axis_types is None: return False return all(t == AxisTypes.Collective for t in self.axis_types.keys()) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7196a6335960..be1f9cfc267a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5208,8 +5208,8 @@ def test_shard_map_full_manual(self): arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) def g(x, y): - self.assertTrue(x.sharding.mesh.are_all_axes_collective) - self.assertTrue(y.sharding.mesh.are_all_axes_collective) + self.assertTrue(x.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.sharding.mesh._are_all_axes_collective) return x * y @jax.jit @@ -5232,8 +5232,8 @@ def test_shard_map_dot(self): arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) def g(x, y): - self.assertTrue(x.sharding.mesh.are_all_axes_collective) - self.assertTrue(y.sharding.mesh.are_all_axes_collective) + self.assertTrue(x.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.sharding.mesh._are_all_axes_collective) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y return jax.lax.psum(z, axis_name='y') From 609dfac29452e4842c62168c9c9036f38976a57d Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Mon, 11 Nov 2024 22:59:24 -0800 Subject: [PATCH 368/698] Adds a flag to control proxy env checking. name typo fix. Fixes comments. --- jax/_src/clusters/cluster.py | 6 ------ jax/_src/distributed.py | 38 ++++++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 2fb13fde72cf..69ef77a6421d 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -49,12 +49,6 @@ def auto_detect_unset_distributed_params(cls, initialization_timeout: int | None, ) -> tuple[str | None, int | None, int | None, Sequence[int] | None]: - - if all(p is not None for p in (coordinator_address, num_processes, - process_id, local_device_ids)): - return (coordinator_address, num_processes, process_id, - local_device_ids) - # First, we check the spec detection method because it will ignore submitted values # If if succeeds. if cluster_detection_method is not None: diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 5b9130fc0455..f80f90bde186 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -27,6 +27,13 @@ logger = logging.getLogger(__name__) +_CHECK_PROXY_ENVS = config.bool_flag( + name="jax_check_proxy_envs", + default=True, + help="Checks proxy vars in user envs and emit warnings.", +) + + class State: process_id: int = 0 num_processes: int = 1 @@ -55,16 +62,17 @@ def initialize(self, if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')): local_device_ids = list(map(int, env_ids.split(","))) - (coordinator_address, num_processes, process_id, local_device_ids) = ( - clusters.ClusterEnv.auto_detect_unset_distributed_params( - coordinator_address, - num_processes, - process_id, - local_device_ids, - cluster_detection_method, - initialization_timeout, - ) - ) + if None in (coordinator_address, num_processes, process_id, local_device_ids): + (coordinator_address, num_processes, process_id, local_device_ids) = ( + clusters.ClusterEnv.auto_detect_unset_distributed_params( + coordinator_address, + num_processes, + process_id, + local_device_ids, + cluster_detection_method, + initialization_timeout, + ) + ) if coordinator_address is None: raise ValueError('coordinator_address should be defined.') @@ -92,8 +100,10 @@ def initialize(self, self.process_id = process_id - # Emit a warning about PROXY variables if they are in the user's env: - proxy_vars = [ key for key in os.environ.keys() if '_proxy' in key.lower()] + proxy_vars = [] + if _CHECK_PROXY_ENVS.value: + proxy_vars = [key for key in os.environ.keys() + if '_proxy' in key.lower()] if len(proxy_vars) > 0: vars = " ".join(proxy_vars) + ". " @@ -179,7 +189,9 @@ def initialize(coordinator_address: str | None = None, ``cluster_detection_method="mpi4py"`` to bootstrap the required arguments. Otherwise, you must provide the ``coordinator_address``, - ``num_processes``, and ``process_id`` arguments to :func:`~jax.distributed.initialize`. + ``num_processes``, ``process_id``, and ``local_device_ids`` arguments + to :func:`~jax.distributed.initialize`. When all four arguments are provided, cluster + environment auto detection will be skipped. Please note: on some systems, particularly HPC clusters that only access external networks through proxy variables such as HTTP_PROXY, HTTPS_PROXY, etc., the call to From 626aea017b6c60b346f5e7edebfc5bbf116ff4cf Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 16 Nov 2024 11:10:45 -0800 Subject: [PATCH 369/698] Deduplicate constants in StableHLO lowering. The goal of this change is to reduce the size of the generated code: we frequently built thousands of scalar 0s, for example. --- jax/_src/interpreters/mlir.py | 41 ++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a1d326162a1c..477ba6880eda 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1752,6 +1752,38 @@ def _emit_lowering_rule_as_fun(lowering_rule, return func_op +class HashableLiteral: + """Hashable wrapper of core.Literal, used for deduplicating IR constants.""" + + __slots__ = ["value"] + + value: core.Literal + + def __init__(self, value): + self.value = value + + def __hash__(self): + h = self.value.hash + return id(self.value.val) if h is None else h + + def __eq__(self, other): + if self is other: + return True + if type(self.value.val) != type(other.value.val): + return False + if self.value.aval != other.value.aval: + return False + if isinstance(self.value.val, (bool, int, float, complex)): + return self.value == other.value + if isinstance(self.value.val, (np.generic, np.ndarray)): + return np.array_equal( + self.value.val, other.value.val, + equal_nan=np.issubdtype(self.value.val.dtype, np.inexact)) + # Since the use case is constant deduplication, it's safe to return + # False in unhandled cases. + return False + + def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, name_stack: source_info_util.NameStack, tokens: TokenSet, @@ -1767,9 +1799,16 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, IR function, in the order of ctx.shape_poly_state.dim_vars. """ assert "gpu" not in ctx.platforms + cached_ir_consts: dict[HashableLiteral, IrValues] = {} + def read(v: core.Atom) -> IrValues: if type(v) is core.Literal: - return ir_constant(xla.canonicalize_dtype(v.val)) + h = HashableLiteral(v) + c = cached_ir_consts.get(h) + if c is None: + c = ir_constant(xla.canonicalize_dtype(v.val)) + cached_ir_consts[h] = c + return c else: assert isinstance(v, core.Var) return env[v] From 1d519f4ce3cd4a621b8f7e1bceab75317ed5db24 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 16 Nov 2024 13:38:23 -0800 Subject: [PATCH 370/698] Return a ndarray in shape_as_value if the shape is known to be constant. --- jax/_src/lax/lax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b780aab870e9..9c183ae93d41 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4547,6 +4547,8 @@ def shape_as_value(shape: core.Shape): """Converts a shape that may contain Poly values into a JAX value.""" if len(shape) == 0: return full((0,), np.array(0, np.int64)) + if core.is_constant_shape(shape): + return np.asarray(shape, dtype=np.int64) dims = [ expand_dims(convert_element_type(core.dimension_as_value(d), np.int64), (0,)) From 7b9914d711593dca8725d46aa1dadb2194284519 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 16 Nov 2024 13:39:24 -0800 Subject: [PATCH 371/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9ab7d704d7fe7e73fc3976adc2ccec070bc9a2ea. PiperOrigin-RevId: 697222155 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e7ae7fe718a6..af2fab8ed55f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "195f45b7082930033f6533a160b0f8f7f1cbfb40" -XLA_SHA256 = "75e77091bae789175f3de24efee9debf8835b167770490db75571bf65c27b727" +XLA_COMMIT = "9ab7d704d7fe7e73fc3976adc2ccec070bc9a2ea" +XLA_SHA256 = "6944ceaa425cacd30a54cca3cd6c4cb88b79f219d421fb97fa87ffbf06007143" def repo(): tf_http_archive( From 8a6c560b2562be13de5c0808db143b614db531ee Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 16 Nov 2024 14:29:20 -0800 Subject: [PATCH 372/698] Use a direct StableHLO lowering for pow. This is slightly faster than lowering via tracing, and the code is simpler also. --- jax/_src/lax/lax.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b780aab870e9..cf16c9b99935 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2574,15 +2574,12 @@ def _pow_jvp_rhs(g, ans, x, y): def _pow_lower(ctx, x, y): x_aval, y_aval = ctx.avals_in - out_aval, = ctx.avals_out - convert = mlir.lower_fun( - partial(convert_element_type, new_dtype=out_aval.dtype), False) - x_aval_ = x_aval.update(dtype=out_aval.dtype) - y_aval_ = y_aval.update(dtype=out_aval.dtype) - [x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x) - [y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y) - ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_]) - return _nary_lower_hlo(hlo.power, ctx_, x_, y_) + if x_aval.dtype != y_aval.dtype: + out_aval, = ctx.avals_out + y_aval = y_aval.update(dtype=out_aval.dtype) + y = hlo.convert(mlir.aval_to_ir_type(y_aval), y) + ctx = ctx.replace(avals_in=[x_aval, y_aval]) + return _nary_lower_hlo(hlo.power, ctx, x, y) mlir.register_lowering(pow_p, _pow_lower) def _integer_pow_dtype_rule(x, *, y): From 27bf80a50617ff38aca19e1da7c0b7599e691ef6 Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Mon, 11 Nov 2024 22:59:24 -0800 Subject: [PATCH 373/698] Adds an env that can let users provide a custom version suffix for jax dev build. fix the local version update to what Jake suggested --- jax/version.py | 6 +++++- tests/version_test.py | 30 ++++++++++++++++++++---------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/jax/version.py b/jax/version.py index c27caf979ddb..3e8a8291ec8d 100644 --- a/jax/version.py +++ b/jax/version.py @@ -60,7 +60,11 @@ def _version_from_git_tree(base_version: str) -> str | None: except: return None else: - return f"{base_version}.dev{datestring}+{commit_hash}" + version = f"{base_version}.dev{datestring}+{commit_hash}" + suffix = os.environ.get("JAX_CUSTOM_VERSION_SUFFIX", None) + if suffix: + return version + "." + suffix + return version def _get_version_for_build() -> str: diff --git a/tests/version_test.py b/tests/version_test.py index 7ce98c8588e5..51297a9716b1 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -26,11 +26,11 @@ # This is a subset of the full PEP440 pattern; for example we skip pre & post releases VERSION_PATTERN = re.compile(r""" - ^ # start of string - (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' - (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' - (?:\+(?P[a-zA-Z0-9_]+))? # optional local version; like '+g6643af3c3' - $ # end of string + ^ # start of string + (?P[0-9]+\.[0-9]+\.[0-9]+) # main version; like '0.4.16' + (?:\.dev(?P[0-9]+))? # optional dev version; like '.dev20230908' + (?:\+(?P[a-zA-Z0-9_.]+))? # optional local version; like '+g6643af3c3' + $ # end of string """, re.VERBOSE) @@ -61,11 +61,12 @@ def assert_no_subprocess_call(): @contextlib.contextmanager -def assert_subprocess_call(): +def assert_subprocess_call(stdout: bytes | None = None): """Run code, asserting that subprocess.Popen *is* called at least once.""" with mock.patch("subprocess.Popen") as mock_Popen: + mock_Popen.return_value.communicate.return_value = (stdout, b"") yield - mock_Popen.assert_called() + mock_Popen.return_value.communicate.assert_called() class JaxVersionTest(unittest.TestCase): @@ -126,7 +127,7 @@ def testBuildVersionFromEnvironment(self): self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY="1"): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() datestring = datetime.date.today().strftime("%Y%m%d") @@ -134,19 +135,28 @@ def testBuildVersionFromEnvironment(self): self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE="1", JAXLIB_RELEASE=None, - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() self.assertEqual(version, base_version) self.assertValidVersion(version) with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE="1", - JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() self.assertEqual(version, base_version) self.assertValidVersion(version) + with jtu.set_env(JAX_RELEASE=None, JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None, + JAX_CUSTOM_VERSION_SUFFIX="test"): + with assert_subprocess_call(stdout=b"1731433958-1c0f1076e"): + version = jax.version._get_version_for_build() + self.assertTrue(version.startswith(f"{base_version}.dev")) + self.assertTrue(version.endswith("test")) + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3") From 742cabc54724456397dee7fd4e92411aa57f16b4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 17 Nov 2024 14:19:00 -0800 Subject: [PATCH 374/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/58ea2935b4316b48979cb47f617ae06ce9f49638. PiperOrigin-RevId: 697425145 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index af2fab8ed55f..b35f9daa2144 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9ab7d704d7fe7e73fc3976adc2ccec070bc9a2ea" -XLA_SHA256 = "6944ceaa425cacd30a54cca3cd6c4cb88b79f219d421fb97fa87ffbf06007143" +XLA_COMMIT = "58ea2935b4316b48979cb47f617ae06ce9f49638" +XLA_SHA256 = "669eef5690be3e1059de8429cdfbf24bf0a15a5aa6e00b9aefd7a072d839d0aa" def repo(): tf_http_archive( From ed250b89831aab2e2ed672ad05e13a7eee818396 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 17 Nov 2024 23:58:46 -0800 Subject: [PATCH 375/698] [AutoPGLE] Temporary disable pgle_test in the OSS. PiperOrigin-RevId: 697517161 --- tests/pgle_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index fa574df18f29..609ca38fd7a5 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -43,6 +43,11 @@ @jtu.pytest_mark_if_available('multiaccelerator') +# TODO(patrios): Remove this skip once b/379267258 is fixed. +@jtu.skip_under_pytest( + 'This test requires specific XLA_FLAGS. However, pytest does not reload ' + 'modules between tests. So if another test is launched before this one ' + 'necessary XLA_FLAGS will not be re-used by the XLA.') class PgleTest(jtu.JaxTestCase): _dump_exit_stack: ExitStack | None = None From 8607cb6470726b074a04379e2cdf295613260c7a Mon Sep 17 00:00:00 2001 From: charleshofer Date: Mon, 18 Nov 2024 09:56:53 -0600 Subject: [PATCH 376/698] Make daily sync permissions at the workflow level and fix merge CI (#143) --- .../workflows/rocm-nightly-upstream-sync.yml | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index e8cb5f480313..f309427df197 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -6,21 +6,22 @@ on: workflow_dispatch: schedule: - cron: '0 6 * * 1-5' +permissions: + contents: write + pull-requests: write env: SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }} jobs: sync-main: - permissions: - contents: write runs-on: ubuntu-latest steps: - - run: gh repo sync rocm/jax -b main + - run: | + gh auth status + gh repo sync rocm/jax -b main env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} create-sync-branch: needs: sync-main - permissions: - contents: write runs-on: ubuntu-latest env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -32,11 +33,15 @@ jobs: git fetch git checkout origin/main git checkout -b $SYNC_BRANCH_NAME + # Try and merge rocm-main into this new branch so that we don't run upstream's CI code + git config --global user.email "github-actions@github.com" + git config --global user.name "GitHub Actions" + git merge origin/rocm-main || true + # If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts + git merge --abort || true git push origin HEAD open-sync-pr: needs: create-sync-branch - permissions: - pull-requests: write runs-on: ubuntu-latest steps: - run: | From ccb331707e80b16d89de6e5c9f2f89b87c1682ed Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 18 Nov 2024 08:11:04 -0800 Subject: [PATCH 377/698] Add a GPU implementation of `lax.linalg.eig`. This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately). This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.) We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable. PiperOrigin-RevId: 697631402 --- CHANGELOG.md | 3 + jax/_src/config.py | 11 + jax/_src/lax/linalg.py | 133 ++++- jax/_src/numpy/linalg.py | 4 +- jaxlib/cpu/lapack_kernels.cc | 28 -- jaxlib/cpu/lapack_kernels.h | 29 ++ jaxlib/cuda/BUILD | 50 ++ jaxlib/gpu/BUILD | 3 + jaxlib/gpu/hybrid.cc | 60 +++ jaxlib/gpu/hybrid_kernels.cc | 631 ++++++++++++++++++++++++ jaxlib/gpu/hybrid_kernels.h | 55 +++ jaxlib/gpu_solver.py | 43 ++ jaxlib/jax.bzl | 1 + jaxlib/rocm/BUILD | 43 ++ jaxlib/tools/build_gpu_kernels_wheel.py | 2 + jaxlib/tools/build_wheel.py | 2 + tests/BUILD | 7 + tests/lax_numpy_test.py | 4 +- tests/linalg_test.py | 35 +- tests/magma_linalg_test.py | 125 +++++ 20 files changed, 1214 insertions(+), 55 deletions(-) create mode 100644 jaxlib/gpu/hybrid.cc create mode 100644 jaxlib/gpu/hybrid_kernels.cc create mode 100644 jaxlib/gpu/hybrid_kernels.h create mode 100644 tests/magma_linalg_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d670e43b6137..204df6a83e52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. declared inline via {func}`dataclasses.field`. See the function documentation for examples. * Added {func}`jax.numpy.put_along_axis`. + * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions + ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now + supported on GPU. See {jax-issue}`#24663` for more details. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/jax/_src/config.py b/jax/_src/config.py index 72f394dba76f..1c62f7125ee7 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1963,3 +1963,14 @@ def _update_garbage_collection_guard(state, key, val): ), include_in_jit_key=True, ) + +gpu_use_magma = enum_state( + name='jax_use_magma', + enum_values=['off', 'on', 'auto'], + default='auto', + help=( + 'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. ' + 'See the documentation for lax.linalg.eig for more details about how ' + 'to use this feature.' + ), +) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0e0390abc78f..62cb72c69fd7 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -121,16 +121,46 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, - compute_right_eigenvectors: bool = True) -> list[Array]: + compute_right_eigenvectors: bool = True, + use_magma: bool | None = None) -> list[Array]: """Eigendecomposition of a general matrix. - Nonsymmetric eigendecomposition is at present only implemented on CPU. + Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU, + the default implementation calls LAPACK directly on the host CPU, but an + experimental GPU implementation using `MAGMA `_ + is also available. The MAGMA implementation is typically slower than the + equivalent LAPACK implementation for small matrices (less than about 2048), + but it may perform better for larger matrices. + + To enable the MAGMA implementation, you must install MAGMA yourself (there + are Debian and conda-forge packages, or you can build from source). Then set + the ``use_magma`` argument to ``True``, or set the ``jax_use_magma`` + configuration variable to ``"on"`` or ``"auto"``: + + .. code-block:: python + + jax.config.update('jax_use_magma', 'on') + + JAX will try to ``dlopen`` the installed MAGMA shared library, raising an + error if it is not found. To explicitly specify the path to the MAGMA + library, set the environment variable `JAX_GPU_MAGMA_PATH` to the full + installation path. + + If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will + be used if the library can be found, and the input matrix is sufficiently + large (>= 2048x2048). Args: x: A batch of square matrices with shape ``[..., n, n]``. compute_left_eigenvectors: If true, the left eigenvectors will be computed. compute_right_eigenvectors: If true, the right eigenvectors will be computed. + use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the + eigendecomposition is computed using MAGMA. If ``False``, the computation + is done using LAPACK on to the host CPU. If ``None`` (default), the + behavior is controlled by the ``jax_use_magma`` flag. This argument + is only used on GPU. + Returns: The eigendecomposition of ``x``, which is a tuple of the form ``(w, vl, vr)`` where ``w`` are the eigenvalues, ``vl`` are the left @@ -142,7 +172,8 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, for that batch element. """ return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors) + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma) def eigh( @@ -678,12 +709,14 @@ def _symmetric_product_jax_fn(a, c, *, alpha, beta): # Asymmetric eigendecomposition -def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): +def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): return dispatch.apply_primitive( eig_p, operand, compute_left_eigenvectors=compute_left_eigenvectors, compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma, ) def eig_lower(*args, **kw): @@ -692,7 +725,8 @@ def eig_lower(*args, **kw): "If your matrix is symmetric or Hermitian, you should use eigh instead.") def eig_abstract_eval(operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if isinstance(operand, ShapedArray): if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]: raise ValueError("Argument to nonsymmetric eigendecomposition must have " @@ -716,7 +750,8 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors, return tuple(output) def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused operand_aval, = ctx.avals_in out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] @@ -763,18 +798,94 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, return output +def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors, + compute_right_eigenvectors, use_magma): + gpu_solver.initialize_hybrid_kernels() + dtype = x.dtype + is_real = dtype == np.float32 or dtype == np.float64 + if is_real: + target_name = f"{target_name_prefix}hybrid_eig_real" + complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + else: + target_name = f"{target_name_prefix}hybrid_eig_comp" + assert dtype == np.complex64 or dtype == np.complex128 + complex_dtype = dtype + + batch_dims = x.shape[:-2] + n, m = x.shape[-2:] + assert n == m + num_batch_dims = len(batch_dims) + + layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims) + out_types = [ + api.ShapeDtypeStruct(batch_dims + (n,), dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims, np.int32), + ] + out_layouts = [None, layout, layout, None] + if is_real: + out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types + out_layouts = [None] + out_layouts + + magma = config.gpu_use_magma.value + if use_magma is not None: + magma = "on" if use_magma else "off" + fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout], + output_layouts=out_layouts) + *w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + assert len(w) == 2 + w = lax.complex(*w) + else: + assert len(w) == 1 + w = w[0] + ok = lax.eq(info, lax.zeros_like_array(info)) + ok = _broadcast_to(ok[..., None], w.shape) + w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j)) + ok = _broadcast_to(ok[..., None], x.shape) + output = [w] + if compute_left_eigenvectors: + vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j)) + output.append(vl) + if compute_right_eigenvectors: + vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j)) + output.append(vr) + return output + + +def _eig_gpu_lowering(target_name_prefix, ctx, operand, *, + compute_left_eigenvectors, compute_right_eigenvectors, + use_magma): + if ctx.is_forward_compat(): + raise NotImplementedError( + "Export of nonsymmetric eigendecomposition on GPU is not supported " + "because of forward compatibility. The " + "'jax_export_ignore_forward_compatibility' configuration option can be " + "used to disable this check.") + rule = mlir.lower_fun(partial( + _eig_gpu_impl, target_name_prefix, + compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), multiple_results=True) + return rule(ctx, operand) + + def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) return (eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors, - compute_right_eigenvectors=compute_right_eigenvectors), + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=use_magma), (0,) * (1 + compute_left_eigenvectors + compute_right_eigenvectors)) def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, - compute_right_eigenvectors): + compute_right_eigenvectors, use_magma): + del use_magma # unused if compute_left_eigenvectors or compute_right_eigenvectors: raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' @@ -793,6 +904,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, eig_p.def_abstract_eval(eig_abstract_eval) mlir.register_lowering(eig_p, eig_lower) mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'), + platform='cuda') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'), + platform='rocm') batching.primitive_batchers[eig_p] = eig_batching_rule ad.primitive_jvps[eig_p] = eig_jvp_rule diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 03f864919887..76a4abff48ad 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -731,7 +731,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: - This differs from :func:`numpy.linalg.eig` in that the return type of :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128 for 64-bit input. - - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. + - At present, non-symmetric eigendecomposition is only implemented on the CPU and + GPU backends. For more details about the GPU implementation, see the + documentation for :func:`jax.lax.linalg.eig`. See also: - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix. diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 19b82a5ce149..ed815e1b1bd2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// LAPACK uses a packed representation to represent a mixture of real -// eigenvectors and complex conjugate pairs. This helper unpacks the -// representation into regular complex matrices. -template -static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag, - const T* packed, std::complex* unpacked) { - for (int j = 0; j < n;) { - if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { - // Real values in each row without imaginary part - // Second row of the imaginary part is not provided - for (int i = 0; i < n; ++i) { - unpacked[j * n + i] = {packed[j * n + i], 0.}; - } - ++j; - } else { - // Complex values where the real part is in the jth row - // and the imaginary part is in the next row (j + 1) - for (int i = 0; i < n; ++i) { - const T real_part = packed[j * n + i]; - const T imag_part = packed[(j + 1) * n + i]; - unpacked[j * n + i] = {real_part, imag_part}; - unpacked[(j + 1) * n + i] = {real_part, -imag_part}; - } - j += 2; - } - } -} - // lapack geev template diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 7d15e494fffc..cddcb1162120 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ #define JAXLIB_CPU_LAPACK_KERNELS_H_ +#include #include #include #include @@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian { // lapack geev +// LAPACK uses a packed representation to represent a mixture of real +// eigenvectors and complex conjugate pairs. This helper unpacks the +// representation into regular complex matrices. +template +static void UnpackEigenvectors(Int n, const T* eigenvals_imag, + const T* packed, std::complex* unpacked) { + for (int j = 0; j < n;) { + if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { + // Real values in each row without imaginary part + // Second row of the imaginary part is not provided + for (int i = 0; i < n; ++i) { + unpacked[j * n + i] = {packed[j * n + i], 0.}; + } + ++j; + } else { + // Complex values where the real part is in the jth row + // and the imaginary part is in the next row (j + 1) + for (int i = 0; i < n; ++i) { + const T real_part = packed[j * n + i]; + const T imag_part = packed[(j + 1) * n + i]; + unpacked[j * n + i] = {real_part, imag_part}; + unpacked[(j + 1) * n + i] = {real_part, -imag_part}; + } + j += 2; + } + } +} + template struct RealGeev { using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 34e40d12d5be..afce2c000ecc 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -476,6 +476,55 @@ pybind_extension( ], ) +cc_library( + name = "cuda_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + module_name = "_hybrid", + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_hybrid_kernels", + ":cuda_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + "@xla//xla/tsl/cuda:cudart", + ], +) + cc_library( name = "cuda_gpu_kernels", srcs = ["//jaxlib/gpu:gpu_kernels.cc"], @@ -633,6 +682,7 @@ py_library( name = "cuda_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_rnn", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 7d50a91cfcda..e888f6a42a9b 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -37,6 +37,9 @@ exports_files(srcs = [ "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", + "hybrid.cc", + "hybrid_kernels.cc", + "hybrid_kernels.h", "linalg.cc", "linalg_kernels.cc", "linalg_kernels.cu.cc", diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc new file mode 100644 index 000000000000..afe95a650d29 --- /dev/null +++ b/jaxlib/gpu/hybrid.cc @@ -0,0 +1,60 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/gpu/hybrid_kernels.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace { +namespace ffi = xla::ffi; +namespace nb = nanobind; + +void GetLapackKernelsFromScipy() { + static bool initialized = false; // Protected by GIL + if (initialized) return; + nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas"); + nb::module_ cython_lapack = + nb::module_::import_("scipy.linalg.cython_lapack"); + nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); + auto lapack_ptr = [&](const char* name) { + return nb::cast(lapack_capi[name]).data(); + }; + + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>(lapack_ptr("cgeev")); + AssignKernelFn>( + lapack_ptr("zgeev")); +} + +NB_MODULE(_hybrid, m) { + m.def("initialize", GetLapackKernelsFromScipy); + m.def("has_magma", []() { return MagmaLookup().FindMagmaInit().ok(); }); + m.def("registrations", []() { + nb::dict dict; + dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal); + dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp); + return dict; + }); +} + +} // namespace +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc new file mode 100644 index 000000000000..1ce2e547b11f --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.cc @@ -0,0 +1,631 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/hybrid_kernels.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +namespace ffi = ::xla::ffi; + +// This helper class is used to define a host buffer that can be copied to and +// from a device buffer. +template +class HostBuffer { + public: + HostBuffer(std::size_t size) : size_(size) { + data_ = std::unique_ptr(new T[size]); + } + + absl::Status CopyFromDevice(gpuStream_t stream, const T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(data_.get(), buffer, size_ * sizeof(T), + gpuMemcpyDeviceToHost, stream)); + } + + absl::Status CopyToDevice(gpuStream_t stream, T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(buffer, data_.get(), size_ * sizeof(T), + gpuMemcpyHostToDevice, stream)); + } + + T* get() const { return data_.get(); } + + private: + std::unique_ptr data_; + size_t size_; +}; + +// Forwarded from MAGMA for use as an input parameter. +typedef enum { + MagmaNoVec = 301, + MagmaVec = 302, +} magma_vec_t; + +// Compile time lookup of MAGMA function names depending on the data type. +template +struct always_false : std::false_type {}; + +template +struct MagmaGeev { + static_assert(always_false::value, "unsupported data type"); +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_sgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_dgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_cgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_zgeev"; +}; + +MagmaLookup::~MagmaLookup() { + if (initialized_) { + void* magma_finalize = dlsym(handle_, "magma_finalize"); + if (magma_finalize != nullptr) { + reinterpret_cast(magma_finalize)(); + } + } + if (handle_ != nullptr) { + dlclose(handle_); + } +} + +absl::StatusOr MagmaLookup::FindMagmaInit() { + void* magma_init = nullptr; + std::vector paths; + const char* magma_lib_path = std::getenv("JAX_GPU_MAGMA_PATH"); + if (magma_lib_path != nullptr) { + paths.push_back(magma_lib_path); + } else { + paths.push_back("libmagma.so.2"); + paths.push_back("libmagma.so"); + paths.push_back(nullptr); + } + for (const auto& path : paths) { + handle_ = dlopen(path, RTLD_LAZY); + if (handle_ != nullptr) { + magma_init = dlsym(handle_, "magma_init"); + if (magma_init != nullptr) { + if (path != nullptr) { + lib_path_ = std::string(path); + } + break; + } + } + } + if (handle_ == nullptr || magma_init == nullptr) { + return absl::InternalError( + "Unable to dlopen a MAGMA shared library that defines a magma_init " + "symbol. Use the JAX_GPU_MAGMA_PATH environment variable to " + "specify an explicit path to the library."); + } + return magma_init; +} + +absl::Status MagmaLookup::Initialize() { + if (failed_) { + return absl::InternalError("MAGMA initialization was unsuccessful."); + } + if (!initialized_) { + auto maybe_magma_init = FindMagmaInit(); + if (!maybe_magma_init.ok()) { + failed_ = true; + return maybe_magma_init.status(); + } + reinterpret_cast(maybe_magma_init.value())(); + initialized_ = true; + } + return absl::OkStatus(); +} + +absl::StatusOr MagmaLookup::Find(const char name[]) { + if (!initialized_) { + return absl::InternalError("MAGMA support has not been initialized."); + } + + auto it = symbols_.find(name); + if (it != symbols_.end()) return it->second; + + void* symbol = dlsym(handle_, name); + if (symbol == nullptr) { + if (lib_path_.has_value()) { + return absl::InternalError(absl::StrFormat( + "Unable to load the symbol '%s' from the MAGMA library at '%s'.", + name, lib_path_.value())); + + } else { + return absl::InternalError(absl::StrFormat( + "Unable to load a globally defined symbol called '%s'. Use the " + "JAX_GPU_MAGMA_PATH environment variable to specify an explicit " + "path to the library.", + name)); + } + } + + symbols_.insert({name, symbol}); + return symbol; +} + +// Lookup the MAGMA symbol for the given function name. This function only +// dlopen the MAGMA library once per process. +absl::StatusOr FindMagmaSymbol(const char name[]) { + static absl::Mutex mu; + static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu); + absl::MutexLock lock(&mu); + auto status = lookup.Initialize(); + if (!status.ok()) { + return status; + } + return lookup.Find(name); +} + +// Real-valued eigendecomposition + +template +class EigRealHost { + using Real = ffi::NativeType; + + public: + explicit EigRealHost() = default; + EigRealHost(EigRealHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecomposition::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + EigenvalueDecomposition::fn(&jobvl_, &jobvr_, &n_, x, &n_, wr, wi, + vl, &n_, vr, &n_, work, &lwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigRealMagma { + using Real = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Real*, int, Real*, Real*, Real*, + int, Real*, int, Real*, int, int*); + + public: + explicit EigRealMagma() = default; + EigRealMagma(EigRealMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Real query_host; + fn_(jobvl_, jobvr_, n, nullptr, n, nullptr, nullptr, nullptr, n, nullptr, n, + &query_host, -1, &query_info); + return static_cast(query_host); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, n_, wr, wi, vl, n_, vr, n_, work, lwork, info); + } + + private: + int n_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigReal(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto wr_host = HostBuffer(batch * cols); + auto wi_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto work_left = AllocateScratchMemory(cols * cols); + auto work_right = AllocateScratchMemory(cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), + [](auto value) { return std::isfinite(value); }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, wr_host.get() + i * cols, + wi_host.get() + i * cols, work_left.get(), work_right.get(), + work_host.get(), lwork, info_host.get() + i); + if (info_host.get()[i] == 0) { + if (left) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_left.get(), + vl_host.get() + i * cols * cols); + } + if (right) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_right.get(), + vr_host.get() + i * cols * cols); + } + } + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + wr_host.CopyToDevice(stream, wr->typed_data())); + FFI_RETURN_IF_ERROR_STATUS( + wi_host.CopyToDevice(stream, wi->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigRealDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != wr->element_type() || dataType != wi->element_type() || + ffi::ToComplex(dataType) != vl->element_type() || + ffi::ToComplex(dataType) != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(wr->dimensions(), {batch, cols}, "wr", "eig")); + FFI_RETURN_IF_ERROR(CheckShape(wi->dimensions(), {batch, cols}, "wi", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::F32: + if (use_magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + case ffi::F64: + if (use_magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_real", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigReal, EigRealDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // wr + .Ret() // wi + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +// Complex-valued eigendecomposition + +template +class EigCompHost { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + public: + explicit EigCompHost() = default; + EigCompHost(EigCompHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecompositionComplex::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + EigenvalueDecompositionComplex::fn(&jobvl_, &jobvr_, &n_, x, &n_, + w, vl, &n_, vr, &n_, work, + &lwork, rwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigCompMagma { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Complex*, int, Complex*, + Complex*, int, Complex*, int, Complex*, int, Real*, int*); + + public: + explicit EigCompMagma() = default; + EigCompMagma(EigCompMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + lda_ = std::max(n_, 1); + ldvl_ = left ? n_ : 1; + ldvr_ = right ? n_ : 1; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Complex query_host; + fn_(jobvl_, jobvr_, n_, nullptr, lda_, nullptr, nullptr, ldvl_, nullptr, + ldvr_, &query_host, -1, nullptr, &query_info); + return static_cast(query_host.real()); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, lda_, w, vl, ldvl_, vr, ldvr_, work, lwork, + rwork, info); + } + + private: + int n_, lda_, ldvl_, ldvr_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigComp(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto w_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto rwork_host = + AllocateScratchMemory(2 * cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { + return std::isfinite(z.real()) && std::isfinite(z.imag()); + }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, w_host.get() + i * cols, + vl_host.get() + i * cols * cols, + vr_host.get() + i * cols * cols, work_host.get(), lwork, + rwork_host.get(), info_host.get() + i); + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + w_host.CopyToDevice(stream, w->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigCompDispatch(gpuStream_t stream, std::string_view magma, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != w->element_type() || dataType != vl->element_type() || + dataType != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::C64: + if (use_magma) { + return EigComp(EigCompMagma(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + case ffi::C128: + if (use_magma) { + return EigComp(EigCompMagma(), batch, cols, + stream, left, right, x, w, vl, vr, info); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_comp", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigComp, EigCompDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // w + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h new file mode 100644 index 000000000000..2890837a2bd5 --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_HYBRID_KERNELS_H_ +#define JAXLIB_GPU_HYBRID_KERNELS_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +// The MagmaLookup class is used for dlopening the MAGMA shared library, +// initializing it, and looking up MAGMA symbols. +class MagmaLookup { + public: + explicit MagmaLookup() = default; + ~MagmaLookup(); + absl::StatusOr FindMagmaInit(); + absl::Status Initialize(); + absl::StatusOr Find(const char name[]); + + private: + bool initialized_ = false; + bool failed_ = false; + void* handle_ = nullptr; + std::optional lib_path_ = std::nullopt; + absl::flat_hash_map symbols_; +}; + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_HYBRID_KERNELS_H_ diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 03fd43e9ef89..59819f1fc914 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -56,6 +56,21 @@ xla_client.register_custom_call_target(_name, _value, platform="CUDA", api_version=api_version) +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: + try: + _cuhybrid = importlib.import_module( + f"{cuda_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _cuhybrid = None + else: + break + +if _cuhybrid: + for _name, _value in _cuhybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="CUDA", + api_version=1) + try: from .rocm import _blas as _hipblas # pytype: disable=import-error except ImportError: @@ -88,6 +103,34 @@ xla_client.register_custom_call_target(_name, _value, platform="ROCM", api_version=api_version) +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hiphybrid = importlib.import_module( + f"{rocm_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _hiphybrid = None + else: + break + +if _hiphybrid: + for _name, _value in _hiphybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM", + api_version=1) + +def initialize_hybrid_kernels(): + if _cuhybrid: + _cuhybrid.initialize() + if _hiphybrid: + _hiphybrid.initialize() + +def has_magma(): + if _cuhybrid: + return _cuhybrid.has_magma() + if _hiphybrid: + return _hiphybrid.has_magma() + return False + def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" return np.finfo(dtype).dtype diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index b5bfe733b992..2bae7ab2a203 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -66,6 +66,7 @@ _py_deps = { "filelock": ["@pypi_filelock//:pkg"], "flatbuffers": ["@pypi_flatbuffers//:pkg"], "hypothesis": ["@pypi_hypothesis//:pkg"], + "magma": [], "matplotlib": ["@pypi_matplotlib//:pkg"], "mpmath": [], "opt_einsum": ["@pypi_opt_einsum//:pkg"], diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index c9b73a5785f1..1076f9a77bf8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -389,6 +389,48 @@ pybind_extension( ], ) +cc_library( + name = "hip_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_hybrid", + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_hybrid_kernels", + ":hip_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) + cc_library( name = "triton_kernels", srcs = ["//jaxlib/gpu:triton_kernels.cc"], @@ -456,6 +498,7 @@ py_library( name = "rocm_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_solver", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 5b3ac636303a..9a47c6ad5409 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -108,6 +108,7 @@ def prepare_wheel_cuda( f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", f"__main__/jaxlib/cuda_plugin_extension.{pyext}", f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", @@ -144,6 +145,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 438cebca2b06..4db36fa0ea97 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -231,6 +231,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) @@ -244,6 +245,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", ], ) diff --git a/tests/BUILD b/tests/BUILD index c80f63e6d7d6..bd4312e4aa24 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -664,6 +664,13 @@ jax_multiplatform_test( }, ) +jax_multiplatform_test( + name = "magma_linalg_test", + srcs = ["magma_linalg_test.py"], + enable_backends = ["gpu"], + deps = py_deps("magma"), +) + jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a1817f528f27..7aad5634775d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1492,8 +1492,8 @@ def testTrimZerosNotOneDArray(self): def testPoly(self, a_shape, dtype, rank): if dtype in (np.float16, jnp.bfloat16, np.int16): self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") - elif rank == 2 and not jtu.test_device_matches(["cpu"]): - self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") + elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]): + self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.") rng = jtu.rand_default(self.rng()) tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } if jtu.test_device_matches(["tpu"]): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d3fe8f476722..d0b109dda07e 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -34,6 +34,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -250,11 +251,11 @@ def testIssue1213(self): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] args_maker = lambda: [rng(shape, dtype)] @@ -293,12 +294,12 @@ def check_left_eigenvectors(a, w, vl): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, @@ -309,15 +310,15 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, @jtu.sample_product( shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types, - ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + ) + @jtu.run_on_devices("cpu", "gpu") def testEigvalsGrad(self, shape, dtype): # This test sometimes fails for large matrices. I (@j-towns) suspect, but # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -329,10 +330,10 @@ def testEigvalsGrad(self, shape, dtype): shape=[(4, 4), (5, 5), (50, 50)], dtype=float_types + complex_types, ) - # TODO: enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -340,9 +341,11 @@ def testEigvals(self, shape, dtype): w2 = jnp.linalg.eigvals(a) self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -350,8 +353,10 @@ def testEigvalsInf(self): shape=[(1, 1), (4, 4), (5, 5)], dtype=float_types + complex_types, ) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape args = rng(shape, dtype) diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py new file mode 100644 index 000000000000..d2abb9fe3a0b --- /dev/null +++ b/tests/magma_linalg_test.py @@ -0,0 +1,125 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np + +from absl.testing import absltest + +import jax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import linalg as lax_linalg +from jax._src.lib import gpu_solver +from jax._src.lib import version as jaxlib_version + +config.parse_flags_with_absl() + +float_types = jtu.dtypes.floating +complex_types = jtu.dtypes.complex + + +class MagmaLinalgTest(jtu.JaxTestCase): + + @jtu.sample_product( + shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEig(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + rng = jtu.rand_default(self.rng()) + n = shape[-1] + args_maker = lambda: [rng(shape, dtype)] + + # Norm, adjusted for dimension and type. + def norm(x): + norm = np.linalg.norm(x, axis=(-2, -1)) + return norm / ((n + 1) * jnp.finfo(dtype).eps) + + def check_right_eigenvectors(a, w, vr): + self.assertTrue( + np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100)) + + def check_left_eigenvectors(a, w, vl): + rank = len(a.shape) + aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2])) + wC = jnp.conj(w) + check_right_eigenvectors(aH, wC, vl) + + a, = args_maker() + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + w = results[0] + + if compute_left_eigenvectors: + check_left_eigenvectors(a, w, results[1]) + if compute_right_eigenvectors: + check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) + + self._CompileAndCheck(jnp.linalg.eig, args_maker, rtol=1e-3) + + @jtu.sample_product( + shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)], + dtype=float_types + complex_types, + compute_left_eigenvectors=[False, True], + compute_right_eigenvectors=[False, True], + ) + @jtu.run_on_devices("gpu") + def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, + compute_right_eigenvectors): + """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for + # complex128 in some configurations. + if dtype == np.complex128: + self.skipTest("MAGMA support for complex128 types is flaky.") + a = jnp.full(shape, jnp.nan, dtype) + results = lax_linalg.eig( + a, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors, + use_magma=True) + for result in results: + self.assertTrue(np.all(np.isnan(result))) + + def testEigMagmaConfig(self): + if jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + rng = jtu.rand_default(self.rng()) + a = rng((5, 5), np.float32) + with config.gpu_use_magma("on"): + hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text() + self.assertIn('magma = "on"', hlo) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 14187399d7d555410fcaf7e18a1d2cfb4ced8987 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 18 Nov 2024 08:51:54 -0800 Subject: [PATCH 378/698] Add new CI script for running Bazel GPU presubmits PiperOrigin-RevId: 697643622 --- .github/workflows/bazel_gpu_rbe.yml | 39 ++++++++++++++++++++++ ci/run_bazel_test_gpu_rbe.sh | 51 +++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 .github/workflows/bazel_gpu_rbe.yml create mode 100755 ci/run_bazel_test_gpu_rbe.sh diff --git a/.github/workflows/bazel_gpu_rbe.yml b/.github/workflows/bazel_gpu_rbe.yml new file mode 100644 index 000000000000..a7cf645b50b3 --- /dev/null +++ b/.github/workflows/bazel_gpu_rbe.yml @@ -0,0 +1,39 @@ +name: CI - Bazel GPU tests (RBE) + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run_tests: + if: github.event.repository.fork == false + strategy: + matrix: + runner: ["linux-x86-n2-16"] + + runs-on: ${{ matrix.runner }} + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + + steps: + - uses: actions/checkout@v3 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel GPU Tests with RBE + run: ./ci/run_bazel_test_gpu_rbe.sh \ No newline at end of file diff --git a/ci/run_bazel_test_gpu_rbe.sh b/ci/run_bazel_test_gpu_rbe.sh new file mode 100755 index 000000000000..0c004c584300 --- /dev/null +++ b/ci/run_bazel_test_gpu_rbe.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one +# GPU apiece on RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece). +echo "Running RBE GPU tests..." + +bazel test --config=rbe_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64=0 \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file From e9864c69da9a9c10012d94b013f302e295434efb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 15 Nov 2024 09:02:09 -0800 Subject: [PATCH 379/698] Make logaddexp and logaddexp2 into ufuncs --- jax/_src/numpy/reductions.py | 32 ++++++++++++++++++++++++++++++ jax/_src/numpy/ufuncs.py | 6 ++---- jax/numpy/__init__.pyi | 4 ++-- tests/lax_numpy_ufuncs_test.py | 36 ++++++++++++++++++++++++---------- 4 files changed, 62 insertions(+), 16 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index bc85bc3e8761..69d6843f5155 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -723,6 +723,38 @@ def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None initial=initial, where_=where) +def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Compute log(sum(exp(a))) while avoiding precision loss.""" + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") + dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") + a_arr, = promote_dtypes_inexact(a) + pos_dims, dims = _reduction_dims(a_arr, axis) + amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf) + amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))) + amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims) + exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) + sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) + result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype)) + return result if initial is None else lax.logaddexp(initial, result) + + +def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + """Compute log2(sum(2 ** a)) via logsumexp.""" + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") + dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") + ln2 = float(np.log(2)) + if initial is not None: + initial *= ln2 + return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims, + where=where, initial=initial) / ln2 + + @export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index bbbce9733aa5..de8688e491ba 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2687,8 +2687,7 @@ def _pow_int_int(x1, x2): return acc -@export -@jit +@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp) def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. @@ -2714,8 +2713,7 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax_other.logaddexp(x1, x2) -@export -@jit +@binary_ufunc(identity=-np.inf, reduce=reductions._logsumexp2) def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index af7b056fcbb0..b71afebe921c 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -629,8 +629,8 @@ def log(x: ArrayLike, /) -> Array: ... def log10(x: ArrayLike, /) -> Array: ... def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... -def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logaddexp: BinaryUfunc +logaddex2: BinaryUfunc logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... logical_or: BinaryUfunc diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 61c86c0a05e4..20a1a58a9dbe 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -179,13 +179,15 @@ def test_unary_ufunc_call(self, name, dtype, shape): rhs_shape=broadcast_compatible_shapes, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_bimary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + def test_binary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): jnp_fun = getattr(jnp, name) np_fun = getattr(np, name) rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, tol=tol) self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( @@ -218,7 +220,9 @@ def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker, tol=tol) self._CompileAndCheck(jnp_fun.outer, args_maker) @jtu.sample_product( @@ -259,7 +263,9 @@ def test_binary_ufunc_reduce(self, name, shape, axis, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( @@ -315,7 +321,9 @@ def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): rng_where = jtu.rand_bool(self.rng()) args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)] - self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( @@ -356,8 +364,10 @@ def np_fun_accumulate(x): result = np_fun.accumulate(x, axis=axis) return result if x.dtype == bool else result.astype(x.dtype) - self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker) - self._CompileAndCheck(jnp_fun_accumulate, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_accumulate, args_maker, tol=tol) @jtu.sample_product( SCALAR_FUNCS, @@ -400,7 +410,9 @@ def np_fun_at(x, idx): np_fun.at(x_copy, idx) return x_copy - self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_at, args_maker) @jtu.sample_product( @@ -422,7 +434,9 @@ def np_fun_at(x, idx, y): np_fun.at(x_copy, idx, y) return x_copy - self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_at, args_maker) def test_frompyfunc_at_broadcasting(self): @@ -483,7 +497,9 @@ def np_fun_reduceat(x, i): # Numpy has different casting behavior. return np_fun.reduceat(x, i).astype(x.dtype) - self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker) + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker, tol=tol) self._CompileAndCheck(jnp_fun.reduceat, args_maker) From 6fe7b1713a5c6b2de3c7ab2fe04bc36beeb8f8f9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 18 Nov 2024 10:44:59 -0800 Subject: [PATCH 380/698] Return SingleDeviceSharding instead of GSPMDShardings when there is only 1 device during `compiled.input_shardings` call. PiperOrigin-RevId: 697683233 --- jax/_src/interpreters/pxla.py | 9 +++++---- tests/pjit_test.py | 8 ++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6c9e54441f8e..2164c1a914c9 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2747,11 +2747,11 @@ def _maybe_get_and_check_out_shardings( return new_out_shardings -def finalize_out_shardings(out_shardings, device_assignment): +def finalize_shardings(shardings, device_assignment): if len(device_assignment) == 1: return [SingleDeviceSharding(device_assignment[0], memory_kind=o.memory_kind) - if isinstance(o, GSPMDSharding) else o for o in out_shardings] - return out_shardings + if isinstance(o, GSPMDSharding) else o for o in shardings] + return shardings @dataclasses.dataclass @@ -2892,7 +2892,8 @@ def from_hlo(name: str, in_shardings, out_shardings, global_in_avals, global_out_avals, intermediate_shardings, context_mesh) - out_shardings = finalize_out_shardings(out_shardings, da) + in_shardings = finalize_shardings(in_shardings, da) + out_shardings = finalize_shardings(out_shardings, da) return UnloadedMeshExecutable( xla_executable=xla_executable, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index be1f9cfc267a..0c1c28809062 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4624,6 +4624,14 @@ def f(x): jax.jit(f, out_shardings=s)(np.arange(8)) self.assertEqual(count[0], 1) + def test_input_shardings_single_device(self): + @jax.jit + def f(x): + return x * 2 + + ins, _ = f.lower(np.arange(8)).compile().input_shardings + self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0])) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") From 5bebd0f6c40e152c90a610db80ae85e04773d088 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 18 Nov 2024 11:04:33 -0800 Subject: [PATCH 381/698] fix typo in numpy/__init__.pyi --- jax/numpy/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b71afebe921c..5d357ab1bb03 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -630,7 +630,7 @@ def log10(x: ArrayLike, /) -> Array: ... def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... logaddexp: BinaryUfunc -logaddex2: BinaryUfunc +logaddexp2: BinaryUfunc logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... logical_or: BinaryUfunc From 0ed6eaeb4a0c5ebf7679f3877b01bd7d6df29bae Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Mon, 18 Nov 2024 12:13:55 -0800 Subject: [PATCH 382/698] [SDY] fix JAX layouts tests for Shardy. PiperOrigin-RevId: 697715276 --- tests/BUILD | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/BUILD b/tests/BUILD index bd4312e4aa24..a645a971a799 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -267,6 +267,9 @@ jax_multiplatform_test( backend_tags = { "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, + enable_configs = [ + "tpu_v3_2x2_shardy", + ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", From 461a2507f8b8e2a4da1d5de9a0c9fee98cfef245 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Nov 2024 17:04:52 -0500 Subject: [PATCH 383/698] Disable some complex function accuracy tests that fail on Mac ARM. --- tests/lax_test.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/lax_test.py b/tests/lax_test.py index 14f453b38e7c..78bc5857acb7 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4398,14 +4398,34 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'tanh': regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj') + elif name == 'arcsin': + if is_arm_cpu and platform.system() == 'Darwin': + regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real') + else: + regions_with_inaccuracies.clear() + + elif name == 'arcsinh': + if is_arm_cpu and platform.system() == 'Darwin': + regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', + 'negj.imag', 'posj.imag') + else: + regions_with_inaccuracies.clear() + elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') - elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan', - 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh', 'square'}: + elif name == 'log1p': + if is_arm_cpu and platform.system() == 'Darwin': + regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', + 'posj.imag') + else: + regions_with_inaccuracies.clear() + + elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'tan', + 'arcsinh', 'arccosh', 'arctan', 'arctanh', 'square'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable From f32505169fe98dc3a8f9c66ebd343bd349c8798e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Nov 2024 14:07:13 -0800 Subject: [PATCH 384/698] Filter custom dtypes by supported_dtypes in `_LazyDtypes`. The other methods of `_LazyDtypes` filter by the supported dtypes, so it's strange that this property does not. Change in preparation for landing https://github.com/jax-ml/jax/pull/23585 without breaking existing tests. PiperOrigin-RevId: 697752034 --- jax/_src/test_util.py | 14 +++++++++++--- tests/api_test.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index e546ebd2a0f3..72154fd5871d 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -457,7 +457,15 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64} + np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64, + _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e5m2} + elif device_under_test() == "gpu": + types = {np.bool_, np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + _dtypes.bfloat16, np.float16, np.float32, np.float64, + np.complex64, np.complex128, + _dtypes.float8_e4m3fn, _dtypes.float8_e5m2} elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: @@ -1464,10 +1472,10 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ + return self.supported([ _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]) @_cached_property def floating(self): diff --git a/tests/api_test.py b/tests/api_test.py index 49cd33ee464c..ae38f50460ab 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4625,7 +4625,7 @@ def test_cache_miss_explanations_no_source_info(self): jax.jit(operator.add)(42, 24) @parameterized.named_parameters([ - {"testcase_name": f"{dtype}", "dtype": dtype} + {"testcase_name": f"{np.dtype(dtype)}", "dtype": dtype} for dtype in jtu.dtypes.custom_floats]) def test_jit_custom_floats(self, dtype): f = lambda x: x + 1 From a60ef6e9bb19a898ab9a87e62fe4ed73c44ede24 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 18 Nov 2024 14:08:04 -0800 Subject: [PATCH 385/698] [Pallas] Increase test coverage of pl.dot. PiperOrigin-RevId: 697752355 --- tests/pallas/ops_test.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 41670137c39f..df48da776e5f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1417,17 +1417,25 @@ def f(x_ref, o_ref): np.testing.assert_allclose(f(x), expected) @parameterized.product( - size=[16, 32, 64], - dtype=["float32", "float16"], + size=[16, 32, 64, 128, 256], + dtype=[jnp.float32, jnp.float16, jnp.bfloat16], trans_x=[False, True], trans_y=[False, True], ) def test_dot(self, size, dtype, trans_x, trans_y): - if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: - self.skipTest("16-bit types are not supported on TPU") + if jtu.test_device_matches(["tpu"]): + if dtype == jnp.float16: + self.skipTest("float16 type is not supported on TPU") + if dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4): + self.skipTest("bfloat16 matmul is supported on TPUv4+") + if trans_x: + self.skipTest("Not implemented: Transposed LHS") - if jtu.test_device_matches(["tpu"]) and trans_x: - self.skipTest("Not implemented: Transposed LHS") + if jtu.test_device_matches(["gpu"]): + if dtype == jnp.bfloat16: + self.skipTest("bfloat16 type are not supported on GPU") + if size > 128: + self.skipTest("Shared memory size limit exceeded") @functools.partial( self.pallas_call, @@ -1444,7 +1452,12 @@ def dot(x_ref, y_ref, o_ref): y = random.normal(k2, (size, size), dtype=dtype) out = dot(x, y) expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y) - np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected.astype(jnp.float32), + atol=0.05, + rtol=0.05, + ) @parameterized.product( size=[1, 2, 64, 129, 1021], From b3ca6c47cc30cdf6e9e3ff3de1a12b9ee1b4ad81 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 18 Nov 2024 14:21:17 -0800 Subject: [PATCH 386/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/082a7014706f67bb8a42fb1c90051bc4990f2fd3. PiperOrigin-RevId: 697756717 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index b35f9daa2144..71fb2a8e9757 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "58ea2935b4316b48979cb47f617ae06ce9f49638" -XLA_SHA256 = "669eef5690be3e1059de8429cdfbf24bf0a15a5aa6e00b9aefd7a072d839d0aa" +XLA_COMMIT = "082a7014706f67bb8a42fb1c90051bc4990f2fd3" +XLA_SHA256 = "f1ca797df8e95bf13419d20520d2b783f075d80d1c5ddf1506ba427c934de849" def repo(): tf_http_archive( From d4316b5760a824bf044622073812e2f4a094a29d Mon Sep 17 00:00:00 2001 From: barnesjoseph Date: Mon, 18 Nov 2024 14:46:10 -0800 Subject: [PATCH 387/698] Adds font fallbacks --- docs/_static/style.css | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/_static/style.css b/docs/_static/style.css index 32033940e8c4..d801c2a412a6 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -103,7 +103,7 @@ body:has(.hero) .bd-container { display: flex; align-items: center; justify-content: center; - font: 500 24px 'Roboto'; + font: 500 24px 'Roboto', sans-serif; } .getting-started .sd-card-title, @@ -165,13 +165,13 @@ body:has(.hero) .bd-container { } .hero-left h2 { - font: 500 32px 'Google Sans'; + font: 500 32px 'Google Sans', 'Roboto', sans-serif; color: white; margin-top: 0; } .hero-left p { - font: 400 16px 'Roboto'; + font: 400 16px 'Roboto', sans-serif; color: white; } @@ -200,7 +200,7 @@ body:has(.hero) .bd-container { } .product-offerings .sd-card-title { - font: 400 24px 'Google Sans'; + font: 400 24px 'Google Sans', 'Roboto', sans-serif; } .color-cards { From e904c177f7644f0a733501bc548d1f5b237396af Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 18 Nov 2024 15:34:55 -0800 Subject: [PATCH 388/698] Delete _normalized_spec from NamedSharding PiperOrigin-RevId: 697779844 --- jax/_src/array.py | 2 +- jax/_src/core.py | 2 +- jax/_src/sharding_impls.py | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index cf346067ea31..d8182976254e 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1035,7 +1035,7 @@ def _get_aval_array(self): if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): return self.aval.update(sharding=NamedSharding( self.sharding.mesh.abstract_mesh, - self.sharding._normalized_spec(self.ndim))) + self.sharding.spec._normalized_spec(self.ndim))) else: return self.aval api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array diff --git a/jax/_src/core.py b/jax/_src/core.py index a1fcdac65df0..cbf3282fb2cc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1704,7 +1704,7 @@ def _get_abstract_sharding(val): if (config.sharding_in_types.value and hasattr(val, 'sharding') and isinstance(val.sharding, NamedSharding)): return NamedSharding(val.sharding.mesh.abstract_mesh, - val.sharding._normalized_spec(val.ndim)) + val.sharding.spec._normalized_spec(val.ndim)) return None def primal_dtype_to_tangent_dtype(primal_dtype): diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 8957a6186339..dc4171eec146 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -363,9 +363,6 @@ def is_fully_replicated(self) -> bool: def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) - def _normalized_spec(self, ndim: int) -> PartitionSpec: - return self.spec._normalized_spec(ndim) - def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) From 2c68569af05d54a66a3c47b28bc1c20317f9e560 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 18 Nov 2024 16:20:21 -0800 Subject: [PATCH 389/698] Fix a bug where mesh checking was not correct PiperOrigin-RevId: 697792885 --- jax/_src/lax/lax.py | 2 +- tests/pjit_test.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f434413834f7..ff9ac0a49578 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2096,11 +2096,11 @@ def broadcasting_sharding_rule(name, *avals): mesh = None for a in avals: if a.sharding is not None: - mesh = a.sharding.mesh if mesh is not None and mesh != a.sharding.mesh: raise ValueError( f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') + mesh = a.sharding.mesh assert mesh is not None shapes = [aval.shape for aval in avals if aval.shape] diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0c1c28809062..6df011419513 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4979,6 +4979,21 @@ def f(x): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) + def test_broadcasting_nary_error(self): + mesh1 = Mesh([jax.devices()[0]], 'x') + mesh2 = Mesh([jax.devices()[0]], 'y') + + arr1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) + arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P())) + + @jax.jit + def f(x, y): + return x + y + + with self.assertRaisesRegex( + ValueError, "Mesh for all inputs should be equal"): + f(arr1, arr2) + def test_sin_unop(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16.).reshape(8, 2) From 45c9c0a585704c0c139a33b838d6827b8d16df5e Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 18 Nov 2024 17:09:28 -0800 Subject: [PATCH 390/698] [pallas] Minor simplifications to Pallas interpreter. BlockMappings are always present now. PiperOrigin-RevId: 697807120 --- jax/_src/pallas/pallas_call.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index f7bd0dd4e4d7..729d0e617a87 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -72,10 +72,6 @@ pallas_call_p.multiple_results = True def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): - if start_idx is None: - assert is_indexing is None - return value - assert is_indexing is not None start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, @@ -84,10 +80,6 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): def _maybe_dynamic_update_slice(start_idx, block_shape, value, update, is_indexing): - if start_idx is None: - assert is_indexing is None - return update - assert is_indexing is not None start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) broadcast_dims = tuple(i for i, b in enumerate(is_indexing) if not b) @@ -234,8 +226,7 @@ def _pallas_call_impl_interpret( for bm in grid_mapping.block_mappings ] block_shapes = [ - None if iid is None - else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) + tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) ] @@ -284,8 +275,9 @@ def body(carry): aval = jax_core.get_aval(s) s.aval = aval.update(dtype=jnp.int32) start_indices = [ - None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) - for bm in grid_mapping.block_mappings] + bm.compute_start_indices_interpret(loop_idx, *scalars) + for bm in grid_mapping.block_mappings + ] blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry_consts_ins, is_indexing_dim) with pallas_core.grid_env(local_grid_env): From c5e8ae80f9949c69bd6b99d245bf599be2644d7b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 18 Nov 2024 09:46:22 -0500 Subject: [PATCH 391/698] Update jax.scipy.special.gamma and gammasgn to return NaN for negative integer inputs. Change to match upstream scipy: https://github.com/scipy/scipy/pull/21827. Fixes #24875 --- CHANGELOG.md | 3 ++ jax/_src/scipy/special.py | 26 ++++++++++++++-- tests/lax_scipy_special_functions_test.py | 37 +++++++++++++++++------ 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 204df6a83e52..9082399c8695 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` on the function inputs. + * {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now + return NaN for negative integer inputs, to match the behavior of SciPy from + https://github.com/scipy/scipy/pull/21827. * `jax.clear_backends` was removed after being deprecated in v0.4.26. * New Features diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 605cde19b1e7..2fffe6381b97 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -66,6 +66,7 @@ def gammaln(x: ArrayLike) -> Array: return lax.lgamma(x) +@jit def gammasgn(x: ArrayLike) -> Array: r"""Sign of the gamma function. @@ -81,6 +82,13 @@ def gammasgn(x: ArrayLike) -> Array: Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Because :math:`\Gamma(x)` is never zero, no condition is required for this case. + * if :math:`x = -\infty`, NaN is returned. + * if :math:`x = \pm 0`, :math:`\pm 1` is returned. + * if :math:`x` is a negative integer, NaN is returned. The sign of gamma + at a negative integer depends on from which side the pole is approached. + * if :math:`x = \infty`, :math:`1` is returned. + * if :math:`x` is NaN, NaN is returned. + Args: x: arraylike, real valued. @@ -92,8 +100,14 @@ def gammasgn(x: ArrayLike) -> Array: - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function """ x, = promote_args_inexact("gammasgn", x) + typ = x.dtype.type floor_x = lax.floor(x) - return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0) + x_negative = x < 0 + return jnp.select( + [(x_negative & (x == floor_x)) | jnp.isnan(x), + (x_negative & (floor_x % 2 != 0)) | ((x == 0) & jnp.signbit(x))], + [typ(np.nan), typ(-1.0)], + typ(1.0)) def gamma(x: ArrayLike) -> Array: @@ -115,6 +129,13 @@ def gamma(x: ArrayLike) -> Array: \Gamma(n) = (n - 1)! + * if :math:`z = -\infty`, NaN is returned. + * if :math:`x = \pm 0`, :math:`\pm \infty` is returned. + * if :math:`x` is a negative integer, NaN is returned. The sign of gamma + at a negative integer depends on from which side the pole is approached. + * if :math:`x = \infty`, :math:`\infty` is returned. + * if :math:`x` is NaN, NaN is returned. + Args: x: arraylike, real valued. @@ -127,7 +148,8 @@ def gamma(x: ArrayLike) -> Array: - :func:`jax.scipy.special.gammasgn`: the sign of the gamma function Notes: - Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs. + Unlike the scipy version, JAX's ``gamma`` does not support complex-valued + inputs. """ x, = promote_args_inexact("gamma", x) return gammasgn(x) * lax.exp(lax.lgamma(x)) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index cb40ae291e76..5753628957c7 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -20,9 +20,11 @@ from absl.testing import parameterized import numpy as np +import scipy import scipy.special as osp_special import jax +import jax.numpy as jnp from jax._src import test_util as jtu from jax.scipy import special as lsp_special @@ -214,7 +216,7 @@ def partial_lax_op(*vals): n=[0, 1, 2, 3, 10, 50] ) def testScipySpecialFunBernoulli(self, n): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. scipy_op = lambda: osp_special.bernoulli(n).astype(dtype) lax_op = functools.partial(lsp_special.bernoulli, n) args_maker = lambda: [] @@ -222,16 +224,33 @@ def testScipySpecialFunBernoulli(self, n): self._CompileAndCheck(lax_op, args_maker, atol=0, rtol=1E-5) def testGammaSign(self): - # Test that the sign of `gamma` matches at integer-valued inputs. - dtype = jax.numpy.zeros(0).dtype # default float dtype. - args_maker = lambda: [np.arange(-10, 10).astype(dtype)] - rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 - self._CheckAgainstNumpy(osp_special.gamma, lsp_special.gamma, args_maker, rtol=rtol) - self._CompileAndCheck(lsp_special.gamma, args_maker, rtol=rtol) + dtype = jnp.zeros(0).dtype # default float dtype. + typ = dtype.type + testcases = [ + (np.arange(-10, 0).astype(dtype), np.array([np.nan] * 10, dtype=dtype)), + (np.nextafter(np.arange(-5, 0).astype(dtype), typ(-np.inf)), + np.array([1, -1, 1, -1, 1], dtype=dtype)), + (np.nextafter(np.arange(-5, 0).astype(dtype), typ(np.inf)), + np.array([-1, 1, -1, 1, -1], dtype=dtype)), + (np.arange(0, 10).astype(dtype), np.ones((10,), dtype)), + (np.nextafter(np.arange(0, 10).astype(dtype), typ(np.inf)), + np.ones((10,), dtype)), + (np.nextafter(np.arange(1, 10).astype(dtype), typ(-np.inf)), + np.ones((9,), dtype)), + (np.array([-np.inf, -0.0, 0.0, np.inf, np.nan]), + np.array([np.nan, -1.0, 1.0, 1.0, np.nan])) + ] + for inp, out in testcases: + self.assertArraysEqual(out, lsp_special.gammasgn(inp)) + self.assertArraysEqual(out, jnp.sign(lsp_special.gamma(inp))) + if jtu.parse_version(scipy.__version__) >= (1, 15): + self.assertArraysEqual(out, osp_special.gammasgn(inp)) + self.assertAllClose(osp_special.gammasgn(inp), + lsp_special.gammasgn(inp)) def testNdtriExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.arange(-10, 10).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) @@ -239,7 +258,7 @@ def testNdtriExtremeValues(self): def testRelEntrExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jnp.zeros(0).dtype # default float dtype. args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype), np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)] rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 From 0fe77bc9f0e5a7c78e3de6371cbbbc9a3a43bf5a Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Mon, 18 Nov 2024 18:06:36 -0800 Subject: [PATCH 392/698] [Mosaic TPU] Support relayout for mask vector We cast i1 vector (mask) to i32 vector before relayout and then cast back to i1 vector (mask) after relayout is finished. PiperOrigin-RevId: 697823543 --- .../tpu/transforms/apply_vector_layout.cc | 49 ++++++++++++++++--- tests/pallas/tpu_ops_test.py | 18 +++++++ 2 files changed, 59 insertions(+), 8 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 2732b63d7638..8292a770a1c3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6314,6 +6314,14 @@ FailureOr> relayout(RewriteContext &ctx, return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); } VectorType vty = v.getType(); + const bool is_mask = vty.getElementTypeBitWidth() == 1; + if (is_mask) { + if (src.bitwidth() != 32 || dst.bitwidth() != 32) { + return emitError(v.getLoc(), + "Not implemented: mask relayout with non-32 bitwidth in " + "vector layout"); + } + } { // Replication imposes a replication constraint on the *logical* value of // the vector: When moving along a replicated axis, all elements must be @@ -6347,6 +6355,31 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); + if (is_mask) { + auto new_tile_ty = + getNativeVregOrVmaskType(builder.getI32Type(), 32, target_shape); + src_tiles.Each([&](const absl::Span idx, Value *tile) { + *tile = + builder.create(tile->getLoc(), new_tile_ty, *tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI32Type()); + } + auto assemble_with_mask_check = [&](xla::Array &tiles, + bool use_implicit_shape = false) { + if (is_mask) { + auto zeros_tile = builder.create( + tiles.begin()->getLoc(), + DenseElementsAttr::get(cast(tiles.begin()->getType()), + builder.getI32IntegerAttr(0))); + tiles.Each([&](const absl::Span idx, Value *tile) { + *tile = builder.create( + tile->getLoc(), arith::CmpIPredicate::ne, *tile, zeros_tile); + }); + vty = VectorType::get(vty.getShape(), builder.getI1Type()); + } + return assemble(builder, vty, dst, tiles, target_shape, use_implicit_shape) + .getResult(); + }; // Two easy cases: source is more general, or is replicated. if (src.generalizes(dst, vty.getShape(), target_shape)) { // A value with a replicated offset might use fewer vregs than a value with @@ -6397,9 +6430,8 @@ FailureOr> relayout(RewriteContext &ctx, .getResult(); } src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() && !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) { @@ -6410,8 +6442,7 @@ FailureOr> relayout(RewriteContext &ctx, xla::Array dst_tiles( /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape), /*value=*/src_tiles.data()[0]); - return assemble(builder, vty, dst, std::move(dst_tiles), target_shape) - .getResult(); + return assemble_with_mask_check(dst_tiles); } // Consider (1,128),-2 -> (8,128). In this case we can change the implicit @@ -6449,9 +6480,8 @@ FailureOr> relayout(RewriteContext &ctx, dst.offsets())); CHECK_EQ(src, dst); // At this point we've should be done. - return assemble(builder, vty, dst, std::move(src_tiles), target_shape, - /*use_implicit_shape=*/true) - .getResult(); + return assemble_with_mask_check(src_tiles, + /*use_implicit_shape=*/true); } // TODO(apaszke): Implement a debug mode that inserts additional assertions. @@ -6491,6 +6521,9 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); + if (*lo == *li) { + continue; + } OpBuilder builder(&op); FAILUREOR_ASSIGN_OR_RETURN( Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index ca5361a70051..8843c6a58064 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -233,6 +233,24 @@ def run(cond, lhs, rhs): assert (run(cond, lhs, rhs) == lhs).all() + def test_logical_and_relayouted_mask(self): + def get_mask(x_ref): + x = x_ref[...] == 1 + iota = jax.lax.broadcasted_iota(jnp.int32, x_ref.shape, 1) + iota = iota > 7 + return jnp.logical_and(x, iota) + + def body(x_ref, y_ref): + y_ref[...] = jnp.where(get_mask(x_ref), 0.0, -1.0) + + shape = (2, 512) + out = jax.ShapeDtypeStruct(shape, jnp.float32) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape(shape) + result = self.pallas_call(body, out_shape=out)(x) + expected = jnp.ones(x.shape, dtype=jnp.float32) + expected = expected.at[...].set(jnp.where(get_mask(x), 0.0, -1.0)) + np.testing.assert_array_equal(result, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True From d397dd968468dc054b91aacd1958a8586c409878 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 18 Nov 2024 23:58:40 -0800 Subject: [PATCH 393/698] Implement lax.pad in Pallas. PiperOrigin-RevId: 697897093 --- jax/_src/pallas/mosaic/lowering.py | 65 +++++++++++++++++++ .../tpu/transforms/apply_vector_layout.cc | 10 ++- tests/pallas/ops_test.py | 58 +++++++++++++++-- 3 files changed, 127 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 3dbb410be29f..be4102dff716 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3243,3 +3243,68 @@ def _lower_fun(shape): lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering + + +def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): + operand, padding_value = args + padding_config = kwargs["padding_config"] + + out_type: ir.VectorType = aval_to_ir_type(ctx.avals_in[0]) + if not isinstance(out_type, ir.VectorType): + raise NotImplementedError("Only vector types are supported.") + + for axis, (low, high, interior) in enumerate(padding_config): + if low == 0 and high == 0 and interior == 0: + continue + + def _pad(val): + shape = list(operand.type.shape) + shape[axis] = val + pad_vec_type = ir.VectorType.get( + shape, + operand.type.element_type, + ) + + if isinstance(padding_value, ir.OpResult): + pad = vector.BroadcastOp( + pad_vec_type, + padding_value, + ).result + else: + scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value) + pad = arith.ConstantOp( + pad_vec_type, + ir.DenseElementsAttr.get_splat( + pad_vec_type, + scalar_attr, + ), + ).result + return pad + + if low != 0: + pad_low = _pad(low) + new_shape = out_type.shape + new_shape[axis] += low + out_type = ir.VectorType.get( + new_shape, + out_type.element_type, + ) + operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis) + + if high != 0: + pad_high = _pad(high) + new_shape = out_type.shape + new_shape[axis] += high + out_type = ir.VectorType.get( + new_shape, + out_type.element_type, + ) + operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis) + + if interior > 0: + raise NotImplementedError("Not implemented: interior padding") + + return operand + + +lowering_rules[lax.pad_p] = _pad_lowering_rule diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8292a770a1c3..4a344fa9d427 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2674,6 +2674,13 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, for (size_t i = 0; i < operand_vregs.size(); ++i) { auto &vreg = operand_vregs[i]; const auto &layout = layouts_in[i]; + const int packing = res_layout->packing(); + + if (layout->tiling()[0] % packing != 0) { + return op.emitOpError( + "Illegal tiling: Non-native tiling in concat - this should " + "have been caught earlier!"); + } const int64_t operand_offset = *layout->offsets()[tiling_dim.value()]; if (operand_offset != 0) { @@ -2685,7 +2692,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, } const auto bitwidth = res_ty.getElementTypeBitWidth(); - const int packing = res_layout->packing(); SmallVector out_idx; vreg.Each([&](absl::Span idx, Value *v) { out_idx.assign(idx.begin(), idx.end()); @@ -2716,7 +2722,7 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, mask = builder.create( op.getLoc(), vmask_ty, ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(layout->tiling()[0]), + ArrayRef{boundIdxConst(layout->tiling()[0] / packing), boundIdxConst(operand_offset)}); } // Blend the current value with the existing value in the output. diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index df48da776e5f..9f0b9aef5af3 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -21,12 +21,9 @@ from typing import Any import unittest -import numpy as np from absl.testing import absltest from absl.testing import parameterized - import jax -import jax.numpy as jnp from jax import lax from jax import random from jax._src import config @@ -34,8 +31,10 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu -from jax.interpreters import partial_eval as pe from jax.experimental import pallas as pl +from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import numpy as np if sys.platform != "win32": from jax.experimental.pallas import triton as plgpu @@ -1980,6 +1979,57 @@ def convert(x_ref, y_ref): y_ref = jax.lax.bitcast_convert_type(x, out_dtype) np.testing.assert_array_equal(y, y_ref) + @parameterized.product( + array_shapes=[(4, 128), (10, 100), (8, 128), (17, 257)], + padding=[ + ((5, 8), (0, 0)), + ((0, 0), (5, 100)), + ((1, 1), (1, 1)), + ((0, 0), (0, 0)), + ], + pad_type=["constant", "wrap"], + dtype=( + jnp.float32, + jnp.bfloat16, + ), + ) + def test_arbitrary_padding_jnp_pad( + self, array_shapes, padding, pad_type, dtype + ): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Not implemented on GPU") + + x = jnp.arange(np.prod(array_shapes), dtype=dtype).reshape(array_shapes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.pad(x_ref[...], padding, mode=pad_type) + + ref = jnp.pad(x, padding, mode=pad_type) + + out_shape = jax.ShapeDtypeStruct(ref.shape, x.dtype) + try: + out = self.pallas_call( + kernel, + out_shape=out_shape, + )(x) + np.testing.assert_array_equal(out, jnp.pad(x, padding, mode=pad_type)) + except Exception as e: + self.assertEqual( + dtype, + jnp.bfloat16, + "some bfloat16 combinations can fail with not implemented", + ) + # The first two options are expected to fail due to current limitations + # in the Pallas TPU lowering. However, the last one is unexpected, and + # should be fixed, it is a pjrt bug. + # b/379787665 + acceptable_errors = ( + "Only 32-bit types supported" in str(e) + or "Not implemented" in str(e) + or "Expected mask vector type" in str(e) + ) + self.assertTrue(acceptable_errors, "Failed with error: " + str(e)) + class OpsInterpretTest(OpsTest): INTERPRET = True From da50ad7ee395eec84930d7a1c87346a547b0ae07 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 19 Nov 2024 01:47:13 -0800 Subject: [PATCH 394/698] [AutoPGLE] Use compile options to override debug options instead of XLA_FLAGS. PiperOrigin-RevId: 697924164 --- tests/pgle_test.py | 327 +++++++++++++++++++++------------------------ 1 file changed, 153 insertions(+), 174 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 609ca38fd7a5..46146abfc7c6 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import ExitStack from functools import partial import glob import logging @@ -43,41 +42,7 @@ @jtu.pytest_mark_if_available('multiaccelerator') -# TODO(patrios): Remove this skip once b/379267258 is fixed. -@jtu.skip_under_pytest( - 'This test requires specific XLA_FLAGS. However, pytest does not reload ' - 'modules between tests. So if another test is launched before this one ' - 'necessary XLA_FLAGS will not be re-used by the XLA.') class PgleTest(jtu.JaxTestCase): - _dump_exit_stack: ExitStack | None = None - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._dump_exit_stack = ExitStack() - - cls.dump_dir = cls._dump_exit_stack.enter_context(tempfile.TemporaryDirectory()) - if 'XLA_FLAGS' in os.environ: - cls.old_xla_flags = os.environ['XLA_FLAGS'] - else: - cls.old_xla_flags = None - - os.environ['XLA_FLAGS'] = ( - f'--xla_dump_to={cls.dump_dir}' - ' --xla_gpu_experimental_dump_fdo_profiles=true' - ' --xla_gpu_enable_latency_hiding_scheduler=true' - # TODO(patrios): Remove this flag once b/376647494 is fixed. - ' --xla_gpu_graph_level=0' - ) - if cls.old_xla_flags: - os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags - - @classmethod - def tearDownClass(cls): - if cls.old_xla_flags: - os.environ['XLA_FLAGS'] = cls.old_xla_flags - cls._dump_exit_stack.close() - super().tearDownClass() def setUp(self): super().setUp() @@ -85,12 +50,6 @@ def setUp(self): cc.reset_cache() def tearDown(self): - # Cleanup dump directory - for file in os.listdir(self.dump_dir): - file_path = os.path.join(self.dump_dir, file) - if os.path.isfile(file_path): - os.remove(file_path) - cc.set_cache_dir(None) super().tearDown() @@ -101,6 +60,7 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) def f(x, y): return x @ y @@ -130,6 +90,11 @@ def testPGLEProfilerGetFDOProfileLarge(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + }, ) def f(x): agg = x @@ -154,6 +119,11 @@ def testAutoPgle(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + }, ) def f(x): return x * 2 @@ -172,7 +142,7 @@ def f(x): # Run 2: Second PGLE run should not recompile the module with jtu.count_cached_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 0) + self.assertLess(cache_miss_count[0], 2) # Run 3: The module should be recompiled with FDO profiles with jtu.count_cached_compilation_cache_miss() as cache_miss_count: @@ -182,7 +152,7 @@ def f(x): # Run 4: Fast-path should be used after PGLE is done with jtu.count_cached_compilation_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 0) + self.assertLess(cache_miss_count[0], 2) def testAutoPgleWithAot(self): @jax.jit @@ -211,145 +181,154 @@ def testAutoPgleWithPersistentCache(self): its = 50 mesh = jtu.create_mesh((2,), ('x',)) - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - ) - def f(x): - agg = x - for _ in range(its): - agg = agg @ x - return agg - - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - - with (config.enable_compilation_cache(True), - config.enable_pgle(True), - config.raise_persistent_cache_errors(True), - config.raise_persistent_cache_errors(True), - config.persistent_cache_min_entry_size_bytes(0), - config.persistent_cache_min_compile_time_secs(0), - config.pgle_profiling_runs(2), - tempfile.TemporaryDirectory() as cache_dir): - cc.reset_cache() - cc.set_cache_dir(cache_dir) - # Run 1: Module should be compiled without FDO - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + with tempfile.TemporaryDirectory() as dump_dir: + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True' + }, + ) + def f(x): + agg = x + for _ in range(its): + agg = agg @ x + return agg + + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + + with (config.enable_compilation_cache(True), + config.enable_pgle(True), + config.raise_persistent_cache_errors(True), + config.raise_persistent_cache_errors(True), + config.persistent_cache_min_entry_size_bytes(0), + config.persistent_cache_min_compile_time_secs(0), + config.pgle_profiling_runs(2), + tempfile.TemporaryDirectory() as cache_dir): + cc.reset_cache() + cc.set_cache_dir(cache_dir) + # Run 1: Module should be compiled without FDO + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertGreater(cache_miss_count[0], 0) + + # Non-pgle profiled version of module should be saved + non_pgle_profiled_files = os.listdir(cache_dir) + self.assertNotEmpty(non_pgle_profiled_files) + + # Run 2: Compilation should not be called + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertLess(cache_miss_count[0], 2) + + module_before_pgle = os.listdir(dump_dir) + self.assertNotEmpty(module_before_pgle) + # Run 3: Module should be compiled with FDO and stored to persistent cache + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + f(x) + self.assertGreater(cache_miss_count[0], 0) + + # Check if FDO profile file of the biggest module is not empty + module_after_pgle = [ + x + for x in os.listdir(dump_dir) + if x not in module_before_pgle + ] + self.assertNotEmpty(module_after_pgle) + biggest_module_after_pgle = max( + module_after_pgle, + key=lambda x: os.path.getsize( + os.path.join(dump_dir, x) + ), + ) + base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) + + # Check if FDO profile file in dump directory is not empty + for module in module_after_pgle: + if module.startswith(base_module_name) and module.endswith( + '.fdo_profile' + ): + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, module)), 0 + ) + + for pgle_profiler in pjit._pgle_profiler_dict.values(): + self.assertTrue(pgle_profiler.is_enabled()) + self.assertTrue(pgle_profiler.is_fdo_consumed()) + + files_after_pgle_profile = os.listdir(cache_dir) + self.assertGreater( + len(files_after_pgle_profile), len(non_pgle_profiled_files) + ) + + # Removing non-pgle profiled module from cache to check that later pgle + # profiled version will be used. + for non_pgle_file in non_pgle_profiled_files: + path = os.path.join(cache_dir, non_pgle_file) + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + # Run 4: Persistent compilation cache should be hit PGLE profiler should + # be disabled + cache_hit = 0 + def check_if_cache_hit(event): + nonlocal cache_hit + if event == '/jax/compilation_cache/cache_hits': + cache_hit += 1 + + monitoring.register_event_listener(check_if_cache_hit) f(x) - self.assertGreater(cache_miss_count[0], 0) + monitoring._unregister_event_listener_by_callback(check_if_cache_hit) - # Non-pgle profiled version of module should be saved - non_pgle_profiled_files = os.listdir(cache_dir) - self.assertNotEmpty(non_pgle_profiled_files) + self.assertGreater(cache_hit, 0) - # Run 2: Compilation should not be called - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - f(x) - self.assertEqual(cache_miss_count[0], 0) + def testPassingFDOProfile(self): + mesh = jtu.create_mesh((2,), ('x',)) - module_before_pgle = os.listdir(self.dump_dir) - self.assertNotEmpty(module_before_pgle) - # Run 3: Module should be compiled with FDO and stored to persistent cache - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - f(x) - self.assertGreater(cache_miss_count[0], 0) - - # Check if FDO profile file of the biggest module is not empty - module_after_pgle = [ - x - for x in os.listdir(self.dump_dir) - if x not in module_before_pgle - ] - self.assertNotEmpty(module_after_pgle) - biggest_module_after_pgle = max( - module_after_pgle, - key=lambda x: os.path.getsize( - os.path.join(self.dump_dir, x) - ), - ) - base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) - - # Check if FDO profile file in dump directory is not empty - for module in module_after_pgle: - if module.startswith(base_module_name) and module.endswith( - '.fdo_profile' - ): - self.assertGreater( - os.path.getsize(os.path.join(self.dump_dir, module)), 0 - ) - - for pgle_profiler in pjit._pgle_profiler_dict.values(): - self.assertTrue(pgle_profiler.is_enabled()) - self.assertTrue(pgle_profiler.is_fdo_consumed()) - - files_after_pgle_profile = os.listdir(cache_dir) - self.assertGreater( - len(files_after_pgle_profile), len(non_pgle_profiled_files) + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, ) + def f(x, y): + return x @ y - # Removing non-pgle profiled module from cache to check that later pgle - # profiled version will be used. - for non_pgle_file in non_pgle_profiled_files: - path = os.path.join(cache_dir, non_pgle_file) - if os.path.isfile(path): - os.remove(path) - elif os.path.isdir(path): - shutil.rmtree(path) - - api.clear_caches() - pjit._pgle_profiler_dict.clear() - - # Run 4: Persistent compilation cache should be hit PGLE profiler should - # be disabled - cache_hit = 0 - def check_if_cache_hit(event): - nonlocal cache_hit - if event == '/jax/compilation_cache/cache_hits': - cache_hit += 1 - - monitoring.register_event_listener(check_if_cache_hit) - f(x) - monitoring._unregister_event_listener_by_callback(check_if_cache_hit) - - self.assertGreater(cache_hit, 0) - - def testPassingFDOProfile(self): - mesh = jtu.create_mesh((2,), ('x',)) - - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - ) - def f(x, y): - return x @ y - - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - y = x + 1 + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + y = x + 1 - with config.pgle_profiling_runs(0): - f_lowered = f.lower(x, y) - compiled = f_lowered.compile() + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() - with tempfile.TemporaryDirectory() as cache_dir: - jax.profiler.start_trace(cache_dir) - compiled(x, y) - jax.profiler.stop_trace() - directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) - directories = [d for d in directories if os.path.isdir(d)] - rundir = directories[-1] - logging.info('rundir: %s', rundir) - fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir) - - if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda(): - self.assertIn(b'custom', fdo_profile) - - logging.info('fdo_profile: %s', fdo_profile) - # Test pass fdo_profile as compiler_options API works. - f_lowered.compile(compiler_options={'fdo_profile': fdo_profile}) + with tempfile.TemporaryDirectory() as cache_dir: + jax.profiler.start_trace(cache_dir) + compiled(x, y) + jax.profiler.stop_trace() + directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) + directories = [d for d in directories if os.path.isdir(d)] + rundir = directories[-1] + logging.info('rundir: %s', rundir) + fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir) + + if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda(): + self.assertIn(b'custom', fdo_profile) + + logging.info('fdo_profile: %s', fdo_profile) + # Test pass fdo_profile as compiler_options API works. + f_lowered.compile(compiler_options={'fdo_profile': fdo_profile}) if __name__ == '__main__': From 1458d3dd562c6c1004c9ac1162de731fed91ec68 Mon Sep 17 00:00:00 2001 From: nireekshak Date: Tue, 19 Nov 2024 15:04:55 +0000 Subject: [PATCH 395/698] Fix some typos --- docs/Custom_Operation_for_GPUs.md | 6 +++--- docs/advanced-autodiff.md | 5 ++--- docs/autodidax.ipynb | 2 +- docs/autodidax.md | 2 +- docs/autodidax.py | 2 +- docs/notebooks/shard_map.ipynb | 2 +- docs/notebooks/shard_map.md | 2 +- 7 files changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index f4b61cbcf7dc..2163272e2542 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -623,16 +623,16 @@ be used with the custom_partitioning registration and for the gradient. (And if you implement the interface to support vmat, it will also be on the outer primitive). -JAX custom_partitioning implementation are callbacks from XLA to Python during XLA sharding logic. +JAX custom_partitioning implementations are callbacks from XLA to Python during XLA sharding logic. XLA sharding goes in two phases: a sharding propagation phase and a partition phase. -The propagation phase is when XLA plan the sharding to be created. It is the partition phase that create the sharded graph. +The propagation phase is when XLA plan the sharding to be created. It is the partition phase that creates the sharded graph. For XLA to be able to shard our custom operations, it needs us to define 2 extra functions: infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively. The infer_sharding_from_operands() function must do what its name say: infer the output sharding from the input sharding. The partition() function will do a few things: -- tell which input sharding will be expected. XLA will reshad if needed. +- tell which input sharding will be expected. XLA will reshard if needed. - tell the final version of the output sharding. - give a function that will create the new instruction from the sharded inputs. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index 023dc8040954..c56e82c77450 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -350,7 +350,7 @@ This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \math and so on. -To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. +To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. ## How it's made: Two foundational autodiff functions @@ -475,7 +475,7 @@ where we use `CT a` to denote the type for the cotangent space for `a`. In words This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.vmap` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. -There's a cost, though: though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). +There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/). @@ -1762,7 +1762,6 @@ print(grad(app, 1)(lambda x: x ** 2, 4.)) Refer to `fixed_point` above for another usage example. **You don't need to use** `nondiff_argnums` **with array-valued arguments**, such as, for example, ones with the integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments. -s ## Next steps diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 8b418b16f878..e620967de4b7 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2797,7 +2797,7 @@ "representing unknown outputs, we need avals, which we get from the abstract\n", "eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n", "`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n", - "weakrefs.)\n", + "`weakref`s.)\n", "\n", "That `process_primitive` logic applies to most primitives, but `xla_call_p`\n", "requires recursive treatment. So we special-case its rule in a\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 9e726e5ed82e..1c16db80f608 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -2195,7 +2195,7 @@ output. If instead any input is unknown then we instead stage out into a representing unknown outputs, we need avals, which we get from the abstract eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using -weakrefs.) +`weakref`s.) That `process_primitive` logic applies to most primitives, but `xla_call_p` requires recursive treatment. So we special-case its rule in a diff --git a/docs/autodidax.py b/docs/autodidax.py index f57af2cd96f2..f74617f31416 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -2187,7 +2187,7 @@ def full_lower(self): # representing unknown outputs, we need avals, which we get from the abstract # eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and # `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using -# weakrefs.) +# `weakref`s.) # # That `process_primitive` logic applies to most primitives, but `xla_call_p` # requires recursive treatment. So we special-case its rule in a diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 37c27ce2728a..d73b0d4c0f3e 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -864,7 +864,7 @@ "Indeed, this implementation is often used on both TPU and GPU!\n", "\n", "The reason `psum_scatter` can require about half the communication as a full\n", - "`psum` is illustrated the `ppermute` section.\n", + "`psum` is illustrated in the `ppermute` section.\n", "\n", "Another intuition is that we can use `psum_scatter` to implement a distributed\n", "matrix multiplication with inputs and outputs sharded over the same axis. In\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 47b11079e27d..c52cf0e6d22b 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -627,7 +627,7 @@ def psum(x, axis_name): Indeed, this implementation is often used on both TPU and GPU! The reason `psum_scatter` can require about half the communication as a full -`psum` is illustrated the `ppermute` section. +`psum` is illustrated in the `ppermute` section. Another intuition is that we can use `psum_scatter` to implement a distributed matrix multiplication with inputs and outputs sharded over the same axis. In From d912034cb5e6c5584255621b34f958f2846d1d11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 19 Nov 2024 16:42:19 +0100 Subject: [PATCH 396/698] fix(docs): typos in macro name chore(docs): sync .md file --- docs/ffi.ipynb | 4 ++-- docs/ffi.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index f1a699b5c56c..72a2a6914fc0 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -139,8 +139,8 @@ "}\n", "\n", "// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare\n", - "// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`\n", - "// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.\n", + "// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`\n", + "// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.\n", "XLA_FFI_DEFINE_HANDLER_SYMBOL(\n", " RmsNorm, RmsNormImpl,\n", " ffi::Ffi::Bind()\n", diff --git a/docs/ffi.md b/docs/ffi.md index dbe901237ed4..96b627675004 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -134,8 +134,8 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x, } // Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare -// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL` -// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`. +// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL` +// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`. XLA_FFI_DEFINE_HANDLER_SYMBOL( RmsNorm, RmsNormImpl, ffi::Ffi::Bind() From 3556a8333443228d341245ee59278c7c93e22238 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 19 Nov 2024 09:52:15 -0800 Subject: [PATCH 397/698] Add missing version guard in GPU tests for jnp.poly. jaxlib v0.4.35 is required for running `jnp.linalg.eig` on GPU which is required for `poly`. PiperOrigin-RevId: 698052642 --- tests/lax_numpy_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7aad5634775d..ef80e368c9c7 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -51,6 +51,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal +from jax._src.lib import version as jaxlib_version from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace config.parse_flags_with_absl() @@ -1494,6 +1495,8 @@ def testPoly(self, a_shape, dtype, rank): self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.") + if rank == 2 and jaxlib_version <= (0, 4, 35) and jtu.test_device_matches(["gpu"]): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } if jtu.test_device_matches(["tpu"]): From 6c31efa3f324a810461389f728ab848abffd767f Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 19 Nov 2024 10:32:28 -0800 Subject: [PATCH 398/698] [Mosaic TPU] Add general tpu.vector_store and support masked store. This cl introduces a general store op called tpu.vector_stores which aims to unify vector::store, tpu::strided_load, vector::masked_store. The tpu.vector_stores should also provide general interface for lowering for both TensorCore and SparseCore. This cl also adds the support for (dynamic) masked store. PiperOrigin-RevId: 698067741 --- jaxlib/mosaic/dialect/tpu/tpu.td | 16 +++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 27 ++++++- .../tpu/transforms/apply_vector_layout.cc | 71 ++++++++++++++----- .../tpu/transforms/infer_vector_layout.cc | 15 +++- 4 files changed, 107 insertions(+), 22 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index b312bca7a7d3..4fd960063dc4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -214,6 +214,22 @@ def TPU_LoadOp : TPU_Op<"load"> { }]; } +// TODO(jevinjiang): migrate tpu.strided_store to general vector store op. +def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { + let arguments = (ins + AnyVector:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask // Elementwise mask. + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) + }]; + let hasVerifier = 1; +} + def TPU_StridedLoadOp : TPU_Op<"strided_load"> { let arguments = (ins AnyMemRef:$base, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 6f690f6a0fcb..96b78c8caf37 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -440,6 +440,31 @@ LogicalResult StridedStoreOp::verify() { getValueToStore().getType()); } +LogicalResult VectorStoreOp::verify() { + if (!getStrides().empty()) { + return emitError("Not implemented: general vector store with strides."); + } + VectorType value_ty = getValueToStore().getType(); + MemRefType ref_ty = getBase().getType(); + + if (value_ty.getElementType() != ref_ty.getElementType()) { + return emitOpError( + "Expected base and valueToStore element type should match"); + } + if (llvm::size(getIndices()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " indices"; + } + if (getMask()) { + if (value_ty.getElementTypeBitWidth() != 32) { + return emitError( + "Not implemented: masked store with non-32-bit element type"); + } + if (value_ty.getShape() != getMask().getType().getShape()) + return emitOpError("Expected valueToStore shape to match mask shape"); + } + return success(); +} + LogicalResult ReinterpretCastOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); @@ -468,7 +493,7 @@ LogicalResult verifyRotateOp(Op op) { } if (op.getStride().has_value() != op.getStrideDimension().has_value()) { op.emitOpError( - "Expected either none or both stride and stride dimension are " + "Expected either none or both stride and stride dimension are " "present"); return failure(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 4a344fa9d427..8ade7450881a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4200,18 +4200,15 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, shape_cast_op->erase(); return success(); } -LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(layouts_out.size(), 0); - MLIRContext *const mlir_ctx = op.getContext(); - TPU_ASSERT_OP(layouts_in.front().has_value()); - TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), - [&](const Layout &l) { return l.has_value(); })); + +template +LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, + const VectorLayout &to_store_layout, + TypedValue store_mask = nullptr) { + Operation &op = *(store_op.getOperation()); + MLIRContext *const mlir_ctx = store_op.getContext(); ImplicitLocOpBuilder builder(op.getLoc(), &op); - vector::StoreOp store_op = cast(op); const VectorType ty = store_op.getValueToStore().getType(); - const VectorLayout &to_store_layout = *layouts_in.front(); const auto memref_ty = getMemRefType(store_op.getBase()); if (!ty.getRank()) { return op.emitOpError("Not implemented: scalar stores to vmem"); @@ -4308,10 +4305,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, } else { // Convert dynamic store to dynamic slice + static store. This saves us a // bunch of scalar core work. - auto slice_result = - sliceRef(builder, store_op.getBase(), - store_op.getVectorType().getShape(), store_op.getIndices(), - ArrayRef(memref_tiling).take_back(tiled_dims)); + auto slice_result = sliceRef( + builder, store_op.getBase(), ty.getShape(), store_op.getIndices(), + ArrayRef(memref_tiling).take_back(tiled_dims)); if (failed(slice_result)) { return failure(); } @@ -4332,6 +4328,13 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, xla::Array tiles, disassemble(builder, to_store_layout, store_op.getValueToStore(), ctx.target_shape)); + std::optional> tile_masks; + if (store_mask) { + FAILUREOR_ASSIGN_OR_RETURN( + tile_masks, + disassemble(builder, to_store_layout, store_mask, ctx.target_shape)); + TPU_ASSERT_EQ_OP(tile_masks->dimensions(), tiles.dimensions()); + } const int64_t ndims = ty.getRank(); const auto base_s = is_1d ? IdxConst(0, builder, op.getLoc()) : tile_base_idxs.front(); @@ -4353,6 +4356,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, const absl::Status status = tiles.EachStatus([&](const absl::Span idx, const Value tile) -> absl::Status { + const auto tile_mask = store_mask ? (*tile_masks)(idx) : nullptr; const std::unique_ptr bounds = to_store_layout.tileDataBounds(mlir_ctx, stored_shape, toArrayRef(idx), ctx.target_shape); @@ -4412,19 +4416,19 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, updated = builder.create(mask, tile, data); } builder.create( - updated, base_addr, indices, sublane_mask, - /*mask=*/nullptr, + updated, base_addr, indices, sublane_mask, tile_mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } else { builder.create( tile, base_addr, indices, sublane_mask, - /*mask=*/mask, + tile_mask + ? builder.create(mask, tile_mask).getResult() + : mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } } else { builder.create( - tile, base_addr, indices, sublane_mask, - /*mask=*/nullptr, + tile, base_addr, indices, sublane_mask, tile_mask, /*sublane_stride=*/builder.getI32IntegerAttr(sublane_stride)); } return absl::OkStatus(); @@ -4434,7 +4438,35 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, } store_op->erase(); return success(); +} + +LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto store_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(), + [&](const Layout &l) { return l.has_value(); })); + return vector_store_impl(ctx, store_op, *layouts_in.front()); +} + +LogicalResult tpu_vector_store_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto store_op = cast(op); + TPU_ASSERT_EQ_OP(layouts_out.size(), 0); + TPU_ASSERT_OP(layouts_in.front().has_value()); + auto other_layouts_in = layouts_in.drop_front(); + if (store_op.getMask()) { + TPU_ASSERT_EQ_OP(layouts_in.front(), layouts_in.back()); + other_layouts_in = other_layouts_in.drop_back(); } + TPU_ASSERT_OP(llvm::none_of(other_layouts_in, + [&](const Layout &l) { return l.has_value(); })); + return vector_store_impl(ctx, store_op, *layouts_in.front(), + store_op.getMask()); +} LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, @@ -4648,6 +4680,7 @@ const llvm::StringMap &rules() { {tpu::StoreOp::getOperationName(), tpu_store_rule}, {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, + {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, {tpu::RegionOp::getOperationName(), tpu_region_rule}, {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 30486b6e995c..d84e4b883172 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -321,8 +321,14 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op)) { + if (inferStore(op, + /*has_mask=*/op.getMask() != nullptr) + .failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { - if (infer(op).failed()) { + if (inferStore(op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -1540,7 +1546,8 @@ class VectorLayoutInferer { return failure(); } - LogicalResult infer(vector::StoreOp op) { + template + LogicalResult inferStore(Op op, bool has_mask = false) { auto ref_ty = getMemRefType(op.getBase()); auto store_ty = op.getValueToStore().getType(); TPU_CHECK_OP(ref_ty.getRank() == store_ty.getRank(), @@ -1648,6 +1655,10 @@ class VectorLayoutInferer { } SmallVector in_layout{store_layout}; in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout); + if (has_mask) { + // Mask layout should be the same as the layout of value to store. + in_layout.push_back(store_layout); + } setInLayout(op, in_layout); return success(); } From c44f11d15e60ccb27d9c21a13a5e789ebded7713 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Tue, 19 Nov 2024 11:25:51 -0800 Subject: [PATCH 399/698] Add alternate implementation of threefry as a pallas kernel. Current restrictions: 1) Dynamic grid sizes are not supported yet. This could in theory allow us to not recompile the kernel for different shapes. 2) fold_in and split still use the original rules. But there isn't a huge benefit to using the kernel right now since the input is so small and we can't avoid re-compilation due to (1). 3) Currently doesn't support high bits on the counter, meaning we can generate at max 4B numbers in one call. This is a fringe use-case since we only support 32-bit, and generating 4B 32-bit numbers would consume 16GB of HBM (an entire TPU v5p worth of HBM). PiperOrigin-RevId: 698086352 --- .../pallas/ops/tpu/random/threefry.py | 156 ++++++++++++++++++ tests/pallas/BUILD | 4 + tests/pallas/tpu_pallas_random_test.py | 51 ++++++ 3 files changed, 211 insertions(+) create mode 100644 jax/experimental/pallas/ops/tpu/random/threefry.py diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py new file mode 100644 index 000000000000..d1e6bf1fd93d --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/random/threefry.py @@ -0,0 +1,156 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the Threefry PRNG as a Pallas kernel.""" +from typing import Sequence +import jax +from jax import lax +from jax._src import prng +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + +Shape = Sequence[int] + +BLOCK_SIZE = (256, 256) + +_round_up = lambda x, y: (x + y - 1) // y * y + + +def blocked_iota(block_shape: Shape, + total_shape: Shape): + """Computes a sub-block of a larger shaped iota. + + Args: + block_shape: The output block shape of the iota. + total_shape: The total shape of the input tensor. + Returns: + Result of the blocked iota. + """ + iota_data = jnp.zeros(block_shape, dtype=jnp.uint32) + multiplier = 1 + for dim in range(len(block_shape)-1, -1, -1): + block_mult = 1 + counts_lo = lax.broadcasted_iota( + dtype=jnp.uint32, shape=block_shape, dimension=dim + ) + iota_data += counts_lo * multiplier * block_mult + multiplier *= total_shape[dim] + return iota_data + + +def _compute_scalar_offset(iteration_index, + total_size: Shape, + block_size: Shape): + ndims = len(iteration_index) + dim_size = 1 + total_idx = 0 + for i in range(ndims-1, -1, -1): + dim_idx = iteration_index[i] * block_size[i] + total_idx += dim_idx * dim_size + dim_size *= total_size[i] + return total_idx + + +def threefry_2x32_count(key, + shape: Shape, + unpadded_shape: Shape, + block_size: tuple[int, int]): + """Generates random bits using the Threefry hash function. + + This function is a fusion of prng.shaped_iota and prng.threefry_2x32 from + the JAX core library. + + Args: + key: A threefry key of shape (2,). + shape: The shape of the output. Must be divisible by `block_size`. + unpadded_shape: If `shape` is padded, then this is the shape of the + output tensor if it were not padded. This is important for indexing + calculations within the kernel. If `shape` is not padded, then this + should be equal to `shape`. + block_size: The block size of the kernel. + + Returns: + A tensor of random bits of shape `shape`. + """ + shape = tuple(shape) + if np.prod(shape) > jnp.iinfo(jnp.uint32).max: + raise ValueError( + f"Shape too large: {np.prod(shape)} > {np.iinfo(jnp.uint32).max}") + + if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0): + raise ValueError( + f"Shape dimension {shape[-2:]} must be divisible by {block_size}") + grid_dims = shape[:-2] + ( + shape[-2] // block_size[-2], shape[-1] // block_size[1],) + + def kernel(key_ref, out_ref): + counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims))) + offset = _compute_scalar_offset(counts_idx, unpadded_shape, block_shape) + counts_lo = blocked_iota(block_size, unpadded_shape) + counts_lo = counts_lo + offset + counts_lo = counts_lo.astype(jnp.uint32) + # TODO(justinfu): Support hi bits on count. + counts_hi = jnp.zeros_like(counts_lo) + k1 = jnp.reshape(key_ref[0, 0], (1, 1)) + k2 = jnp.reshape(key_ref[0, 1], (1, 1)) + o1, o2 = prng.threefry2x32_p.bind( + k1, k2, counts_hi, counts_lo) + out_bits = o1 ^ o2 + out_ref[...] = out_bits.reshape(out_ref.shape) + + key = key.reshape((1, 2)) + out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32) + block_shape = (1,) * (len(shape)-2) + block_size + result = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs), + grid=grid_dims, + out_shape=out, + )(key) + return result + +def plthreefry_random_bits(key, bit_width: int, shape: Shape): + if bit_width != 32: + raise ValueError("Only 32-bit PRNG supported.") + if len(shape) == 0: + return plthreefry_random_bits(key, bit_width, (1, 1))[0, 0] + elif len(shape) == 1: + return plthreefry_random_bits(key, bit_width, (1, *shape))[0] + + requires_pad = ( + shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0) + if requires_pad: + padded_shape = tuple(shape[:-2]) + ( + _round_up(shape[-2], BLOCK_SIZE[-2]), + _round_up(shape[-1], BLOCK_SIZE[-1]), + ) + padded_result = threefry_2x32_count( + key, padded_shape, shape, block_size=BLOCK_SIZE) + return padded_result[..., :shape[-2], :shape[-1]] + else: + return threefry_2x32_count(key, shape, shape, block_size=BLOCK_SIZE) + + +plthreefry_prng_impl = prng.PRNGImpl( + key_shape=(2,), + seed=prng.threefry_seed, + split=prng.threefry_split, + random_bits=plthreefry_random_bits, + fold_in=prng.threefry_fold_in, + name="pallas_threefry2x32", + tag="plfry") + +prng.register_prng(plthreefry_prng_impl) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 92cab875df7d..50c1054ba9fd 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -394,9 +394,13 @@ jax_multiplatform_test( "tpu_pallas_random_test.py", ], enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5p_2x2", + ], deps = [ "//jax:pallas", "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", "//jax/_src/pallas/mosaic:random", ] + py_deps("absl/testing") + py_deps("numpy"), ) diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index 2b5c315263c9..88c33a020ce9 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -20,10 +20,14 @@ from jax._src import test_util as jtu from jax._src.pallas.mosaic import random as plrandom from jax.experimental import pallas as pl +from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.random import threefry # pylint: disable=unused-import # noqa: F401 import jax.numpy as jnp import numpy as np +P = jax.sharding.PartitionSpec + jax.config.parse_flags_with_absl() @@ -253,6 +257,53 @@ def body(key_ref, o_ref): ) np.testing.assert_array_equal(result, jax_result) + @parameterized.parameters( + ((512, 512),), + ((137, 275),), # Non block-aligned shape + ((4, 512, 512),), # Greater than 2D shape + ((34,),), # 1D + (tuple(),), # 0D + ) + def test_threefry_kernel_matches_jax_threefry(self, shape): + with jax.threefry_partitionable(True): + key_jax = jax_random.key(0, impl="threefry2x32") + jax_gen = jax_random.bits(key_jax, shape=shape) + key_pl = jax_random.key(0, impl="pallas_threefry2x32") + pl_gen = jax_random.bits(key_pl, shape=shape) + + np.testing.assert_array_equal(jax_gen, pl_gen) + + @parameterized.parameters( + ((256, 256),), + ((35, 113),), # Non block-aligned shape + ((331,),), # 1D + ) + def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): + if jax.device_count() < 2: + self.skipTest("Need at least 2 devices") + num_devices = jax.device_count() + partition = P("x") + mesh = jax.make_mesh((num_devices,), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + with jax.threefry_partitionable(True): + key_jax = jax_random.split( + jax_random.key(0, impl="threefry2x32"), num_devices) + key_pallas = jax_random.split( + jax_random.key(0, impl="pallas_threefry2x32"), num_devices) + key_jax = jax.device_put(key_jax, sharding) + key_pallas = jax.device_put(key_pallas, sharding) + generate = shard_map.shard_map( + lambda x: jax_random.bits(x[0], shape=shape), + mesh=mesh, + in_specs=partition, + out_specs=partition, + ) + jax_gen = generate(key_jax) + pl_gen = generate(key_pallas) + + np.testing.assert_array_equal(jax_gen, pl_gen) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From a59bbb7cd721cc146a499e9ef37577e94a7357fb Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 19 Nov 2024 11:59:59 -0800 Subject: [PATCH 400/698] Add test utility for accessing jaxlib version tuple. We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests. PiperOrigin-RevId: 698097487 --- jax/_src/test_util.py | 5 +++++ tests/compilation_cache_test.py | 5 ++--- tests/linalg_test.py | 13 ++++++------- tests/magma_linalg_test.py | 7 +++---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 527d7a46ed13..c5a713743fb8 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -44,6 +44,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes as _dtypes +from jax._src import lib as _jaxlib from jax._src import linear_util as lu from jax._src import monitoring from jax._src import pjit as pjit_lib @@ -451,6 +452,10 @@ def assert_num_jit_and_pmap_compilations(times): f"but executed {count[0]}") +def jaxlib_version() -> tuple[int, ...]: + return _jaxlib.version + + def device_under_test(): return _TEST_DUT.value or xla_bridge.get_backend().platform diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index d10558afbe16..0f949aaf1490 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -41,7 +41,6 @@ from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface from jax._src.lib import xla_client as xc -from jax._src.lib import version as jaxlib_version from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -538,7 +537,7 @@ def test_backend_serialization_deserialization(self): executable.fingerprint, deserialized_executable.fingerprint) def test_persistent_cache_enable_xla_caches(self): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("Test requires AutotuneCacheMode bindings") with config.compilation_cache_dir("jax-cache"): with config.persistent_cache_enable_xla_caches("none"): @@ -609,7 +608,7 @@ def test_tasks_disable_cache_metric(self): self.assertEqual(count_after_second_use, count_after_first_use) def test_persistent_cache_enable_xla_caches_disabled(self): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("Test requires AutotuneCacheMode bindings") with config.enable_compilation_cache(False): compile_options = compiler.get_compile_options( diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d0b109dda07e..7c135b4ffeca 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -34,7 +34,6 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -254,7 +253,7 @@ def testIssue1213(self): @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] @@ -298,7 +297,7 @@ def check_left_eigenvectors(a, w, vl): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( @@ -317,7 +316,7 @@ def testEigvalsGrad(self, shape, dtype): # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -332,7 +331,7 @@ def testEigvalsGrad(self, shape, dtype): ) @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] @@ -344,7 +343,7 @@ def testEigvals(self, shape, dtype): @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -355,7 +354,7 @@ def testEigvalsInf(self): ) @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): - if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py index d2abb9fe3a0b..bf9c0fb6b51d 100644 --- a/tests/magma_linalg_test.py +++ b/tests/magma_linalg_test.py @@ -24,7 +24,6 @@ from jax._src import test_util as jtu from jax._src.lax import linalg as lax_linalg from jax._src.lib import gpu_solver -from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -43,7 +42,7 @@ class MagmaLinalgTest(jtu.JaxTestCase): @jtu.run_on_devices("gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") @@ -94,7 +93,7 @@ def check_left_eigenvectors(a, w, vl): def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") @@ -111,7 +110,7 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, self.assertTrue(np.all(np.isnan(result))) def testEigMagmaConfig(self): - if jaxlib_version <= (0, 4, 35): + if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("eig on GPU requires jaxlib version > 0.4.35") if not gpu_solver.has_magma(): self.skipTest("MAGMA is not installed or can't be loaded.") From 2c80d1af50ed580d2fb34bb45a471cce11679d99 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Nov 2024 10:57:07 -0500 Subject: [PATCH 401/698] Add a new API jax.lax.split. This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently. Before: ``` In [1]: import jax.numpy as jnp, jax In [2]: x = jnp.ones((3,)) In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr Out[3]: { lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let f:f32[5,3] = pjit[ name=unstack jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let l:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] k m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0 n:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] j o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0 p:f32[5,3] = add_any m o q:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] i r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0 s:f32[5,3] = add_any p r t:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] h u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0 v:f32[5,3] = add_any s u w:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] g x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0 y:f32[5,3] = add_any v x in (y,) } ] a b c d e in (f,) } ``` Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents. After: ``` In [1]: import jax.numpy as jnp, jax In [2]: x = jnp.ones((3,)) In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr Out[3]: { lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let f:f32[5,3] = pjit[ name=unstack jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let l:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] k m:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] j n:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] i o:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] h p:f32[1,3] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 3) sharding=None ] g q:f32[5,3] = concatenate[dimension=0] p o n m l in (q,) } ] a b c d e in (f,) } ``` --- CHANGELOG.md | 3 + docs/jax.lax.rst | 1 + jax/_src/lax/lax.py | 96 ++++++++++++++++++++++++++---- jax/_src/numpy/array_methods.py | 3 +- jax/_src/numpy/lax_numpy.py | 31 +++++----- jax/_src/pallas/mosaic/lowering.py | 21 +++++++ jax/experimental/jax2tf/jax2tf.py | 6 ++ jax/experimental/jet.py | 1 + jax/lax/__init__.py | 2 + tests/lax_autodiff_test.py | 18 ++++++ tests/lax_test.py | 27 +++++++++ tests/lax_vmap_test.py | 18 ++++++ 12 files changed, 197 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9082399c8695..a0901e87ccfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now supported on GPU. See {jax-issue}`#24663` for more details. + * Added {func}`jax.lax.split`. This is a primitive version of + {func}`jax.numpy.split`, added because it yields a more compact + transpose in automatic differentiation. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 065127718c54..d8a28bc399c8 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -154,6 +154,7 @@ Operators slice_in_dim sort sort_key_val + split sqrt square squeeze diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ff9ac0a49578..e97427445aef 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -654,6 +654,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: return concatenate_p.bind(*operands, dimension=dimension) +def split(operand: ArrayLike, sizes: Sequence[int], + axis: int = 0) -> Sequence[Array]: + """Splits an array along ``axis``. + + Args: + operand: an array to split + sizes: the sizes of the split arrays. The sum of the sizes must be equal + to the size of the ``axis`` dimension of ``operand``. + axis: the axis along which to split the array. + + Returns: + A sequence of ``len(sizes)`` arrays. If ``sizes`` is + ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``, + taken along ``axis``. + """ + operand = asarray(operand) + return split_p.bind(operand, sizes=tuple(sizes), + axis=canonicalize_axis(axis, operand.ndim)) + + _precision_strings: dict[Any, Precision] = {} class Precision(enum.Enum): @@ -4373,18 +4393,8 @@ def _concatenate_transpose_rule(t, *operands, dimension): return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None for o in operands] else: - limit_points = np.cumsum( - [shape[dimension] for shape in operand_shapes]).tolist() - starts = np.zeros((len(operands), t.ndim), dtype=int).tolist() - limits = np.tile(t.shape, (len(operands), 1)).tolist() - - for i, s in enumerate(starts[1:]): - s[dimension] = limit_points[:-1][i] - for i, l in enumerate(limits): - l[dimension] = limit_points[i] - - return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o) - else None for o, start, limit in zip(operands, starts, limits)] + return split(t, tuple(shape[dimension] for shape in operand_shapes), + axis=dimension) def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims) @@ -4413,6 +4423,68 @@ def _concatenate_lower(ctx, *xs, dimension): mlir.register_lowering(concatenate_p, _concatenate_lower) +def _split_shape_rule(operand, *, sizes, axis): + offset = 0 + shapes = [] + shape = list(operand.shape) + if any(s < 0 for s in sizes): + raise ValueError( + f"Sizes passed to split must be nonnegative, got {list(sizes)}") + if operand.shape[axis] != np.sum(sizes): + raise ValueError( + f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the " + f"operand shape {list(operand.shape)}") + for size in sizes: + shape[axis] = size + shapes.append(tuple(shape)) + return shapes + +def _split_dtype_rule(operand, *, sizes, axis): + return (operand.dtype,) * len(sizes) + +def _split_weak_type_rule(operand, *, sizes, axis): + return (operand.weak_type,) * len(sizes) + +def _split_transpose_rule(cotangents, operand, *, sizes, axis): + assert ad.is_undefined_primal(operand) + if all(type(t) is ad_util.Zero for t in cotangents): + return ad_util.Zero(operand.aval), + cotangents = [ + _zeros(t.aval) if type(t) is ad_util.Zero else t + for t in cotangents + ] + return concatenate(cotangents, dimension=axis), + +def _split_batch_rule(batched_args, batch_dims, *, sizes, axis): + operand, = batched_args + bdim, = batch_dims + new_bdims = (bdim,) * len(sizes) + out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis) + return out, new_bdims + +def _split_lower(ctx, x, *, sizes, axis): + x_aval, = ctx.avals_in + start_indices = [0] * x_aval.ndim + limit_indices = list(x_aval.shape) + strides = (1,) * x_aval.ndim + outs = [] + for aval_out in ctx.avals_out: + limit_indices[axis] = start_indices[axis] + aval_out.shape[axis] + outs.append(mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, + limit_indices=limit_indices, strides=strides)) + start_indices[axis] = limit_indices[axis] + return outs + +split_p = core.Primitive('split') +split_p.multiple_results = True +split_p.def_abstract_eval( + partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, + _split_dtype_rule, _split_weak_type_rule)) +split_p.def_impl(partial(dispatch.apply_primitive, split_p)) +ad.deflinear2(split_p, _split_transpose_rule) +batching.primitive_batchers[split_p] = _split_batch_rule +mlir.register_lowering(split_p, _split_lower) + def _pad_dtype_rule(operand, padding_value, *, padding_config): if operand.dtype != padding_value.dtype: msg = "pad operand and padding_value must be same dtype: got {} and {}." diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 4768a8126c72..617213ca03de 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -629,7 +629,8 @@ def _multi_slice(self: Array, # avoid circular imports. @jax.jit def _unstack(x: Array) -> list[Array]: - return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] + dims = (0,) + return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] def _chunk_iter(x, size): if size > x.shape[0]: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 898e4255dd8e..d256c97a9957 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, + ceil_of_ratio, partition_list, safe_zip, set_module, unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) @@ -3280,10 +3280,10 @@ def _split(op: str, ary: ArrayLike, if (isinstance(indices_or_sections, (tuple, list)) or isinstance(indices_or_sections, (np.ndarray, Array)) and indices_or_sections.ndim > 0): - indices_or_sections = [ + split_indices = np.asarray([0] + [ core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1") - for i_s in indices_or_sections] - split_indices = [0] + list(indices_or_sections) + [size] + for i_s in indices_or_sections] + [size]) + sizes = list(np.diff(split_indices)) else: if core.is_symbolic_dim(indices_or_sections): raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is " @@ -3292,21 +3292,14 @@ def _split(op: str, ary: ArrayLike, f"in jax.numpy.{op} argument 1") part_size, r = divmod(size, num_sections) if r == 0: - split_indices = [i * part_size - for i in range(num_sections + 1)] + sizes = [part_size] * num_sections elif op == "array_split": - split_indices = ( - [i * (part_size + 1) for i in range(r + 1)] + - [i * part_size + ((r + 1) * (part_size + 1) - 1) - for i in range(num_sections - r)]) + sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r) else: raise ValueError(f"array split does not result in an equal division: rest is {r}") - split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] - for i in split_indices] - starts, ends = [0] * ndim(ary), shape(ary) - _subval = lambda x, i, v: subvals(x, [(i, v)]) - return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) - for start, end in zip(split_indices[:-1], split_indices[1:])] + sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] + for i in sizes] + return list(lax.split(ary, sizes, axis=axis)) @export @@ -4669,7 +4662,11 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: "Unstack requires arrays with rank > 0, however a scalar array was " "passed." ) - return tuple(moveaxis(x, axis, 0)) + dimensions = (axis,) + return tuple( + lax.squeeze(t, dimensions) + for t in lax.split(x, (1,) * x.shape[axis], axis=axis) + ) @export diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index be4102dff716..f0286c156e45 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1871,6 +1871,27 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule +def _split_lowering_rule( + ctx: LoweringRuleContext, x, *, sizes, axis +): + (x_aval,) = ctx.avals_in + slice_size = np.array(x_aval.shape, dtype=np.int64) + starts = np.zeros_like(slice_size) + strides = np.ones_like(slice_size) + outs = [] + for size, aval_out in zip(sizes, ctx.avals_out): + slice_size[axis] = size + outs.append( + vector.extract_strided_slice( + aval_to_ir_type(aval_out), x, starts, slice_size, strides + ) + ) + starts[axis] += size + return outs + +lowering_rules[lax.split_p] = _split_lowering_rule + + def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): out_type = aval_to_ir_type(ctx.avals_out[0]) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c41eda693d7f..2cc670ef6a43 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2087,6 +2087,12 @@ def _concatenate(*operands, dimension): tf_impl[lax.concatenate_p] = _concatenate +def _split(operand, *, sizes, axis): + return tf.split(operand, sizes, axis=axis) + +tf_impl[lax.split_p] = _split + + def _conv_general_dimension_numbers_proto(dimension_numbers): """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.""" assert isinstance(dimension_numbers, lax.ConvDimensionNumbers) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 2681ad1a2a7b..29ec21319361 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -323,6 +323,7 @@ def linear_prop(prim, primals_in, series_in, **params): deflinear(lax.convert_element_type_p) deflinear(lax.broadcast_in_dim_p) deflinear(lax.concatenate_p) +deflinear(lax.split_p) deflinear(lax.pad_p) deflinear(lax.reshape_p) deflinear(lax.squeeze_p) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index d569ed641138..dc9c69d97795 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -203,6 +203,8 @@ sort as sort, sort_key_val as sort_key_val, sort_p as sort_p, + split as split, + split_p as split_p, sqrt as sqrt, sqrt_p as sqrt_p, square as square, diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 78d90cb8a072..c7cbde069cc8 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -273,6 +273,24 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs): concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(base_shape)) + ], + num_pieces=range(3), + dtype=float_dtypes, + ) + def testSplitGrad(self, axis, base_shape, dtype, num_pieces): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + operands = (rng(shape, dtype),) + split = lambda x: lax.split(x, sizes, axis) + check_grads(split, operands, 2, ["fwd", "rev"], eps=1.) + + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides) for lhs_shape, rhs_shape, all_strides in itertools.chain( diff --git a/tests/lax_test.py b/tests/lax_test.py index 78bc5857acb7..48f70baa1e32 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -283,6 +283,33 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs): numpy_op = lambda *args: lax_reference.concatenate(args, dim) self._CheckAgainstNumpy(numpy_op, op, args_maker) + @jtu.sample_product( + [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(shape))], + num_pieces=range(3), + dtype=lax_test_util.default_dtypes, + ) + def testSplit(self, axis, base_shape, dtype, num_pieces): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + op = lambda x: lax.split(x, sizes, axis=axis) + def numpy_op(x): + return np.split(x, np.cumsum(sizes[:-1]), axis=axis) + self._CompileAndCheck(op, args_maker) + self._CheckAgainstNumpy(numpy_op, op, args_maker) + + def testSplitErrors(self): + with self.assertRaisesRegex(ValueError, + "Sizes passed to split must be nonnegative"): + lax.split(np.arange(5), [-1]) + with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"): + lax.split(np.arange(5), [6]) + with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"): + lax.split(np.arange(5), sizes=(), axis=1) + @jtu.sample_product( [ dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 83d4d657751b..49e06e17be15 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -344,6 +344,24 @@ def testSlice(self, shape, dtype, starts, limits, strides, bdims): op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis, bdims=bdims) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(base_shape)) + for bdims in lax_test_util.all_bdims(base_shape) + ], + num_pieces=range(3), + dtype=lax_test_util.default_dtypes, + ) + def testSplit(self, base_shape, dtype, num_pieces, axis, bdims): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + op = lambda x: lax.split(x, sizes, axis) + self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng, + multiple_results=True) + @jtu.sample_product( [dict(shape=shape, perm=perm, bdims=bdims) for shape, perm in [ From 1bf70fbbc42b28c4d929e70bc949347b9b5732ae Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 19 Nov 2024 13:01:53 -0800 Subject: [PATCH 402/698] [pallas:mosaic_gpu] `copy_gmem_to_smem` no longer requires `barrier` to be a keyword argument ... because there really isn't any reason to require that. PiperOrigin-RevId: 698116984 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 1 - .../pallas/ops/gpu/attention_mgpu.py | 10 +++++----- tests/pallas/mosaic_gpu_test.py | 18 +++++++++--------- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 91e1e1c45429..9b6adc86f981 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -63,7 +63,7 @@ def copy_in(self, slot, grid_indices, barrier_ref): gpu_primitives.copy_gmem_to_smem( self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands self.smem_ref.at[slot], - barrier=barrier_ref.at[slot], + barrier_ref.at[slot], ) def copy_out(self, slot, grid_indices, predicate=None): diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 5fc4ed5e7afc..36dcba5d15d0 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -218,7 +218,6 @@ def _copy_gmem_to_smem_lowering( def copy_gmem_to_smem( src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef, - *, barrier: pallas_core.AbstractMemoryRef, ) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 56db5379d5e2..1c5b4d9f741b 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -89,7 +89,7 @@ def _compute_wg(): plgpu.copy_gmem_to_smem( q_ref.at[pl.ds(q_seq_base, block_q), q_head], qo_smem, - barrier=q_barriers.at[wg_idx], + q_barriers.at[wg_idx], ) plgpu.barrier_wait(q_barriers.at[wg_idx]) @@ -166,17 +166,17 @@ def _memory_wg(): kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) for i in range(max_concurrent_steps): s = (pl.ds(i * block_kv, block_kv), kv_head) - plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], barrier=k_barriers.at[i]) - plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], barrier=v_barriers.at[i]) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) s = (pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barrier) - plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], barrier=k_barriers.at[tma_slot]) + plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barrier) - plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], barrier=v_barriers.at[tma_slot]) + plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) def kv_epilogue(i, _): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 83202937503d..fe52a33c1637 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -263,7 +263,7 @@ def test_copy_gmem_to_smem(self, indexer): ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.copy_gmem_to_smem( - x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier=barrier_ref + x_ref_gmem.at[indexer], scratch_ref.at[indexer], barrier_ref ) plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] + 1 @@ -284,7 +284,7 @@ def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.copy_gmem_to_smem( - x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer] + x_ref_gmem, scratch_ref, barrier_ref.at[indexer] ) plgpu.barrier_wait(barrier_ref.at[indexer]) o_ref[...] = scratch_ref[...] + 1 @@ -296,7 +296,7 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_copy_with_transforms(self, to_smem): def kernel(x_ref, o_ref, barrier_ref): if to_smem: - plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) else: plgpu.commit_smem() @@ -329,7 +329,7 @@ def test_scoped_copy_with_transforms(self): ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): - plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) o_ref[...] = tmp_ref[...] * 2 pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) @@ -351,7 +351,7 @@ def body(tmp_ref): def test_copy_with_transforms_and_indexing(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): - plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) @@ -379,7 +379,7 @@ def test_indexing_before_transpose(self): def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem( - x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier=barrier_ref + x_ref, plgpu.transpose_ref(o_ref.at[i], (1, 0, 2)), barrier_ref ) plgpu.barrier_wait(barrier_ref) @@ -407,7 +407,7 @@ def test_copy_gmem_to_smem_in_run_scoped(self): def kernel(x_ref_gmem, o_ref): def body(barrier_ref): def inner_body(scratch_ref): - plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier=barrier_ref) + plgpu.copy_gmem_to_smem(x_ref_gmem, scratch_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] + 1 pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32)) @@ -1092,7 +1092,7 @@ def body(step, _): lambda: plgpu.copy_gmem_to_smem( x_gmem.at[gmem_slice, pl.ds(fetch_step * 16, 16)], x_smem.at[fetch_slot], - barrier=barrier.at[fetch_slot], + barrier.at[fetch_slot], ), lambda: None, ) @@ -1103,7 +1103,7 @@ def body(step, _): plgpu.copy_gmem_to_smem( x_gmem.at[gmem_slice, pl.ds(slot * 16, 16)], x_smem.at[slot], - barrier=barrier.at[slot], + barrier.at[slot], ) jax.lax.fori_loop(0, num_steps, body, ()) From 0d36b0b433a93c707f86dac89b0c05d40302775a Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Tue, 19 Nov 2024 13:39:38 -0800 Subject: [PATCH 403/698] [Mosaic] Add target core type parameter to tpu.sem_signal Adds the optional core type parameter to `tpu.sem_signal` for cross-core signalling. If the target core type is not provided, the target core type is assumed to be that of the core issuing the signal. The issuing core type is determined based on the core type annotation of the parent function; if the annotation is not provided, the issuing core type is assumed to be TensorCore. PiperOrigin-RevId: 698129842 --- jaxlib/mosaic/BUILD | 1 + jaxlib/mosaic/dialect/tpu/tpu.td | 12 ++++++--- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 10 +++++++ jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 5 ++++ jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 34 ++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 14f3ee13c0f5..da7498ed437d 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -62,6 +62,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 4fd960063dc4..55d2e1ec975e 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -653,12 +653,18 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { MemRefOf<[TPU_SemaphoreType]>:$semaphore, I32:$amount, Optional:$device_id, // For remote DMAs - Optional:$core_id // For megacore + Optional:$core_id, // For megacore + OptionalAttr:$core_type ); - let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? attr-dict `:` type($semaphore) +let assemblyFormat = [{ + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; + let builders = [ + // A backward-compatible builder that sets `core_type` to nullptr. + OpBuilder<(ins "Value":$semaphore, "Value":$amount, + "Value":$device_id, "Value":$core_id)>, + ]; } def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 10ab154b7c10..92e8953837e3 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "absl/hash/hash.h" #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" @@ -81,6 +82,15 @@ void TPUDialect::initialize() { return mlir::cast(attr).getValue(); } +FailureOr> GetCoreTypeOfParentFunc(Operation &op) { + mlir::Operation *func_op = op.getParentOfType(); + if (func_op == nullptr) { + return op.emitError() << "Operation " << op.getName() + << " is not inside a func.func"; + } + return TPUDialect::GetCoreTypeAttr(func_op); +} + void VectorLayoutAttr::print(AsmPrinter &printer) const { printer << '<'; printer << getLayout(); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index dbb2ddaa5853..a8569acc6239 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -94,6 +95,10 @@ std::unique_ptr> createDebugAssertInsertionPass(); #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" +// Determine the core type of the given op based on the `tpu.core_type` +// annotation of its parent function. +FailureOr> GetCoreTypeOfParentFunc(Operation &op); + // Changes the memory space of the value and propagates it through the program. LogicalResult specializeMemorySpace(TypedValue value, MemorySpace memory_space); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 96b78c8caf37..b4dcca66f7dc 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -28,9 +28,12 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" +#include "mlir/include/mlir/IR/Builders.h" #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/include/mlir/IR/BuiltinTypes.h" #include "mlir/include/mlir/IR/IRMapping.h" +#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -837,11 +840,42 @@ LogicalResult GetBarrierSemaphoreOp::verify() { return success(); } +void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state, + Value semaphore, Value amount, Value device_id, + Value core_id) { + build(builder, state, semaphore, amount, device_id, core_id, + /*core_type=*/nullptr); +} + LogicalResult SemaphoreSignalOp::verify() { auto sem_type = getMemRefType(getSemaphore()); if (sem_type.getRank() != 0) { return emitOpError("Semaphore reference must be rank 0"); } + + FailureOr> issuing_core_type_maybe = + GetCoreTypeOfParentFunc(**this); + if (failed(issuing_core_type_maybe)) { + return issuing_core_type_maybe; + } + CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc); + CoreType target_core_type = getCoreType().value_or(issuing_core_type); + + if (getCoreId() == nullptr && getDeviceId() == nullptr) { + if (target_core_type != issuing_core_type) { + return emitOpError( + absl::StrFormat("Target core type (%s) must match source core type " + "(%s) when device_id and core_id are not specified", + stringifyCoreType(target_core_type), + stringifyCoreType(issuing_core_type))); + } + } + if ((issuing_core_type == CoreType::kTc && + target_core_type == CoreType::kScScalarSubcore) || + (issuing_core_type == CoreType::kScScalarSubcore && + target_core_type == CoreType::kTc)) { + return emitOpError("Signalling between TC and SC is not implemented"); + } return success(); } From 3161a28424995d231b56c38ac43b89c0807d683a Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 19 Nov 2024 14:00:40 -0800 Subject: [PATCH 404/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/229f376e046b9a51039dc1566d1e388ee7c1ca6d. PiperOrigin-RevId: 698136955 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 71fb2a8e9757..99c2af75f3ad 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "082a7014706f67bb8a42fb1c90051bc4990f2fd3" -XLA_SHA256 = "f1ca797df8e95bf13419d20520d2b783f075d80d1c5ddf1506ba427c934de849" +XLA_COMMIT = "229f376e046b9a51039dc1566d1e388ee7c1ca6d" +XLA_SHA256 = "895b39b5cb298460185f29df3ecc8882f4ee151b0f7dc93e5387ef81ea32e374" def repo(): tf_http_archive( From 42fbd301fc7bed57386423722c1a2ddae11f91ec Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Tue, 19 Nov 2024 14:18:32 -0800 Subject: [PATCH 405/698] Move JAX example to public XLA:CPU API PiperOrigin-RevId: 698143471 --- examples/jax_cpp/BUILD | 7 ++++++- examples/jax_cpp/main.cc | 12 ++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index 6e4647b5e491..b3cb995aae21 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -26,8 +26,13 @@ cc_binary( "@tsl//tsl/platform:platform_port", "@xla//xla:literal", "@xla//xla:literal_util", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", "@xla//xla/pjrt:pjrt_client", - "@xla//xla/pjrt/cpu:cpu_client", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/service:hlo_module_config", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 2a8f8d4debba..ceac2cd2d7c9 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -36,15 +36,21 @@ limitations under the License. // } // ) +#include #include #include #include #include "third_party/absl/status/statusor.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" @@ -66,8 +72,10 @@ int main(int argc, char** argv) { // Run it using JAX C++ Runtime (PJRT). // Get a CPU client. + xla::CpuClientOptions options; + options.asynchronous = true; std::unique_ptr client = - xla::GetTfrtCpuClient(/*asynchronous=*/true).value(); + xla::GetXlaPjrtCpuClient(options).value(); // Compile XlaComputation to PjRtExecutable. xla::XlaComputation xla_computation(test_module_proto); From 525b646c0ebd5205f4fa0639c94adb2de47e1cf0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Nov 2024 14:46:44 -0800 Subject: [PATCH 406/698] Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a PiperOrigin-RevId: 698152759 --- CHANGELOG.md | 3 - docs/jax.lax.rst | 1 - jax/_src/lax/lax.py | 96 ++++-------------------------- jax/_src/numpy/array_methods.py | 3 +- jax/_src/numpy/lax_numpy.py | 31 +++++----- jax/_src/pallas/mosaic/lowering.py | 21 ------- jax/experimental/jax2tf/jax2tf.py | 6 -- jax/experimental/jet.py | 1 - jax/lax/__init__.py | 2 - tests/lax_autodiff_test.py | 18 ------ tests/lax_test.py | 27 --------- tests/lax_vmap_test.py | 18 ------ 12 files changed, 30 insertions(+), 197 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0901e87ccfc..9082399c8695 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,9 +59,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now supported on GPU. See {jax-issue}`#24663` for more details. - * Added {func}`jax.lax.split`. This is a primitive version of - {func}`jax.numpy.split`, added because it yields a more compact - transpose in automatic differentiation. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index d8a28bc399c8..065127718c54 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -154,7 +154,6 @@ Operators slice_in_dim sort sort_key_val - split sqrt square squeeze diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e97427445aef..ff9ac0a49578 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -654,26 +654,6 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: return concatenate_p.bind(*operands, dimension=dimension) -def split(operand: ArrayLike, sizes: Sequence[int], - axis: int = 0) -> Sequence[Array]: - """Splits an array along ``axis``. - - Args: - operand: an array to split - sizes: the sizes of the split arrays. The sum of the sizes must be equal - to the size of the ``axis`` dimension of ``operand``. - axis: the axis along which to split the array. - - Returns: - A sequence of ``len(sizes)`` arrays. If ``sizes`` is - ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``, - taken along ``axis``. - """ - operand = asarray(operand) - return split_p.bind(operand, sizes=tuple(sizes), - axis=canonicalize_axis(axis, operand.ndim)) - - _precision_strings: dict[Any, Precision] = {} class Precision(enum.Enum): @@ -4393,8 +4373,18 @@ def _concatenate_transpose_rule(t, *operands, dimension): return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None for o in operands] else: - return split(t, tuple(shape[dimension] for shape in operand_shapes), - axis=dimension) + limit_points = np.cumsum( + [shape[dimension] for shape in operand_shapes]).tolist() + starts = np.zeros((len(operands), t.ndim), dtype=int).tolist() + limits = np.tile(t.shape, (len(operands), 1)).tolist() + + for i, s in enumerate(starts[1:]): + s[dimension] = limit_points[:-1][i] + for i, l in enumerate(limits): + l[dimension] = limit_points[i] + + return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o) + else None for o, start, limit in zip(operands, starts, limits)] def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims) @@ -4423,68 +4413,6 @@ def _concatenate_lower(ctx, *xs, dimension): mlir.register_lowering(concatenate_p, _concatenate_lower) -def _split_shape_rule(operand, *, sizes, axis): - offset = 0 - shapes = [] - shape = list(operand.shape) - if any(s < 0 for s in sizes): - raise ValueError( - f"Sizes passed to split must be nonnegative, got {list(sizes)}") - if operand.shape[axis] != np.sum(sizes): - raise ValueError( - f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the " - f"operand shape {list(operand.shape)}") - for size in sizes: - shape[axis] = size - shapes.append(tuple(shape)) - return shapes - -def _split_dtype_rule(operand, *, sizes, axis): - return (operand.dtype,) * len(sizes) - -def _split_weak_type_rule(operand, *, sizes, axis): - return (operand.weak_type,) * len(sizes) - -def _split_transpose_rule(cotangents, operand, *, sizes, axis): - assert ad.is_undefined_primal(operand) - if all(type(t) is ad_util.Zero for t in cotangents): - return ad_util.Zero(operand.aval), - cotangents = [ - _zeros(t.aval) if type(t) is ad_util.Zero else t - for t in cotangents - ] - return concatenate(cotangents, dimension=axis), - -def _split_batch_rule(batched_args, batch_dims, *, sizes, axis): - operand, = batched_args - bdim, = batch_dims - new_bdims = (bdim,) * len(sizes) - out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis) - return out, new_bdims - -def _split_lower(ctx, x, *, sizes, axis): - x_aval, = ctx.avals_in - start_indices = [0] * x_aval.ndim - limit_indices = list(x_aval.shape) - strides = (1,) * x_aval.ndim - outs = [] - for aval_out in ctx.avals_out: - limit_indices[axis] = start_indices[axis] + aval_out.shape[axis] - outs.append(mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, - limit_indices=limit_indices, strides=strides)) - start_indices[axis] = limit_indices[axis] - return outs - -split_p = core.Primitive('split') -split_p.multiple_results = True -split_p.def_abstract_eval( - partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, - _split_dtype_rule, _split_weak_type_rule)) -split_p.def_impl(partial(dispatch.apply_primitive, split_p)) -ad.deflinear2(split_p, _split_transpose_rule) -batching.primitive_batchers[split_p] = _split_batch_rule -mlir.register_lowering(split_p, _split_lower) - def _pad_dtype_rule(operand, padding_value, *, padding_config): if operand.dtype != padding_value.dtype: msg = "pad operand and padding_value must be same dtype: got {} and {}." diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 617213ca03de..4768a8126c72 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -629,8 +629,7 @@ def _multi_slice(self: Array, # avoid circular imports. @jax.jit def _unstack(x: Array) -> list[Array]: - dims = (0,) - return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] + return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] def _chunk_iter(x, size): if size > x.shape[0]: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d256c97a9957..898e4255dd8e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, set_module, unzip2, + ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) @@ -3280,10 +3280,10 @@ def _split(op: str, ary: ArrayLike, if (isinstance(indices_or_sections, (tuple, list)) or isinstance(indices_or_sections, (np.ndarray, Array)) and indices_or_sections.ndim > 0): - split_indices = np.asarray([0] + [ + indices_or_sections = [ core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1") - for i_s in indices_or_sections] + [size]) - sizes = list(np.diff(split_indices)) + for i_s in indices_or_sections] + split_indices = [0] + list(indices_or_sections) + [size] else: if core.is_symbolic_dim(indices_or_sections): raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is " @@ -3292,14 +3292,21 @@ def _split(op: str, ary: ArrayLike, f"in jax.numpy.{op} argument 1") part_size, r = divmod(size, num_sections) if r == 0: - sizes = [part_size] * num_sections + split_indices = [i * part_size + for i in range(num_sections + 1)] elif op == "array_split": - sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r) + split_indices = ( + [i * (part_size + 1) for i in range(r + 1)] + + [i * part_size + ((r + 1) * (part_size + 1) - 1) + for i in range(num_sections - r)]) else: raise ValueError(f"array split does not result in an equal division: rest is {r}") - sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] - for i in sizes] - return list(lax.split(ary, sizes, axis=axis)) + split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] + for i in split_indices] + starts, ends = [0] * ndim(ary), shape(ary) + _subval = lambda x, i, v: subvals(x, [(i, v)]) + return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) + for start, end in zip(split_indices[:-1], split_indices[1:])] @export @@ -4662,11 +4669,7 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: "Unstack requires arrays with rank > 0, however a scalar array was " "passed." ) - dimensions = (axis,) - return tuple( - lax.squeeze(t, dimensions) - for t in lax.split(x, (1,) * x.shape[axis], axis=axis) - ) + return tuple(moveaxis(x, axis, 0)) @export diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f0286c156e45..be4102dff716 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1871,27 +1871,6 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule -def _split_lowering_rule( - ctx: LoweringRuleContext, x, *, sizes, axis -): - (x_aval,) = ctx.avals_in - slice_size = np.array(x_aval.shape, dtype=np.int64) - starts = np.zeros_like(slice_size) - strides = np.ones_like(slice_size) - outs = [] - for size, aval_out in zip(sizes, ctx.avals_out): - slice_size[axis] = size - outs.append( - vector.extract_strided_slice( - aval_to_ir_type(aval_out), x, starts, slice_size, strides - ) - ) - starts[axis] += size - return outs - -lowering_rules[lax.split_p] = _split_lowering_rule - - def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): out_type = aval_to_ir_type(ctx.avals_out[0]) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 2cc670ef6a43..c41eda693d7f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2087,12 +2087,6 @@ def _concatenate(*operands, dimension): tf_impl[lax.concatenate_p] = _concatenate -def _split(operand, *, sizes, axis): - return tf.split(operand, sizes, axis=axis) - -tf_impl[lax.split_p] = _split - - def _conv_general_dimension_numbers_proto(dimension_numbers): """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.""" assert isinstance(dimension_numbers, lax.ConvDimensionNumbers) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 29ec21319361..2681ad1a2a7b 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -323,7 +323,6 @@ def linear_prop(prim, primals_in, series_in, **params): deflinear(lax.convert_element_type_p) deflinear(lax.broadcast_in_dim_p) deflinear(lax.concatenate_p) -deflinear(lax.split_p) deflinear(lax.pad_p) deflinear(lax.reshape_p) deflinear(lax.squeeze_p) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index dc9c69d97795..d569ed641138 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -203,8 +203,6 @@ sort as sort, sort_key_val as sort_key_val, sort_p as sort_p, - split as split, - split_p as split_p, sqrt as sqrt, sqrt_p as sqrt_p, square as square, diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index c7cbde069cc8..78d90cb8a072 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -273,24 +273,6 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs): concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(len(base_shape)) - ], - num_pieces=range(3), - dtype=float_dtypes, - ) - def testSplitGrad(self, axis, base_shape, dtype, num_pieces): - sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) - shape = list(base_shape) - shape[axis] = np.sum(sizes) - rng = jtu.rand_default(self.rng()) - operands = (rng(shape, dtype),) - split = lambda x: lax.split(x, sizes, axis) - check_grads(split, operands, 2, ["fwd", "rev"], eps=1.) - - @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides) for lhs_shape, rhs_shape, all_strides in itertools.chain( diff --git a/tests/lax_test.py b/tests/lax_test.py index 48f70baa1e32..78bc5857acb7 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -283,33 +283,6 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs): numpy_op = lambda *args: lax_reference.concatenate(args, dim) self._CheckAgainstNumpy(numpy_op, op, args_maker) - @jtu.sample_product( - [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(len(shape))], - num_pieces=range(3), - dtype=lax_test_util.default_dtypes, - ) - def testSplit(self, axis, base_shape, dtype, num_pieces): - sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) - shape = list(base_shape) - shape[axis] = np.sum(sizes) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - op = lambda x: lax.split(x, sizes, axis=axis) - def numpy_op(x): - return np.split(x, np.cumsum(sizes[:-1]), axis=axis) - self._CompileAndCheck(op, args_maker) - self._CheckAgainstNumpy(numpy_op, op, args_maker) - - def testSplitErrors(self): - with self.assertRaisesRegex(ValueError, - "Sizes passed to split must be nonnegative"): - lax.split(np.arange(5), [-1]) - with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"): - lax.split(np.arange(5), [6]) - with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"): - lax.split(np.arange(5), sizes=(), axis=1) - @jtu.sample_product( [ dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 49e06e17be15..83d4d657751b 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -344,24 +344,6 @@ def testSlice(self, shape, dtype, starts, limits, strides, bdims): op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis, bdims=bdims) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(len(base_shape)) - for bdims in lax_test_util.all_bdims(base_shape) - ], - num_pieces=range(3), - dtype=lax_test_util.default_dtypes, - ) - def testSplit(self, base_shape, dtype, num_pieces, axis, bdims): - sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) - shape = list(base_shape) - shape[axis] = np.sum(sizes) - rng = jtu.rand_default(self.rng()) - op = lambda x: lax.split(x, sizes, axis) - self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng, - multiple_results=True) - @jtu.sample_product( [dict(shape=shape, perm=perm, bdims=bdims) for shape, perm in [ From c04aec9d525dd2e767495e41b98e82dd79315f37 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Tue, 19 Nov 2024 15:22:27 -0800 Subject: [PATCH 407/698] [Mosaic] Extend tpu.sem_signal with subcore_id This change: - Bumps up the version of Mosaic to 4 in `serde.cc`. - Adds optional `subcore_id` parameter to `tpu.sem_signal` for signalling specific subcores. - Extends deserialization to correctly parse the older versions of Mosaic without the new parameter `subcore_id` of `tpu.sem_signal`. PiperOrigin-RevId: 698163836 --- jaxlib/mosaic/dialect/tpu/tpu.td | 5 ++- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 36 +++++++++++---- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 44 +++++++++++++------ 3 files changed, 61 insertions(+), 24 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 55d2e1ec975e..590c27ac2099 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -654,14 +654,15 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { I32:$amount, Optional:$device_id, // For remote DMAs Optional:$core_id, // For megacore + Optional:$subcore_id, // For the SC vector subcore OptionalAttr:$core_type ); let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`subcore_id` $subcore_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; let builders = [ - // A backward-compatible builder that sets `core_type` to nullptr. + // A backward-compatible builder that sets `subcore_id` and `core_type` to nullptr. OpBuilder<(ins "Value":$semaphore, "Value":$amount, "Value":$device_id, "Value":$core_id)>, ]; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index b4dcca66f7dc..a103cda7dae2 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -844,7 +844,7 @@ void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state, Value semaphore, Value amount, Value device_id, Value core_id) { build(builder, state, semaphore, amount, device_id, core_id, - /*core_type=*/nullptr); + /*subcore_id=*/nullptr, /*core_type=*/nullptr); } LogicalResult SemaphoreSignalOp::verify() { @@ -861,21 +861,39 @@ LogicalResult SemaphoreSignalOp::verify() { CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc); CoreType target_core_type = getCoreType().value_or(issuing_core_type); - if (getCoreId() == nullptr && getDeviceId() == nullptr) { + if (getCoreId() == nullptr && getDeviceId() == nullptr && + getSubcoreId() == nullptr) { if (target_core_type != issuing_core_type) { - return emitOpError( - absl::StrFormat("Target core type (%s) must match source core type " - "(%s) when device_id and core_id are not specified", - stringifyCoreType(target_core_type), - stringifyCoreType(issuing_core_type))); + return emitOpError(absl::StrFormat( + "Target core type (%s) must match source core type " + "(%s) when device_id, core_id and subcore_id are not specified", + stringifyCoreType(target_core_type), + stringifyCoreType(issuing_core_type))); } } + if (target_core_type == CoreType::kScVectorSubcore && + issuing_core_type != CoreType::kScVectorSubcore && + getSubcoreId() == nullptr) { + return emitOpError( + "Subcore ID must be specified for the SC vector subcore"); + } + if (target_core_type != CoreType::kScVectorSubcore && + getSubcoreId() != nullptr) { + return emitOpError( + "Subcore ID must be specified only for the SC vector subcore"); + } if ((issuing_core_type == CoreType::kTc && - target_core_type == CoreType::kScScalarSubcore) || - (issuing_core_type == CoreType::kScScalarSubcore && + (target_core_type == CoreType::kScScalarSubcore || + target_core_type == CoreType::kScVectorSubcore)) || + ((issuing_core_type == CoreType::kScScalarSubcore || + issuing_core_type == CoreType::kScVectorSubcore) && target_core_type == CoreType::kTc)) { return emitOpError("Signalling between TC and SC is not implemented"); } + if (target_core_type == CoreType::kScVectorSubcore && + (getCoreId() != nullptr || getDeviceId() != nullptr)) { + return emitOpError("Signalling remote SC vector subcores is not supported"); + } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index fd68c9e6c95e..27a886ebeb7e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -15,19 +15,21 @@ limitations under the License. // We need to keep some extra headers for the code in tpu_passes.h.inc. +#include #include // IWYU pragma: keep #include #include #include +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" +#include "absl/strings/str_format.h" #include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" @@ -43,7 +45,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 3; +constexpr int kVersion = 4; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -86,21 +88,37 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) { LogicalResult semaphore_signal_rule(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. + // Added subcore_id in version 4. if (version < 2) { if (op->getNumOperands() == 2) { // Local signal. - op->setAttr(OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0, 0})); } else if (op->getNumOperands() == 3) { // Remote signal. - // Hardcoding that one optional value is device_id, not core_id. This - // could misinterpret sem_signals where core_id is specified, but - // device_id isn't. - op->setAttr(OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); - } else { - return op->emitError("Unexpected operand count in tpu.semaphore_signal"); + op->setAttr( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0, 0})); + } + return op->emitError("Unexpected operand count in tpu.semaphore_signal"); + } else if (version < 4) { + ArrayRef operand_segment_sizes = + op->getAttrOfType( + OpTrait::AttrSizedOperandSegments< + SemaphoreSignalOp>::getOperandSegmentSizeAttr()); + if (operand_segment_sizes.size() != 4) { + return op->emitError(absl::StrFormat( + "Expected operand count to be 4 in tpu.semaphore_signal. Got %d", + operand_segment_sizes.size())); } + SmallVector new_operand_segment_sizes( + operand_segment_sizes.begin(), operand_segment_sizes.end()); + new_operand_segment_sizes.push_back(0); + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), + new_operand_segment_sizes)); } return success(); } From 8c71d1ad6d543f95db1b191505150dd19b0b6e69 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 19 Nov 2024 18:30:57 -0800 Subject: [PATCH 408/698] Make deprecated jax.experimental.array_api module visibility internal-only This is in preparation for the module to be removed. PiperOrigin-RevId: 698215225 --- jax/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 0da99677dc7b..26694fec2ad3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1047,7 +1047,7 @@ pytype_library( "experimental/array_api/*.py", ], ), - visibility = [":internal"] + jax_visibility("array_api"), + visibility = [":internal"], deps = [ ":jax", ], From 867a36189bf6c9d19f0f4a6522e91306dec5945f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Nov 2024 18:59:08 -0800 Subject: [PATCH 409/698] Fix a bug where constant deduplication used an inappropriate inequality. We need to compare constants for bitwise equality, not, e.g., floating point equality. The change that added deduplication caused us to conflate +0.0 and -0.0, which led a downstream test not to terminate. PiperOrigin-RevId: 698221147 --- jax/_src/interpreters/mlir.py | 34 ++++++++++++++++++++-------------- tests/pjit_test.py | 11 +++++++++++ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 23325a9d7e26..102e4f490b5c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1755,33 +1755,39 @@ def _emit_lowering_rule_as_fun(lowering_rule, class HashableLiteral: """Hashable wrapper of core.Literal, used for deduplicating IR constants.""" - __slots__ = ["value"] + __slots__ = ["value", "data"] value: core.Literal + # Copy of the value suitable for an equality comparison. We are careful to + # avoid floating point comparisons here, because in particular we don't want + # 0.0 and -0.0 to be considered equal, but we are fine with NaNs being equal. + data: bytes | int | bool | None + def __init__(self, value): self.value = value + if isinstance(value.val, (np.generic, np.ndarray)): + self.data = value.val.tobytes() + elif isinstance(value.val, (bool, int)): + self.data = value.val + elif isinstance(value.val, float): + self.data = np.float64(value.val).tobytes() + elif isinstance(value.val, complex): + self.data = np.complex128(value.val).tobytes() + else: + self.data = None # Unhandled case. def __hash__(self): - h = self.value.hash - return id(self.value.val) if h is None else h + return hash(self.data) def __eq__(self, other): - if self is other: - return True if type(self.value.val) != type(other.value.val): return False if self.value.aval != other.value.aval: return False - if isinstance(self.value.val, (bool, int, float, complex)): - return self.value == other.value - if isinstance(self.value.val, (np.generic, np.ndarray)): - return np.array_equal( - self.value.val, other.value.val, - equal_nan=np.issubdtype(self.value.val.dtype, np.inexact)) - # Since the use case is constant deduplication, it's safe to return - # False in unhandled cases. - return False + if self.data is None: + return id(self) == id(other) + return self.data == other.data def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6df011419513..e32424cfdded 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1312,6 +1312,17 @@ def under_jvp(f): ans2 = jnp.sin(x0) # cpp_pjit hit with bad cache entry assert(ans1.devices() == ans2.devices()) + def test_zero_literal_equality(self): + # This test verifies that we don't accidentally conflate positive and + # negative zeros when deduplicating literals in the IR. + f = jax.jit(lambda x: (x / np.float32(-0.0), x / np.float32(0.0))) + a, b = f(np.float32(1.0)) + self.assertEqual(a, -np.inf) + self.assertEqual(b, np.inf) + ir = f.lower(np.float32(1.0)).as_text() + self.assertIn("stablehlo.constant dense<0.000000e+00>", ir) + self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) + @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): From 6c291d67b7a9dfbc0517c0ab7828e80dc88bdc01 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Tue, 19 Nov 2024 19:03:55 -0800 Subject: [PATCH 410/698] [Mosaic] Add `tpu.log` verification on SC Guards against using formatting and targeting vector subcores on SC. PiperOrigin-RevId: 698222100 --- jaxlib/mosaic/dialect/tpu/tpu.td | 1 + jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 590c27ac2099..de5e3514fc1d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -761,6 +761,7 @@ def TPU_LogOp : TPU_Op<"log"> { ); let results = (outs); let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; + let hasVerifier = 1; } def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::FuncOp"> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index a103cda7dae2..8586e2a16c8a 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1053,6 +1053,30 @@ LogicalResult ConcatenateOp::verify() { return success(); } +LogicalResult LogOp::verify() { + FailureOr> logging_core_type_maybe = + GetCoreTypeOfParentFunc(**this); + if (failed(logging_core_type_maybe)) { + return failure(); + } + CoreType logging_core_type = logging_core_type_maybe->value_or(CoreType::kTc); + if ((logging_core_type == CoreType::kScScalarSubcore || + logging_core_type == CoreType::kScVectorSubcore) && + getFormattedAttr() != nullptr && getFormattedAttr().getValue()) { + return emitOpError("Formatted logging is not supported on SC"); + } + switch (logging_core_type) { + case CoreType::kTc: + case CoreType::kScScalarSubcore: + return success(); + case CoreType::kScVectorSubcore: + return emitOpError("Log op is not supported on the SC vector subcore"); + } + return emitOpError( + absl::StrFormat("Unexpected core type: %s", + stringifyCoreType(logging_core_type_maybe->value()))); +} + } // namespace tpu } // namespace mlir From 4bb81075bcc8c5ac8ea9d5993f9d877bb16f3a13 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Tue, 29 Oct 2024 12:46:08 -0700 Subject: [PATCH 411/698] represent `random.key_impl` of builtin RNGs by canonical string name We do not have great reason to return specs here, and sticking to strings instead can help with simple serialization. --- jax/_src/random.py | 9 +++++---- tests/extend_test.py | 30 ++++++++++++++++++------------ tests/random_test.py | 4 ++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index dc9fc18aff38..6c04b0620080 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -293,14 +293,15 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: return _return_prng_keys(wrapped, _split(typed_key, num)) -def _key_impl(keys: KeyArray) -> PRNGImpl: +def _key_impl(keys: KeyArray) -> str | PRNGSpec: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) keys_dtype = typing.cast(prng.KeyTy, keys.dtype) - return keys_dtype._impl + impl = keys_dtype._impl + return impl.name if impl.name in prng.prngs else PRNGSpec(impl) -def key_impl(keys: KeyArrayLike) -> PRNGSpec: +def key_impl(keys: KeyArrayLike) -> str | PRNGSpec: typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) - return PRNGSpec(_key_impl(typed_keys)) + return _key_impl(typed_keys) def _key_data(keys: KeyArray) -> Array: diff --git a/tests/extend_test.py b/tests/extend_test.py index 84a907c7331d..42196a940a76 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -70,35 +70,41 @@ def test_symbols(self): class RandomTest(jtu.JaxTestCase): - def test_key_make_with_custom_impl(self): - shape = (4, 2, 7) - + def make_custom_impl(self, shape, seed=False, split=False, fold_in=False, + random_bits=False): + assert not split and not fold_in and not random_bits # not yet implemented def seed_rule(_): return jnp.ones(shape, dtype=jnp.dtype('uint32')) def no_rule(*args, **kwargs): assert False, 'unreachable' - impl = jex.random.define_prng_impl( - key_shape=shape, seed=seed_rule, split=no_rule, fold_in=no_rule, - random_bits=no_rule) + return jex.random.define_prng_impl( + key_shape=shape, seed=seed_rule if seed else no_rule, split=no_rule, + fold_in=no_rule, random_bits=no_rule) + + def test_key_make_with_custom_impl(self): + impl = self.make_custom_impl(shape=(4, 2, 7), seed=True) k = jax.random.key(42, impl=impl) self.assertEqual(k.shape, ()) self.assertEqual(impl, jax.random.key_impl(k)) def test_key_wrap_with_custom_impl(self): - def no_rule(*args, **kwargs): - assert False, 'unreachable' - shape = (4, 2, 7) - impl = jex.random.define_prng_impl( - key_shape=shape, seed=no_rule, split=no_rule, fold_in=no_rule, - random_bits=no_rule) + impl = self.make_custom_impl(shape=shape) data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32')) k = jax.random.wrap_key_data(data, impl=impl) self.assertEqual(k.shape, (3,)) self.assertEqual(impl, jax.random.key_impl(k)) + def test_key_impl_is_spec(self): + # this is counterpart to random_test.py: + # KeyArrayTest.test_key_impl_builtin_is_string_name + spec_ref = self.make_custom_impl(shape=(4, 2, 7), seed=True) + key = jax.random.key(42, impl=spec_ref) + spec = jax.random.key_impl(key) + self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})") + class FfiTest(jtu.JaxTestCase): diff --git a/tests/random_test.py b/tests/random_test.py index fed12792d5c6..f9167b22b4ea 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1125,10 +1125,10 @@ class A: pass jax.random.key(42, impl=A()) @jtu.sample_product(name=[name for name, _ in PRNG_IMPLS]) - def test_key_spec_repr(self, name): + def test_key_impl_builtin_is_string_name(self, name): key = jax.random.key(42, impl=name) spec = jax.random.key_impl(key) - self.assertEqual(repr(spec), f"PRNGSpec({name!r})") + self.assertEqual(spec, name) def test_keyarray_custom_vjp(self): # Regression test for https://github.com/jax-ml/jax/issues/18442 From 4d60db17413208cf4ff829242d21eaa46c9586c4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 19 Nov 2024 21:32:44 -0800 Subject: [PATCH 412/698] Add test_compute_on_host_shared_sharding in memories_test PiperOrigin-RevId: 698250352 --- tests/memories_test.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/memories_test.py b/tests/memories_test.py index da4239338c02..ca676a2b1993 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -808,6 +808,46 @@ def h(x): self.assertArraysEqual(out2, inp * 6) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_on_host_shared_sharding(self): + mesh = jtu.create_mesh((2,), ("x")) + device_sharding = NamedSharding(mesh, P("x")) + host_sharding = device_sharding.with_memory_kind("pinned_host") + + @compute_on("device_host") + @functools.partial( + jax.jit, + in_shardings=(host_sharding, device_sharding), + out_shardings=(host_sharding, device_sharding), + donate_argnums=(0, 1), + ) + def host_func(x, y): + return (x * y), ((x**2) * (y**2)) + + @functools.partial( + jax.jit, + in_shardings=(host_sharding, device_sharding), + out_shardings=(host_sharding, device_sharding), + donate_argnums=(0), + ) + def device_func(host_data, device_data): + host_data, device_data = host_func(host_data, device_data) + device_data = device_data * 2 + host_data, device_data = host_func(host_data, device_data) + return (host_data, device_data) + + input_x = jnp.ones(8) + input_host = jax.device_put(input_x, host_sharding) + + input_device = jnp.arange(8) + input_device = jnp.where(input_device < 4, 0, 1) + input_device = jax.device_put(input_device, device_sharding) + + output_host, output_device = device_func(input_host, input_device) + self.assertEqual(output_host.sharding.memory_kind, 'pinned_host') + self.assertEqual(output_device.sharding.memory_kind, 'device') + self.assertArraysEqual(output_host, [0., 0., 0., 0., 2., 2., 2., 2.]) + self.assertArraysEqual(output_device, [0., 0., 0., 0., 4., 4., 4., 4.]) + def test_compute_on_basic_inline(self): @compute_on('device_host') @jax.jit From 1afb05e2e2341362a9107a6726721f4f617db46c Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Wed, 20 Nov 2024 03:01:11 -0800 Subject: [PATCH 413/698] [mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise. Only propagate signedness from operands when the output type of `op` is an `ir.IntegerType`. PiperOrigin-RevId: 698324596 --- jax/experimental/mosaic/gpu/fragmented_array.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index fd989d052917..e45202386b47 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -623,10 +623,6 @@ def to_layout(self, new_layout: FragmentedLayout): ) def _pointwise(self, op, *other, output_is_signed: bool | None = None): - is_signed = ( - output_is_signed if output_is_signed is not None else self.is_signed - ) - other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): @@ -636,7 +632,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): raise NotImplementedError(o) o = FragmentedArray.splat( - o, shape=self.shape, layout=self.layout, is_signed=is_signed + o, shape=self.shape, layout=self.layout, is_signed=self.is_signed ) if isinstance(o.layout, WGSplatFragLayout): @@ -646,7 +642,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): o.registers.flat[0], shape=self.shape, layout=self.layout, - is_signed=is_signed, + is_signed=self.is_signed, ) else: if self.layout != o.layout: @@ -659,8 +655,13 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) + reg_ty = new_regs.flat[0].type + if ir.VectorType.isinstance(reg_ty): + reg_ty = ir.VectorType(reg_ty).element_type + if output_is_signed is None and ir.IntegerType.isinstance(reg_ty): + output_is_signed = self.is_signed return FragmentedArray( - _registers=new_regs, _layout=self.layout, _is_signed=is_signed + _registers=new_regs, _layout=self.layout, _is_signed=output_is_signed ) def __pos__(self): From 14da7ebb76d5a97b9955822e8781d1f45505cb9e Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Wed, 20 Nov 2024 03:40:40 -0800 Subject: [PATCH 414/698] [pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type. Only handles the case where operand type and target type have the same bitwidth. PiperOrigin-RevId: 698332564 --- jax/_src/pallas/mosaic_gpu/lowering.py | 25 +++++++++++++++++++ .../mosaic/gpu/fragmented_array.py | 8 ++++-- tests/pallas/mosaic_gpu_test.py | 24 ++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6d30cdb0d4a3..5b5a6f4ace83 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1501,6 +1501,31 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): return list(switch_op.results) +@register_lowering_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, operand, *, new_dtype +): + # TODO(petebu) Handle case where src and dst types have different bitwidths + [operand_aval] = ctx.avals_in + operand = _ensure_fa(operand, operand_aval.dtype) + src_elem_type = mgpu_utils.dtype_to_ir_type(operand_aval.dtype) + dst_elem_type = mgpu_utils.dtype_to_ir_type(new_dtype) + assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) + assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) + if src_elem_type.width != dst_elem_type.width: + raise NotImplementedError( + f"Can't bitcast from {operand_aval.dtype} to {new_dtype} because they" + " have different widths" + ) + if ir.IntegerType.isinstance(dst_elem_type): + output_is_signed = mgpu_utils.is_signed(new_dtype) + else: + output_is_signed = None + return mgpu.FragmentedArray.bitcast( + operand, dst_elem_type, output_is_signed=output_is_signed + ) + + def _bcast( x: ir.Value, y: ir.Value, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index e45202386b47..2b985ff5c9b8 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -929,7 +929,9 @@ def fast_instr(x): raise NotImplementedError(x.type) return fast_instr - def bitcast(self, elt: ir.Type): + def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): + if elt == self.mlir_dtype: + return self reg_type = self.registers.flat[0].type if ir.VectorType.isinstance(reg_type): reg_shape = ir.VectorType(reg_type).shape @@ -937,7 +939,9 @@ def bitcast(self, elt: ir.Type): else: ty = elt - return self._pointwise(lambda x: arith.bitcast(ty, x)) + return self._pointwise( + lambda x: arith.bitcast(ty, x), output_is_signed=output_is_signed + ) def __getitem__(self, idx): if self.layout != WGMMA_LAYOUT: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index fe52a33c1637..b8098f40eccf 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1052,6 +1052,30 @@ def kernel(x_ref, o_ref): self.assertEqual(data.count('"name": "store"'), 2) np.testing.assert_array_equal(y, x + x) + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ) + def test_bitcast_convert_type(self, in_dtype, out_dtype): + m, n = 16, 8 + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () + + @functools.partial(pl.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) + + x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + class PipelineTest(PallasTest): From c76e5fe9a0d1c4b67fdc844a824f1bd53821653d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 04:28:56 -0800 Subject: [PATCH 415/698] [pallas:mosaic_gpu] `copy_smem_to_gmem` now supports `wait_read_only` PiperOrigin-RevId: 698343812 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 4 +++- jax/_src/pallas/mosaic_gpu/primitives.py | 24 +++++++++++++++++------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 9b6adc86f981..90c00765e8b1 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -207,7 +207,9 @@ def loop_body(step, carry): # Wait for the current GMEM->SMEM copy to complete. gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. - gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) + gpu_primitives.wait_smem_to_gmem( + max_concurrent_steps - 1, wait_read_only=True + ) with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): body( diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 36dcba5d15d0..0f25f9808ac1 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -363,20 +363,30 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None: @wait_smem_to_gmem_p.def_effectful_abstract_eval -def _wait_smem_to_gmem_abstract_eval(n): - del n # Unused. +def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): + del n, wait_read_only # Unused. return (), {gpu_core._memory_effect} @lowering.register_lowering_rule(wait_smem_to_gmem_p) -def _wait_smem_to_gmem_lowering(ctx: lowering.LoweringRuleContext, n): - ctx.launch_ctx.await_async_copy(allow_groups=n) +def _wait_smem_to_gmem_lowering( + ctx: lowering.LoweringRuleContext, n, *, wait_read_only +): + ctx.launch_ctx.await_async_copy( + allow_groups=n, await_read_only=wait_read_only + ) return () -def wait_smem_to_gmem(n: int) -> None: - """Waits until there are no more than ``n`` SMEM->GMEM copies in flight.""" - wait_smem_to_gmem_p.bind(n) +def wait_smem_to_gmem(n: int, wait_read_only: bool = False) -> None: + """Waits until there are no more than ``n`` SMEM->GMEM copies in flight. + + Args: + n: The maximum number of copies in flight to wait for. + wait_read_only: If ``True``, wait for the in flight copies to finish + reading from SMEM. The writes to GMEM are not waited for. + """ + wait_smem_to_gmem_p.bind(n, wait_read_only=wait_read_only) # WGMMA on an accumulator reference From f442d40f926f801135cea7637a64cce47f05eae1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 04:29:12 -0800 Subject: [PATCH 416/698] [mosaic_gpu] Fixed `FragmentedArray` comparisons with literals PiperOrigin-RevId: 698343858 --- tests/mosaic/gpu_test.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 157f682f5eef..ab2a00c730d6 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1318,18 +1318,21 @@ def kernel(ctx, dst, _): operator.ne, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], + rhs_is_literal=[False, True] ) - def test_comparison(self, op, dtype, m=64, n=32): + def test_comparison(self, op, dtype, rhs_is_literal, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + rhs = 0 if rhs_is_literal else iota + 1 + op(iota, rhs).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) - np.testing.assert_array_equal(result, op(iota, iota + 1)) + rhs = rhs = 0 if rhs_is_literal else iota + 1 + np.testing.assert_array_equal(result, op(iota, rhs)) @parameterized.product( op=[operator.and_, operator.or_, operator.xor], From 04e4c69f7f72e3aabee726315f370c8182045b49 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 05:05:14 -0800 Subject: [PATCH 417/698] [mosaic_gpu] Handle older `jaxlib`s in the profiler module `measure` now raises a `RuntimeError` if the available `jaxlib` does not have the required custom calls. PiperOrigin-RevId: 698351662 --- jax/experimental/mosaic/gpu/profiler.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 337581c54b86..0594e9239be7 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -36,12 +36,15 @@ try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: - pass + has_registrations = False else: - for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): - xla_client.register_custom_call_target( - name, handler, platform="CUDA", api_version=1 - ) + # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36. + has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations") + if has_registrations: + for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): + xla_client.register_custom_call_target( + name, handler, platform="CUDA", api_version=1 + ) # ruff: noqa: F405 # mypy: ignore-errors @@ -80,6 +83,11 @@ def measure( Returns: The return value of ``f`` and the elapsed time in milliseconds. """ + if not has_registrations: + raise RuntimeError( + "This function requires jaxlib >=0.4.36 with CUDA support." + ) + if not (args or kwargs): # We require at least one argument and at least one output to ensure # that there is a data dependency between `_event_record` calls in From 1df4b5f79885e0d9fb0d8e097b7a526e577f04ef Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 05:06:50 -0800 Subject: [PATCH 418/698] [pallas] Do not skip vmap tests on GPU when x64 is enabled PiperOrigin-RevId: 698351984 --- tests/pallas/pallas_vmap_test.py | 37 +++++++++++++++----------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index fefccfe7eb4f..ffa6195625dd 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -22,6 +22,7 @@ import jax from jax import random from jax._src import config +from jax._src import dtypes from jax._src import test_util as jtu from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl @@ -35,6 +36,10 @@ config.parse_flags_with_absl() +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -42,8 +47,6 @@ class PallasBaseTest(jtu.JaxTestCase): def setUp(self): if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled: - self.skipTest("On GPU the test works only in 32-bit") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") @@ -67,7 +70,7 @@ def setUp(self): def test_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -77,7 +80,7 @@ def add_one(x_ref, o_ref): def test_vmap_of_simple_kernel_with_in_axes_None(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add(x_ref, y_ref, o_ref): o_ref[()] = x_ref[()] + y_ref[()] @@ -87,7 +90,7 @@ def add(x_ref, y_ref, o_ref): def test_double_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -97,7 +100,7 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_simple_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), ) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -108,7 +111,7 @@ def add_one(x_ref, o_ref): def test_quadruple_vmap_of_batched_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), intx), grid=(7,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -120,7 +123,7 @@ def add_one(x_ref, o_ref): def test_vmap_of_slicing_kernel(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -151,7 +154,7 @@ def kernel(src, dst): def test_vmap_of_kernel_with_input_output_aliases(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx), input_output_aliases={1:0}, grid=()) def add(x_ref, _, o_ref): @@ -163,7 +166,7 @@ def add(x_ref, _, o_ref): def test_vmap_of_kernel_with_input_output_aliases_different_axes(self): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + out_shape=jax.ShapeDtypeStruct((4,), intx), input_output_aliases={0: 0}, grid=(), ) @@ -176,7 +179,7 @@ def add(x_ref, o_ref): def test_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): i = pl.program_id(0) @@ -194,7 +197,7 @@ def add_one(x_ref, o_ref): def test_double_vmap_of_slicing_kernel_different_axes(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx), grid=(4,)) def sin(x_ref, o_ref): i = pl.program_id(0) @@ -211,7 +214,7 @@ def sin(x_ref, o_ref): def test_small_large_vmap(self): # Catches https://github.com/jax-ml/jax/issues/18361 @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -230,7 +233,7 @@ def add_one(x_ref, o_ref): def test_small_small_large_vmap(self): @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx), grid=(2,)) def add_one(x_ref, o_ref): o_ref[()] = x_ref[()] + 1 @@ -249,12 +252,6 @@ def add_one(x_ref, o_ref): class PallasCallVmapInterpretTest(PallasCallVmapTest): INTERPRET = True - def setUp(self): - super().setUp() - if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: - # TODO: assertion failures on CPU in 64-bit mode - self.skipTest("On CPU the test works only in 32-bit mode") - if __name__ == "__main__": absltest.main() From a582df02971337dba2834c5a3953f4af067caaa0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 20 Nov 2024 06:38:33 -0800 Subject: [PATCH 419/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/fcee07f619a765db815d9ed4e2bc229275818a2b. PiperOrigin-RevId: 698371906 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 99c2af75f3ad..a554cfd03687 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "229f376e046b9a51039dc1566d1e388ee7c1ca6d" -XLA_SHA256 = "895b39b5cb298460185f29df3ecc8882f4ee151b0f7dc93e5387ef81ea32e374" +XLA_COMMIT = "fcee07f619a765db815d9ed4e2bc229275818a2b" +XLA_SHA256 = "1dd144e64e2c2dcc20a2130e10607fec7b3a810926ba912918dd5437698a3375" def repo(): tf_http_archive( From a4266b5e31853a62a06281b211024cb8c2581876 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 08:23:19 -0800 Subject: [PATCH 420/698] Mention python 3.13 in docs & package metadata --- docs/deprecation.md | 3 +++ jaxlib/setup.py | 1 + setup.py | 1 + 3 files changed, 5 insertions(+) diff --git a/docs/deprecation.md b/docs/deprecation.md index 385d31271421..603a027f5efc 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -18,6 +18,7 @@ This means we support at least: * **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**. * **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**. * **Python 3.12** was released October 2023, and will be supported in new JAX releases at least until **July 2027**. + * **Python 3.13** was released October 2024, and will be supported in new JAX releases at least until **July 2028**. * All NumPy feature releases in the 24 months prior to each JAX release. For example: @@ -25,6 +26,7 @@ This means we support at least: * **NumPy 1.25** was released June 2023, and will be supported in new JAX releases at least until **June 2025** * **NumPy 1.26** was released September 2023, and will be supported in new JAX releases at least until **September 2025** * **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026** + * **NumPy 2.1** was released August 2024, and will be supported in new JAX releases at least until **August 2026** * All SciPy feature releases in the 24 months prior to each JAX release. For example: @@ -32,6 +34,7 @@ This means we support at least: * **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**. * **Scipy 1.12** was released January 2024, and will be supported in new JAX releases at least until **January 2026**. * **Scipy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**. + * **Scipy 1.14** was released June 2024, and will be supported in new JAX releases at least until **June 2026**. JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed diff --git a/jaxlib/setup.py b/jaxlib/setup.py index dea9503c7c00..989a8314eb92 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -72,6 +72,7 @@ def has_ext_modules(self): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={ 'jaxlib': [ diff --git a/setup.py b/setup.py index 98d509375d62..a3b54f7aa94f 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ def load_version_module(pkg_path): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], zip_safe=False, ) From 1e9e85a39eee20f7362c7aa6e79a8f345bbef748 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 20 Nov 2024 08:26:12 -0800 Subject: [PATCH 421/698] Simplify handling of `DotAlgorithmPreset` output types. Create a clear distinction between the type used for accumulation and possible output types. PiperOrigin-RevId: 698399447 --- jax/_src/lax/lax.py | 84 ++++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ff9ac0a49578..39c5bca5819c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -879,11 +879,11 @@ def __str__(self) -> str: def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: match self: case ( - DotAlgorithmPreset.DEFAULT | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.DEFAULT + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): return None case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32: @@ -906,14 +906,26 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: return self.lhs_precision_type @property - def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: + def accumulation_type(self) -> DTypeLike | None: match self: case ( - DotAlgorithmPreset.DEFAULT | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.DEFAULT + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): return None + case DotAlgorithmPreset.F16_F16_F16: + return np.float16 + case DotAlgorithmPreset.BF16_BF16_BF16: + return dtypes.bfloat16 + case DotAlgorithmPreset.F64_F64_F64: + return np.float64 + case _: + return np.float32 + + @property + def supported_output_types(self) -> tuple[DTypeLike, ...] | None: + match self: case ( DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM @@ -921,16 +933,11 @@ def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz) - case DotAlgorithmPreset.F16_F16_F16: - return np.float16 case DotAlgorithmPreset.F16_F16_F32: return (np.float32, np.float16) - case DotAlgorithmPreset.BF16_BF16_BF16: - return dtypes.bfloat16 - case DotAlgorithmPreset.F64_F64_F64: - return np.float64 case _: - return np.float32 + accumulation_type = self.accumulation_type + return None if accumulation_type is None else (accumulation_type,) def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: @@ -941,16 +948,18 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, tf32 = ir.FloatTF32Type.get() match self: case ( - DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | - DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY + | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): - fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), - np.dtype(dtypes.float8_e4m3fn), - np.dtype(dtypes.float8_e4m3fnuz), - np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)] + fp8_dtypes = [ + np.dtype(dtypes.float8_e4m3b11fnuz), + np.dtype(dtypes.float8_e4m3fn), + np.dtype(dtypes.float8_e4m3fnuz), + np.dtype(dtypes.float8_e5m2), + np.dtype(dtypes.float8_e5m2fnuz), + ] if dtypes.float8_e3m4 is not None: fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] if dtypes.float8_e4m3 is not None: @@ -958,13 +967,20 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " - f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.") + f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.' + ) lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype)) rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype)) acc = ir.F32Type.get() return hlo.DotAlgorithm.get( - lhs, rhs, acc, 1, 1, 1, - self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM) + lhs, + rhs, + acc, + 1, + 1, + 1, + self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM, + ) case DotAlgorithmPreset.F16_F16_F16: return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False) case DotAlgorithmPreset.F16_F16_F32: @@ -3649,9 +3665,8 @@ def maybe_convert_dtype(input_dtype, target_dtype): return input_dtype if not isinstance(target_dtype, tuple): target_dtype = (target_dtype,) - if any(input_dtype == d for d in target_dtype): - return input_dtype - return target_dtype[0] + return input_dtype if input_dtype in target_dtype else target_dtype[0] + if algorithm == DotAlgorithmPreset.BF16_BF16_F32: lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type) rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type) @@ -3662,10 +3677,15 @@ def maybe_convert_dtype(input_dtype, target_dtype): out_dtype = maybe_convert_dtype(out_dtype, np.float32) return lhs_dtype, rhs_dtype, out_dtype else: + if isinstance(algorithm, DotAlgorithmPreset): + supported_output_types = algorithm.supported_output_types + else: + supported_output_types = (algorithm.accumulation_type,) + return ( maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type), maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type), - maybe_convert_dtype(out_dtype, algorithm.accumulation_type), + maybe_convert_dtype(out_dtype, supported_output_types), ) From 85e2969aea15141bedd6d4ec0548cc02ef45b069 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 08:48:26 -0800 Subject: [PATCH 422/698] Deprecate several private APIs in jax.lib --- CHANGELOG.md | 6 ++++++ jax/lib/xla_client.py | 7 ++++++- jax/lib/xla_extension.py | 25 +++++++++++++++++++++++-- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9082399c8695..37fd68bce39d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. result in an indexing overflow for batch sizes close to int32 max. See {jax-issue}`#24843` for more details. +* Deprecations + * `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated; + use `jax.Array` instead. + * `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError` + instead. + ## jax 0.4.35 (Oct 22, 2024) * Breaking Changes diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index aaf3791037d0..cd3696d8838c 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -18,7 +18,6 @@ get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile mlir_api_version = _xc.mlir_api_version -ArrayImpl = _xc.ArrayImpl Client = _xc.Client CompileOptions = _xc.CompileOptions DeviceAssignment = _xc.DeviceAssignment @@ -95,6 +94,11 @@ "XlaComputation is deprecated; use StableHLO instead.", _xc.XlaComputation, ), + # Added Nov 20 2024 + "ArrayImpl": ( + "jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.", + _xc.ArrayImpl, + ), } import typing as _typing @@ -106,6 +110,7 @@ ops = _xc.ops register_custom_call_target = _xc.register_custom_call_target shape_from_pyval = _xc.shape_from_pyval + ArrayImpl = _xc.ArrayImpl Device = _xc.Device FftType = _FftType PaddingType = _xc.PaddingType diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 20ce459685aa..52fe94e231d1 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -24,7 +24,6 @@ pmap_lib = _xe.pmap_lib profiler = _xe.profiler pytree = _xe.pytree -ArrayImpl = _xe.ArrayImpl Device = _xe.Device DistributedRuntimeClient = _xe.DistributedRuntimeClient HloModule = _xe.HloModule @@ -33,6 +32,28 @@ PjitFunctionCache = _xe.PjitFunctionCache PjitFunction = _xe.PjitFunction PmapFunction = _xe.PmapFunction -XlaRuntimeError = _xe.XlaRuntimeError +_deprecations = { + # Added Nov 20 2024 + "ArrayImpl": ( + "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", + _xe.ArrayImpl, + ), + "XlaRuntimeError": ( + "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", + _xe.XlaRuntimeError, + ), +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + ArrayImpl = _xe.ArrayImpl + XlaRuntimeError = _xe.XlaRuntimeError +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing del _xe From 62225926253474c6e5e4b202d5c9cf3363a02a03 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 20 Nov 2024 17:46:06 +0000 Subject: [PATCH 423/698] Fix KeyError recently introduced in cloud_tpu_init.py This fixes a bug introduced in https://github.com/jax-ml/jax/pull/24889 --- jax/_src/cloud_tpu_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 8ff52bd2f559..a2f137686dae 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,7 @@ def cloud_tpu_init() -> None: os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']: + if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''): os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true' # this makes tensorstore serialization work better on TPU From 8d84f2837346b29b52d1b797f672af10df05df41 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 20 Nov 2024 09:59:39 -0800 Subject: [PATCH 424/698] [pallas mgpu] Lowering for while loops as long as they are secretly for loops. PiperOrigin-RevId: 698427307 --- jax/_src/pallas/mosaic_gpu/lowering.py | 38 ++++++++++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 16 +++++++++++ 2 files changed, 54 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5b5a6f4ace83..66437839cce2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1473,6 +1473,44 @@ def _scan_lowering_rule( return for_out +@register_lowering_rule(lax.while_p) +def _while_lowering_rule( + ctx: LoweringRuleContext, + *args, + cond_jaxpr, + body_jaxpr, + cond_nconsts, + body_nconsts, +): + # First try to lower via a simpler fori loop, which may optimize better. + fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop( + cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts + ) + del cond_jaxpr, body_jaxpr + if fori_jaxpr is None: + raise NotImplementedError(err) + + if fori_jaxpr.constvars: + raise NotImplementedError + + lb_aval, ub_aval, *_ = ctx.avals_in[body_nconsts:] + # Reflect the changes of the pattern matcher to the context. + avals_in = ( + *ctx.avals_in[cond_nconsts:body_nconsts], + ctx.avals_in[body_nconsts], # the index + *ctx.avals_in[body_nconsts + 2:], + ) + + avals_out = tuple(ctx.avals_out[2:]) + ctx = ctx.replace(avals_in=avals_in, avals_out=avals_out) + _, consts, (lb, ub, *args) = util.split_list(args, [cond_nconsts, body_nconsts]) + + lb, ub = _ensure_ir_value(lb, lb_aval.dtype), _ensure_ir_value(ub, ub_aval.dtype) + length = arith_dialect.subi(ub, lb) + + for_out = _lower_jaxpr_to_for_loop(ctx, fori_jaxpr, lb, length, consts, *args, has_loop_index=True) + return (ub, ub, *for_out) + @register_lowering_rule(lax.cond_p) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b8098f40eccf..48c047697911 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -676,6 +676,22 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) + def test_fori_loop_dynamic_bounds(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + grid=(1,) + ) + def kernel(o_ref): + zero = pl.program_id(0) + # Equivalent to 2 + 3. + o_ref[...] = jax.lax.broadcast( + jax.lax.fori_loop(2 + zero, 4 + zero, lambda i, x: x + i, 0), o_ref.shape + ) + + np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) + def test_fori_loop_tuple(self): @functools.partial( pl.pallas_call, From d0f17c0c04bec626a5e03cbf33a4dae43cfc8443 Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 4 Nov 2024 13:33:19 -0500 Subject: [PATCH 425/698] Make a direct linearize trace. This is an alternative to doing JVP followed by partial eval. The linearize trace has two parent traces, one for the primal computation and one for the tangent computation. If we make the tangent trace a DynamicJaxprTrace then we get staged linearization. If we make it the same as the primal trace then we get primal and tangent computations occurring in step (JVP). This is a neat trick enabled by stackless which now lives up to its name. With two parent traces we have a tree of traces not a linked list stack. Primitive ops can have their own linearization rules but as a fallback we can derive a linearization rule for a single op using jvp/partial-eval. For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can make this the default for linearize/grad. It should help with remat and AD through state which are awkward to express via partial eval. --- jax/_src/config.py | 10 ++++ jax/_src/interpreters/ad.py | 101 ++++++++++++++++++++++++++++++++++-- tests/api_test.py | 15 ++++++ 3 files changed, 123 insertions(+), 3 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 1c62f7125ee7..eff9b757b95b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -219,6 +219,7 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, sharding_in_types.value, + use_direct_linearize.value, softmax_custom_jvp.value, enable_memories.value, disable_jit.value, @@ -263,6 +264,7 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, sharding_in_types.value, + use_direct_linearize.value, softmax_custom_jvp.value, enable_memories.value, disable_jit.value, @@ -983,6 +985,7 @@ class _GlobalExtraJitContext(NamedTuple): threefry_partitionable: bool = False threefry_gpu_kernel_lowering: bool = False sharding_in_types: bool = False + use_direct_linearize: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 pgle_profiling_runs: int = 0 @@ -1025,6 +1028,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): threefry_partitionable: bool | None = None threefry_gpu_kernel_lowering: bool | None = None sharding_in_types: bool | None = None + use_direct_linearize: bool | None = None softmax_custom_jvp: bool | None = None xla_profile_version: int | None = None pgle_profiling_runs: int | None = None @@ -1318,6 +1322,12 @@ def _update_jax_memories_thread_local(val): 'avals have sharding on them.'), include_in_jit_key=True) +use_direct_linearize = bool_state( + name='jax_use_direct_linearize', + default=False, + help=('Use direct linearization instead JVP followed by partial eval'), + include_in_jit_key=True) + data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', default=False, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 99340e728545..91f061fd2210 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -39,7 +39,6 @@ as_hashable_function, weakref_lru_cache, partition_list) - zip = safe_zip map = safe_map def identity(x): return x @@ -106,7 +105,29 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents): store.store(aux_primals) return out_primals, out_tangents +def direct_linearize(traceable, *primals, **kwargs): + has_aux = kwargs.pop('has_aux', False) + assert not has_aux + with core.take_current_trace() as parent_trace: + frame = pe.JaxprStackFrame() + tangent_trace = pe.DynamicJaxprTrace(frame) + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] + tag = core.TraceTag() + linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag) + tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] + with core.set_current_trace(linearize_trace): + ans = traceable.call_wrapped(*tracers) + + out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans)) + out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) + jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents) + out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents] + del attrs_tracked # TODO: attrs + return out_primals, out_tangents_pvals, jaxpr, consts + def linearize(traceable, *primals, **kwargs): + if config.use_direct_linearize.value: + return direct_linearize(traceable, *primals, **kwargs) has_aux = kwargs.pop('has_aux', False) if not has_aux: jvpfun = jvp(traceable) @@ -444,15 +465,89 @@ def _primal_tangent_shapes_match(primal, tangent): call_param_updaters: dict[core.Primitive, Callable] = {} call_transpose_param_updaters: dict[core.Primitive, Callable] = {} +# -------------------- Linearize trace -------------------- + +class LinearizeTrace(Trace): + + def __init__(self, parent_trace, tangent_trace, tag): + self.tag = tag + self.parent_trace = parent_trace + self.tangent_trace = tangent_trace + + def to_primal_tangent_pair(self, val): + if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag: + return (val.primal, val.tangent) + else: + tangent_zero = Zero.from_primal_value(val) + return (val, tangent_zero) + + def process_primitive(self, primitive, args, params): + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) + if all(type(t) is Zero for t in tangents_in): + return primitive.bind_with_trace(self.parent_trace, primals_in, params) + lin = primitive_linearizations.get(primitive) + if lin is None: + lin = partial(fallback_linearize_rule, primitive) + with core.set_current_trace(self.parent_trace): + primal_out, linearized = lin(*primals_in, **params) + with core.set_current_trace(self.tangent_trace): + tangent_out = linearized(*tangents_in) + if primitive.multiple_results: + return [maybe_linearize_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)] + else: + return maybe_linearize_tracer(self, primal_out, tangent_out) + +def maybe_linearize_tracer(trace, primal, tangent): + if type(tangent) is Zero: + return primal + else: + return LinearizeTracer(trace, primal, tangent) + +def fallback_linearize_rule(prim, *args, **kwargs): + def call_prim(*args_): + return prim.bind(*args_, **kwargs) + with config.use_direct_linearize(False): + out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize( + lu.wrap_init(call_prim), *args, **kwargs) + def linearized(*tangents): + tangents_out = iter(core.eval_jaxpr(jaxpr, consts, *tangents)) + full_out = [pval.get_known() if pval.is_known() else next(tangents_out) + for pval in out_tangents_pvals] + assert next(tangents_out, None) is None + return full_out + return out_primals, linearized + +class LinearizeTracer(Tracer): + __slots__ = ['primal', 'tangent'] + + def __init__(self, trace, primal, tangent): + if config.enable_checks.value: + _primal_tangent_shapes_match(primal, tangent) + self._trace = trace + self.primal = primal + self.tangent = tangent + + @property + def aval(self): + return get_aval(self.primal) + + def full_lower(self): + if type(self.tangent) is Zero: + return core.full_lower(self.primal) + else: + return self + + def to_concrete_value(self): + return core.to_concrete_value(self.primal) + # -------------------- Primitives -------------------- primitive_jvps : dict[core.Primitive, Callable] = {} - primitive_transposes: dict[core.Primitive, Callable] = {} # transpose rules that internally perform reductions over the given named axes reducing_transposes: dict[core.Primitive, Callable] = {} - +primitive_linearizations: dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive) diff --git a/tests/api_test.py b/tests/api_test.py index ae38f50460ab..ff7855b68991 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4807,6 +4807,21 @@ def add_one_and_dupe(x: int) -> tuple[int, int]: jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True) jax.eval_shape(jit_add_one_dupe, 0) # don't crash + def test_use_direct_linearize(self): + + def check_invariant_to_use_direct_linearize(f): + with config.use_direct_linearize(False): + ans1 = f() + with config.use_direct_linearize(True): + ans2 = f() + + self.assertEqual(ans1, ans2) + + def sin_of_sin(x): + return jnp.sin(jnp.sin(x)) + + check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) + class RematTest(jtu.JaxTestCase): From fee272e550109e7409e8ae6e992bbde7bd1f1b90 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 10:30:12 -0800 Subject: [PATCH 426/698] Remove internal KeyArray alias This was useful during the transition to typed PRNG keys, but is no longer necessary. It also makes generated HTML docs confusing: it's better to just use Array as we expect users to. --- jax/_src/blocked_sampler.py | 4 +- jax/_src/nn/initializers.py | 25 +++++---- jax/_src/random.py | 104 ++++++++++++++++++------------------ 3 files changed, 65 insertions(+), 68 deletions(-) diff --git a/jax/_src/blocked_sampler.py b/jax/_src/blocked_sampler.py index 16da61d75b3f..3bc592d88246 100644 --- a/jax/_src/blocked_sampler.py +++ b/jax/_src/blocked_sampler.py @@ -23,7 +23,7 @@ Shape = random.Shape class SampleFn(Protocol): - def __call__(self, key: random.KeyArrayLike, *args, shape: Shape, + def __call__(self, key: ArrayLike, *args, shape: Shape, **kwargs) -> Array: ... @@ -43,7 +43,7 @@ def _compute_scalar_index(iteration_index: Sequence[int], def blocked_fold_in( - global_key: random.KeyArrayLike, + global_key: ArrayLike, total_size: Shape, block_size: Shape, tile_size: Shape, diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index eb1bb1609bbf..8086a97a3748 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -36,7 +36,6 @@ export = set_module('jax.nn.initializers') -KeyArray = Array # TODO: Import or define these to match # https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py. DTypeLikeFloat = Any @@ -48,13 +47,13 @@ @typing.runtime_checkable class Initializer(Protocol): @staticmethod - def __call__(key: KeyArray, + def __call__(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: raise NotImplementedError @export -def zeros(key: KeyArray, +def zeros(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of zeros. @@ -69,7 +68,7 @@ def zeros(key: KeyArray, return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) @export -def ones(key: KeyArray, +def ones(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of ones. @@ -100,7 +99,7 @@ def constant(value: ArrayLike, Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -126,7 +125,7 @@ def uniform(scale: RealNumeric = 1e-2, Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -152,7 +151,7 @@ def normal(stddev: RealNumeric = 1e-2, Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -189,7 +188,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2, [-3.836303 , -4.192359 , 0.6022964]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -230,7 +229,7 @@ def _compute_fans(shape: Sequence[int], fan_out = out_size * receptive_field_size return fan_in, fan_out -def _complex_uniform(key: KeyArray, +def _complex_uniform(key: Array, shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ @@ -244,7 +243,7 @@ def _complex_uniform(key: KeyArray, theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def _complex_truncated_normal(key: KeyArray, upper: ArrayLike, +def _complex_truncated_normal(key: Array, upper: ArrayLike, shape: Sequence[int], dtype: DTypeLikeInexact) -> Array: """ @@ -314,7 +313,7 @@ def variance_scaling( dtype: the dtype of the weights. """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: shape = core.canonicalize_shape(shape) @@ -599,7 +598,7 @@ def orthogonal(scale: RealNumeric = 1.0, Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) @@ -654,7 +653,7 @@ def delta_orthogonal( .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393 """ - def init(key: KeyArray, + def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) diff --git a/jax/_src/random.py b/jax/_src/random.py index 6c04b0620080..4313d9036eda 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -55,8 +55,6 @@ Shape = Sequence[int] PRNGImpl = prng.PRNGImpl -KeyArray = Array -KeyArrayLike = ArrayLike UINT_DTYPES = prng.UINT_DTYPES @@ -69,8 +67,8 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -def _check_prng_key(name: str, key: KeyArrayLike, *, - allow_batched: bool = False) -> tuple[KeyArray, bool]: +def _check_prng_key(name: str, key: ArrayLike, *, + allow_batched: bool = False) -> tuple[Array, bool]: if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): wrapped_key = key wrapped = False @@ -113,7 +111,7 @@ def _return_prng_keys(was_wrapped, key): return prng.random_unwrap(key) if was_wrapped else key -def _random_bits(key: KeyArray, bit_width: int, shape: Shape) -> Array: +def _random_bits(key: Array, bit_width: int, shape: Shape) -> Array: assert jnp.issubdtype(key.dtype, dtypes.prng_key) return prng.random_bits(key, bit_width=bit_width, shape=shape) @@ -188,7 +186,7 @@ def resolve_prng_impl(impl_spec: PRNGSpecDesc | None) -> PRNGImpl: def _key(ctor_name: str, seed: int | ArrayLike, - impl_spec: PRNGSpecDesc | None) -> KeyArray: + impl_spec: PRNGSpecDesc | None) -> Array: impl = resolve_prng_impl(impl_spec) if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key): raise TypeError( @@ -200,7 +198,7 @@ def _key(ctor_name: str, seed: int | ArrayLike, return prng.random_seed(seed, impl=impl) def key(seed: int | ArrayLike, *, - impl: PRNGSpecDesc | None = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> Array: """Create a pseudo-random number generator (PRNG) key given an integer seed. The result is a scalar array containing a key, whose dtype indicates @@ -220,7 +218,7 @@ def key(seed: int | ArrayLike, *, return _key('key', seed, impl) def PRNGKey(seed: int | ArrayLike, *, - impl: PRNGSpecDesc | None = None) -> KeyArray: + impl: PRNGSpecDesc | None = None) -> Array: """Create a legacy PRNG key given an integer seed. This function produces old-style legacy PRNG keys, which are arrays @@ -248,7 +246,7 @@ def PRNGKey(seed: int | ArrayLike, *, return _return_prng_keys(True, _key('PRNGKey', seed, impl)) -def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: +def fold_in(key: ArrayLike, data: IntegerArray) -> Array: """Folds in data to a PRNG key to form a new PRNG key. Args: @@ -267,7 +265,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: return _return_prng_keys(wrapped, key_out) -def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: +def _split(key: Array, num: int | tuple[int, ...] = 2) -> Array: # Alternative to split() to use within random samplers. # TODO(frostig): remove and use split(); we no longer need to wait # to always enable_custom_prng @@ -278,7 +276,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: shape = tuple(num) if isinstance(num, Sequence) else (num,) return prng.random_split(key, shape=shape) -def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: +def split(key: ArrayLike, num: int | tuple[int, ...] = 2) -> Array: """Splits a PRNG key into `num` new keys by adding a leading axis. Args: @@ -293,22 +291,22 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: return _return_prng_keys(wrapped, _split(typed_key, num)) -def _key_impl(keys: KeyArray) -> str | PRNGSpec: +def _key_impl(keys: Array) -> str | PRNGSpec: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) keys_dtype = typing.cast(prng.KeyTy, keys.dtype) impl = keys_dtype._impl return impl.name if impl.name in prng.prngs else PRNGSpec(impl) -def key_impl(keys: KeyArrayLike) -> str | PRNGSpec: +def key_impl(keys: ArrayLike) -> str | PRNGSpec: typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) return _key_impl(typed_keys) -def _key_data(keys: KeyArray) -> Array: +def _key_data(keys: Array) -> Array: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) return prng.random_unwrap(keys) -def key_data(keys: KeyArrayLike) -> Array: +def key_data(keys: ArrayLike) -> Array: """Recover the bits of key data underlying a PRNG key array.""" keys, _ = _check_prng_key("key_data", keys, allow_batched=True) return _key_data(keys) @@ -345,7 +343,7 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) -def bits(key: KeyArrayLike, +def bits(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeUInt | None = None) -> Array: """Sample uniform bits in the form of unsigned integers. @@ -374,7 +372,7 @@ def bits(key: KeyArrayLike, return _random_bits(key, bit_width, shape) -def uniform(key: KeyArrayLike, +def uniform(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., @@ -444,7 +442,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: lax.reshape(floats * (maxval - minval) + minval, shape)) -def randint(key: KeyArrayLike, +def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, @@ -533,7 +531,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: return lax.add(minval, lax.convert_element_type(random_offset, dtype)) -def permutation(key: KeyArrayLike, +def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, independent: bool = False) -> Array: @@ -596,7 +594,7 @@ def _shuffle(key, x, axis) -> Array: return x -def choice(key: KeyArrayLike, +def choice(key: ArrayLike, a: int | ArrayLike, shape: Shape = (), replace: bool = True, @@ -677,7 +675,7 @@ def choice(key: KeyArrayLike, arr.shape[0:axis] + tuple(shape) + arr.shape[axis+1:]) -def normal(key: KeyArrayLike, +def normal(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample standard normal random values with given shape and float dtype. @@ -730,7 +728,7 @@ def _normal_real(key, shape, dtype) -> Array: return lax.mul(np.array(np.sqrt(2), dtype), lax.erf_inv(u)) -def multivariate_normal(key: KeyArrayLike, +def multivariate_normal(key: ArrayLike, mean: RealArray, cov: RealArray, shape: Shape | None = None, @@ -813,7 +811,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array: return result -def truncated_normal(key: KeyArrayLike, +def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, @@ -879,7 +877,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))) -def bernoulli(key: KeyArrayLike, +def bernoulli(key: ArrayLike, p: RealArray = np.float32(0.5), shape: Shape | None = None) -> Array: r"""Sample Bernoulli random values with given shape and mean. @@ -924,7 +922,7 @@ def _bernoulli(key, p, shape) -> Array: return uniform(key, shape, lax.dtype(p)) < p -def beta(key: KeyArrayLike, +def beta(key: ArrayLike, a: RealArray, b: RealArray, shape: Shape | None = None, @@ -985,7 +983,7 @@ def _beta(key, a, b, shape, dtype) -> Array: return gamma_a_scaled / (gamma_a_scaled + gamma_b_scaled) -def cauchy(key: KeyArrayLike, +def cauchy(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Cauchy random values with given shape and float dtype. @@ -1024,7 +1022,7 @@ def _cauchy(key, shape, dtype) -> Array: return lax.tan(lax.mul(pi, lax.sub(u, _lax_const(u, 0.5)))) -def dirichlet(key: KeyArrayLike, +def dirichlet(key: ArrayLike, alpha: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1096,7 +1094,7 @@ def _softmax(x, axis) -> Array: return unnormalized / unnormalized.sum(axis, keepdims=True) -def exponential(key: KeyArrayLike, +def exponential(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Exponential random values with given shape and float dtype. @@ -1135,7 +1133,7 @@ def _exponential(key, shape, dtype) -> Array: return lax.neg(lax.log1p(lax.neg(u))) -def _gamma_one(key: KeyArray, alpha, log_space) -> Array: +def _gamma_one(key: Array, alpha, log_space) -> Array: # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang # The algorithm can also be founded in: # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables @@ -1263,7 +1261,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space): multiple_results=False), platform='cpu') batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule -def gamma(key: KeyArrayLike, +def gamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1310,7 +1308,7 @@ def gamma(key: KeyArrayLike, return _gamma(key, a, shape=shape, dtype=dtype) -def loggamma(key: KeyArrayLike, +def loggamma(key: ArrayLike, a: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1452,7 +1450,7 @@ def _poisson(key, lam, shape, dtype) -> Array: return lax.select(lam == 0, jnp.zeros_like(result), result) -def poisson(key: KeyArrayLike, +def poisson(key: ArrayLike, lam: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: @@ -1497,7 +1495,7 @@ def poisson(key: KeyArrayLike, return _poisson(key, lam, shape, dtype) -def gumbel(key: KeyArrayLike, +def gumbel(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: """Sample Gumbel random values with given shape and float dtype. @@ -1533,7 +1531,7 @@ def _gumbel(key, shape, dtype) -> Array: uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) -def categorical(key: KeyArrayLike, +def categorical(key: ArrayLike, logits: RealArray, axis: int = -1, shape: Shape | None = None) -> Array: @@ -1575,7 +1573,7 @@ def categorical(key: KeyArrayLike, axis=axis) -def laplace(key: KeyArrayLike, +def laplace(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample Laplace random values with given shape and float dtype. @@ -1612,7 +1610,7 @@ def _laplace(key, shape, dtype) -> Array: return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u)))) -def logistic(key: KeyArrayLike, +def logistic(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample logistic random values with given shape and float dtype. @@ -1648,7 +1646,7 @@ def _logistic(key, shape, dtype): return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) -def pareto(key: KeyArrayLike, +def pareto(key: ArrayLike, b: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1697,7 +1695,7 @@ def _pareto(key, b, shape, dtype) -> Array: return lax.exp(e / b) -def t(key: KeyArrayLike, +def t(key: ArrayLike, df: RealArray, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: @@ -1749,7 +1747,7 @@ def _t(key, df, shape, dtype) -> Array: return n * jnp.sqrt(half_df / g) -def chisquare(key: KeyArrayLike, +def chisquare(key: ArrayLike, df: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -1801,7 +1799,7 @@ def _chisquare(key, df, shape, dtype) -> Array: return chi2 -def f(key: KeyArrayLike, +def f(key: ArrayLike, dfnum: RealArray, dfden: RealArray, shape: Shape | None = None, @@ -1865,7 +1863,7 @@ def _f(key, dfnum, dfden, shape, dtype) -> Array: return f -def rademacher(key: KeyArrayLike, +def rademacher(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeInt = int) -> Array: r"""Sample from a Rademacher distribution. @@ -1900,7 +1898,7 @@ def _rademacher(key, shape, dtype) -> Array: return (2 * bernoulli_samples - 1).astype(dtype) -def maxwell(key: KeyArrayLike, +def maxwell(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float) -> Array: r"""Sample from a one sided Maxwell distribution. @@ -1940,7 +1938,7 @@ def _maxwell(key, shape, dtype) -> Array: return jnp.linalg.norm(norm_rvs, axis=-1) -def double_sided_maxwell(key: KeyArrayLike, +def double_sided_maxwell(key: ArrayLike, loc: RealArray, scale: RealArray, shape: Shape = (), @@ -1992,7 +1990,7 @@ def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array: return random_sign * maxwell_rvs * scale + loc -def weibull_min(key: KeyArrayLike, +def weibull_min(key: ArrayLike, scale: RealArray, concentration: RealArray, shape: Shape = (), @@ -2038,7 +2036,7 @@ def _weibull_min(key, scale, concentration, shape, dtype) -> Array: def orthogonal( - key: KeyArrayLike, + key: ArrayLike, n: int, shape: Shape = (), dtype: DTypeLikeFloat = float @@ -2073,7 +2071,7 @@ def orthogonal( return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2])) def generalized_normal( - key: KeyArrayLike, + key: ArrayLike, p: float, shape: Shape = (), dtype: DTypeLikeFloat = float @@ -2108,7 +2106,7 @@ def generalized_normal( return r * g ** (1 / p) def ball( - key: KeyArrayLike, + key: ArrayLike, d: int, p: float = 2, shape: Shape = (), @@ -2140,7 +2138,7 @@ def ball( return g / (((jnp.abs(g) ** p).sum(-1) + e) ** (1 / p))[..., None] -def rayleigh(key: KeyArrayLike, +def rayleigh(key: ArrayLike, scale: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2193,7 +2191,7 @@ def _rayleigh(key, scale, shape, dtype) -> Array: ray = lax.mul(scale, sqrt_u) return ray -def wald(key: KeyArrayLike, +def wald(key: ArrayLike, mean: RealArray, shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2251,7 +2249,7 @@ def _wald(key, mean, shape, dtype) -> Array: w = lax.select(lax.le(z, mean / (mean + x)), x, mean_sq / x) return w -def geometric(key: KeyArrayLike, +def geometric(key: ArrayLike, p: RealArray, shape: Shape | None = None, dtype: DTypeLikeInt = int) -> Array: @@ -2304,7 +2302,7 @@ def _geometric(key, p, shape, dtype) -> Array: return g.astype(dtype) -def triangular(key: KeyArrayLike, +def triangular(key: ArrayLike, left: RealArray, mode: RealArray, right: RealArray, @@ -2368,7 +2366,7 @@ def _triangular(key, left, mode, right, shape, dtype) -> Array: return tri -def lognormal(key: KeyArrayLike, +def lognormal(key: ArrayLike, sigma: RealArray = np.float32(1), shape: Shape | None = None, dtype: DTypeLikeFloat = float) -> Array: @@ -2573,7 +2571,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array: def binomial( - key: KeyArray, + key: Array, n: RealArray, p: RealArray, shape: Shape | None = None, From 2c9b917b9d01149f4b6b5db1523fa742af413ced Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 20 Nov 2024 10:35:16 -0800 Subject: [PATCH 427/698] Don't psum over auto mesh dims in _unmentioned2. PiperOrigin-RevId: 698440525 --- jax/experimental/shard_map.py | 9 +++++---- tests/shard_map_test.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 4ad248c17ee2..c2673b55dd9a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1547,10 +1547,11 @@ def fun(*res_and_args): return jaxpr -def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: +def _unmentioned2(mesh: Mesh, names: AxisNames, + auto: frozenset[AxisName]) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} + name_set = {n for ns in names.values() for n in ns} | auto return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set] @@ -1559,7 +1560,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero else x if rewrite or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns)))) + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) @@ -1577,7 +1578,7 @@ def fun_trans(out_cts, args): ) out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero else x if rewrite - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns))) + else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) for ns, x in zip(in_names, out)] return out diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 84017bab5122..2a343f7ba784 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2046,6 +2046,29 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_grad_nested_partial_auto(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) + + def g(x): + return x * x + + def h(x): + return shard_map(g, mesh, + in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) + + @jax.jit + def f(x): + return shard_map(h, mesh, + in_specs=P('i', None), + out_specs=P('i', None), + check_rep=False, + auto=frozenset({'j'}))(x).sum() + + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False) + def test_axis_size_1_partial_auto(self): mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) From 9584ee3bb9c3a48299635a7c0a11df1029cf9f59 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 20 Nov 2024 10:41:24 -0800 Subject: [PATCH 428/698] [pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test Turns out we can mix parallel grid with `plgpu.emit_pipeline` without doing indexing at all! PiperOrigin-RevId: 698442820 --- tests/pallas/mosaic_gpu_test.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 48c047697911..a4bbc67ee14f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1218,33 +1218,33 @@ def kernel_body(x_smem, o_smem): np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16]) def test_emit_with_parallel_grid(self): - self.skipTest("Enable once we support multiple levels of indexing") - - num_steps = 4 + num_steps1 = 4 + num_steps2 = 5 def kernel(x_gmem, o_gmem): - gmem_slice = pl.ds(pl.program_id(0) * 32, 32) + pid = pl.program_id(0) plgpu.emit_pipeline( kernel_body, - in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], - out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], - grid=(num_steps,), + in_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (pid, i))], + grid=(num_steps2,), max_concurrent_steps=2, - )(x_gmem.at[gmem_slice], o_gmem.at[gmem_slice]) + )(x_gmem, o_gmem) def kernel_body(x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 - x = jnp.arange(4 * 32 * num_steps * 16) - x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) + x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) kernel_fn = pl.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - grid=(4, 1), + grid=(num_steps1,), ) - np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + y = x + 1.0 + np.testing.assert_array_equal(kernel_fn(x), y) def test_emit_with_2d_grid(self): num_steps1 = 4 From 621e39de27098a941094fa332cf03f42018f3b91 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 10:47:23 -0800 Subject: [PATCH 429/698] Set __module__ attribute of jax.numpy.linalg APIs --- jax/_src/numpy/linalg.py | 36 ++++++++++++++++++++++++++++++++- tests/package_structure_test.py | 1 + 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 76a4abff48ad..be6828c36e6a 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -35,10 +35,13 @@ from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions, ufuncs from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike -from jax._src.util import canonicalize_axis +from jax._src.util import canonicalize_axis, set_module from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg +export = set_module('jax.numpy.linalg') + + class EighResult(NamedTuple): eigenvalues: jax.Array eigenvectors: jax.Array @@ -67,6 +70,7 @@ def _H(x: ArrayLike) -> Array: def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 +@export @partial(jit, static_argnames=['upper']) def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """Compute the Cholesky decomposition of a matrix. @@ -191,6 +195,7 @@ def svd( ... +@export @partial( jit, static_argnames=( @@ -311,6 +316,7 @@ def svd( ) +@export @partial(jit, static_argnames=('n',)) def matrix_power(a: ArrayLike, n: int) -> Array: """Raise a square matrix to an integer power. @@ -392,6 +398,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: return result +@export @jit def matrix_rank( M: ArrayLike, rtol: ArrayLike | None = None, *, @@ -496,6 +503,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]: return sign_diag * sign_taus, log_abs_det +@export @partial(jit, static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ @@ -675,6 +683,7 @@ def _det_jvp(primals, tangents): return y, jnp.trace(z, axis1=-1, axis2=-2) +@export @jit def det(a: ArrayLike) -> Array: """ @@ -711,6 +720,7 @@ def det(a: ArrayLike) -> Array: raise ValueError(msg.format(a_shape)) +@export def eig(a: ArrayLike) -> tuple[Array, Array]: """ Compute the eigenvalues and eigenvectors of a square array. @@ -756,6 +766,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: return w, v +@export @jit def eigvals(a: ArrayLike) -> Array: """ @@ -793,6 +804,7 @@ def eigvals(a: ArrayLike) -> Array: compute_right_eigenvectors=False)[0] +@export @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: @@ -848,6 +860,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None, return EighResult(w, v) +@export @partial(jit, static_argnames=('UPLO',)) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ @@ -884,6 +897,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: # TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires. +@export def pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False, *, rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: @@ -997,6 +1011,7 @@ def _pinv_jvp(rtol, hermitian, primals, tangents): return p, p_dot +@export @jit def inv(a: ArrayLike) -> Array: """Return the inverse of a square matrix @@ -1057,6 +1072,7 @@ def inv(a: ArrayLike) -> Array: arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2])) +@export @partial(jit, static_argnames=('ord', 'axis', 'keepdims')) def norm(x: ArrayLike, ord: int | str | None = None, axis: None | tuple[int, ...] | int = None, @@ -1222,6 +1238,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ... +@export @partial(jit, static_argnames=('mode',)) def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: """Compute the QR decomposition of an array @@ -1305,6 +1322,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: return QRResult(q, r) +@export @jit def solve(a: ArrayLike, b: ArrayLike) -> Array: """Solve a linear system of equations @@ -1408,6 +1426,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, _jit_lstsq = jit(partial(_lstsq, numpy_resid=False)) +@export def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: """ @@ -1448,6 +1467,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, return _jit_lstsq(a, b, rcond) +@export def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): r"""Compute the cross-product of two 3D vectors @@ -1493,6 +1513,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): return jnp.cross(x1, x2, axis=axis) +@export def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute the outer product of two 1-dimensional arrays. @@ -1523,6 +1544,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: return x1[:, None] * x2[None, :] +@export def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array: """Compute the norm of a matrix or stack of matrices. @@ -1553,6 +1575,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1)) +@export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose a matrix or stack of matrices. @@ -1608,6 +1631,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) +@export def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: """Compute the vector norm of a vector or batch of vectors. @@ -1652,6 +1676,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa return norm(x, axis=axis, keepdims=keepdims, ord=ord) +@export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -1702,6 +1727,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, preferred_element_type=preferred_element_type) +@export def matmul(x1: ArrayLike, x2: ArrayLike, /, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: @@ -1762,6 +1788,7 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /, *, preferred_element_type=preferred_element_type) +@export def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, precision: PrecisionLike = None, @@ -1843,6 +1870,7 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, preferred_element_type=preferred_element_type) +@export def svdvals(x: ArrayLike, /) -> Array: """Compute the singular values of a matrix. @@ -1867,6 +1895,7 @@ def svdvals(x: ArrayLike, /) -> Array: return svd(x, compute_uv=False, hermitian=False) +@export def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: """Extract the diagonal of an matrix or stack of matrices. @@ -1907,6 +1936,7 @@ def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1) +@export def tensorinv(a: ArrayLike, ind: int = 2) -> Array: """Compute the tensor inverse of an array. @@ -1949,6 +1979,7 @@ def tensorinv(a: ArrayLike, ind: int = 2) -> Array: return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape) +@export def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array: """Solve the tensor equation a x = b for x. @@ -1998,6 +2029,7 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) return solve(a_arr, b_arr.ravel()).reshape(out_shape) +@export def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array: """Efficiently compute matrix products between a sequence of arrays. @@ -2090,6 +2122,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) - optimize='optimal', precision=precision) +@export @partial(jit, static_argnames=['p']) def cond(x: ArrayLike, p=None): """Compute the condition number of a matrix. @@ -2149,6 +2182,7 @@ def cond(x: ArrayLike, p=None): return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r) +@export def trace(x: ArrayLike, /, *, offset: int = 0, dtype: DTypeLike | None = None) -> Array: """Compute the trace of a matrix. diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 9bc8d0f6d71c..25468c4ba700 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -40,6 +40,7 @@ class PackageStructureTest(jtu.JaxTestCase): "number", "object_", "printoptions", "save", "savez", "set_printoptions", "shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"] ), + _mod("jax.numpy.linalg"), _mod("jax.nn.initializers"), _mod( "jax.tree_util", From dfe27a16825663ea3a90417ad452e99dc43d7f53 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 20 Nov 2024 14:53:52 -0500 Subject: [PATCH 430/698] Mention stackless in the release notes. --- CHANGELOG.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37fd68bce39d..be9aaebcd615 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.36 * Breaking Changes + * This release lands "stackless", an internal change to JAX's tracing + machinery. We made trace dispatch purely a function of context rather than a + function of both context and data. This let us delete a lot of machinery for + managing data-dependent tracing: levels, sublevels, `post_process_call`, + `new_base_main`, `custom_bind`, and so on. The change should only affect + users that use JAX internals. + + If you do use JAX internals then you may need to + update your code (see + https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f + for clues about how to do this). There might also be version skew + issues with JAX libraries that do this. If you find this change breaks your + non-JAX-internals-using code then try the + `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if + you need help updating your code then please file a bug. * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` or with `enable_xla=False` have been deprecated since July 2024, with JAX version 0.4.31. Now we removed support for these use cases. `jax2tf` From 40fc6598f96999271a3c19cfaab6f02579c003d6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 13:06:39 -0800 Subject: [PATCH 431/698] [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs. Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too. Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`. PiperOrigin-RevId: 698493184 --- jax/BUILD | 1 + jax/_src/config.py | 5 ++++- jax/_src/core.py | 30 +++++++++++++++++++++------- jax/_src/mesh.py | 25 ++++++++++++++++++++--- jax/_src/pallas/core.py | 2 +- jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pjit.py | 32 ++++++++++++++++++++++++++---- jax/_src/state/primitives.py | 5 ++++- jax/experimental/shard_map.py | 5 ++++- 9 files changed, 88 insertions(+), 19 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 26694fec2ad3..64bfa627f42e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -451,6 +451,7 @@ pytype_strict_library( ":deprecations", ":dtypes", ":effects", + ":mesh", ":pretty_printer", ":source_info_util", ":traceback_util", diff --git a/jax/_src/config.py b/jax/_src/config.py index eff9b757b95b..2723b4f90d3b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -209,7 +209,9 @@ def trace_context(): Values included in this set should also most likely be included in the C++ JIT state, which is handled separately. """ - return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value, + return (axis_env_state.value, mesh_context_manager.value, + xla_metadata_context_manager.value, + abstract_mesh_context_manager.value, compute_on_context_manager.value, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, @@ -969,6 +971,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: trace_state = config_ext.Config(None, include_in_jit_key=True) axis_env_state = config_ext.Config((), include_in_jit_key=True) mesh_context_manager = config_ext.Config((), include_in_jit_key=True) + abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True) compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) else: diff --git a/jax/_src/core.py b/jax/_src/core.py index cbf3282fb2cc..1bd3fb4fa889 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -38,6 +38,7 @@ from jax._src import config from jax._src import effects from jax._src import compute_on +from jax._src import mesh as mesh_lib from jax._src.errors import ( ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) @@ -1596,6 +1597,23 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) + +def get_sharding(sharding, ndim): + from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore + + if sharding is not None: + assert len(sharding.spec) == ndim + return sharding + + context_mesh = mesh_lib.mesh_context.mesh + # TODO(yashkatariya): Error out and ask users to set the context mesh in their + # code. + if context_mesh is None: + return None + assert sharding is None + return NamedSharding(context_mesh, P(*[None] * ndim)) + + class ShapedArray(UnshapedArray): __slots__ = ['shape', 'sharding'] # inherits slots from parent array_abstraction_level = 2 @@ -1605,20 +1623,18 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None): self.dtype = _dtype_object(dtype) self.weak_type = weak_type if config.sharding_in_types.value: - if sharding is not None: - assert len(sharding.spec) == len(self.shape) - self.sharding = sharding + self.sharding = get_sharding(sharding, len(self.shape)) - def update(self, shape=None, dtype=None, weak_type=None, sharding=None): + def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: shape = self.shape if dtype is None: dtype = self.dtype if weak_type is None: weak_type = self.weak_type - if sharding is None: - sharding = getattr(self, 'sharding', None) - return ShapedArray(shape, dtype, weak_type, sharding=sharding) + if 'sharding' not in kwargs: + kwargs['sharding'] = getattr(self, 'sharding', None) + return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) size = property(lambda self: diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index a2ab261fa0e9..3d0e1b0cccf5 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -107,6 +107,9 @@ class AxisTypes(enum.Enum): User = enum.auto() Collective = enum.auto() + def __repr__(self): + return self.name + def axis_names_to_types(axis_types) -> dict[str, AxisTypes]: if axis_types is None: return {} @@ -452,14 +455,22 @@ def local_mesh(self): _raise_value_error("local_mesh") def __enter__(self): - raise RuntimeError("AbstractMesh is not a context manager") + mesh_context.stack.append(self) + mesh_context.mesh = self + jax_config.abstract_mesh_context_manager.set_local( + tuple(m for m in mesh_context.stack if m is not None)) + return self def __exit__(self, exc_type, exc_value, traceback): - raise RuntimeError("AbstractMesh is not a context manager") + mesh_context.stack.pop() + mesh_context.mesh = mesh_context.stack[-1] + jax_config.abstract_mesh_context_manager.set_local( + tuple(m for m in mesh_context.stack if m is not None)) + return False @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): - jax_config.mesh_context_manager.set_local(mesh) + jax_config.abstract_mesh_context_manager.set_local(mesh) return @@ -467,3 +478,11 @@ def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): # property raises an exception unconditionally. Remove this once that is fixed. def _raise_value_error(name): raise ValueError(f"AbstractMesh does not implement {name}") + + +class MeshContext(threading.local): + def __init__(self): + self.stack = [None] + self.mesh = self.stack[-1] + +mesh_context = MeshContext() diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 72ed07674f1f..cf1e0b524963 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -873,7 +873,7 @@ def get_grid_mapping( ) # The inputs for the index maps index_map_avals = ( - (index_map_grid_aval,) * len(grid_spec.grid)) + (index_map_grid_aval.update(sharding=None),) * len(grid_spec.grid)) index_map_tree = tree_util.tree_structure((index_map_avals, {})) num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index be4102dff716..c9e20843a806 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1380,7 +1380,7 @@ def _masked_swap_lowering_rule( 1 if b is pallas_core.mapped else next(mem_slice_shape_iter) for b in ref_block_shape ] - mem_aval = aval_out.update(shape=tuple(mem_slice_shape)) + mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None) mem_aval_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)) if need_stride: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f1844c7ba13b..aff956862753 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -16,6 +16,7 @@ from collections import defaultdict from collections.abc import Callable, Sequence, Iterable +import contextlib import dataclasses from functools import partial import inspect @@ -637,10 +638,13 @@ def _infer_params_impl( in_avals, in_tree, dbg, device_or_backend_set, have_kwargs) attr_token = _attr_token(flat_fun, in_type) - jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( - flat_fun, in_type, attr_token, dbg, - HashableFunction(res_paths, closure=()), - IgnoreKey(ji.inline)) + + abstract_mesh = get_abstract_mesh(in_type) + with abstract_mesh: + jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( + flat_fun, in_type, attr_token, dbg, + HashableFunction(res_paths, closure=()), + IgnoreKey(ji.inline)) _attr_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( @@ -683,6 +687,26 @@ def _infer_params_impl( attrs_tracked), args_flat +def get_abstract_mesh(in_avals): + if not config.sharding_in_types.value: + return contextlib.nullcontext() + m = None + for a in in_avals: + # TODO(yashkatariya): Remove this when mesh context can be set by the user. + if a.sharding is None: # type: ignore + continue + if m is not None and m != a.sharding.mesh: + raise ValueError( + f'Mesh for all inputs should be equal. Got one mesh: {m} and' + f' another mesh: {a.sharding.mesh}') + m = a.sharding.mesh # type: ignore + # TODO(yashkatariya): Remove this when mesh context can be set by the user. + if m is None: + return contextlib.nullcontext() + assert m is not None + return m + + class InferParamsCacheEntry: """Mutable value object for _infer_params_cached.""" __slots__ = ['pjit_params'] diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 0897e778d079..14d42ad0809c 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -214,7 +214,10 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) - out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype) + # TODO(yashkatariya): Transform the sharding too instead of setting it to + # None. + out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype, + sharding=None) else: if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c2673b55dd9a..07f631f6ec49 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -483,7 +483,8 @@ def _shard_map_staging( in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) - with core.extend_axis_env_nd(list(mesh.shape.items())): + with (core.extend_axis_env_nd(list(mesh.shape.items())), + pjit.get_abstract_mesh(in_avals_)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: @@ -547,6 +548,8 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) + # TODO(yashkatariya): Reset the mesh properly based on the input avals if the + # mesh of shard_map specifies collective axes. if config.sharding_in_types.value: spec = _names_to_pspec(names)._normalized_spec(aval.ndim) new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec) From 9d2f62f811e23d4c9b2c33d923fff70ed78a4acf Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 20 Nov 2024 14:03:12 -0800 Subject: [PATCH 432/698] [Pallas TPU] Support masked store PiperOrigin-RevId: 698514079 --- jax/_src/pallas/mosaic/lowering.py | 23 ++++++++++++++++--- tests/pallas/tpu_pallas_test.py | 36 ++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index c9e20843a806..1f0062cad0f9 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -42,6 +42,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -1315,12 +1316,20 @@ def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): ref, transforms, val, mask = args_tree.unflatten(args_flat) - ref_aval, transforms_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in) + ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten( + ctx.avals_in + ) (*prev_transforms, idx) = transforms (*_, idx_aval) = transforms_avals if mask is not None: - raise NotImplementedError + if val_aval.dtype.itemsize != 4: + raise NotImplementedError("masked swap with non-32-bit data") + if val_aval.shape != mask_aval.shape: + raise ValueError( + "Expected value and mask to have the same shape, but got" + f" value shape {val_aval.shape} vs. mask shape {mask_aval.shape}." + ) ref_block_shape, *_ = ctx.block_shapes ref, ref_block_shape = _transform_ref( @@ -1351,6 +1360,8 @@ def _masked_swap_lowering_rule( need_stride = not all((s is None or s == 1) for s in strides) if is_smem_store: + if mask is not None: + raise ValueError("SMEM store does not support masks") if val_aval.shape: raise ValueError("Can only store scalars to SMEM") result = memref.load(ref, starts) @@ -1399,9 +1410,15 @@ def _masked_swap_lowering_rule( result = _maybe_cast_load_to_bool(val_aval, result) if need_stride: + if mask is not None: + raise NotImplementedError("masked swap with strided store") tpu.StridedStoreOp(val, ref, starts, strides) - else: + elif jaxlib_version <= (0, 4, 35): + if mask is not None: + raise NotImplementedError("masked swap with vector store") vector.StoreOp(val, ref, starts) + else: + tpu.VectorStoreOp(val, ref, starts, [], mask=mask) return result diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 347a06c50323..9c4788d7447f 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1723,6 +1723,42 @@ def test(x: jax.Array) -> jax.Array: y = test(x) np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1)) + def test_masked_store(self): + if jtu.jaxlib_version() <= (0, 4, 35): + self.skipTest("Test requires masked store support") + shape = (16, 256) + mask_shape = (10, 130) + mask_start = (4, 5) + dtype = jnp.float32 + def body(scalar_ref, x_ref, o_ref): + o_ref[...] = jnp.full(shape, -1, dtype=dtype) + b0, b1 = scalar_ref[0], scalar_ref[1] + e0, e1 = b0 + mask_shape[0], b1 + mask_shape[1] + iota0 = lax.broadcasted_iota(jnp.int32, shape, 0) + iota1 = lax.broadcasted_iota(jnp.int32, shape, 1) + mask0 = jnp.logical_and(b0 <= iota0, iota0 < e0) + mask1 = jnp.logical_and(b1 <= iota1, iota1 < e1) + pl.store( + o_ref, + (slice(None), slice(None)), + x_ref[...], + mask=jnp.logical_and(mask0, mask1), + ) + + s = jnp.array(mask_start, jnp.int32) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + out = pl.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + ), + )(s, x) + slices = tuple(slice(b, b + l) for b, l in zip(mask_start, mask_shape)) + expected = jnp.full(shape, -1, dtype=dtype) + expected = expected.at[slices].set(x[slices]) + np.testing.assert_array_equal(out, expected) + class PallasUXTest(PallasBaseTest): From 9b941808463ee614a541c5174ea98cb5c58a080b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 14:29:59 -0800 Subject: [PATCH 433/698] [sharding_in_types] Add slice_p and squeeze_p sharding rule to make flash attention work in backward pass For `slice_p`'s sharding rule, I error out if the operand dim is sharded and the output dim is not divisible by that axis size. I am working on a design to make JAX support uneven sharding at the top level after which slice_p's sharding rule can just `return operand.sharding`. Another option is to add `out_sharding` to `slice` but after uneven sharding support lands, it won't be necessary. PiperOrigin-RevId: 698522980 --- jax/_src/lax/lax.py | 14 ++++++++++++-- jax/_src/lax/slicing.py | 34 +++++++++++++++++++++++++++++++--- jax/_src/pallas/core.py | 4 ++++ tests/pjit_test.py | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 39c5bca5819c..79e48c440271 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4527,6 +4527,12 @@ def _squeeze_dtype_rule(operand, *, dimensions): def _squeeze_shape_rule(operand, *, dimensions): return _compute_squeeze_shape(np.shape(operand), dimensions) +def _squeeze_sharding_rule(operand, *, dimensions): + dims_set = set(dimensions) + new_spec = tuple(s for i, s in enumerate(operand.sharding.spec) + if i not in dims_set) + return NamedSharding(operand.sharding.mesh, P(*new_spec)) + def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) if len(dims_set) != len(dimensions): @@ -4555,7 +4561,7 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze') + 'squeeze', sharding_rule=_squeeze_sharding_rule) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -4563,7 +4569,11 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. - return [mlir.reshape(ctx, operand, ctx.avals_out[0])] + aval_out, = ctx.avals_out + out = mlir.reshape(ctx, operand, aval_out) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(squeeze_p, _squeeze_lower) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 40a04ff11d2c..117c8b655152 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -42,6 +42,7 @@ _input_dtype, standard_primitive, ) +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike, Shape @@ -1270,6 +1271,29 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): return tuple(core.stride_dim(d, window_size=1, window_stride=s) for d, s in zip(diff, strides)) +def _get_sub_spec_size(mesh, sub_spec): + if isinstance(sub_spec, tuple): + return math.prod(mesh.shape[s] for s in sub_spec) + return mesh.shape[sub_spec] + +def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _slice_shape_rule(operand, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + mesh = operand.sharding.mesh + new_spec = [] + for op_sh, out_sh, op_spec in safe_zip( + operand.shape, out_shape, operand.sharding.spec): + if (op_sh != out_sh and op_spec is not None and + out_sh % _get_sub_spec_size(mesh, op_spec) != 0): + raise NotImplementedError( + f"slicing on sharded dims where out dim ({out_sh}) is not divisble by" + f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" + f" ({op_spec}) is not implemented.") + new_spec.append(op_spec) + return NamedSharding(mesh, P(*new_spec)) + def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert ad.is_undefined_primal(operand) operand_shape = operand.aval.shape @@ -1308,7 +1332,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, out = slice(operand, new_start_indices, new_limit_indices, new_strides) return out, bdim -slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice') +slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', + sharding_rule=_slice_sharding_rule) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries @@ -1333,8 +1358,11 @@ def _slice_impl(x, start_indices, limit_indices, strides): def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): strides = strides or [1] * len(start_indices) aval_out, = ctx.avals_out - return [mlir.slice_op(ctx, x, aval_out, - start_indices=start_indices, limit_indices=limit_indices, strides=strides)] + out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(slice_p, _slice_lower) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index cf1e0b524963..acbf0d4f7ed5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -219,6 +219,10 @@ def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' + @property + def sharding(self): + return self.inner_aval.sharding + def update_weak_type(self, weak_type): return AbstractMemoryRef( self.inner_aval.update_weak_type(weak_type), self.memory_space) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e32424cfdded..e52c805ef5e6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5285,6 +5285,43 @@ def f(x, y): self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_slice(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(4, 4) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) + + @jax.jit + def f(x): + y = lax.slice(x, (0, 0), (4, 3)) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertIn('@Sharding', f.lower(arr).as_text()) + + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) + + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) + + def test_squeeze(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(4, 4, 1) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) + + @jax.jit + def f(x): + y = lax.squeeze(x, (2,)) + self.assertEqual(y.sharding.spec, P('x', None)) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertIn('@Sharding', f.lower(arr).as_text()) + self.assertArraysEqual(out, np.squeeze(np_inp, axis=2)) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 6fe78042b51d066e1b886dcf8f77df627831c00e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 20 Nov 2024 14:37:36 -0800 Subject: [PATCH 434/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e763f8875b0a9bfca876be9b02c874979e55422a. PiperOrigin-RevId: 698525361 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a554cfd03687..327e4ca422ac 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fcee07f619a765db815d9ed4e2bc229275818a2b" -XLA_SHA256 = "1dd144e64e2c2dcc20a2130e10607fec7b3a810926ba912918dd5437698a3375" +XLA_COMMIT = "e763f8875b0a9bfca876be9b02c874979e55422a" +XLA_SHA256 = "7b6a33894c6510167cac6e0ab7a6331ffa84e7fcaaa1d3b1c462ec5ecacb0682" def repo(): tf_http_archive( From f749fca760cbcd019bed8f9e1a64cf525c0bcb14 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 14:50:06 -0800 Subject: [PATCH 435/698] [array api] use most recent version of array_api_tests --- .github/workflows/jax-array-api.yml | 2 +- pyproject.toml | 1 - tests/array_api_skips.txt | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 763a4c04be5d..8f2029eb9191 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'a3f3f376308e64f0ac15b307dfe27be945409e41' # Latest commit as of 2024-11-14 + ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index 73e1c51fc8af..d688f7fbbf01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,6 @@ filterwarnings = [ # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", - "default:The .* method is good for exploring strategies.*", # NOTE: this is probably not where you want to add code to suppress a # warning. Only pytest tests look at this list, whereas Bazel tests also diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2ac2edcdfd99..e1d4c35eae68 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -6,6 +6,7 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] # Test suite attempts in-place mutation: array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking +array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] From 2699e9507e462047fb853c32768812467ec1c13c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 15:13:14 -0800 Subject: [PATCH 436/698] DOC: add examples for jax.lax.pad --- jax/_src/lax/lax.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 79e48c440271..1bf1ea816ca7 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1306,6 +1306,36 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, Returns: The ``operand`` array with padding value ``padding_value`` inserted in each dimension according to the ``padding_config``. + + Examples: + >>> from jax import lax + >>> import jax.numpy as jnp + + Pad a 1-dimensional array with zeros, We'll specify two zeros in front and + three at the end: + + >>> x = jnp.array([1, 2, 3, 4]) + >>> lax.pad(x, 0, [(2, 3, 0)]) + Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32) + + Pad a 1-dimensional array with *interior* zeros; i.e. insert a single zero + between each value: + + >>> lax.pad(x, 0, [(0, 0, 1)]) + Array([1, 0, 2, 0, 3, 0, 4], dtype=int32) + + Pad a 2-dimensional array with the value ``-1`` at front and end, with a pad + size of 2 in each dimension: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) + Array([[-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, 1, 2, 3, -1, -1], + [-1, -1, 4, 5, 6, -1, -1], + [-1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) From 17825882d2bb87b84387963bfaf53ce191cbf71b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 20 Nov 2024 16:21:45 -0800 Subject: [PATCH 437/698] jax.lax.pad: improve input validation --- jax/_src/lax/lax.py | 3 ++- tests/lax_test.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 79e48c440271..934f100ffe34 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4441,7 +4441,8 @@ def _pad_dtype_rule(operand, padding_value, *, padding_config): return _input_dtype(operand, padding_value) def _pad_shape_rule(operand, padding_value, *, padding_config): - del padding_value + if np.ndim(padding_value) != 0: + raise ValueError(f"padding_value must be a scalar; got {np.shape(padding_value)=}") op_shape = np.shape(operand) if not len(padding_config) == np.ndim(operand): raise ValueError("length of padding_config must equal the number of axes " diff --git a/tests/lax_test.py b/tests/lax_test.py index 78bc5857acb7..10fa8c006184 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1597,6 +1597,8 @@ def testPadAgainstNumpy(self, shape, dtype, pads): self._CheckAgainstNumpy(numpy_op, op, args_maker) def testPadErrors(self): + with self.assertRaisesRegex(ValueError, "padding_value must be a scalar"): + lax.pad(np.zeros(2), np.zeros(2), [(0, 0, 0)]) with self.assertRaisesRegex(ValueError, "padding_config"): lax.pad(np.zeros(2), 0., [(0, 1, 0), (0, 1, 0)]) with self.assertRaisesRegex(ValueError, "interior padding in padding_config must be nonnegative"): From bf7f9aa8f27da525bc0a1e42a3b6e15c1f93b2f4 Mon Sep 17 00:00:00 2001 From: barnesjoseph Date: Wed, 20 Nov 2024 16:38:58 -0800 Subject: [PATCH 438/698] Adds Google Sans font --- docs/_static/style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/_static/style.css b/docs/_static/style.css index d801c2a412a6..36b54b8432f0 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -1,4 +1,5 @@ @import url("theme.css"); +@import url('https://fonts.googleapis.com/css2?family=Google+Sans'); /* Base LP sidebar modifications */ body:has(.hero) .sidebar-toggle, From 1f6152d11e28bc94ab86906ae07d43967f8c759e Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 20 Nov 2024 17:12:01 -0800 Subject: [PATCH 439/698] [Pallas] Use Pallas cost estimator for flash attention. PiperOrigin-RevId: 698573265 --- jax/_src/pallas/cost_estimate.py | 40 ++++++++++++++++--- jax/experimental/pallas/__init__.py | 1 + .../pallas/ops/tpu/flash_attention.py | 27 +++++++------ tests/pallas/pallas_cost_estimate_test.py | 10 ++--- 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index 1bcf704b3579..b83c36159555 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -16,9 +16,12 @@ import math from typing import Any, Sequence +import jax from jax._src import core as jax_core -from jax._src.pallas import core as pallas_core +from jax._src import custom_derivatives from jax._src import linear_util as lu +from jax._src import pjit +from jax._src.pallas import core as pallas_core from jax._src.interpreters import partial_eval as pe from jax._src.util import safe_map from jax._src.util import safe_zip @@ -71,22 +74,28 @@ def cost_estimate_jaxpr( bytes_accessed=total_cost.bytes_accessed, ) -def cost_estimate(fun, *args) -> pallas_core.CostEstimate: +def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate: """Computes a cost estimate for the given function. Args: fun: The function to compute the cost estimate for. *args: The arguments to the function. Can be jax.ShapeDtypeStruct or jax.Array. + **kwargs: The keyword arguments to the function. Returns: A pallas_core.CostEstimate object containing the cost estimate. """ - wrapped_fun = lu.wrap_init(lambda *args, **kwargs: (fun(*args, **kwargs),)) - avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in args] + flattened_args, treedef = jax.tree.flatten(args) + def _partial_fun(*flat_args): + return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs) + wrapped_fun = lu.wrap_init( + lambda *args, **kwargs: (_partial_fun(*args, **kwargs),)) + avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) - input_bytes = sum(math.prod(a.shape) * a.dtype.itemsize for a in args) + input_bytes = sum( + math.prod(a.shape) * a.dtype.itemsize for a in flattened_args) output_bytes = sum( math.prod(a.aval.shape) * a.aval.dtype.itemsize for a in jaxpr.outvars) return pallas_core.CostEstimate( @@ -213,3 +222,24 @@ def dot_general_cost_rule(ctx: Context, bytes_accessed=0, ) register_cost_rule(lax.dot_general_p, dot_general_cost_rule) + +# Higher-order primitives +def _pjit_cost_rule(ctx, *, jaxpr: jax_core.ClosedJaxpr, **_): + del ctx + inner_cost = cost_estimate_jaxpr(jaxpr) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(pjit.pjit_p, _pjit_cost_rule) + +def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): + del ctx + inner_cost = cost_estimate_jaxpr(fun_jaxpr) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 34cb5328f36a..7e6527ad999a 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -30,6 +30,7 @@ from jax._src.pallas.core import no_block_spec as no_block_spec from jax._src.pallas.core import Unblocked as Unblocked from jax._src.pallas.core import unblocked as unblocked +from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost from jax._src.pallas.pallas_call import pallas_call as pallas_call from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p from jax._src.pallas.primitives import atomic_add as atomic_add diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 9b122fcc03ef..0cb3d798d09e 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -574,26 +574,23 @@ def _fwd_cost_estimate( q: jax.Array, k: jax.Array, v: jax.Array, + ab: jax.Array | None, + segment_ids: SegmentIds | None, *, + causal: bool, + sm_scale: jax.Array | None, kernel_inputs_specs, kernel_outputs_specs, ) -> pl.CostEstimate | None: - b, h, tq, dqk = q.shape - tk = k.shape[-2] - dv = v.shape[-1] - - # Simplify flop computation to include only matmul operations. - qk_flops = 2 * tq * tk * dqk - av_flops = 2 * tq * tk * dv - per_head_flops = qk_flops + av_flops - flops = b * h * per_head_flops - - transcendentals = b * tq * tk * h + body_cost = pl.estimate_cost( + mha_reference, + q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale + ) input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs)) output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs)) return pl.CostEstimate( - flops=flops, - transcendentals=transcendentals, + flops=body_cost.flops, + transcendentals=body_cost.transcendentals, bytes_accessed=input_bytes + output_bytes, ) @@ -790,6 +787,10 @@ def kv_segment_ids_index_map( q, k, v, + ab, + segment_ids, + causal=causal, + sm_scale=sm_scale, kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids), kernel_outputs_specs=out_shape, ), diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py index 74dd150fbc10..fcdeac4cab82 100644 --- a/tests/pallas/pallas_cost_estimate_test.py +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -29,7 +29,7 @@ class PallasCostEstimateTest(jtu.JaxTestCase): def test_exp_add(self): def exp_add(x, y): return jnp.exp(x + y) - cost = cost_estimate.cost_estimate(exp_add, + cost = cost_estimate.estimate_cost(exp_add, jnp.ones(10, dtype=jnp.float32), jnp.ones(10, dtype=jnp.float32)) self.assertEqual(cost.flops, 10) @@ -40,7 +40,7 @@ def test_very_large_matmul(self): def matmul(a, b): return a @ b m, k, n = 400_000, 800_000, 900_000 - cost = cost_estimate.cost_estimate( + cost = cost_estimate.estimate_cost( matmul, jax.ShapeDtypeStruct((m, k), jnp.bfloat16), jax.ShapeDtypeStruct((k, n), jnp.bfloat16)) @@ -52,7 +52,7 @@ def test_batched_matmul(self): def matmul(a, b): return jnp.matmul(a, b) b, m, k, n = 7, 37, 91, 23 - cost = cost_estimate.cost_estimate( + cost = cost_estimate.estimate_cost( matmul, jax.ShapeDtypeStruct((b, m, k), jnp.float32), jax.ShapeDtypeStruct((b, k, n), jnp.float32)) @@ -67,7 +67,7 @@ def test_attention(self): q_len = 64 def attention(q, k, v): return jax.nn.softmax(q @ k.T, axis=-1) @ v - cost = cost_estimate.cost_estimate( + cost = cost_estimate.estimate_cost( attention, jnp.zeros((q_len, qk_dim), dtype=jnp.float32), jnp.zeros((kv_len, qk_dim), dtype=jnp.float32), @@ -85,7 +85,7 @@ def attention(q, k, v): (1, 0), (7, 5), (8, 4), (9, 5) ) def test_integer_pow(self, power, expected_flops_per_element): - cost = cost_estimate.cost_estimate(lambda x: lax.integer_pow(x, power), + cost = cost_estimate.estimate_cost(lambda x: lax.integer_pow(x, power), jnp.ones(10, dtype=jnp.float32)) self.assertEqual(cost.flops, 10 * expected_flops_per_element) self.assertEqual(cost.transcendentals, 0) From 840cf3f7d20ce06861a5c48c684ac9a61009d856 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 17:12:29 -0800 Subject: [PATCH 440/698] [sharding_in_types] Add `pad_p` support to sharding_in_types to handle transpose to slice correctly. PiperOrigin-RevId: 698573396 --- jax/_src/lax/lax.py | 27 +++++++++++++++++-- tests/pjit_test.py | 66 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4a9925ce33a4..9a27460906ab 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4490,6 +4490,25 @@ def _pad_shape_rule(operand, padding_value, *, padding_config): raise ValueError(msg) return result +def _pad_sharding_rule(operand, padding_value, *, padding_config): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _pad_shape_rule(operand, padding_value, + padding_config=padding_config) + mesh = operand.sharding.mesh + new_spec = [] + for op_sh, out_sh, op_spec in safe_zip( + operand.shape, out_shape, operand.sharding.spec): + if (op_sh != out_sh and op_spec is not None and + out_sh % slicing._get_sub_spec_size(mesh, op_spec) != 0): + raise NotImplementedError( + f"padding on sharded dims where out dim ({out_sh}) is not divisble by" + f" mesh axes ({slicing._get_sub_spec_size(mesh, op_spec)}) with spec" + f" ({op_spec}) is not implemented.") + new_spec.append(op_spec) + return NamedSharding(mesh, P(*new_spec)) + + def _pad_transpose(t, operand, padding_value, *, padding_config): if type(t) is ad_util.Zero: t_operand = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None @@ -4529,14 +4548,18 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): (operand_bdim,)) return select(mask, x, broadcasted_padding), operand_bdim -pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad') +pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', + sharding_rule=_pad_sharding_rule) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule def _pad_lower(ctx, x, padding_value, *, padding_config): aval_out, = ctx.avals_out low, high, interior = util.unzip3(padding_config) - return [mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)] + out = mlir.pad(ctx, aval_out, x, padding_value, low, high, interior) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(pad_p, _pad_lower) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e52c805ef5e6..372d6c334f5d 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5287,7 +5287,7 @@ def f(x, y): def test_slice(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(4, 4) + np_inp = np.arange(16.).reshape(4, 4) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @jax.jit @@ -5300,6 +5300,16 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) self.assertIn('@Sharding', f.lower(arr).as_text()) + def g(x): + out = f(x) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) @@ -5308,7 +5318,7 @@ def f(x): def test_squeeze(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - np_inp = np.arange(16).reshape(4, 4, 1) + np_inp = np.arange(16.).reshape(4, 4, 1) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) @jax.jit @@ -5322,6 +5332,58 @@ def f(x): self.assertIn('@Sharding', f.lower(arr).as_text()) self.assertArraysEqual(out, np.squeeze(np_inp, axis=2)) + def g(x): + out = f(x) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + def test_pad(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + np_inp = np.arange(8.) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + + @partial(jax.jit, static_argnums=(1, 2)) + def f(x, padding_config, spec): + y = lax.pad(x, 0., padding_config) + self.assertEqual(y.sharding.spec, spec) + return y + + out = f(arr, ((2, 2, 0),), P('x')) + self.assertArraysEqual(out, np.pad(np_inp, 2)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertIn('@Sharding', f.lower(arr, ((2, 2, 0),), P('x')).as_text()) + + out = f(arr, ((0, 0, 0),), P('x')) + self.assertArraysEqual(out, np_inp) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + f(arr, ((0, 3, 1), ), P('x')) # doesn't crash + + def g(x): + out = f(x, ((2, 2, 0),), P('x')) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + f(arr, ((2, 3, 0), ), None) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + f(arr, ((0, 3, 0), ), None) + + with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) + f(arr, ((4, 4, 1),), None) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 869a53345d1551071ce613d56f1f18cce20837e3 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 20 Nov 2024 17:27:25 -0800 Subject: [PATCH 441/698] [Mosaic TPU] Add bound check for general vector store op. PiperOrigin-RevId: 698577015 --- .../dialect/tpu/transforms/debug_assert_insertion.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc index 5478c64f9944..846e3bbb341f 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc @@ -122,6 +122,14 @@ void tpu_strided_store_rule(tpu::StridedStoreOp op) { /*strides=*/op.getStrides()); } +void tpu_vector_store_rule(tpu::VectorStoreOp op) { + // TODO(b/379925823): Take strides into account. + assertIsValidSubwindow( + op, op.getIndices(), + /*window_shape=*/op.getValueToStore().getType().getShape(), + /*full_shape=*/op.getBase().getType().getShape()); +} + const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ // TODO: tpu::LoadOp, tpu::StoreOp @@ -133,6 +141,8 @@ const llvm::StringMap &rules() { as_generic_rule(tpu_strided_load_rule)}, {tpu::StridedStoreOp::getOperationName(), as_generic_rule(tpu_strided_store_rule)}, + {tpu::VectorStoreOp::getOperationName(), + as_generic_rule(tpu_vector_store_rule)}, }; return *rules; } From 6568713a046b46c8e7f484f7d1db653e20d3aded Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 20 Nov 2024 20:12:01 -0800 Subject: [PATCH 442/698] [sharding_in_types] Add `concatenate_p` support PiperOrigin-RevId: 698621325 --- jax/_src/lax/lax.py | 20 +++++++++++++++++--- tests/pjit_test.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9a27460906ab..0519fa48f45a 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1700,6 +1700,8 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) + if config.sharding_in_types.value: + return broadcast(scalar_zero, aval.shape, sharding=aval.sharding) return broadcast(scalar_zero, aval.shape) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array @@ -4401,7 +4403,7 @@ def _concatenate_shape_rule(*operands, **kwargs): raise TypeError(msg.format(dimension, ", ".join([str(o.shape) for o in operands]))) shapes = [operand.shape[:dimension] + operand.shape[dimension+1:] for operand in operands] - if not shapes[:-1] == shapes[1:]: + if shapes[:-1] != shapes[1:]: msg = ("Cannot concatenate arrays with shapes that differ in dimensions " "other than the one being concatenated: concatenating along " "dimension {} for shapes {}.") @@ -4412,6 +4414,13 @@ def _concatenate_shape_rule(*operands, **kwargs): ex_shape = operands[0].shape return ex_shape[:dimension] + (concat_size,) + ex_shape[dimension+1:] +def _concatenate_sharding_rule(*operands, **kwargs): + if not all(o.sharding == operands[0].sharding for o in operands): + ss = ", ".join(str(o.sharding) for o in operands) + raise TypeError( + f"All operands should have the same sharding. Got shardings {ss}") + return operands[0].sharding + def _concatenate_dtype_rule(*operands, **kwargs): check_same_dtypes('concatenate', *operands) return operands[0].dtype @@ -4452,14 +4461,19 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): raise NotImplementedError # TODO(mattjj) concatenate_p = standard_primitive( - _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate') + _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', + sharding_rule=_concatenate_sharding_rule) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): - return [hlo.concatenate(xs, mlir.i64_attr(dimension))] + aval_out, = ctx.avals_out + out = hlo.concatenate(xs, mlir.i64_attr(dimension)) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(concatenate_p, _concatenate_lower) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 372d6c334f5d..dd1415b680a4 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5384,6 +5384,48 @@ def g(x): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) + def test_concatenate(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + np_inp = np.arange(16.).reshape(4, 4) + s = NamedSharding(mesh, P('x', 'y')) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np.arange(4.).reshape(4, 1), s) + + @partial(jax.jit, static_argnums=2) + def f(x, y, method='jnp'): + if method == 'jnp': + y = jnp.concatenate([x, y], axis=1) + else: + assert method == 'lax' + y = lax.concatenate([x, y], dimension=1) + self.assertEqual(y.sharding.spec, P('x', 'y')) + return y + + out = f(arr1, arr2) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) + self.assertIn('@Sharding', f.lower(arr1, arr2).as_text()) + + out = f(arr1, arr2, method='lax') + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) + + with self.assertRaisesRegex( + TypeError, "All operands should have the same sharding"): + arr3 = jax.device_put(np.arange(4.).reshape(4, 1), + NamedSharding(mesh, P('x'))) + f(arr1, arr3) + + def g(x, y): + out = f(x, y) + return jnp.square(jnp.sum(out)) + + out = jax.grad(g)(arr1, arr2) + self.assertEqual(out.sharding, s) + + out = jax.jit(jax.grad(g))(arr1, arr2) + self.assertEqual(out.sharding, s) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From e72b449089f6af4ceb18288e36215b3c76e69245 Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Wed, 20 Nov 2024 22:45:05 -0800 Subject: [PATCH 443/698] Reverts c04aec9d525dd2e767495e41b98e82dd79315f37 PiperOrigin-RevId: 698654038 --- jaxlib/mosaic/dialect/tpu/tpu.td | 5 +-- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 36 ++++----------- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 44 ++++++------------- 3 files changed, 24 insertions(+), 61 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index de5e3514fc1d..8a4f573bce24 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -654,15 +654,14 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { I32:$amount, Optional:$device_id, // For remote DMAs Optional:$core_id, // For megacore - Optional:$subcore_id, // For the SC vector subcore OptionalAttr:$core_type ); let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`subcore_id` $subcore_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) }]; let hasVerifier = 1; let builders = [ - // A backward-compatible builder that sets `subcore_id` and `core_type` to nullptr. + // A backward-compatible builder that sets `core_type` to nullptr. OpBuilder<(ins "Value":$semaphore, "Value":$amount, "Value":$device_id, "Value":$core_id)>, ]; diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 8586e2a16c8a..3271c0874572 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -844,7 +844,7 @@ void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state, Value semaphore, Value amount, Value device_id, Value core_id) { build(builder, state, semaphore, amount, device_id, core_id, - /*subcore_id=*/nullptr, /*core_type=*/nullptr); + /*core_type=*/nullptr); } LogicalResult SemaphoreSignalOp::verify() { @@ -861,39 +861,21 @@ LogicalResult SemaphoreSignalOp::verify() { CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc); CoreType target_core_type = getCoreType().value_or(issuing_core_type); - if (getCoreId() == nullptr && getDeviceId() == nullptr && - getSubcoreId() == nullptr) { + if (getCoreId() == nullptr && getDeviceId() == nullptr) { if (target_core_type != issuing_core_type) { - return emitOpError(absl::StrFormat( - "Target core type (%s) must match source core type " - "(%s) when device_id, core_id and subcore_id are not specified", - stringifyCoreType(target_core_type), - stringifyCoreType(issuing_core_type))); + return emitOpError( + absl::StrFormat("Target core type (%s) must match source core type " + "(%s) when device_id and core_id are not specified", + stringifyCoreType(target_core_type), + stringifyCoreType(issuing_core_type))); } } - if (target_core_type == CoreType::kScVectorSubcore && - issuing_core_type != CoreType::kScVectorSubcore && - getSubcoreId() == nullptr) { - return emitOpError( - "Subcore ID must be specified for the SC vector subcore"); - } - if (target_core_type != CoreType::kScVectorSubcore && - getSubcoreId() != nullptr) { - return emitOpError( - "Subcore ID must be specified only for the SC vector subcore"); - } if ((issuing_core_type == CoreType::kTc && - (target_core_type == CoreType::kScScalarSubcore || - target_core_type == CoreType::kScVectorSubcore)) || - ((issuing_core_type == CoreType::kScScalarSubcore || - issuing_core_type == CoreType::kScVectorSubcore) && + target_core_type == CoreType::kScScalarSubcore) || + (issuing_core_type == CoreType::kScScalarSubcore && target_core_type == CoreType::kTc)) { return emitOpError("Signalling between TC and SC is not implemented"); } - if (target_core_type == CoreType::kScVectorSubcore && - (getCoreId() != nullptr || getDeviceId() != nullptr)) { - return emitOpError("Signalling remote SC vector subcores is not supported"); - } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 27a886ebeb7e..fd68c9e6c95e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -15,21 +15,19 @@ limitations under the License. // We need to keep some extra headers for the code in tpu_passes.h.inc. -#include #include // IWYU pragma: keep #include #include #include -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" -#include "absl/strings/str_format.h" #include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" @@ -45,7 +43,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 4; +constexpr int kVersion = 3; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -88,37 +86,21 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) { LogicalResult semaphore_signal_rule(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. - // Added subcore_id in version 4. if (version < 2) { if (op->getNumOperands() == 2) { // Local signal. - op->setAttr( - OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0, 0})); + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); } else if (op->getNumOperands() == 3) { // Remote signal. - op->setAttr( - OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0, 0})); - } - return op->emitError("Unexpected operand count in tpu.semaphore_signal"); - } else if (version < 4) { - ArrayRef operand_segment_sizes = - op->getAttrOfType( - OpTrait::AttrSizedOperandSegments< - SemaphoreSignalOp>::getOperandSegmentSizeAttr()); - if (operand_segment_sizes.size() != 4) { - return op->emitError(absl::StrFormat( - "Expected operand count to be 4 in tpu.semaphore_signal. Got %d", - operand_segment_sizes.size())); + // Hardcoding that one optional value is device_id, not core_id. This + // could misinterpret sem_signals where core_id is specified, but + // device_id isn't. + op->setAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr(), + mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); + } else { + return op->emitError("Unexpected operand count in tpu.semaphore_signal"); } - SmallVector new_operand_segment_sizes( - operand_segment_sizes.begin(), operand_segment_sizes.end()); - new_operand_segment_sizes.push_back(0); - op->setAttr(OpTrait::AttrSizedOperandSegments< - EnqueueDMAOp>::getOperandSegmentSizeAttr(), - mlir::DenseI32ArrayAttr::get(op->getContext(), - new_operand_segment_sizes)); } return success(); } From f18df8f39cfa9471449e6c66a5b765e17f10c90d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 21 Nov 2024 03:12:51 -0800 Subject: [PATCH 444/698] [pallas:mosaic_gpu] Pulled `delay_release` into `emit_pipeline` The implementation exactly matches the one we have in the lowering. PiperOrigin-RevId: 698713343 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 33 +++++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 90c00765e8b1..feb7f1af6301 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -125,20 +125,41 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore def emit_pipeline( - body, + body: Callable[..., None], *, grid: pallas_core.StaticGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, + delay_release: int = 0, ): - """Creates a function to emit a manual pipeline within a Pallas kernel.""" + """Creates a function to emit a manual pipeline within a Pallas kernel. + + Args: + body: The pipeline body. + grid: The grid to use for the pipeline. + in_specs: The block specs for the inputs. + out_specs: The block specs for the outputs. + max_concurrent_steps: The maximum number of sequential stages that are + active concurrently. Defaults to 1. + delay_release: The number of steps to wait before reusing the input/output + references. Defaults to 0, and must be strictly smaller than + ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you + don't await the WGMMA in the body. + """ num_steps = math.prod(grid) + if max_concurrent_steps <= delay_release: + raise ValueError( + "max_concurrent_steps must be greater than delay_release, but" + f" {max_concurrent_steps=}, {delay_release=}" + ) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to - # reduce the size of the allocated buffers below. + # reduce the size of the refs allocated in SMEM. if max_concurrent_steps > num_steps: max_concurrent_steps = num_steps + delay_release = 0 # No need to delay anything. def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)): @@ -208,7 +229,7 @@ def loop_body(step, carry): gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. gpu_primitives.wait_smem_to_gmem( - max_concurrent_steps - 1, wait_read_only=True + max_concurrent_steps - (1 + delay_release), wait_read_only=True ) with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): @@ -245,10 +266,10 @@ def loop_body(step, carry): predicate=lax.bitwise_or(slices_changed, is_last_step), ) - fetch_step = step + max_concurrent_steps + fetch_step = step + (max_concurrent_steps - delay_release) fetch_slot = slot # (x + y) % y == x % y jax.lax.cond( - fetch_step < num_steps, + lax.bitwise_and(fetch_step >= delay_release, fetch_step < num_steps), lambda: map( lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref), in_brefs, From 1bc9df429d87920bdbbf874e84a63fbe3111e27d Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Thu, 21 Nov 2024 05:24:38 -0800 Subject: [PATCH 445/698] Integrate LLVM at llvm/llvm-project@33fcd6acc755 Updates LLVM usage to match [33fcd6acc755](https://github.com/llvm/llvm-project/commit/33fcd6acc755) PiperOrigin-RevId: 698742870 --- jax/experimental/mosaic/gpu/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index b716456eceb3..0ce1140cfa07 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -296,6 +296,12 @@ def globaltimer(kind: Literal["low", "high"] | None = None): def bytewidth(ty: ir.Type): + # The actual width of TF32 is 19 bits. However, sinc we need to treat it as + # 32 bits for compatibility reasons. TF32 used to be 32 bits wide in upstream + # MLIR, but it changed in + # https://github.com/llvm/llvm-project/commit/67a1fdb014790a38a205d28e1748634de34471dd. + if ir.FloatTF32Type.isinstance(ty): + return 4 if ir.IntegerType.isinstance(ty): return ir.IntegerType(ty).width // 8 if ir.FloatType.isinstance(ty): From 0831e2e3401dfde3b12e407cb4c366b420b16348 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 20 Nov 2024 20:50:37 -0800 Subject: [PATCH 446/698] [shape_poly] Adding shape polymorphism support for the state primitives. --- benchmarks/shape_poly_benchmark.py | 3 +- jax/_src/core.py | 64 ++++++++++++++++++++++++++++++ jax/_src/numpy/lax_numpy.py | 61 +--------------------------- jax/_src/state/indexing.py | 8 ++-- tests/shape_poly_test.py | 30 +++++++++++++- tests/state_test.py | 2 +- 6 files changed, 100 insertions(+), 68 deletions(-) diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index d26801d8dfe5..d365a6facd90 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -17,7 +17,6 @@ import jax from jax import core -from jax._src.numpy import lax_numpy from jax import export jax.config.parse_flags_with_absl() @@ -76,7 +75,7 @@ def inequalities_slice(state): while state: for _ in range(30): a.scope._clear_caches() - start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b) + start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b) _ = 0 <= slice_size <= b _ = start >= 0 _ = start + slice_size <= b diff --git a/jax/_src/core.py b/jax/_src/core.py index cbf3282fb2cc..faf33f00bbf9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2047,6 +2047,70 @@ def dimension_as_value(d: DimSize): if hasattr(d, "dimension_as_value"): return d.dimension_as_value() return operator.index(d) +def canonicalize_slice( + s: slice, + axis_size: DimSize + ) -> tuple[DimSize, DimSize, DimSize]: + """Computes the start index, step, and size of the slice `x[s]`. + + This is similar to `s.indices(axis_size)`, except that it returns + `(start, step, size)`, and it works when the slice and/or the + `axis_size` are symbolic. + + See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding + """ + def convert_to_index(d: DimSize) -> DimSize: + # Convert np.array and jax.Array to int, leave symbolic dimensions alone + try: + return operator.index(d) + except: + return d + + # Must resolve statically if step is {<0, ==0, >0} + step = convert_to_index(s.step) if s.step is not None else 1 + try: + if step == 0: + raise ValueError("slice step cannot be zero") + step_gt_0 = (step > 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the step ({step}) must " + + f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") + + def clamp_index(i: DimSize, which: str): + try: + i_ge_0 = (i >= 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the {which} ({i}) must " + + f"be resolved statically if it is >= 0.\nDetails: {e}") + if i_ge_0: + if step_gt_0: + return min_dim(axis_size, i) + else: + return min_dim(axis_size - 1, i) + else: + if step_gt_0: + return max_dim(0, axis_size + i) + else: + return max_dim(-1, axis_size + i) + + if s.start is None: + start = 0 if step_gt_0 else axis_size - 1 + else: + start = clamp_index(convert_to_index(s.start), "start") + + if s.stop is None: + stop = axis_size if step_gt_0 else -1 + else: + stop = clamp_index(convert_to_index(s.stop), "stop") + + gap = step if step_gt_0 else - step + distance = (stop - start) if step_gt_0 else (start - stop) + slice_size = max_dim(0, distance + gap - 1) // gap + return start, step, slice_size + + class SomeTracer: __slots__ = () def __repr__(self): return "[dynamic]" diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 898e4255dd8e..5f380fad902c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -12116,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], "arrays within JIT compiled functions).") raise IndexError(msg) - start, step, slice_size = _preprocess_slice(i, x_shape[x_axis]) + start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis]) slice_shape.append(slice_size) if core.definitely_equal(step, 1): @@ -12319,65 +12319,6 @@ def _canonicalize_tuple_index(arr_ndim, idx): idx = tuple(idx) + colons return idx -def _preprocess_slice( - s: slice, - axis_size: core.DimSize - ) -> tuple[core.DimSize, core.DimSize, core.DimSize]: - """Computes the start index, step, and size of the slice `x[s]`.""" - # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding - # "this is harder to get right than you may think" - # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275) - def convert_to_index(d: DimSize) -> DimSize: - # Convert np.array and jax.Array to int, leave symbolic dimensions alone - try: - return operator.index(d) - except: - return d - - # Must resolve statically if step is {<0, ==0, >0} - step = convert_to_index(s.step) if s.step is not None else 1 - try: - if step == 0: - raise ValueError("slice step cannot be zero") - step_gt_0 = (step > 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the step ({step}) must " + - f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") - - def clamp_index(i: DimSize, which: str): - try: - i_ge_0 = (i >= 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the {which} ({i}) must " + - f"be resolved statically if it is >= 0.\nDetails: {e}") - if i_ge_0: - if step_gt_0: - return core.min_dim(axis_size, i) - else: - return core.min_dim(axis_size - 1, i) - else: - if step_gt_0: - return core.max_dim(0, axis_size + i) - else: - return core.max_dim(-1, axis_size + i) - - if s.start is None: - start = 0 if step_gt_0 else axis_size - 1 - else: - start = clamp_index(convert_to_index(s.start), "start") - - if s.stop is None: - stop = axis_size if step_gt_0 else -1 - else: - stop = clamp_index(convert_to_index(s.stop), "stop") - - gap = step if step_gt_0 else - step - distance = (stop - start) if step_gt_0 else (start - stop) - slice_size = core.max_dim(0, distance + gap - 1) // gap - return start, step, slice_size - @export def blackman(M: int) -> Array: diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 538f3f8e4888..2da93e3d8e80 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -46,11 +46,11 @@ def __post_init__(self): @property def is_dynamic_start(self): - return not isinstance(self.start, int) + return not core.is_dim(self.start) @property def is_dynamic_size(self): - return not isinstance(self.size, int) + return not core.is_dim(self.size) def tree_flatten(self): # If `start` is statically known, we treat it as static information @@ -72,10 +72,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice: @classmethod def from_slice(cls, slc: slice, size: int) -> Slice: - start, stop, step = slc.indices(size) + start, step, size = core.canonicalize_slice(slc, size) if step < 1: raise ValueError(f"slice must have a step >= 1 (found: {step})") - return cls(start, max((stop - start + step - 1) // step, 0), step) + return cls(start, size, step) def dslice( diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index eda4c4309960..668907ffee27 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -48,6 +48,9 @@ from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.state import discharge +from jax._src.state import primitives as ref_primitives + import numpy as np config.parse_flags_with_absl() @@ -2062,6 +2065,31 @@ def test_vmap_error(self): polymorphic_shapes=["b, ...", "c, ...", None]) + @jtu.parameterized_filterable( + kwargs=[ + dict(slc=slc) + for slc in [ + slice(None, None, None), + slice(2, 5), + ] + ]) + def test_stateful(self, slc: slice): + w, = export.symbolic_shape("w", constraints=["w >= 3"]) + def f(x_ref): + ones = jnp.ones_like(x_ref)[slc] + ref_primitives.ref_addupdate(x_ref, slc, ones) + x1 = ref_primitives.ref_get(x_ref, slc) + x2 = x1 + ones + ref_primitives.ref_set(x_ref, slc, x2) + + exp = export.export(jax.jit(discharge.run_state(f)))( + jax.ShapeDtypeStruct((w,), dtype=_f32)) + x = np.ones((32,), dtype=_f32) + expected = np.copy(x) + expected[slc] = 3. + self.assertAllClose(exp.call(x), expected) + + # List containing either harnesses, or lists of harnesses _POLY_SHAPE_TEST_HARNESSES = [ PolyHarness("add", "", @@ -3603,7 +3631,7 @@ def test_harness(self, harness: PolyHarness): not harness.polymorphic_shapes[0].endswith("...") and jtu.test_device_matches(["tpu"])): raise unittest.SkipTest( - "Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.") + "Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.") config_flags = harness.override_jax_config_flags # Update this here rather than in harness object because vmap_random_gamma is derived diff --git a/tests/state_test.py b/tests/state_test.py index c8458742619d..44caded0ca64 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -752,7 +752,7 @@ def f(a_ref, b_ref): lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) - prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) + prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns) self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr)) self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr)) From 7d7a0fa249c7d42dfa11d492fb62b4d1909fa628 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 21 Nov 2024 07:25:15 -0800 Subject: [PATCH 447/698] Run the TPU workflow on new self-hosted runners We are not able to run the TPU workflows because of no active runners (https://github.com/jax-ml/jax/actions/runs/11879479226/job/33101456081). So this adds the new self-hosted runners to the TPU workflow to fix this issue. The v3 type is disabled as we do not have that available yet. PiperOrigin-RevId: 698772505 --- .github/workflows/cloud-tpu-ci-nightly.yml | 49 ++++++++++++---------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index a5fac5ebdbc3..16c0751f40f8 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -13,7 +13,7 @@ name: CI - Cloud TPU (nightly) on: schedule: - - cron: "0 14 * * *" # daily at 7am PST + - cron: "* */2 * * *" # Run every 2 hours workflow_dispatch: # allows triggering the workflow run manually # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. @@ -26,15 +26,18 @@ jobs: matrix: jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - {type: "v3-8", cores: "4"}, - {type: "v4-8", cores: "4"}, - {type: "v5e-8", cores: "8"} + # {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available + # {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] + python-version: ["3.10"] name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: LIBTPU_OLDEST_VERSION_DATE: 20240722 ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} - runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"] + PYTHON: python${{ matrix.python-version }} + runs-on: ${{ matrix.tpu.runner }} + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" timeout-minutes: 120 defaults: run: @@ -46,37 +49,37 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install JAX test requirements run: | - pip install -U -r build/test-requirements.txt - pip install -U -r build/collect-profile-requirements.txt + $PYTHON -m pip install -U -r build/test-requirements.txt + $PYTHON -m pip install -U -r build/collect-profile-requirements.txt - name: Install JAX run: | - pip uninstall -y jax jaxlib libtpu + $PYTHON -m pip uninstall -y jax jaxlib libtpu if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then - pip install .[tpu] \ + $PYTHON -m pip install .[tpu] \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - pip install --pre libtpu \ + $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $PYTHON -m pip install --pre libtpu \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $PYTHON -m pip install requests elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then - pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. - pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ + $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install requests + $PYTHON -m pip install requests else echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" exit 1 fi - python3 -c 'import sys; print("python version:", sys.version)' - python3 -c 'import jax; print("jax version:", jax.__version__)' - python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' - strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on' - python3 -c 'import jax; print("libtpu version:", + $PYTHON -c 'import sys; print("python version:", sys.version)' + $PYTHON -c 'import jax; print("jax version:", jax.__version__)' + $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' + strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' + $PYTHON -c 'import jax; print("libtpu version:", jax.lib.xla_bridge.get_backend().platform_version)' - name: Run tests env: @@ -84,14 +87,14 @@ jobs: PY_COLORS: 1 run: | # Run single-accelerator tests in parallel - JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ + JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ --maxfail=20 -m "not multiaccelerator" tests examples # Run Pallas printing tests, which need to run with I/O capturing disabled. - TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \ + TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \ tests/pallas/tpu_pallas_test.py::PallasCallPrintTest # Run multi-accelerator across all chips - python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests + $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests - name: Send chat on failure # Don't notify when testing the workflow from a branch. if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }} From bf0150bb22b2ed7986adf3762cb1bc555ed3fee8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 21 Nov 2024 08:20:41 -0800 Subject: [PATCH 448/698] [JAX] Ignore xla_gpu_experimental_autotune_cache_mode when calculating module hash. PiperOrigin-RevId: 698789020 --- jax/_src/cache_key.py | 3 +++ tests/cache_key_test.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 6e025653b81d..324fa85f81ed 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -21,6 +21,7 @@ from typing import cast as type_cast from jax._src import config +from jax._src.lib import version as jaxlib_version from jax._src.lib import version_str as jaxlib_version_str from jax._src.lib import xla_client from jax._src.lib.mlir import ir @@ -225,6 +226,8 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_dump_hlo_as_long_text = False debug_options.xla_dump_disable_metadata = False debug_options.xla_dump_hlo_pipeline_re = "" + if jaxlib_version > (0, 4, 35): + debug_options.xla_gpu_experimental_autotune_cache_mode = 0 # Optional way to specify the cuda install path to be used by the compiler. # This could possibly affect the cuda version compiled with, but this should # already be included in the platform information (and might not be reflected diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 00925c5f7dfc..8f9c5d0e8b82 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -31,6 +31,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.mesh import Mesh from jax._src.partition_spec import PartitionSpec as P @@ -68,6 +69,8 @@ def test_serialized_compile_options(self): debug_options.xla_dump_hlo_as_long_text = True debug_options.xla_dump_disable_metadata = True debug_options.xla_dump_hlo_pipeline_re = "xyzzy" + if jaxlib_version > (0, 4, 35): + debug_options.xla_gpu_experimental_autotune_cache_mode = 2 hash2 = self.get_hashed_value( cache_key._hash_serialized_compile_options, compile_options ) From 1e6654a0314cc067a6f257dd4f5c5a5a5d409f39 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Thu, 21 Nov 2024 09:08:23 -0800 Subject: [PATCH 449/698] Fix cron schedule to run past minute 0 every 2nd hour In the previous schedule, we were running at every minute at every 2nd hour. PiperOrigin-RevId: 698804124 --- .github/workflows/cloud-tpu-ci-nightly.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 16c0751f40f8..4ac167bd37c1 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -13,7 +13,7 @@ name: CI - Cloud TPU (nightly) on: schedule: - - cron: "* */2 * * *" # Run every 2 hours + - cron: "0 */2 * * *" # Run every 2 hours workflow_dispatch: # allows triggering the workflow run manually # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. From 1d2dc17e5f226db7de2a8996c4a2d3bef4c8a0f6 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 21 Nov 2024 09:49:35 -0800 Subject: [PATCH 450/698] [mgpu] Pointwise op can handle LHS splats. PiperOrigin-RevId: 698818035 --- .../mosaic/gpu/fragmented_array.py | 34 ++++++++++++++++++- tests/mosaic/gpu_test.py | 23 +++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 2b985ff5c9b8..e1ee37f3d24d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -623,6 +623,38 @@ def to_layout(self, new_layout: FragmentedLayout): ) def _pointwise(self, op, *other, output_is_signed: bool | None = None): + if isinstance(self.layout, WGSplatFragLayout): + # Find either the largest operand or an operand that has a + # concrete layout base the layout computation of that. + widest_idx = None + for i, o in enumerate(other): + if not isinstance(o, FragmentedArray): + continue + elif not isinstance(o.layout, WGSplatFragLayout): + widest_idx = i + break + elif not o.layout.can_broadcast_to(self.layout.shape): + # Note: equal shapes can be broadcast to each other. Using + # the negation we make sure to only consider strictly larger + # shapes so that we don't end up ping ponging between equal + # shapes. + widest_idx = i + + if widest_idx is not None: + # We need to retain the order of arguments that the op + # expects. + def _op(wide_o, self_o, *args): + pre_wide = args[:widest_idx - 1] + post_wide = args[widest_idx - 1:] + return op(self_o, *pre_wide, wide_o, *post_wide) + return other[widest_idx]._pointwise( + _op, + self, + *other[:widest_idx], + *other[widest_idx + 1:], + output_is_signed=output_is_signed, + ) + other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): @@ -642,7 +674,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): o.registers.flat[0], shape=self.shape, layout=self.layout, - is_signed=self.is_signed, + is_signed=o.is_signed, ) else: if self.layout != o.layout: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ab2a00c730d6..87dc2c452041 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1489,6 +1489,29 @@ def kernel(ctx, dst, _): )() np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) + + def test_splat_binary_ops(self): + def kernel(ctx, src, dst, _): + f32 = ir.F32Type.get() + pi_arr = mgpu.FragmentedArray.load_strided(src) + assert isinstance(pi_arr.layout, mgpu.WGStridedFragLayout) + pi_scalar = arith.constant(f32, ir.FloatAttr.get(f32, 3.14)) + pi_splat = mgpu.FragmentedArray.splat(pi_scalar, ()) + assert isinstance(pi_splat.layout, mgpu.WGSplatFragLayout) + pi_arr_sq = pi_arr * pi_splat.broadcast(pi_arr.shape) + assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) + pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq + assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) + (pi_arr_sq + pi_arr_cube).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) + inp = jnp.ones_like(out_shape) * 3.14 + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, () + )(inp) + np.testing.assert_allclose(result, np.full((128, 32), 3.14 ** 2 + 3.14 ** 3, np.float32)) + + @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128))) def test_strided_load_store(self, in_shape): def kernel(ctx, *args): From 2178ed2fa42eeb7f609369d56d90950af60d25ca Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Thu, 21 Nov 2024 09:49:48 -0800 Subject: [PATCH 451/698] [pallas] Add more test cases for Triton bitcast_convert_type lowering rule. PiperOrigin-RevId: 698818103 --- tests/pallas/ops_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 9f0b9aef5af3..d7c1bac5dc61 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1941,9 +1941,13 @@ def kernel(x_ref, out_ref): @parameterized.parameters( (jnp.float16, jnp.float16), # Noop - (jnp.int16, jnp.float16), (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), ) def test_bitcast_convert_type(self, in_dtype, out_dtype): if jtu.test_device_matches(["tpu"]): From 96c012990de86ccb0eb815a11ae4e2c337802794 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 21 Nov 2024 10:32:37 -0800 Subject: [PATCH 452/698] Fix false positive `debug_nans` error caused by NaNs that are properly handled in `jax.scipy.stats.gamma` As reported in https://github.com/jax-ml/jax/issues/24939, even though the implementation of `jax.scipy.stats.gamma.logpdf` handles invalid inputs (e.g. `x < loc`) by returning `-inf`, the existing implementation incorrectly triggers the NaN checks introduced by JAX's debug NaNs mode. This change updates the implementation to no longer produce internal NaNs. Fixes https://github.com/jax-ml/jax/issues/24939 PiperOrigin-RevId: 698833589 --- jax/_src/scipy/stats/gamma.py | 5 +++-- tests/scipy_stats_test.py | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index f410d08e4f3d..4343c080251c 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -51,12 +51,13 @@ def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) - :func:`jax.scipy.stats.gamma.logsf` """ x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale) + ok = lax.ge(x, loc) one = _lax_const(x, 1) - y = lax.div(lax.sub(x, loc), scale) + y = jnp.where(ok, lax.div(lax.sub(x, loc), scale), one) log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y) shape_terms = lax.add(gammaln(a), lax.log(scale)) log_probs = lax.sub(log_linear_term, shape_terms) - return jnp.where(lax.lt(x, loc), -jnp.inf, log_probs) + return jnp.where(ok, log_probs, -jnp.inf) def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index f02ed0fc04bb..88a126c284a7 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -543,6 +543,13 @@ def testGammaLogPdfZero(self): self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) + def testGammaDebugNans(self): + # Regression test for https://github.com/jax-ml/jax/issues/24939 + with jax.debug_nans(True): + self.assertAllClose( + osp_stats.gamma.pdf(0.0, 1.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0, 1.0) + ) + @genNamedParametersNArgs(4) def testGammaLogCdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng()) From 1efef6bf6b3af91b91fb601e6302b4f17db739e0 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 21 Nov 2024 11:38:04 -0800 Subject: [PATCH 453/698] [pallas:mosaic_gpu] `emit_pipeline` now correctly supports `BlockSpec`s in GMEM This is necessary to replace the pipelining logic in the lowering with `emit_pipeline`. PiperOrigin-RevId: 698858380 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 49 ++++++++++++++++++-------- tests/pallas/mosaic_gpu_test.py | 33 +++++++++++++++++ 2 files changed, 67 insertions(+), 15 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index feb7f1af6301..9fcca6acdacc 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -46,7 +46,16 @@ class BufferedRef: spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True}) is_index_invariant: bool = dataclasses.field(metadata={"static": True}) gmem_ref: pallas_core.AbstractMemoryRef - smem_ref: pallas_core.AbstractMemoryRef # [num_slots, *spec.block_shape] + # ``None`` if the ref is pinned to GMEM; otherwise, has shape + # [num_slots, *spec.block_shape]. + smem_ref: pallas_core.AbstractMemoryRef | None + + def get_ref_for_slot( + self, slot: int | jax.Array + ) -> pallas_core.AbstractMemoryRef: + if self.smem_ref is None: + return self.gmem_ref + return self.smem_ref.at[slot] def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: index_map = self.spec.index_map @@ -59,6 +68,9 @@ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: ) def copy_in(self, slot, grid_indices, barrier_ref): + if not _in_smem(self.spec): + return + assert self.smem_ref is not None gmem_slices = self.compute_gmem_slice(grid_indices) gpu_primitives.copy_gmem_to_smem( self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands @@ -67,6 +79,9 @@ def copy_in(self, slot, grid_indices, barrier_ref): ) def copy_out(self, slot, grid_indices, predicate=None): + if not _in_smem(self.spec): + return + assert self.smem_ref is not None gmem_slices = self.compute_gmem_slice(grid_indices) gpu_primitives.copy_smem_to_gmem( self.smem_ref.at[slot], @@ -88,8 +103,8 @@ def _uses_arguments( def _is_index_invariant( spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid ) -> bool: - index_map = spec.index_map - assert index_map is not None + if (index_map := spec.index_map) is None: + return True return not any(_uses_arguments(index_map, len(grid))) @@ -105,6 +120,10 @@ def _inc_grid_by_1( return tuple(reversed(next_indices)) +def _in_smem(spec: pallas_core.BlockSpec) -> bool: + return spec.memory_space in (None, gpu_core.SMEM) + + # ``pl.Slice`` uses a different pytree encoding, depending on whether the # start/size are static or dynamic. This leads to pytree structure mismatch # in the pipeline body. So, we define a different ``Slice`` class below. @@ -166,6 +185,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): if any( spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore for idx in range(1, len(grid) + 1) + if spec.block_shape is not None ): raise NotImplementedError( f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block" @@ -174,14 +194,12 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( - map( - lambda spec, ref: gpu_core.SMEM( - (max_concurrent_steps, *spec.block_shape), # type: ignore - ref.dtype, - ), - it.chain(in_specs, out_specs), - gmem_refs, - ), + [ + gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype) # type: ignore + if _in_smem(spec) + else None + for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs) + ], [len(in_specs)], ) return pl.run_scoped( @@ -194,7 +212,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): out_smem_refs=out_smem_refs, barrier_ref=gpu_core.Barrier( # TODO(slebedev): Change this to arrive only once. - len(in_specs), + sum(map(_in_smem, in_specs)), num_barriers=max_concurrent_steps, ), ) @@ -233,9 +251,10 @@ def loop_body(step, carry): ) with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body( - *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs)) - ) + body(*( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + )) if not all(bref.is_index_invariant for bref in out_brefs): gpu_primitives.commit_smem() diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index a4bbc67ee14f..110d83bd992b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1186,6 +1186,39 @@ def kernel_body(x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_nested_emit(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + grid=(), + )(x_gmem, o_gmem) + + def nested_kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + nested_kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def nested_kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit_with_grid_invariant_output(self): num_steps = 4 From f3e7e6829adae587e60a536a45852a9014389ab6 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Thu, 21 Nov 2024 12:17:28 -0800 Subject: [PATCH 454/698] Remove unneeded dependency from rocm_plugin_extension. PiperOrigin-RevId: 698872849 --- jaxlib/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 8c402cfcefe8..987fe24a8008 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -243,7 +243,6 @@ pybind_extension( "@local_config_rocm//rocm:rocm_headers", "@nanobind", "@xla//third_party/python_runtime:headers", - "@xla//xla:status", "@xla//xla:util", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", From f899d515354d19801f631b6c096e9db075ac820d Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 21 Nov 2024 13:28:30 -0800 Subject: [PATCH 455/698] [Mosaic TPU] Fold sublane offset to indices when storing to untiled ref. This optimization avoids unnecessary retiling when storing to untiled ref but adds at most one extra store op for sublane offset (since sublane offset is limieted to < VregSlice[0]). PiperOrigin-RevId: 698896373 --- .../dialect/tpu/transforms/infer_vector_layout.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index d84e4b883172..c0b2c6c96e7e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1640,14 +1640,14 @@ class VectorLayoutInferer { // Since it is untiled, we can store to any arbitrary address which // means the sublane offset can be any value and we can fold it to // 2nd minor index. - // TODO(jevinjiang): We can fold the sublane offset into the 2nd minor - // index. But we need to handle negative index in lower-to-llo. For - // now, we just force the sublane offset to be 0. + auto prev_store_layout = getLayout(op.getValueToStore()); + TPU_CHECK_OP(prev_store_layout.has_value(), "missing vector layout"); + offsets[0] = prev_store_layout->offsets()[0].value_or(0); if (offsets[1].value_or(0) >= tiling[1]) { offsets[1] = 0; } - store_layout = VectorLayout(bitwidth, {0, offsets[1]}, - nativeTiling(bitwidth), ImplicitDim::kNone); + store_layout = VectorLayout(bitwidth, offsets, nativeTiling(bitwidth), + ImplicitDim::kNone); } else { store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]}, ImplicitDim::kNone); From 26443bbd6696ab296408b808fdc9f3974c4cfa3b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 21 Nov 2024 14:25:39 -0800 Subject: [PATCH 456/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/85360d67ffc0a6d6923605b848de12ec204ca336. PiperOrigin-RevId: 698915433 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 327e4ca422ac..46f71523be05 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e763f8875b0a9bfca876be9b02c874979e55422a" -XLA_SHA256 = "7b6a33894c6510167cac6e0ab7a6331ffa84e7fcaaa1d3b1c462ec5ecacb0682" +XLA_COMMIT = "85360d67ffc0a6d6923605b848de12ec204ca336" +XLA_SHA256 = "7afa7e599adf7b1a636ea9e55419c253a115ef27217ec862ca8a03cef1abd11a" def repo(): tf_http_archive( From 344d0d998d682ffae5e35ad0ed4a3de50e51940c Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 21 Nov 2024 15:42:34 -0800 Subject: [PATCH 457/698] [Pallas] Add readme page for debugging tips. PiperOrigin-RevId: 698939951 --- jax/experimental/pallas/g3doc/debugging.md | 207 +++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 jax/experimental/pallas/g3doc/debugging.md diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md new file mode 100644 index 000000000000..40b109d102d5 --- /dev/null +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -0,0 +1,207 @@ +# Debugging Pallas + + + + + +[TOC] + +This document contains a collection of tips and tricks for debugging Pallas +programs. For any specific requests or ideas for improvement, please create +a ticket on https://github.com/jax-ml/jax/issues. + +## Debugging Tools + +### Interpret (HLO) Mode + +Passing in `interpret=True` into `pl.pallas_call` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. + +Note that interpret mode will not be able to fully replicate the behavior or programs that use communication (DMAs) between devices. This is because low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations. + +### debug_print + +The `pl.debug_print` function can be used to print runtime values inside of a kernel. The implementation is currently limited to scalar values, but we are working on lifting this limitation. + +For TPUs only, the kernel must be compiled with the 'xla_tpu_enable_log_recorder' option. + + +```python +kernel = pl.pallas_call(...) +compiled_kernel = ( + jax.jit(kernel) + .lower(x) + .compile({'xla_tpu_enable_log_recorder': 'true'}) + ) +result = compiled_kernel(x) +``` + +### Runtime Asserts + +Checkify can be used to insert runtime asserts, nan checks, out of bounds errors, etc. inside of a kernel. +Pallas implements two options for assertions: a *hard assert* which will crash the TPU if failed, and a *functionalized assertion* which will simulate a runtime assertion that can be thrown +as a Python error after the kernel has successfully executed. + +#### Hard assertion + +Hard assertions can be inserted with `checkify.check` +and running your program with the `--jax_pallas_enable_runtime_assert` flag. + +Your code will look like the following: + +```python +from jax.experimental import checkify + +def kernel(...): + checkify.check(x > y, "Check x > y failed") # Will halt if x <= y +``` + +This will print a relatively lengthy dump which resembles the following: + +``` +E1001 15:22:33.275768 4353 real_program_continuator.cc:1350] 0x0x0_TC0: [Physical location: dldgr4:pe1:1] generic::internal: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x169 (from TensorCoreSequencer:1:0x213): Check x > y failed HLO: main; HLO computation: main.3 +``` + +The benefit of a hard assertion is that it is guaranteed to either pass or +halt the TPU. The kernel will never proceed past the assertion if it fails. +However, the downside is that if the assertion fails you will +likely have to restart the program in order to run any other TPU operations, +and there is no Python error thrown that can be caught. + +#### Functionalized assertion +Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op like so: + +```python +from jax.experimental import checkify + +def kernel(...): + checkify.check(x > y, "Check x > y failed") # Will throw an error if x <= y + +kernel = pl.pallas_call(...) +checkified_kernel = checkify.checkify(kernel, + errors=checkify.all_checks) +error, result = checkified_kernel(x) +error.throw() +``` + +This will throw a Python error if any checks failed, such as if a NaN occurred +or if an out-of-bounds index was accessed. + +The benefit of a functionalized assert is that it will throw Python errors +that can be caught, and it will not interfere with downstream TPU operations. +However, it requires the kernel to successfully complete, meaning if your +error would have caused a TPU crash, the crash would still happen and +the error would not be thrown. + + +### Dumping Jaxprs + +Passing in `debug=True` into `pl.pallas_call` will print out the Jaxpr of the kernel as well as the lowered Mosaic code. + +```python +def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +x = jnp.ones((8, 128), dtype=jnp.float32) +pl.pallas_call( + kernel, + out_shape=jax.ShapeDTypeStruct((8, 128), jnp.float32) + debug=True, + name="my_call", +)(x, x) +``` + +This will output: + +``` +The kernel jaxpr for the pallas_call my_call for kernel function kernel at ...:1000: +{ lambda ; a:MemRef{float32[8,128]} b:MemRef{float32[8,128]} c:MemRef{float32[8,128]}. let + d:f32[8,128] <- a[:,:] + e:f32[8,128] <- b[:,:] + f:f32[8,128] = add d e + c[:,:] <- f + in () } + +The Mosaic module for the pallas_call my_call for kernel function kernel at ...:1000: +module { + func.func @main(%arg0: memref<8x128xf32, #tpu.memory_space>, %arg1: memref<8x128xf32, #tpu.memory_space>, %arg2: memref<8x128xf32, #tpu.memory_space>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} { + %c0 = arith.constant 0 : index + %c0_0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0_0] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + %c0_1 = arith.constant 0 : index + %c0_2 = arith.constant 0 : index + %1 = vector.load %arg1[%c0_1, %c0_2] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + %2 = arith.addf %0, %1 : vector<8x128xf32> + %c0_3 = arith.constant 0 : index + %c0_4 = arith.constant 0 : index + %3 = vector.load %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + vector.store %2, %arg2[%c0_3, %c0_4] : memref<8x128xf32, #tpu.memory_space>, vector<8x128xf32> + return + } +} +``` + +### Dumping Mosaic Passes + +Mosaic is the underlying TPU compiler for Pallas. It can be useful to dump Mosaic if you are running into errors that are originating from the Mosaic compiler to see what code is actually being generated. + +Passing the `--xla_mosaic_dump_to=` argument will dump the output of all intermediate Mosaic passes. The names of the files contain either the parameter `name` passed to the `pallas_call`, or the name of the kernel function. A useful option is to dump to Sponge with `--test_arg=--xla_mosaic_dump_to=sponge` after which you will see all passes under the “Artifacts” tab in sponge. + +### Static Verification + +The static verification tool can be used to automatically detect race conditions in distributed kernels. +Because this tool uses formal verification, it is best used for small kernels (<=2 devices). + +Verification can be performed by running your kernel with the `--jax_pallas_dump_promela_to=`, +which will output a Promela dump file. Afterwards, the dump file can be +analyzed using the [`spin`](https://spinroot.com) tool. For example, with a dump named `dump.pml`, run: + +``` +spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan +``` + + + +## Useful Command line flags + +* OOB Checks: `--xla_mosaic_on_device_checks=bounds` +* Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true` + +* Dump Mosaic: `--xla_mosaic_dump_to=` +* Enable trace markers in XProf: `--xla_enable_transpose_trace` + +## Common Errors + +### INTERNAL Mosaic failed to compile TPU Kernel + +`INTERNAL Mosaic failed to compile TPU Kernel: Not implemented X` + +This error means that you hit an unimplemented case in the underlying Mosaic compiler. +Our recommended course of action here is to file a ticket if one does not already +exist for your specific error. + +In some cases, your error may be due to an operation which cannot be implemented +efficiently in the compiler, in which your best course of action is to find a workaround. This +is most commonly seen in `layout` and `shape_cast` errors. The important tip +to remember regarding layouts is that the last 2 dimensions of arrays in Pallas +are physically tiled into registers, so any reshapes, slicing, transposes, etc. +on the last 2 dimensions may trigger a relayout. + + +### VerificationError + +A verification error indicates that Pallas produced invalid code for Mosaic. + +This is a bug in Pallas, so please file a bug under https://github.com/jax-ml/jax/issues. + +### LoweringError + +This is a catch-all error type during Pallas to Mosaic lowering and can have many causes. +In most cases the error message should hint at what is wrong. + +For specific errors: + +* `Mixed dtype operands in cmp` when using `jnp.mod`: Use lax.rem instead of jnp.mod + + From 170718c8d476e6727baf070f66c2ddbd8829f95a Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 21 Nov 2024 17:46:21 -0800 Subject: [PATCH 458/698] Change signature of linearization rules. Give the rule the nonzero tangent pattern up-front. This is needed to make a linearization rule for pjit_p. Also make the rules return the nonzero tangents out, an explicit residual, and a closed tangent function. Add a rule for sin_p to test it out. We still need to figure out how to avoid having to precompute `cos(x)`. I think we need to update our backward pass code. --- jax/_src/interpreters/ad.py | 31 ++++++++++++++++++------------- jax/_src/lax/lax.py | 6 +++++- tests/api_test.py | 2 +- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 91f061fd2210..9fa2fdb9ffbf 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -483,39 +483,44 @@ def to_primal_tangent_pair(self, val): def process_primitive(self, primitive, args, params): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) + tangent_nonzeros = [type(t) is not Zero for t in tangents_in] if all(type(t) is Zero for t in tangents_in): return primitive.bind_with_trace(self.parent_trace, primals_in, params) lin = primitive_linearizations.get(primitive) if lin is None: lin = partial(fallback_linearize_rule, primitive) with core.set_current_trace(self.parent_trace): - primal_out, linearized = lin(*primals_in, **params) + primal_out, tangent_nonzeros_out, residuals, linearized = lin( + tangent_nonzeros, *primals_in, **params) with core.set_current_trace(self.tangent_trace): - tangent_out = linearized(*tangents_in) + tangent_out = linearized(residuals, *tangents_in) if primitive.multiple_results: - return [maybe_linearize_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)] + return [maybe_linearize_tracer(self, x, nz, t) + for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)] else: - return maybe_linearize_tracer(self, primal_out, tangent_out) + return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out) -def maybe_linearize_tracer(trace, primal, tangent): - if type(tangent) is Zero: - return primal - else: +def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): + if is_nonzero: + assert not type(tangent) is Zero return LinearizeTracer(trace, primal, tangent) + else: + assert type(tangent) is Zero + return primal -def fallback_linearize_rule(prim, *args, **kwargs): +def fallback_linearize_rule(prim, _, *args, **kwargs): def call_prim(*args_): return prim.bind(*args_, **kwargs) with config.use_direct_linearize(False): out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize( lu.wrap_init(call_prim), *args, **kwargs) - def linearized(*tangents): - tangents_out = iter(core.eval_jaxpr(jaxpr, consts, *tangents)) + def linearized(residuals, *tangents): + tangents_out = iter(core.eval_jaxpr(jaxpr, residuals, *tangents)) full_out = [pval.get_known() if pval.is_known() else next(tangents_out) for pval in out_tangents_pvals] assert next(tangents_out, None) is None return full_out - return out_primals, linearized + return out_primals, [True for _ in out_primals], consts, linearized class LinearizeTracer(Tracer): __slots__ = ['primal', 'tangent'] @@ -547,7 +552,7 @@ def to_concrete_value(self): primitive_transposes: dict[core.Primitive, Callable] = {} # transpose rules that internally perform reductions over the given named axes reducing_transposes: dict[core.Primitive, Callable] = {} -primitive_linearizations: dict[core.Primitive, Callable] = {} +primitive_linearizations : dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0519fa48f45a..1099919a6474 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2400,12 +2400,16 @@ def _sin_lowering(ctx, x): return sine(ctx, x) return _nary_lower_hlo(hlo.sine, ctx, x) +def _sin_p_lin(_, x): + cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) + return (sin_p.bind(x), True, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule - def _cos_complex(x): # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x))) # see also _sin_complex diff --git a/tests/api_test.py b/tests/api_test.py index ff7855b68991..a27938eed392 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4818,7 +4818,7 @@ def check_invariant_to_use_direct_linearize(f): self.assertEqual(ans1, ans2) def sin_of_sin(x): - return jnp.sin(jnp.sin(x)) + return lax.sin(lax.sin(x)) check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) From 355589f32b29ab1a2c59b58cef2ede80d4d3f642 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 21 Nov 2024 20:12:21 -0800 Subject: [PATCH 459/698] [sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here * Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during it's lowering in `_scan_impl` (do the same for AOT path) * Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager. * Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them. * scan only allows `xs` where the 0th dim is full replicated i.e. None. PiperOrigin-RevId: 699014167 --- jax/_src/core.py | 8 +++- jax/_src/lax/control_flow/loops.py | 31 ++++++++++----- jax/_src/lax/lax.py | 14 +------ jax/_src/lax/slicing.py | 63 ++++++++++++++++++++++-------- jax/_src/pjit.py | 41 +++++++++++-------- jax/_src/sharding_impls.py | 5 +++ jax/_src/stages.py | 11 ++++-- tests/pjit_test.py | 42 +++++++++++++++++++- 8 files changed, 153 insertions(+), 62 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 2ad4264e9edd..86646faa980b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2263,16 +2263,20 @@ def _map_shaped_array( assert axis is None or aval.shape[axis] == size # TODO: Extend the named shape if axis is None: return aval + sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) + if config.sharding_in_types.value else None) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type) + weak_type=aval.weak_type, sharding=sharding) def _unmap_shaped_array( size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray ) -> ShapedArray: if axis is None: return aval elif type(axis) is int: + sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name)) + if config.sharding_in_types.value else None) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type) + weak_type=aval.weak_type, sharding=sharding) else: raise TypeError(axis) def _map_dshaped_array( diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d15917b8b1da..76132ccdc99a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -227,6 +227,11 @@ def scan(f, init, xs, length=None): msg.format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err + if (config.sharding_in_types.value and + not all(x.sharding.spec[0] is None for x in xs_flat)): + raise ValueError('0th dimension of all xs should be replicated. Got ' + f'{", ".join(str(x.sharding.spec) for x in xs_flat)}') + if length is not None: try: length = int(length) @@ -250,7 +255,8 @@ def scan(f, init, xs, length=None): if config.disable_jit.value: if length == 0: - raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.") + raise ValueError("zero-length scan is not supported in disable_jit() " + "mode because the output type is unknown.") carry = init ys = [] maybe_reversed = reversed if reverse else lambda x: x @@ -424,7 +430,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, num_trips, remainder = 0, length if unroll == 1: xss = xs_ - yss = _map(partial(_empty_array, (length,)), y_avals) + yss = _map(partial(_empty_array, (length,), None), y_avals) else: if remainder: if not reverse: @@ -432,7 +438,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, else: xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_)) xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_] - yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals) + yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals) def cond_fun(while_carry): i, _, _ = while_carry @@ -477,8 +483,11 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) -def _empty_array(prefix, aval): - return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape)) +def _empty_array(prefix, length_spec, aval): + sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec)) + if config.sharding_in_types.value else None) + return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), + sharding=sharding) eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True @@ -486,11 +495,13 @@ def _stage_jaxpr(trace, *tracers, jaxpr): params = dict(call_jaxpr=jaxpr) return trace.default_process_primitive(core.closed_call_p, tracers, params) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr + @eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf -def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects +def _stage_jaxpr_abstract_eval(*_, jaxpr): + return jaxpr.out_avals, jaxpr.effects def _prepend_dim_to_aval(sz, aval): - return core.unmapped_aval(sz, core.no_axis_name, 0, aval) + return core.unmapped_aval(sz, None, 0, aval) def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): @@ -674,7 +685,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, extensive_res = _map(trace.new_instantiated_const, extensive_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) - ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval) + ys_avals = [core.unmapped_aval(length, None, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in itertools.chain(carry_avals, ys_avals)] @@ -1041,7 +1052,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Create residual variables. intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals) - ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a) + ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a) for a in ext_avals_mapped] newvar = core.gensym() intensive_res = _map(newvar, intensive_avals) @@ -1119,7 +1130,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, jaxpr.in_avals, [num_consts, num_carry]) carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry]) x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) - y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a) + y_avals = [core.unmapped_aval(length, None, 0, a) for a in y_avals_mapped] if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1099919a6474..1b84797d630e 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4513,18 +4513,8 @@ def _pad_sharding_rule(operand, padding_value, *, padding_config): # change this logic to `return operand.sharding` directly. out_shape = _pad_shape_rule(operand, padding_value, padding_config=padding_config) - mesh = operand.sharding.mesh - new_spec = [] - for op_sh, out_sh, op_spec in safe_zip( - operand.shape, out_shape, operand.sharding.spec): - if (op_sh != out_sh and op_spec is not None and - out_sh % slicing._get_sub_spec_size(mesh, op_spec) != 0): - raise NotImplementedError( - f"padding on sharded dims where out dim ({out_sh}) is not divisble by" - f" mesh axes ({slicing._get_sub_spec_size(mesh, op_spec)}) with spec" - f" ({op_spec}) is not implemented.") - new_spec.append(op_spec) - return NamedSharding(mesh, P(*new_spec)) + return slicing._get_sharding_for_varying_out_shape( + out_shape, operand, 'padding') def _pad_transpose(t, operand, padding_value, *, padding_config): diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 117c8b655152..c6c85ce4f6a3 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -42,7 +42,6 @@ _input_dtype, standard_primitive, ) -from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike, Shape @@ -1276,23 +1275,33 @@ def _get_sub_spec_size(mesh, sub_spec): return math.prod(mesh.shape[s] for s in sub_spec) return mesh.shape[sub_spec] -def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): - # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, - # change this logic to `return operand.sharding` directly. - out_shape = _slice_shape_rule(operand, start_indices=start_indices, - limit_indices=limit_indices, strides=strides) +def _get_sharding_for_varying_out_shape(out_shape, operand, name): + """Returns a sharding when out_shape may not be the same as operand shape""" mesh = operand.sharding.mesh - new_spec = [] for op_sh, out_sh, op_spec in safe_zip( operand.shape, out_shape, operand.sharding.spec): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): raise NotImplementedError( - f"slicing on sharded dims where out dim ({out_sh}) is not divisble by" + f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" f" ({op_spec}) is not implemented.") - new_spec.append(op_spec) - return NamedSharding(mesh, P(*new_spec)) + # TODO(yashkatariya): Returning operand.sharding as is may or may not move + # data. So think about how to avoid it which might include creating a new + # mesh? For example: + # mesh = {'x': 4} + # x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))` + # ys = lax.split(x, [4, 4]) # This will create outputs of shape (4,) + # According to the current logic, ys[0].sharding.spec == P('x') + # which involves data movement. + return operand.sharding + +def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shape = _slice_shape_rule(operand, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + return _get_sharding_for_varying_out_shape(out_shape, operand, 'slicing') def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert ad.is_undefined_primal(operand) @@ -1367,8 +1376,7 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): mlir.register_lowering(slice_p, _slice_lower) -def _dynamic_slice_shape_rule( - operand, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if operand.ndim != len(start_indices): msg = ("dynamic_slice start_indices must have length equal to the number " @@ -1391,6 +1399,12 @@ def _dynamic_slice_shape_rule( f" got indices {start_indices}") return tuple(lax._merge_dyn_shape(slice_sizes, dyn)) +def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes): + out_shape = _dynamic_slice_shape_rule( + operand, *starts_and_dyn_sizes, slice_sizes=slice_sizes) + return _get_sharding_for_varying_out_shape(out_shape, operand, 'dynamic_slice') + + def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) if any(i.dtype != start_indices[0].dtype or @@ -1494,7 +1508,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + sharding_rule=_dynamic_slice_sharding_rule) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule @@ -1508,7 +1523,10 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes): aval_out, = ctx.avals_out if dyn: aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn)) - return [mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)] + out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower) @@ -1539,6 +1557,14 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): f"scalars, got indices {start_indices}") return operand.shape +def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): + if operand.sharding != update.sharding: + raise TypeError( + "dynamic_update_slice update sharding must be equal to operand" + f" sharding, got update sharding {update.sharding} for operand sharding" + f" {operand.sharding}.") + return operand.sharding + def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): lax.check_same_dtypes("dynamic_update_slice", operand, update) if any(i.dtype != start_indices[0].dtype or @@ -1604,7 +1630,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice') + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -1613,8 +1639,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): def _dynamic_update_slice_lower(ctx, x, update, *start_indices): aval_out, = ctx.avals_out - return [mlir.dynamic_update_slice(ctx, aval_out, x, update, - start_indices=start_indices)] + out = mlir.dynamic_update_slice(ctx, aval_out, x, update, + start_indices=start_indices) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] + return [out] mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index aff956862753..4f16e0013f25 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -185,16 +185,19 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - if (core.trace_state_clean() and - not config.debug_key_reuse.value and - not config.data_dependent_tracing_fallback.value): - args_flat = map(core.full_lower, args_flat) - core.check_eval_args(args_flat) - out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) - else: - out_flat = pjit_p.bind(*args_flat, **p.params) - compiled = None - profiler = None + # TODO(yashkatariya): Maybe thread this into pjit params like resource_env + # and set the context manager down the stack? + with p.abstract_mesh: + if (core.trace_state_clean() and + not config.debug_key_reuse.value and + not config.data_dependent_tracing_fallback.value): + args_flat = map(core.full_lower, args_flat) + core.check_eval_args(args_flat) + out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) + else: + out_flat = pjit_p.bind(*args_flat, **p.params) + compiled = None + profiler = None except pxla.DeviceAssignmentMismatchError as e: fails, = e.args api_name = 'jit' if p.params['resource_env'] is None else 'pjit' @@ -330,9 +333,10 @@ def cache_miss(*args, **kwargs): if config.no_tracing.value: raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " "`jit`, but 'no_tracing' is set") - outs, out_flat, out_tree, args_flat, jaxpr, \ - attrs_tracked, executable, pgle_profiler = _python_pjit_helper( - fun, jit_info, *args, **kwargs) + + (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable, + pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) + maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, @@ -495,10 +499,10 @@ def trace(*args, **kwargs) -> stages.Traced: donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) lower_callable = partial(_resolve_and_lower, args_flat, **p.params, - pgle_profiler=None) + pgle_profiler=None) return stages.Traced( p.params['jaxpr'], args_info, p.params["name"], p.out_tree, - lower_callable, args_flat, p.arg_names, p.num_consts) + lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts) wrapped = _cpp_pjit(fun, jit_info) wrapped.lower = lower @@ -534,6 +538,7 @@ class PjitParams(NamedTuple): arg_names: tuple[str, ...] | None num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + abstract_mesh: AbstractMesh def _infer_params_impl( @@ -639,7 +644,9 @@ def _infer_params_impl( attr_token = _attr_token(flat_fun, in_type) - abstract_mesh = get_abstract_mesh(in_type) + abstract_mesh = ( + get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None + else mesh_lib.mesh_context.mesh) with abstract_mesh: jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, dbg, @@ -684,7 +691,7 @@ def _infer_params_impl( ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), donated_invars, dbg.arg_names if dbg else None, len(consts), - attrs_tracked), args_flat + attrs_tracked, abstract_mesh), args_flat def get_abstract_mesh(in_avals): diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index dc4171eec146..8abe58e52a74 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -363,6 +363,11 @@ def is_fully_replicated(self) -> bool: def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) + def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: + if not isinstance(spec, PartitionSpec): + spec = PartitionSpec(*spec) + return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind) + def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 92c680009c93..b6f3b63d3de4 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,6 +30,7 @@ """ from __future__ import annotations +import contextlib import functools from collections.abc import Sequence from dataclasses import dataclass @@ -716,13 +717,14 @@ class Traced(Stage): "_args_flat", "_arg_names", "_num_consts"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, args_flat=None, arg_names=None, - num_consts: int = 0): + lower_callable, abstract_mesh=contextlib.nullcontext(), + args_flat=None, arg_names=None, num_consts: int = 0): self.jaxpr = jaxpr self.args_info = args_info self.fun_name = fun_name self._out_tree = out_tree self._lower_callable = lower_callable + self._abstract_mesh = abstract_mesh self._args_flat = args_flat self._arg_names = arg_names self._num_consts = num_consts @@ -743,7 +745,10 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, self._lower_callable, lowering_platforms=lowering_platforms, lowering_parameters=_private_parameters) try: - lowering = new_callable() + # TODO(yashkatariya): Maybe thread this into pjit params like resource_env + # and set the context manager down the stack? + with self._abstract_mesh: + lowering = new_callable() except pxla.DeviceAssignmentMismatchError as e: fails, = e.args msg = pjit._device_assignment_mismatch_error( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index dd1415b680a4..293026b2b612 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4674,7 +4674,7 @@ def f(x): if config.use_shardy_partitioner.value: self.assertIn('sdy.sharding_constraint', lowered_text) else: - self.assertEqual(lowered_text.count('@Sharding'), 2) + self.assertEqual(lowered_text.count('@Sharding'), 3) @jax.jit def g(x): @@ -5244,6 +5244,7 @@ def test_shard_map_full_manual(self): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective) return x * y @jax.jit @@ -5268,6 +5269,7 @@ def test_shard_map_dot(self): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y return jax.lax.psum(z, axis_name='y') @@ -5426,6 +5428,44 @@ def g(x, y): out = jax.jit(jax.grad(g))(arr1, arr2) self.assertEqual(out.sharding, s) + def test_scan(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + carry = jax.device_put(np.arange(16.).reshape(2, 8), + NamedSharding(mesh, P(None, 'x'))) + arr = jax.device_put(np.arange(128.).reshape(8, 8, 2), + NamedSharding(mesh, P(None, 'x', 'y'))) + + @jax.jit + def f(carry, xs): + def g(carry, x): + self.assertEqual(carry.sharding.spec, P(None, 'x')) + self.assertEqual(x.sharding.spec, P('x', 'y')) + y = carry @ x + self.assertEqual(y.sharding.spec, P(None, 'y')) + z = jax.nn.relu(y) + self.assertEqual(z.sharding.spec, P(None, 'y')) + a = z @ x.T + self.assertEqual(a.sharding.spec, P(None, 'x')) + return a, y + return jax.lax.scan(g, carry, xs) + + activation, mean = f(carry, arr) + self.assertEqual(activation.sharding, NamedSharding(mesh, P(None, 'x'))) + self.assertEqual(mean.sharding, NamedSharding(mesh, P(None, None, 'y'))) + + f.lower(carry, arr).compile()(carry, arr) # doesn't crash + + def g(carry, arr): + out = f(carry, arr) + return jnp.sum(out[0]) + out = jax.jit(jax.grad(g, argnums=(0, 1)))(carry, arr) + self.assertEqual(out[0].sharding, carry.sharding) + self.assertEqual(out[1].sharding, arr.sharding) + + with self.assertRaisesRegex( + ValueError, "0th dimension of all xs should be replicated"): + f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 73fa0f48cb0081fc69cb82dffe6ceb5433cdc446 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 21 Nov 2024 23:33:50 -0800 Subject: [PATCH 460/698] [Pallas] Deprecate dictionary compiler_params in favor of dataclass. PiperOrigin-RevId: 699057658 --- jax/experimental/pallas/ops/gpu/layer_norm.py | 4 ++-- jax/experimental/pallas/ops/gpu/rms_norm.py | 2 +- tests/pallas/tpu_pallas_pipeline_test.py | 11 +++++------ 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index 7d11e4faf299..d37afaf4d9e0 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -94,7 +94,7 @@ def layer_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -215,7 +215,7 @@ def layer_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index 2a7824315a0f..ff224c6dfde7 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -196,7 +196,7 @@ def rms_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index ca64275d3f09..2af00cf6b8c6 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -486,12 +486,11 @@ def _wait_on_prev_dma(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs ), - compiler_params=dict( - mosaic=dict(collective_id=0, - # must set scoped vmem flag *larger* than below! e.g.: - # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.TPUCompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB ), ) From 34a2f0ca4a8f8a26d9a056f8785f412bd156dc23 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 22 Nov 2024 06:44:46 -0800 Subject: [PATCH 461/698] Add a jaxlib at head build to the cloud-tpu-ci-nightly workflow This will allow us to test TPU compatibility with jaxlib at head. Also, enable v4 runners as they are now online. PiperOrigin-RevId: 699155667 --- .github/workflows/cloud-tpu-ci-nightly.yml | 28 ++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 4ac167bd37c1..7d7bc84fe135 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -24,10 +24,10 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] + jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - # {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available - # {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, + # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available + {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} ] python-version: ["3.10"] @@ -47,6 +47,13 @@ jobs: # mandates using a specific commit for non-Google actions. We use # https://github.com/sethvargo/ratchet to pin specific versions. - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # Checkout XLA at head, if we're building jaxlib at head. + - name: Checkout XLA at head + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ matrix.jaxlib-version == 'head' }} + with: + repository: openxla/xla + path: xla - name: Install JAX test requirements run: | $PYTHON -m pip install -U -r build/test-requirements.txt @@ -54,7 +61,20 @@ jobs: - name: Install JAX run: | $PYTHON -m pip uninstall -y jax jaxlib libtpu - if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then + if [ "${{ matrix.jaxlib-version }}" == "head" ]; then + # Build and install jaxlib at head + $PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \ + --bazel_options="--override_repository=xla=$(pwd)/xla" \ + --bazel_options=--color=yes + $PYTHON -m pip install dist/*.whl + + # Install "jax" at head + $PYTHON -m pip install -U -e . + + # Install libtpu + $PYTHON -m pip install --pre libtpu \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then $PYTHON -m pip install .[tpu] \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html From 236d4c605f0493eeb582bca774ded66c175e417f Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 22 Nov 2024 10:19:30 -0500 Subject: [PATCH 462/698] Use optimize='auto' for multi_dot. --- jax/_src/numpy/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index be6828c36e6a..e7e2e369722d 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -2119,7 +2119,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) - if arrs[-1].ndim == 1: einsum_axes[-1] = einsum_axes[-1][:1] return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload] - optimize='optimal', precision=precision) + optimize='auto', precision=precision) @export From 846697f761a5e6857ecea7fcadf02cb7dd5ff18e Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Fri, 22 Nov 2024 10:36:01 -0600 Subject: [PATCH 463/698] Longer timeout for doc render --- .github/workflows/ci-build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 0fd188098ee9..b3f683f89f78 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -139,7 +139,7 @@ jobs: documentation_render: name: Documentation - render documentation runs-on: ubuntu-latest - timeout-minutes: 10 + timeout-minutes: 20 strategy: matrix: python-version: ['3.10'] From c0811c9dffb5a6ddd6f5baf41a41651ffb7efea1 Mon Sep 17 00:00:00 2001 From: Keith Rush Date: Fri, 22 Nov 2024 09:13:46 -0800 Subject: [PATCH 464/698] Adds coverage for spmd-axisname-filtering in shard_map transpose. PiperOrigin-RevId: 699193349 --- tests/shard_map_test.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 2a343f7ba784..56cf9987911d 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -709,6 +709,26 @@ def f(x): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) + def test_vmap_of_grad_spmd_axis_name(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + @partial( + shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False + ) + def f(x): + return jnp.sin(jnp.sum(x)) + + x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4) + put_x = jax.device_put( + x, + jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')), + ) + vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x) + vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x) + self.assertArraysEqual( + vmap_spmd_axisname_result, vmap_no_spmd_axisname_result + ) + def test_vmap_spmd_axis_name_pair(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) From 763560526221374d74248517a2e1d98a8712ece3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 22 Nov 2024 11:01:20 -0800 Subject: [PATCH 465/698] Use with_spec where possible to clean up the code a bit PiperOrigin-RevId: 699226058 --- jax/_src/lax/lax.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1b84797d630e..56a4266cc2df 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3426,7 +3426,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, if config.sharding_in_types.value: xs = x.aval.sharding inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) - ds = NamedSharding(xs.mesh, P(*inverse_spec)) + ds = xs.with_spec(inverse_spec) else: ds = None dot_general_out = dot_general(g, y, dims, precision=precision, @@ -4116,7 +4116,7 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, orig_spec = iter(operand.sharding.spec) new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))] assert next(orig_spec, None) is None - return NamedSharding(operand.sharding.mesh, P(*new_spec)) + return operand.sharding.with_spec(new_spec) def _broadcast_in_dim_typecheck_rule( _, operand, *dyn_shape, shape, broadcast_dimensions, sharding): @@ -4593,7 +4593,7 @@ def _squeeze_sharding_rule(operand, *, dimensions): dims_set = set(dimensions) new_spec = tuple(s for i, s in enumerate(operand.sharding.spec) if i not in dims_set) - return NamedSharding(operand.sharding.mesh, P(*new_spec)) + return operand.sharding.with_spec(new_spec) def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) @@ -4688,7 +4688,7 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions): if n != sh: raise NotImplementedError new_spec.append(sp) - return NamedSharding(operand.sharding.mesh, P(*new_spec)) + return operand.sharding.with_spec(new_spec) def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions): if not dyn_shape: @@ -4791,7 +4791,7 @@ def _transpose_shape_rule(operand, *, permutation): def _transpose_sharding_rule(operand, *, permutation): o_spec = operand.sharding.spec new_spec = [o_spec[old_idx] for old_idx in permutation] - return NamedSharding(operand.sharding.mesh, P(*new_spec)) + return operand.sharding.with_spec(new_spec) def _transpose_batch_rule(batched_args, batch_dims, *, permutation): operand, = batched_args @@ -5165,7 +5165,7 @@ def _reduce_op_sharding_rule(operand, *, axes): axes = frozenset(axes) new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec) if i not in axes)) - return NamedSharding(operand.sharding.mesh, new_spec) + return operand.sharding.with_spec(new_spec) reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), @@ -6237,7 +6237,7 @@ def _const(example, val): def _one(x): if config.sharding_in_types.value: return full_like(x, shape=(), fill_value=1, - sharding=NamedSharding(x.sharding.mesh, P())) + sharding=x.sharding.with_spec(P())) return full_like(x, shape=(), fill_value=1) _twos: Callable = partial(full_like, fill_value=2) From 21f8885a9e104b8828c9a8b721eed0c68b622691 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 22 Nov 2024 11:59:39 -0800 Subject: [PATCH 466/698] [sharding_in_types] Make argmax and argmin work with sharding_in_types. This also requires adding `reduce_p` sharding rule PiperOrigin-RevId: 699244204 --- jax/_src/lax/lax.py | 41 +++++++++++++++++++++++++++++------------ jax/_src/lax/utils.py | 11 ++++++++--- tests/pjit_test.py | 22 ++++++++++++++++++++++ 3 files changed, 59 insertions(+), 15 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 56a4266cc2df..a86a17b3c636 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5007,6 +5007,11 @@ def _reduce_shape_rule(*avals, computation, jaxpr, dimensions): raise ValueError(f'reduce found non-scalar initial value: {init_val_shapes}') return [tuple(np.delete(op.shape, dimensions)) for op in operand_avals] +def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions): + operand_avals, _ = split_list(avals, [len(avals) // 2]) + return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions)) + for op in operand_avals] + def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions): operand_avals, init_val_avals = split_list(avals, [len(avals) // 2]) operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals] @@ -5093,7 +5098,7 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -5115,6 +5120,9 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions): *reducer.arguments, dim_var_values=ctx.dim_var_values) hlo.return_(mlir.flatten_ir_values(out_nodes)) + if config.sharding_in_types.value: + return [mlir.lower_sharding_under_shit(ctx, r, aval) + for r, aval in safe_zip(op.results, ctx.avals_out)] return op.results mlir.register_lowering(reduce_p, _reduce_lower) @@ -5227,7 +5235,12 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype): if operand.shape[axis] < 1: raise ValueError("argmin and argmax require non-empty reduced dimension. " f"operand.shape={operand.shape} {axis=}") - return tuple(np.delete(operand.shape, axis)) + return util.tuple_delete(operand.shape, axis) + +def _argminmax_sharding_rule(operand, *, axes, index_dtype): + axis, = axes + return operand.sharding.with_spec( + util.tuple_delete(operand.sharding.spec, axis)) def _argminmax_dtype_rule(operand, *, axes, index_dtype): if not dtypes.issubdtype(index_dtype, np.integer): @@ -5264,7 +5277,9 @@ def _compute_argminmax(value_comparator, get_identity, # value_comparator is either lax.lt (for argmin) or lax.gt # get_identity(operand.dtype) is inf for argmin or -inf for argmax axis, = axes - indices = broadcasted_iota(index_dtype, np.shape(operand), axis) + indices = broadcasted_iota( + index_dtype, np.shape(operand), axis, + _sharding=operand.sharding if config.sharding_in_types.value else None) res = reduce([operand, indices], [get_identity(operand.dtype), np.array(0, index_dtype)], _ArgMinMaxReducer(value_comparator), @@ -5272,22 +5287,24 @@ def _compute_argminmax(value_comparator, get_identity, return res[1] argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, - 'argmin', weak_type_rule=_strip_weak_type) + 'argmin', weak_type_rule=_strip_weak_type, + sharding_rule=_argminmax_sharding_rule) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, - 'argmax', weak_type_rule=_strip_weak_type) + 'argmax', weak_type_rule=_strip_weak_type, + sharding_rule=_argminmax_sharding_rule) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) -mlir.register_lowering(argmin_p, mlir.cache_lowering(mlir.lower_fun( - partial(_compute_argminmax, lt, _get_min_identity), - multiple_results=False))) +mlir.register_lowering(argmin_p, mlir.cache_lowering( + mlir.lower_fun(partial(_compute_argminmax, lt, _get_min_identity), + multiple_results=False))) -mlir.register_lowering(argmax_p, mlir.cache_lowering(mlir.lower_fun( - partial(_compute_argminmax, gt, _get_max_identity), - multiple_results=False))) +mlir.register_lowering(argmax_p, mlir.cache_lowering( + mlir.lower_fun(partial(_compute_argminmax, gt, _get_max_identity), + multiple_results=False))) def _reduce_logical_shape_rule(operand, *, axes): @@ -5882,7 +5899,7 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule)) + _rng_bit_generator_weak_type_rule, None)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 82804c796e6e..78d125436029 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -69,7 +69,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, raise TypeError(avals, least_specialized) def standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs): + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, + *avals, **kwargs): assert prim.multiple_results assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals least_specialized = max(map(type, avals), key=_get_array_abstraction_level) @@ -77,8 +78,12 @@ def standard_multi_result_abstract_eval( if least_specialized is core.ShapedArray: out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) - return [core.ShapedArray(s, d, weak_type=weak_type) - for s, d, weak_type in zip(out_shapes, out_dtypes, weak_types)] + out_shardings = (sharding_rule(*avals, **kwargs) + if config.sharding_in_types.value else + [None] * len(out_shapes)) + return [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh) + for s, d, weak_type, sh in zip(out_shapes, out_dtypes, weak_types, + out_shardings)] elif least_specialized is core.UnshapedArray: out_dtypes = dtype_rule(*avals, **kwargs) return [core.UnshapedArray(dtype, weak_type=weak_type) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 293026b2b612..4d9b98b4d595 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5466,6 +5466,28 @@ def g(carry, arr): ValueError, "0th dimension of all xs should be replicated"): f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) + def test_argminmax(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + z = jnp.argmax(x, axis=0) + self.assertEqual(z.sharding.spec, P('y')) + a = jnp.argmin(x, axis=1) + self.assertEqual(a.sharding.spec, P('x')) + return z, a + + out1, out2 = f(arr) + self.assertArraysEqual(out1, np.argmax(np_inp, axis=0)) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('y'))) + self.assertArraysEqual(out2, np.argmin(np_inp, axis=1)) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) + + self.assertIn('@Sharding', f.lower(arr).as_text()) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 2f28601608e54e463b99fdefbd0a3cd6b188fa06 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Fri, 22 Nov 2024 15:41:48 -0600 Subject: [PATCH 467/698] Fix upstream PR workflow to use origin branches (#151) --- .github/workflows/rocm-open-upstream-pr.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml index 09dfd06e907e..e711d964a0fb 100644 --- a/.github/workflows/rocm-open-upstream-pr.yml +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -20,8 +20,9 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Rebase code to main run: | - git checkout -b $NEW_BRANCH_NAME ${{ github.head_ref }} - git rebase --onto main + git fetch + git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }} + git rebase --onto origin/main git push origin HEAD # TODO: Change the base of the PR to upstream main - name: Create a PR to upstream From b1d1dcf607761533fbe0348a496657c9b1cf986e Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 22 Nov 2024 14:15:46 -0800 Subject: [PATCH 468/698] Add linearization rule for pjit_p --- jax/_src/interpreters/ad.py | 66 +++++++++++++++++++++------ jax/_src/interpreters/partial_eval.py | 20 ++++---- jax/_src/lax/lax.py | 5 +- jax/_src/pjit.py | 46 +++++++++++++++++++ tests/api_test.py | 2 +- 5 files changed, 112 insertions(+), 27 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 9fa2fdb9ffbf..804df185d4ed 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -105,22 +105,56 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents): store.store(aux_primals) return out_primals, out_tangents +def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr: + dbg = jaxpr.debug_info and jaxpr.debug_info._replace( + arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars)) + return core.Jaxpr(constvars=(), + invars=jaxpr.invars + jaxpr.constvars, + outvars=jaxpr.outvars, eqns=jaxpr.eqns, + effects=jaxpr.effects, debug_info=dbg) + +def linearize_jaxpr(jaxpr, nonzeros): + primal_trace = pe.DynamicJaxprTrace() + tangent_trace = pe.DynamicJaxprTrace() + lin_trace = LinearizeTrace(primal_trace, tangent_trace) + + def new_arg(primal_aval, nz): + primal = primal_trace.new_arg(primal_aval) + tangent_aval = primal_aval.to_tangent_aval() + tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval) + return LinearizeTracer(lin_trace, primal, tangent) + + tracers = [new_arg(v.aval, nz) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)] + with core.set_current_trace(lin_trace): + ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers) + + out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans)) + nzs_out = [type(t) is not Zero for t in out_tangents] + out_tangents = [tangent_trace.to_jaxpr_tracer(t) + for (nz, t) in zip(nzs_out, out_tangents) if nz] + tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) + del attrs_tracked # TODO: attrs + residuals_and_primals = (*tangent_consts, *out_primals) + primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals) + num_residuals = len(tangent_consts) + tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr)) + del attrs_tracked # TODO: attrs + return core.ClosedJaxpr(primal_jaxpr, primal_consts), num_residuals, nzs_out, tangent_jaxpr + def direct_linearize(traceable, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) assert not has_aux with core.take_current_trace() as parent_trace: - frame = pe.JaxprStackFrame() - tangent_trace = pe.DynamicJaxprTrace(frame) + tangent_trace = pe.DynamicJaxprTrace() tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] - tag = core.TraceTag() - linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag) + linearize_trace = LinearizeTrace(parent_trace, tangent_trace) tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] with core.set_current_trace(linearize_trace): ans = traceable.call_wrapped(*tracers) out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans)) out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) - jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents) + jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents] del attrs_tracked # TODO: attrs return out_primals, out_tangents_pvals, jaxpr, consts @@ -469,8 +503,8 @@ def _primal_tangent_shapes_match(primal, tangent): class LinearizeTrace(Trace): - def __init__(self, parent_trace, tangent_trace, tag): - self.tag = tag + def __init__(self, parent_trace, tangent_trace, tag=None): + self.tag = core.TraceTag() if tag is None else tag self.parent_trace = parent_trace self.tangent_trace = tangent_trace @@ -509,18 +543,20 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent): return primal def fallback_linearize_rule(prim, _, *args, **kwargs): + assert not prim.multiple_results + def call_prim(*args_): - return prim.bind(*args_, **kwargs) + return [prim.bind(*args_, **kwargs)] + with config.use_direct_linearize(False): - out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize( + (out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize( lu.wrap_init(call_prim), *args, **kwargs) + def linearized(residuals, *tangents): - tangents_out = iter(core.eval_jaxpr(jaxpr, residuals, *tangents)) - full_out = [pval.get_known() if pval.is_known() else next(tangents_out) - for pval in out_tangents_pvals] - assert next(tangents_out, None) is None - return full_out - return out_primals, [True for _ in out_primals], consts, linearized + out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents) + return out_tangent + + return out_primal, True, consts, linearized class LinearizeTracer(Tracer): __slots__ = ['primal', 'tangent'] diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 943c15b6ea49..6e2f11833b9d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1575,6 +1575,7 @@ def get_referent(self): val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) return self if val is None else get_referent(val) + def _dynamic_jaxpr_tracer_shaped_abstractify(x): return x.aval api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify @@ -1805,8 +1806,8 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: class DynamicJaxprTrace(core.Trace): - def __init__(self, frame): - self.frame = frame + def __init__(self): + self.frame = JaxprStackFrame() def invalidate(self): # avoid cyclic refs @@ -2068,6 +2069,9 @@ def transpose_jaxpr_thunk(): self.frame.add_eqn(eqn) return out_tracers + def to_jaxpr(self, out_tracers: Sequence[Tracer]): + return self.frame.to_jaxpr(self, out_tracers) + custom_staging_rules: dict[Primitive, Callable] = {} @@ -2166,10 +2170,8 @@ def trace_to_jaxpr_dynamic( list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - frame = JaxprStackFrame() - frame.debug_info = debug_info - - trace = DynamicJaxprTrace(frame) + trace = DynamicJaxprTrace() + trace.frame.debug_info = debug_info with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] @@ -2177,8 +2179,8 @@ def trace_to_jaxpr_dynamic( ans = fun.call_wrapped(*in_tracers) out_tracers = map(trace.to_jaxpr_tracer, ans) - jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers) - del trace, fun, frame, in_tracers, out_tracers, ans + jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers) + del trace, fun, in_tracers, out_tracers, ans config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked @@ -2188,7 +2190,7 @@ def trace_to_jaxpr_dynamic2( fun: lu.WrappedFun, debug_info: DebugInfo | None = None ) -> tuple[Jaxpr, OutputType, list[Any]]: - trace = DynamicJaxprTrace(JaxprStackFrame()) + trace = DynamicJaxprTrace() with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): trace.frame.debug_info = debug_info in_avals, keep_inputs = unzip2(fun.in_type) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1b84797d630e..0a12d04445a5 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2400,9 +2400,10 @@ def _sin_lowering(ctx, x): return sine(ctx, x) return _nary_lower_hlo(hlo.sine, ctx, x) -def _sin_p_lin(_, x): +def _sin_p_lin(nzs, x): + nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) - return (sin_p.bind(x), True, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 4f16e0013f25..e0af0b7d2137 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2107,6 +2107,52 @@ def _filter_zeros(is_nz_l, l): ad.primitive_jvps[pjit_p] = _pjit_jvp +def _pjit_linearization(nzs, *primals_in, jaxpr, + in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): + primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) + # constvars will become residuals. Move them to the end of the ordinary args. + res_shardings = (UNSPECIFIED,) * num_residuals + res_layouts = (None,) * num_residuals + res_donated = (False,) * num_residuals + def tangent_fun(consts_, *tangents): + tangents_nz = _filter_zeros(nzs, tangents) + assert len(consts_) == num_residuals + return pjit_p.bind(*(*tangents_nz, *consts_), + jaxpr=tangent_jaxpr, + in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, + out_shardings=_filter_zeros(nzs_out, out_shardings), + in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts, + out_layouts=_filter_zeros(nzs_out, out_layouts), + resource_env=resource_env, + donated_invars=_filter_zeros(nzs, donated_invars) + res_donated, + name=name, + keep_unused=keep_unused, + inline=inline, + compiler_options_kvs=compiler_options_kvs) + + def _filter_zeros(is_nz_l, l): + return tuple(x for nz, x in zip(is_nz_l, l) if nz) + + ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, + in_shardings=in_shardings, + out_shardings=(*res_shardings, *out_shardings), + in_layouts=in_layouts, + out_layouts=(*res_layouts, *out_layouts), + resource_env=resource_env, + donated_invars=donated_invars, + name=name, + keep_unused=keep_unused, + inline=inline, + compiler_options_kvs=compiler_options_kvs) + residuals_ans, primal_ans = split_list(ans, [num_residuals]) + + return primal_ans, nzs_out, residuals_ans, tangent_fun + +ad.primitive_linearizations[pjit_p] = _pjit_linearization + + def _pjit_partial_eval(trace, *in_tracers, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, diff --git a/tests/api_test.py b/tests/api_test.py index a27938eed392..a6c1c5d53d91 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4818,7 +4818,7 @@ def check_invariant_to_use_direct_linearize(f): self.assertEqual(ans1, ans2) def sin_of_sin(x): - return lax.sin(lax.sin(x)) + return lax.sin(jax.jit(lax.sin)(x)) check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0)) From 9f6dbef3dc7a499cae3a4e3ec001d028add3561a Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 22 Nov 2024 14:50:10 -0800 Subject: [PATCH 469/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0564969ba385bfc895baad8f64879236bfbc717b. PiperOrigin-RevId: 699295115 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 46f71523be05..171c5f774c5f 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "85360d67ffc0a6d6923605b848de12ec204ca336" -XLA_SHA256 = "7afa7e599adf7b1a636ea9e55419c253a115ef27217ec862ca8a03cef1abd11a" +XLA_COMMIT = "0564969ba385bfc895baad8f64879236bfbc717b" +XLA_SHA256 = "e54fccae6075c574493c27443658ee178833362d8890f3285a9d4787f4bdbc09" def repo(): tf_http_archive( From a07abe2466b578247a31d75eee17fe59741159e4 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Fri, 22 Nov 2024 16:56:29 -0600 Subject: [PATCH 470/698] Add token for GitHub CLI (#152) --- .github/workflows/rocm-open-upstream-pr.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml index e711d964a0fb..96c2d6e8128a 100644 --- a/.github/workflows/rocm-open-upstream-pr.yml +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -36,5 +36,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Leave comment on old PR + env: + GH_TOKEN: ${{ github.token }} run: gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body ${{ needs.open-upstream.outputs.new-pr-link.link }} From 8699f5d970cc9cd2aa4340d2021d3e0d65a3d0ea Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 22 Nov 2024 17:32:20 -0800 Subject: [PATCH 471/698] When host local inputs on all hosts are the same, use `_DeferredShardArg` to do the transfers instead of `jit` to avoid blocking. PiperOrigin-RevId: 699336402 --- jax/_src/dispatch.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 081abf394f98..b3f16c724ee4 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -411,15 +411,12 @@ def _device_put_sharding_impl(x, aval, device, copy): if not s.is_fully_addressable: if ((isinstance(x, array.ArrayImpl) and not x._committed) or type(x) in array_types): - # TODO(yashkatariya): Move this check to `jit`. multihost_utils.assert_equal( x, fail_message=( f"{type(x)} passed to device_put is not the same on each" " process. Make sure you are passing the same value of" f" {type(x)} on each process.")) - return api.jit( - _identity_fn, out_shardings=s, - donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x) + return _DeferredShardArg(x, s, aval, True, copy) # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array. raise ValueError( "device_put's second argument must be a Device or a Sharding which" From b259fde5415201350caf4979353365122a915957 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 22 Nov 2024 17:39:10 -0800 Subject: [PATCH 472/698] Fix member access to xla backend. The correct member is `client` instead of `backend` PiperOrigin-RevId: 699338495 --- jax/experimental/colocated_python/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index 971002f51160..abf92306ef0f 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -32,7 +32,7 @@ def colocated_cpu_devices( raise NotImplementedError("Requires xla_extension_version >= 290") cpu_devices_by_colocation_id = collections.defaultdict(list) - for device in devices[0].backend._get_all_devices(): # pylint: disable=protected-access + for device in devices[0].client._get_all_devices(): # pylint: disable=protected-access if device.device_kind == "cpu": cpu_devices_by_colocation_id[device.colocation_id].append(device) if not cpu_devices_by_colocation_id: From e53ff2cbfc5ff153208ca061d4204583b16c696d Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 22 Nov 2024 23:39:05 -0800 Subject: [PATCH 473/698] [Mosaic][Easy] - Wire up kernel names to MLIR dump PiperOrigin-RevId: 699408419 --- jax/_src/tpu_custom_call.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index f463986ffb50..10a979dffc6a 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -278,6 +278,7 @@ def _lower_tpu_kernel( module: ir.Module, hardware_generation: int, target_shape: tuple[int, int], + kernel_name: str | None = None, ) -> ir.Module: """Runs MLIR passes lowering the given module to an MLIR module. @@ -303,8 +304,7 @@ def _lower_tpu_kernel( tpu.register_dialect(ctx) mhlo.register_mhlo_dialect(ctx) mhlo.register_mhlo_passes() - - dump_mlir(module, "original") + dump_mlir(module, "original", kernel_name) if _MOSAIC_ALLOW_HLO.value: # Run hlo dialect conversion: hlo -> linalg -> vector. @@ -406,6 +406,7 @@ def _lower_mosaic_module_to_asm( *, backend: str, device_type: str | None, + kernel_name: str | None, ) -> tuple[ir.Module, tuple[bool, bool, bool, bool]]: has_communication, has_custom_barrier = tpu.private_has_communication( module.operation @@ -429,7 +430,7 @@ def _lower_mosaic_module_to_asm( hardware_generation = int(device_kind[len("TPU v")]) target_shape = get_target_shape(hardware_generation) module = _lower_tpu_kernel( - module, hardware_generation, target_shape=target_shape + module, hardware_generation, target_shape=target_shape, kernel_name=kernel_name, ) needs_hlo_passes = False needs_layout_passes = False @@ -504,6 +505,7 @@ def _lower_to_custom_call_config( collective_id: int | None, serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + kernel_name: str | None = None, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -514,6 +516,7 @@ def _lower_to_custom_call_config( module, backend=backend, device_type=device_type, + kernel_name=kernel_name, ) return _lowered_to_custom_call_config( lowered_module_asm, @@ -613,6 +616,7 @@ def lower_module_to_custom_call( device_type=device_type, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, + kernel_name=kernel_name, ) return _tpu_custom_call_lowering( ctx, @@ -654,6 +658,7 @@ def as_tpu_kernel( collective_id=collective_id, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, + kernel_name=kernel_name, ) return _as_jax_callable( config, @@ -735,7 +740,7 @@ def apply_kernel(*args): return jax.jit(apply_kernel) -def dump_mlir(module: ir.Module, name: str): +def dump_mlir(module: ir.Module, name: str, kernel_name: str | None = None): """A helper function to dump mosaic mlir module""" try: should_dump = FLAGS["xla_mosaic_dump_to"].value @@ -744,6 +749,8 @@ def dump_mlir(module: ir.Module, name: str): if should_dump == "sponge": outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None) if outdir: + if kernel_name: + name = f"{kernel_name}-{name}" path = os.path.join(outdir, f"{time.time_ns()}-mosaic-dump-{name}-py.txt") with open(path, "w") as f: f.write(str(module)) From 4d8751bff4b8f3bb020eef4cae4a6a53df14be0f Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 23 Nov 2024 15:30:03 -0800 Subject: [PATCH 474/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/90af2896ab4992ff14a1cd2a75ce02e43f46c090. PiperOrigin-RevId: 699545393 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 171c5f774c5f..fda631f6abfc 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "0564969ba385bfc895baad8f64879236bfbc717b" -XLA_SHA256 = "e54fccae6075c574493c27443658ee178833362d8890f3285a9d4787f4bdbc09" +XLA_COMMIT = "90af2896ab4992ff14a1cd2a75ce02e43f46c090" +XLA_SHA256 = "a232f506f5c0ff31863fdb9b612691742c6e72874c529920156406ff520ee376" def repo(): tf_http_archive( From b372ce4b1ab0bee7a1da495b098ff3948a6c0d4d Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 24 Nov 2024 14:59:35 -0800 Subject: [PATCH 475/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/40d457a268baf95e42cd95709dedef70c0ea2994. PiperOrigin-RevId: 699768724 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index fda631f6abfc..32508f8310bb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "90af2896ab4992ff14a1cd2a75ce02e43f46c090" -XLA_SHA256 = "a232f506f5c0ff31863fdb9b612691742c6e72874c529920156406ff520ee376" +XLA_COMMIT = "40d457a268baf95e42cd95709dedef70c0ea2994" +XLA_SHA256 = "a7bd8cc608964ba91d9edfd6070929cffe09b448e9e36fd8224bc8fc99202db3" def repo(): tf_http_archive( From 69e3f0d37de437337d3bad91af651237c3402a80 Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Mon, 25 Nov 2024 03:30:19 -0800 Subject: [PATCH 476/698] [pallas:mosaic_gpu] Add test for FragmentedArray.bitcast. PiperOrigin-RevId: 699919048 --- .../mosaic/gpu/fragmented_array.py | 10 ++++-- tests/mosaic/gpu_test.py | 34 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index e1ee37f3d24d..094c683c1695 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -463,8 +463,8 @@ def __init__( if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype): raise TypeError( - "is_signed must only be non-None if the MLIR type is an integer" - f" type, got {_is_signed=} for {self.mlir_dtype}" + "is_signed must be non-None if and only if the MLIR type is an" + f" integer type, got {_is_signed=} for {self.mlir_dtype}" ) match self.layout: @@ -962,6 +962,12 @@ def fast_instr(x): return fast_instr def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): + if (output_is_signed is not None) != ir.IntegerType.isinstance(elt): + raise TypeError( + "output_is_signed must be non-None if and only if the MLIR type is an" + f" integer type, got {output_is_signed=} for {elt}" + ) + if elt == self.mlir_dtype: return self reg_type = self.registers.flat[0].type diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 87dc2c452041..aeddbc7e033d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1577,6 +1577,40 @@ def kernel(ctx, _): _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), (), None)() + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.bfloat16), + (jnp.int16, jnp.float16), + (jnp.uint16, jnp.float16), + (jnp.float32, jnp.int32), + (jnp.float32, jnp.uint32), + (jnp.uint32, jnp.int32), + (jnp.int32, jnp.uint32), + ) + def test_bitcast(self, in_dtype, out_dtype): + out_ir_type = utils.dtype_to_ir_type(out_dtype) + in_is_signed = utils.is_signed(in_dtype) + out_is_signed = utils.is_signed(out_dtype) + + def kernel(ctx, inp, out, smem): + del ctx, smem + arr = mgpu.FragmentedArray.load_strided(inp, is_signed=in_is_signed) + arr = arr.bitcast(out_ir_type, output_is_signed=out_is_signed) + arr.store_untiled(out) + + x = jnp.arange(256, dtype=in_dtype) + reference = jax.lax.bitcast_convert_type(x, out_dtype) + + result = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + x, + reference, + None, + )(x) + np.testing.assert_array_equal(result, reference) + class ProfilerTest(TestCase): From 84a9cba85bd358960d674b1724922515b752c80a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 21 Nov 2024 12:44:27 -0500 Subject: [PATCH 477/698] Refactor FFI examples to consolidate several examples into one submodule. --- examples/ffi/CMakeLists.txt | 13 ++--- examples/ffi/README.md | 27 ++++----- examples/ffi/src/jax_ffi_example/counter.cc | 53 ------------------ examples/ffi/src/jax_ffi_example/counter.py | 38 ------------- .../{attrs.cc => cpu_examples.cc} | 55 ++++++++++++++++++- .../{attrs.py => cpu_examples.py} | 17 +++--- .../{cuda_e2e.cu => cuda_examples.cu} | 0 .../{cuda_e2e.py => cuda_examples.py} | 2 +- examples/ffi/tests/counter_test.py | 55 ------------------- .../{attrs_test.py => cpu_examples_test.py} | 47 +++++++++++++--- ...cuda_e2e_test.py => cuda_examples_test.py} | 4 +- 11 files changed, 122 insertions(+), 189 deletions(-) delete mode 100644 examples/ffi/src/jax_ffi_example/counter.cc delete mode 100644 examples/ffi/src/jax_ffi_example/counter.py rename examples/ffi/src/jax_ffi_example/{attrs.cc => cpu_examples.cc} (57%) rename examples/ffi/src/jax_ffi_example/{attrs.py => cpu_examples.py} (73%) rename examples/ffi/src/jax_ffi_example/{cuda_e2e.cu => cuda_examples.cu} (100%) rename examples/ffi/src/jax_ffi_example/{cuda_e2e.py => cuda_examples.py} (99%) delete mode 100644 examples/ffi/tests/counter_test.py rename examples/ffi/tests/{attrs_test.py => cpu_examples_test.py} (55%) rename examples/ffi/tests/{cuda_e2e_test.py => cuda_examples_test.py} (96%) diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 9f9090e2b7ef..843c2cda0e3b 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -15,8 +15,7 @@ find_package(nanobind CONFIG REQUIRED) set( JAX_FFI_EXAMPLE_PROJECTS "rms_norm" - "attrs" - "counter" + "cpu_examples" ) foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS}) @@ -27,9 +26,9 @@ endforeach() if(JAX_FFI_EXAMPLE_ENABLE_CUDA) enable_language(CUDA) - add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu") - set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON - CUDA_STANDARD 17) - target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR}) - install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) + add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu") + set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON + CUDA_STANDARD 17) + target_include_directories(_cuda_examples PUBLIC ${XLA_DIR}) + install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) endif() diff --git a/examples/ffi/README.md b/examples/ffi/README.md index eb730b483b76..bd45408e50d8 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -11,18 +11,19 @@ Within the example project, there are several example calls: demonstrates the most basic use of the FFI. It also includes customization of behavior under automatic differentiation using `jax.custom_vjp`. -2. `counter`: This example demonstrates a common pattern for how an FFI call can - use global cache to maintain state between calls. This pattern is useful when - an FFI call requires an expensive initialization step which shouldn't be - run on every execution, or if there is other shared state that could be - reused between calls. In this simple example we just count the number of - times the call was executed. +2. `cpu_examples`: This submodule includes several smaller examples: -3. `attrs`: An example demonstrating the different ways that attributes can be - passed to the FFI. For example, we can pass arrays, variadic attributes, and - user-defined types. Full support of user-defined types isn't yet supported - by XLA, so that example will be added in the future. + * `counter`: This example demonstrates a common pattern for how an FFI call + can use global cache to maintain state between calls. This pattern is + useful when an FFI call requires an expensive initialization step which + shouldn't be run on every execution, or if there is other shared state + that could be reused between calls. In this simple example we just count + the number of times the call was executed. + * `attrs`: An example demonstrating the different ways that attributes can be + passed to the FFI. For example, we can pass arrays, variadic attributes, + and user-defined types. Full support of user-defined types isn't yet + supported by XLA, so that example will be added in the future. -4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with - CUDA. The specifics of the kernels are not very important, but the general - structure, and packaging of the extension are useful for testing. +3. `cuda_examples`: An end-to-end example demonstrating the use of the JAX FFI + with CUDA. The specifics of the kernels are not very important, but the + general structure, and packaging of the extension are useful for testing. diff --git a/examples/ffi/src/jax_ffi_example/counter.cc b/examples/ffi/src/jax_ffi_example/counter.cc deleted file mode 100644 index d7f17e730fd6..000000000000 --- a/examples/ffi/src/jax_ffi_example/counter.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "nanobind/nanobind.h" -#include "xla/ffi/api/ffi.h" - -namespace nb = nanobind; -namespace ffi = xla::ffi; - -ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0 out) { - static std::mutex mutex; - static auto& cache = *new std::unordered_map(); - { - const std::lock_guard lock(mutex); - auto it = cache.find(index); - if (it != cache.end()) { - out->typed_data()[0] = ++it->second; - } else { - cache.insert({index, 0}); - out->typed_data()[0] = 0; - } - } - return ffi::Error::Success(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - Counter, CounterImpl, - ffi::Ffi::Bind().Attr("index").Ret>()); - -NB_MODULE(_counter, m) { - m.def("registrations", []() { - nb::dict registrations; - registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); - return registrations; - }); -} diff --git a/examples/ffi/src/jax_ffi_example/counter.py b/examples/ffi/src/jax_ffi_example/counter.py deleted file mode 100644 index 12c7f015bf58..000000000000 --- a/examples/ffi/src/jax_ffi_example/counter.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""An example demonstrating how an FFI call can maintain "state" between calls - -In this case, the ``counter`` call simply accumulates the number of times it -was executed, but this pattern can also be used for more advanced use cases. -For example, this pattern is used in jaxlib for: - -1. The GPU solver linear algebra kernels which require an expensive "handler" - initialization, and -2. The ``triton_call`` function which caches the compiled triton modules after - their first use. -""" - -import jax -import jax.extend as jex - -from jax_ffi_example import _counter - -for name, target in _counter.registrations().items(): - jex.ffi.register_ffi_target(name, target) - - -def counter(index): - return jex.ffi.ffi_call( - "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/cpu_examples.cc similarity index 57% rename from examples/ffi/src/jax_ffi_example/attrs.cc rename to examples/ffi/src/jax_ffi_example/cpu_examples.cc index 7ff5c98e52e1..3832c86b29b2 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.cc +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" @@ -21,6 +24,17 @@ limitations under the License. namespace nb = nanobind; namespace ffi = xla::ffi; +// ---------- +// Attributes +// ---------- +// +// An example demonstrating the different ways that attributes can be passed to +// the FFI. +// +// For example, we can pass arrays, variadic attributes, and user-defined types. +// Full support of user-defined types isn't yet supported by XLA, so that +// example will be added in the future. + ffi::Error ArrayAttrImpl(ffi::Span array, ffi::ResultBufferR0 res) { int64_t total = 0; @@ -54,13 +68,52 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, .Ret>() .Ret>()); -NB_MODULE(_attrs, m) { +// ------- +// Counter +// ------- +// +// An example demonstrating how an FFI call can maintain "state" between calls +// +// In this case, the ``Counter`` call simply accumulates the number of times it +// was executed, but this pattern can also be used for more advanced use cases. +// For example, this pattern is used in jaxlib for: +// +// 1. The GPU solver linear algebra kernels which require an expensive "handler" +// initialization, and +// 2. The ``triton_call`` function which caches the compiled triton modules +// after their first use. + +ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0 out) { + static std::mutex mutex; + static auto &cache = *new std::unordered_map(); + { + const std::lock_guard lock(mutex); + auto it = cache.find(index); + if (it != cache.end()) { + out->typed_data()[0] = ++it->second; + } else { + cache.insert({index, 0}); + out->typed_data()[0] = 0; + } + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Counter, CounterImpl, + ffi::Ffi::Bind().Attr("index").Ret>()); + +// Boilerplate for exposing handlers to Python +NB_MODULE(_cpu_examples, m) { m.def("registrations", []() { nb::dict registrations; registrations["array_attr"] = nb::capsule(reinterpret_cast(ArrayAttr)); registrations["dictionary_attr"] = nb::capsule(reinterpret_cast(DictionaryAttr)); + + registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); + return registrations; }); } diff --git a/examples/ffi/src/jax_ffi_example/attrs.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py similarity index 73% rename from examples/ffi/src/jax_ffi_example/attrs.py rename to examples/ffi/src/jax_ffi_example/cpu_examples.py index 2f215e8e25b1..7771237e41d1 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.py +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.py @@ -12,22 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An example demonstrating the different ways that attributes can be passed to -the FFI. - -For example, we can pass arrays, variadic attributes, and user-defined types. -Full support of user-defined types isn't yet supported by XLA, so that example -will be added in the future. -""" - import numpy as np import jax import jax.extend as jex -from jax_ffi_example import _attrs +from jax_ffi_example import _cpu_examples -for name, target in _attrs.registrations().items(): +for name, target in _cpu_examples.registrations().items(): jex.ffi.register_ffi_target(name, target) @@ -43,3 +35,8 @@ def dictionary_attr(**kwargs): "dictionary_attr", (jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)), )(**kwargs) + + +def counter(index): + return jex.ffi.ffi_call( + "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu b/examples/ffi/src/jax_ffi_example/cuda_examples.cu similarity index 100% rename from examples/ffi/src/jax_ffi_example/cuda_e2e.cu rename to examples/ffi/src/jax_ffi_example/cuda_examples.cu diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.py b/examples/ffi/src/jax_ffi_example/cuda_examples.py similarity index 99% rename from examples/ffi/src/jax_ffi_example/cuda_e2e.py rename to examples/ffi/src/jax_ffi_example/cuda_examples.py index 500677050a4b..b60b12af577e 100644 --- a/examples/ffi/src/jax_ffi_example/cuda_e2e.py +++ b/examples/ffi/src/jax_ffi_example/cuda_examples.py @@ -27,7 +27,7 @@ import jax.extend as jex # Load the shared library with the FFI target definitions -SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so") +SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so") library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd), diff --git a/examples/ffi/tests/counter_test.py b/examples/ffi/tests/counter_test.py deleted file mode 100644 index 1e2ad38a363f..000000000000 --- a/examples/ffi/tests/counter_test.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -import jax -from jax._src import test_util as jtu - -from jax_ffi_example import counter - -jax.config.parse_flags_with_absl() - - -class CounterTests(jtu.JaxTestCase): - def setUp(self): - super().setUp() - if not jtu.test_device_matches(["cpu"]): - self.skipTest("Unsupported platform") - - def test_basic(self): - self.assertEqual(counter.counter(0), 0) - self.assertEqual(counter.counter(0), 1) - self.assertEqual(counter.counter(0), 2) - self.assertEqual(counter.counter(1), 0) - self.assertEqual(counter.counter(0), 3) - - def test_jit(self): - @jax.jit - def counter_fun(x): - return x, counter.counter(2) - - self.assertEqual(counter_fun(0)[1], 0) - self.assertEqual(counter_fun(0)[1], 1) - - # Persists across different cache hits - self.assertEqual(counter_fun(1)[1], 2) - - # Persists after the cache is cleared - counter_fun.clear_cache() - self.assertEqual(counter_fun(0)[1], 3) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/examples/ffi/tests/attrs_test.py b/examples/ffi/tests/cpu_examples_test.py similarity index 55% rename from examples/ffi/tests/attrs_test.py rename to examples/ffi/tests/cpu_examples_test.py index 2eef1f627006..cb2653d2e928 100644 --- a/examples/ffi/tests/attrs_test.py +++ b/examples/ffi/tests/cpu_examples_test.py @@ -18,7 +18,7 @@ import jax.numpy as jnp from jax._src import test_util as jtu -from jax_ffi_example import attrs +from jax_ffi_example import cpu_examples jax.config.parse_flags_with_absl() @@ -30,11 +30,11 @@ def setUp(self): self.skipTest("Unsupported platform") def test_array_attr(self): - self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum()) - self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum()) + self.assertEqual(cpu_examples.array_attr(5), jnp.arange(5).sum()) + self.assertEqual(cpu_examples.array_attr(3), jnp.arange(3).sum()) def test_array_attr_jit_cache(self): - jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,)) + jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,)) with jtu.count_jit_and_pmap_lowerings() as count: jit_array_attr(5) self.assertEqual(count[0], 1) # compiles once the first time @@ -44,22 +44,51 @@ def test_array_attr_jit_cache(self): def test_array_attr_no_jit(self): with jax.disable_jit(): - attrs.array_attr(5) # doesn't crash + cpu_examples.array_attr(5) # doesn't crash def test_dictionary_attr(self): - secret, count = attrs.dictionary_attr(secret=5) + secret, count = cpu_examples.dictionary_attr(secret=5) self.assertEqual(secret, 5) self.assertEqual(count, 1) - secret, count = attrs.dictionary_attr(secret=3, a_string="hello") + secret, count = cpu_examples.dictionary_attr(secret=3, a_string="hello") self.assertEqual(secret, 3) self.assertEqual(count, 2) with self.assertRaisesRegex(Exception, "Unexpected attribute"): - attrs.dictionary_attr() + cpu_examples.dictionary_attr() with self.assertRaisesRegex(Exception, "Wrong attribute type"): - attrs.dictionary_attr(secret="invalid") + cpu_examples.dictionary_attr(secret="invalid") + + +class CounterTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + def test_basic(self): + self.assertEqual(cpu_examples.counter(0), 0) + self.assertEqual(cpu_examples.counter(0), 1) + self.assertEqual(cpu_examples.counter(0), 2) + self.assertEqual(cpu_examples.counter(1), 0) + self.assertEqual(cpu_examples.counter(0), 3) + + def test_jit(self): + @jax.jit + def counter_fun(x): + return x, cpu_examples.counter(2) + + self.assertEqual(counter_fun(0)[1], 0) + self.assertEqual(counter_fun(0)[1], 1) + + # Persists across different cache hits + self.assertEqual(counter_fun(1)[1], 2) + + # Persists after the cache is cleared + counter_fun.clear_cache() + self.assertEqual(counter_fun(0)[1], 3) if __name__ == "__main__": diff --git a/examples/ffi/tests/cuda_e2e_test.py b/examples/ffi/tests/cuda_examples_test.py similarity index 96% rename from examples/ffi/tests/cuda_e2e_test.py rename to examples/ffi/tests/cuda_examples_test.py index 83397f7ff5d7..f4a736599ce4 100644 --- a/examples/ffi/tests/cuda_e2e_test.py +++ b/examples/ffi/tests/cuda_examples_test.py @@ -28,8 +28,8 @@ def setUp(self): self.skipTest("Unsupported platform") # Import here to avoid trying to load the library when it's not built. - from jax_ffi_example import cuda_e2e - self.foo = cuda_e2e.foo + from jax_ffi_example import cuda_examples + self.foo = cuda_examples.foo def test_fwd_interpretable(self): shape = (2, 3) From 914600a063e31a6fe935ea672fbac50af6c39542 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 25 Nov 2024 07:59:55 -0800 Subject: [PATCH 478/698] [Mosaic GPU] Simplify logic for pointwise splat operands The previous version of the code was too complicated and failed to account for the fact that in an op that broadcasts there does not necessarily exist and operand that has the output shape. Reading through the code now, it's a bit weird that we allow implicit broadcasting of operands with splat layouts, but not any other operands. But I guess that's a thing to implement later. PiperOrigin-RevId: 699983045 --- .../mosaic/gpu/fragmented_array.py | 42 +++++++------------ 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 094c683c1695..26157f407844 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -623,37 +623,27 @@ def to_layout(self, new_layout: FragmentedLayout): ) def _pointwise(self, op, *other, output_is_signed: bool | None = None): + # If our layout is a splat, then we should either dispatch to a non-splat + # layout, or broadcast ourselves to the output shape first. if isinstance(self.layout, WGSplatFragLayout): - # Find either the largest operand or an operand that has a - # concrete layout base the layout computation of that. - widest_idx = None + output_shape = self.shape for i, o in enumerate(other): if not isinstance(o, FragmentedArray): continue elif not isinstance(o.layout, WGSplatFragLayout): - widest_idx = i - break - elif not o.layout.can_broadcast_to(self.layout.shape): - # Note: equal shapes can be broadcast to each other. Using - # the negation we make sure to only consider strictly larger - # shapes so that we don't end up ping ponging between equal - # shapes. - widest_idx = i - - if widest_idx is not None: - # We need to retain the order of arguments that the op - # expects. - def _op(wide_o, self_o, *args): - pre_wide = args[:widest_idx - 1] - post_wide = args[widest_idx - 1:] - return op(self_o, *pre_wide, wide_o, *post_wide) - return other[widest_idx]._pointwise( - _op, - self, - *other[:widest_idx], - *other[widest_idx + 1:], - output_is_signed=output_is_signed, - ) + return o._pointwise( + lambda o, *args: op(*args[:i], o, *args[i:]), + self, + *other[:i], + *other[i + 1 :], + output_is_signed=output_is_signed, + ) + else: + output_shape = np.broadcast_shapes(output_shape, o.shape) + # If we get here then we haven't found any non-splat layout. + return self.broadcast(output_shape)._pointwise( + op, *other, output_is_signed=output_is_signed + ) other_arrs = [] for o in other: From e8934b95ebc228d86704a94f5a1db5b7c1b0f4df Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Mon, 25 Nov 2024 10:21:48 -0600 Subject: [PATCH 479/698] [ROCm] Add rocm version information --- jax_plugins/rocm/plugin_setup.py | 7 ++++++- jax_plugins/rocm/setup.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index a84a6b34ea48..d504d0a11666 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -22,6 +22,11 @@ project_name = f"jax-rocm{rocm_version}-plugin" package_name = f"jax_rocm{rocm_version}_plugin" +# Extract ROCm version from the `ROCM_PATH` environment variable. +default_rocm_path = "/opt/rocm" +rocm_path = os.getenv("ROCM_PATH", default_rocm_path) +rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown" + def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( 'version', os.path.join(pkg_path, 'version.py')) @@ -43,7 +48,7 @@ def has_ext_modules(self): name=project_name, version=__version__, cmdclass=_cmdclass, - description="JAX Plugin for AMD GPUs", + description=f"JAX Plugin for AMD GPUs (ROCm:{rocm_detected_version})", long_description="", long_description_content_type="text/markdown", author="Ruturaj4", diff --git a/jax_plugins/rocm/setup.py b/jax_plugins/rocm/setup.py index d131e732c91a..ec3eae2d8821 100644 --- a/jax_plugins/rocm/setup.py +++ b/jax_plugins/rocm/setup.py @@ -21,6 +21,11 @@ project_name = f"jax-rocm{rocm_version}-pjrt" package_name = f"jax_plugins.xla_rocm{rocm_version}" +# Extract ROCm version from the `ROCM_PATH` environment variable. +default_rocm_path = "/opt/rocm" +rocm_path = os.getenv("ROCM_PATH", default_rocm_path) +rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown" + def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( 'version', os.path.join(pkg_path, 'version.py')) @@ -41,7 +46,7 @@ def load_version_module(pkg_path): setup( name=project_name, version=__version__, - description="JAX XLA PJRT Plugin for AMD GPUs", + description=f"JAX XLA PJRT Plugin for AMD GPUs (ROCm:{rocm_detected_version})", long_description="", long_description_content_type="text/markdown", author="Ruturaj4", From aa05dc0b5cf8d9008519a4196bdfed7c619fee5c Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 25 Nov 2024 08:29:19 -0800 Subject: [PATCH 480/698] Automated Code Change PiperOrigin-RevId: 699991540 --- jaxlib/gpu/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index e888f6a42a9b..a5069cfb4a8e 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -20,6 +20,7 @@ load( "jax_visibility", "xla_py_proto_library", ) +# Placeholder: load proto_library licenses(["notice"]) From c35f8b22c1b081135e0644a936c262a974c75f07 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 25 Nov 2024 09:17:56 -0800 Subject: [PATCH 481/698] Add abstract mesh context manager to trace_context in the fallback path too (which will be deleted after jax 0.4.36 release) PiperOrigin-RevId: 700006186 --- jax/_src/config.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 2723b4f90d3b..43c29c996cfb 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -244,6 +244,7 @@ def trace_context(): tls = jax_jit.thread_local_state() axis_env_state = () mesh_context_manager = () + abstract_mesh_context_manager = () xla_metadata_context_manager = () compute_on_context_manager = () @@ -252,11 +253,14 @@ def trace_context(): axis_env_state = context.axis_env_state if context and context.mesh_context_manager: mesh_context_manager = context.mesh_context_manager + if context and context.abstract_mesh_context_manager: + abstract_mesh_context_manager = context.abstract_mesh_context_manager if context and context.xla_metadata_context_manager: xla_metadata_context_manager = context.xla_metadata_context_manager if context and context.compute_on_context_manager: compute_on_context_manager = context.compute_on_context_manager - return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, + return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager, + xla_metadata_context_manager, compute_on_context_manager, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, @@ -1014,6 +1018,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): trace_state: Any | None = None axis_env_state: Hashable = () mesh_context_manager: Hashable = () + abstract_mesh_context_manager: Hashable = () compute_on_context_manager: Hashable = () xla_metadata_context_manager: Hashable = () @@ -1080,6 +1085,7 @@ def set_local(self, value): trace_state = JitConfig('trace_state') axis_env_state = JitConfig('axis_env_state') mesh_context_manager = JitConfig('mesh_context_manager') + abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager') compute_on_context_manager = JitConfig('compute_on_context_manager') xla_metadata_context_manager = JitConfig('xla_metadata_context_manager') From 9866372d310a7546c86c397b0635071d26f6a3c9 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Mon, 25 Nov 2024 17:42:36 +0000 Subject: [PATCH 482/698] [cuda] Bump nvidia-cuda-nvcc-cu12 dependency to 12.6.85 --- jax_plugins/cuda/plugin_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index 8e99907d7078..ce31684de46f 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -55,7 +55,7 @@ def has_ext_modules(self): 'with_cuda': [ "nvidia-cublas-cu12>=12.1.3.1", "nvidia-cuda-cupti-cu12>=12.1.105", - "nvidia-cuda-nvcc-cu12>=12.1.105", + "nvidia-cuda-nvcc-cu12>=12.6.85", "nvidia-cuda-runtime-cu12>=12.1.105", "nvidia-cudnn-cu12>=9.1,<10.0", "nvidia-cufft-cu12>=11.0.2.54", From bb1024f3fd1047e0f5e8f68b1a1fa41098900d53 Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Mon, 25 Nov 2024 10:30:27 -0800 Subject: [PATCH 483/698] [SDY] enable `cpu_shardy` for JAX shard_alike test. PiperOrigin-RevId: 700029576 --- tests/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index a645a971a799..cfebd88f99f2 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -279,9 +279,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], - disable_configs = [ - "cpu_shardy", # TODO(b/355263220): enable once shard_alike is supported. - ], enable_configs = [ "tpu_v3_2x2", "tpu_v5e_4x2", From 066859e62fd7dc86defec79b4af3f77c496bd3b2 Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Mon, 25 Nov 2024 11:09:09 -0800 Subject: [PATCH 484/698] [SDY] Enable `test_pjit_array_multi_input_multi_output` since Shardy conflict resolution is now complete. PiperOrigin-RevId: 700042542 --- tests/pjit_test.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 4d9b98b4d595..3340483ebc86 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1900,11 +1900,6 @@ def _checks(out, input_data): ) def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, s2_shape, s3_shape, s4_shape): - if config.use_shardy_partitioner.value: - self.skipTest( - 'TODO(b/355263220) Shardy conflict resolution is not complete. Issue ' - 'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while ' - 'Shardy gives it fully replicated.') global_mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2) From 84dc9bab3394260d1d71dc956a68e8ed392c8b98 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 25 Nov 2024 19:25:08 +0000 Subject: [PATCH 485/698] Update ROCm scripts to match new build.py usage --- build/rocm/dev_build_rocm.py | 9 +++++---- build/rocm/tools/build_wheels.py | 7 ++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/build/rocm/dev_build_rocm.py b/build/rocm/dev_build_rocm.py index 2be64152f667..aa5754b789d3 100755 --- a/build/rocm/dev_build_rocm.py +++ b/build/rocm/dev_build_rocm.py @@ -77,13 +77,14 @@ def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path): build_command = [ "python3", "./build/build.py", - "--enable_rocm", - "--build_gpu_plugin", - "--gpu_plugin_rocm_version=60", + "build" f"--use_clang={str(use_clang).lower()}", + "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt" + "--rocm_path=%/opt/rocm-{rocm_version}/", + "--rocm_version=60", f"--rocm_amdgpu_targets={rocm_target}", - f"--rocm_path=/opt/rocm-{rocm_version}/", bazel_options, + "--verbose" ] if clang_option: diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index deb6ab703391..ec825f40b7d2 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -93,11 +93,12 @@ def build_jaxlib_wheel( cmd = [ "python", "build/build.py", - "--enable_rocm", - "--build_gpu_plugin", - "--gpu_plugin_rocm_version=60", + "build" + "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt" "--rocm_path=%s" % rocm_path, + "--rocm_version=60", "--use_clang=%s" % use_clang, + "--verbose" ] # Add clang path if clang is used. From deab6fbd803a4f78384843a82fbe81e9099154f4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 25 Nov 2024 11:40:05 -0800 Subject: [PATCH 486/698] Remove _pjit_lower_cached cache. We can simplify the caching of jit as we have downstream caches and a cpp cache too. If you drop out of cpp cache, things are going to be slow anyways. PiperOrigin-RevId: 700052522 --- jax/_src/api.py | 1 - jax/_src/pjit.py | 10 +++------- jax/experimental/pjit.py | 2 -- tests/pjit_test.py | 40 ++++++++++++---------------------------- 4 files changed, 15 insertions(+), 38 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 8d464b51d741..308e7c230dc2 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2777,7 +2777,6 @@ def clear_backends(): dispatch.xla_primitive_callable.cache_clear() util.clear_all_caches() pjit._infer_params_cached.cache_clear() - pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache_fun_only.clear() pjit._cpp_pjit_cache_explicit_attributes.clear() diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index e0af0b7d2137..fb60f9d52727 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1665,7 +1665,8 @@ def _pjit_call_impl_python( compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items()) # TODO(patrios): Do not pass mutable profile session through cached lowering # chain. Instead we need to move profilers dictionary to pxla module and use - # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode. + # module as key. Right now we can't do that since there is no way to evict + # _pjit_lower_cached cache for in PGLE mode. compiled = _resolve_and_lower( args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, @@ -1776,12 +1777,7 @@ def call_impl_cache_miss(*args_, **kwargs_): pjit_p.def_impl(_pjit_call_impl) -def _pjit_lower(*args, **kwargs): - return _pjit_lower_cached(*args, **kwargs) - - -@weakref_lru_cache -def _pjit_lower_cached( +def _pjit_lower( jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index a711c6bc472c..8ba7eb25d646 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -22,5 +22,3 @@ AUTO as AUTO, UNSPECIFIED as _UNSPECIFIED, ) - -from jax._src.pjit import _pjit_lower_cached, _pjit_lower diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3340483ebc86..e5aa1a604eae 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -52,7 +52,6 @@ from jax._src.sharding_impls import ( AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, SingleDeviceSharding, parse_flatten_op_sharding) -import jax._src.pjit as pjit_lib from jax._src.pjit import pjit from jax._src import mesh as mesh_lib from jax._src.interpreters import pxla @@ -2157,13 +2156,13 @@ def add(x, y): return x + y out = add(a, b) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + cache_info1 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, a + b) self.assertFalse(out._committed) out2 = add(out, out) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + cache_info2 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out2, array.ArrayImpl) self.assertArraysEqual(out2, 2 * (a + b)) self.assertFalse(out2._committed) @@ -2173,7 +2172,7 @@ def add(x, y): c = jax.device_put(a, jax.devices()[0]) out3 = add(c, c) - cache_info3 = pjit_lib._pjit_lower_cached.cache_info() + cache_info3 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(out3, 2 * c) self.assertTrue(out3._committed) @@ -2216,14 +2215,11 @@ def test_pjit_different_device_recompilation(self): f = pjit(lambda x: x) - out1 = f(a) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() - - out2 = f(b) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + with jtu.count_jit_compilation_cache_miss() as count: + out1 = f(a) + out2 = f(b) + self.assertEqual(count[0], 2) - self.assertEqual(cache_info2.hits, cache_info1.hits) - self.assertEqual(cache_info2.misses, cache_info1.misses + 1) self.assertArraysEqual(out1, val1) self.assertArraysEqual(out2, val2) @@ -2880,13 +2876,13 @@ def f(x, y, z): return x, y, z o1, o2, o3 = f(a, y=b, z=c) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + cache_info1 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(o1, a) self.assertArraysEqual(o2, b) self.assertArraysEqual(o3, c) o4, o5, o6 = f(x=a, y=b, z=c) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + cache_info2 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(o4, a) self.assertArraysEqual(o5, b) self.assertArraysEqual(o6, c) @@ -2895,7 +2891,7 @@ def f(x, y, z): self.assertEqual(cache_info2.misses, cache_info1.misses + 1) o7, o8, o9 = f(a, b, c) - cache_info3 = pjit_lib._pjit_lower_cached.cache_info() + cache_info3 = pxla._cached_lowering_to_hlo.cache_info() self.assertArraysEqual(o7, a) self.assertArraysEqual(o8, b) self.assertArraysEqual(o9, c) @@ -2982,26 +2978,19 @@ def _check(out, expected_device, expected_out): x = jnp.arange(8).reshape(4, 2) f_out = f(x) f_out2 = f(f_out) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() _check(f_out, jax.devices()[1], x) _check(f_out2, jax.devices()[1], f_out) y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y'))) out2 = f(y) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() _check(out2, jax.devices()[1], y) - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): h = pjit(mul, device=jax.devices()[-1]) h_out = h(y) - cache_info3 = pjit_lib._pjit_lower_cached.cache_info() _check(h_out, jax.devices()[-1], y) - self.assertEqual(cache_info3.hits, cache_info2.hits) - # AOT test compiled = f.lower(core.ShapedArray(y.shape, y.dtype)).compile() out3 = compiled(y) @@ -3531,11 +3520,11 @@ def mul(x): with jtu.count_pjit_cpp_cache_miss() as count: out = f(arr) - cache_info1 = pjit_lib._pjit_lower_cached.cache_info() + cache_info1 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out.sharding, NamedSharding) out2 = f(np_arr) - cache_info2 = pjit_lib._pjit_lower_cached.cache_info() + cache_info2 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out2.sharding, NamedSharding) # Drops out of C++ cache i.e. cache miss @@ -3624,13 +3613,11 @@ def test_jit_mul_sum_sharding_preserved(self): f = jax.jit(lambda x: x * 2) out = f(arr) cache_info1 = pxla._cached_compilation.cache_info() - pl_cache_info1 = pjit_lib._pjit_lower_cached.cache_info() self.assertIsInstance(out.sharding, NamedSharding) with jtu.count_pjit_cpp_cache_miss() as count: out2 = f(arr2) cache_info2 = pxla._cached_compilation.cache_info() - pl_cache_info2 = pjit_lib._pjit_lower_cached.cache_info() self.assertIsInstance(out2.sharding, PositionalSharding) # This will hit the cpp cache. @@ -3641,9 +3628,6 @@ def test_jit_mul_sum_sharding_preserved(self): self.assertEqual(cache_info2.hits, cache_info1.hits + 1) self.assertEqual(cache_info2.misses, cache_info1.misses) - self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits) - self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1) - out4 = jnp.sum(arr) self.assertIsInstance(out4.sharding, NamedSharding) From 107bc96c2911338eed4719a32f362a9225c64f38 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 25 Nov 2024 11:40:07 -0800 Subject: [PATCH 487/698] [Mosaic GPU] Support batch dimensions in FA3 MGPU kernel. PiperOrigin-RevId: 700052530 --- .../pallas/ops/gpu/attention_mgpu.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 1c5b4d9f741b..c4ac7e625942 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -59,17 +59,27 @@ def attention(q, k, v, config: TuningConfig): q_heads_per_kv_head = num_q_heads // num_kv_heads if head_dim % 64: raise ValueError(f"{head_dim=} must be divisible by 64") - if batch_size != 1: - raise NotImplementedError(f"Only batch_size=1 is supported, got: {batch_size=}") if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") - q, k, v = map(lambda x: x[0], (q, k, v)) + # Squash batch and sequence dimensions. + # This is required because CUDA grid/TMA descriptors have a limited number of + # slice dimensions. + # TODO(apaszke): Implement slice squashing for TMAs. + q = jnp.reshape(q, (batch_size * q_seq_len, num_q_heads, head_dim)) + k = jnp.reshape(k, (batch_size * kv_seq_len, num_kv_heads, head_dim)) + v = jnp.reshape(v, (batch_size * kv_seq_len, num_kv_heads, head_dim)) + max_concurrent_steps = min( config.max_concurrent_steps, kv_seq_len // config.block_kv ) block_q, block_kv = config.block_q, config.block_kv + num_q_tiles, rem = divmod(q_seq_len, block_q * 2) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + bidx = lax.div(lax.axis_index("bq"), num_q_tiles) + qidx = lax.rem(lax.axis_index("bq"), num_q_tiles) smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") qo_smem2, k_smem, v_smem = smem_buffers @@ -83,7 +93,7 @@ def perform_schedule_barrier(): def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] - q_seq_base = lax.axis_index("q") * (2 * block_q) + wg_idx * block_q + q_seq_base = qidx * (2 * block_q) + wg_idx * block_q + bidx * q_seq_len q_head = lax.axis_index("heads") plgpu.copy_gmem_to_smem( @@ -165,14 +175,16 @@ def _memory_wg(): plgpu.set_max_registers(40, action="decrease") kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) for i in range(max_concurrent_steps): - s = (pl.ds(i * block_kv, block_kv), kv_head) + start = i * block_kv + bidx * kv_seq_len + s = (pl.ds(start, block_kv), kv_head) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) - s = (pl.ds(tma_step * block_kv, block_kv), kv_head) + start = tma_step * block_kv + bidx * kv_seq_len + s = (pl.ds(start, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barrier) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barrier) @@ -187,13 +199,10 @@ def kv_epilogue(i, _): def run(refs): q_ref, k_ref, v_ref, out_ref = refs - num_q_tiles, rem = divmod(q_seq_len, block_q * 2) - if rem: - raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") mesh = plgpu.GPUMesh( - grid=(num_q_tiles, num_q_heads), + grid=(batch_size * num_q_tiles, num_q_heads), num_threads=3, - axis_names=("q", "heads", "wg"), + axis_names=("bq", "heads", "wg"), approx_math=True, ) @pl.core_map(mesh) @@ -227,7 +236,7 @@ def _kernel_entry(): ) _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) - return out[None] + return jnp.reshape(out, [batch_size, q_seq_len, num_q_heads, head_dim]) @jax.jit @@ -247,13 +256,13 @@ def attention_reference(q, k, v): def main(unused_argv): - batch_size = 1 num_q_heads = 1 num_kv_heads = 1 - problem_it = itertools.product((4096, 32768,), (64, 128, 256,)) - for seq_len, head_dim in problem_it: + problem_it = itertools.product((1, 2), (4096, 32768,), (64, 128, 256,)) + for batch_size, seq_len, head_dim in problem_it: q_seq_len = kv_seq_len = seq_len - print(f"==== {kv_seq_len=:<6} {q_seq_len=:<6} {num_q_heads=:<4} {head_dim=:<6} ====") + print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}" + f"{num_q_heads=:<4} {head_dim=:<6} ====") param_it = itertools.product((64,), (64, 128, 256)) best = None k1, k2, k3 = jax.random.split(jax.random.key(42), 3) From 95029abc18081de001a1dcb4c09920caa5a6fbb7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 25 Nov 2024 11:41:03 -0800 Subject: [PATCH 488/698] drop compute capability check PiperOrigin-RevId: 700052796 --- jax/_src/cudnn/fused_attention_stablehlo.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 8ccf08ec643c..ef4e33ad0665 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -1022,11 +1022,9 @@ def dot_product_attention(query: Array, Returns: Output of the same shape as the query. """ - # check if cuDNN is installed + # TODO(b/380898464): Check the compute capability, e.g., require GPU device, + # in the kernel implementation (c++) code. cudnn_version = check_cudnn_version() - # only support at least Ampere - if not check_compute_capability("8.0"): - raise RuntimeError("Require at least Ampere arch to run") layout = _normalize_layout(qkv_layout) if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask") @@ -1047,7 +1045,7 @@ def dot_product_attention(query: Array, # combine bias and mask if bias is None: - bias = mask + bias = mask else: if mask is not None: # should be broadcast to same shape From f22bafac31a6d16546edfe49e0e5d545dcb9a55d Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Mon, 25 Nov 2024 11:43:09 -0800 Subject: [PATCH 489/698] [SDY] remove TODO for enabling Layouts for Shardy post cl/697715276. PiperOrigin-RevId: 700053383 --- tests/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/BUILD b/tests/BUILD index cfebd88f99f2..f0668b42b309 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -260,7 +260,6 @@ jax_multiplatform_test( ], ) -# TODO(b/355263220): enable on TPU once layouts is supported with Shardy. jax_multiplatform_test( name = "layout_test", srcs = ["layout_test.py"], From 676151265859f8b0dd8baf6f6ae50c3367ed0509 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 25 Nov 2024 13:02:22 -0800 Subject: [PATCH 490/698] Re-factor build CLI to a subcommand based approach This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script. Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups to achieve a cleaner separation and improves the readability when the CLI subcommands are run with `--help`. It also makes it clear as to which parts of the build they affect. E.g: CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionalities into distinct subcommands also simplifies the code which should help with the maintenance and future extensions. There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time. Usage: * Building `jaxlib`: ``` python build/build.py build --wheels=jaxlib --python_version=3.10 ``` * Building `jax-cuda-plugin`: ``` python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10 ``` * Building multiple packages: ``` python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10 ``` * Building `jax-rocm-pjrt`: ``` python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm ``` * Using a local XLA path: ``` python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla ``` * Updating requirements_lock.txt files: ``` python build/build.py requirements_update --python_version=3.10 ``` For more details on each argument and to see available options, run: ``` python build/build.py build --help ``` or ``` python build/build.py requirements_update --help ``` PiperOrigin-RevId: 700075411 --- .bazelrc | 5 + .github/workflows/asan.yaml | 2 +- .github/workflows/wheel_win_x64.yml | 2 +- .github/workflows/windows_ci.yml | 5 +- CHANGELOG.md | 5 + build/build.py | 957 +++++++++++++++------------- build/tools/command.py | 111 ++++ build/tools/utils.py | 89 ++- docs/developer.md | 60 +- third_party/xla/workspace.bzl | 2 +- 10 files changed, 713 insertions(+), 525 deletions(-) create mode 100644 build/tools/command.py diff --git a/.bazelrc b/.bazelrc index 98bca5901d47..6ef7d4493937 100644 --- a/.bazelrc +++ b/.bazelrc @@ -183,6 +183,7 @@ build:macos_cache_push --config=macos_cache --remote_upload_local_results=true - build:ci_linux_x86_64 --config=avx_linux --config=avx_posix build:ci_linux_x86_64 --config=mkl_open_source_only build:ci_linux_x86_64 --config=clang --verbose_failures=true +build:ci_linux_x86_64 --color=yes # TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA # toolchain for both CPU and GPU builds. @@ -203,6 +204,7 @@ build:ci_linux_x86_64_cuda --config=ci_linux_x86_64 # Linux Aarch64 CI configs build:ci_linux_aarch64_base --config=clang --verbose_failures=true build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" +build:ci_linux_aarch64_base --color=yes build:ci_linux_aarch64 --config=ci_linux_aarch64_base build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" @@ -221,11 +223,13 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm build:ci_darwin_x86_64 --macos_minimum_os=10.14 build:ci_darwin_x86_64 --config=macos_cache_push build:ci_darwin_x86_64 --verbose_failures=true +build:ci_darwin_x86_64 --color=yes # Mac Arm64 CI configs build:ci_darwin_arm64 --macos_minimum_os=11.0 build:ci_darwin_arm64 --config=macos_cache_push build:ci_darwin_arm64 --verbose_failures=true +build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows @@ -233,6 +237,7 @@ build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=tru build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE +build:ci_windows_amd64 --color=yes # ############################################################################# # RBE config options below. These inherit the CI configs above and set the diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index ea87d4e29e40..d261ba3a09c2 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -65,7 +65,7 @@ jobs: run: | source ${GITHUB_WORKSPACE}/venv/bin/activate cd jax - python build/build.py \ + python build/build.py build --wheels=jaxlib --verbose \ --bazel_options=--color=yes \ --bazel_options=--copt=-fsanitize=address \ --clang_path=/usr/bin/clang-18 diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 2b4a616e224a..3904bf1b8f10 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -40,7 +40,7 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` + python.exe build\build.py build --wheels=jaxlib ` --bazel_options=--color=yes ` --bazel_options=--config=win_clang ` --verbose diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 3173b81e6819..4c404ef4cb75 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -49,9 +49,10 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` + python.exe build\build.py build --wheels=jaxlib ` --bazel_options=--color=yes ` - --bazel_options=--config=win_clang + --bazel_options=--config=win_clang ` + --verbose - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index be9aaebcd615..ce8b040439c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `platforms` instead. * Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. + * Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and + replaces previous build.py usage. Run `python build/build.py --help` for + more details. Brief overview of the new subcommand options: + * `build`: Builds JAX wheel packages. For e.g., `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt` + * `requirements_update`: Updates requirements_lock.txt files. * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` on the function inputs. diff --git a/build/build.py b/build/build.py index 62e4217c10a2..12ad0fa3b011 100755 --- a/build/build.py +++ b/build/build.py @@ -14,94 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# Helper script for building JAX's libjax easily. +# CLI for building JAX wheel packages from source and for updating the +# requirements_lock.txt files import argparse +import asyncio import logging import os import platform -import textwrap +import sys +import copy -from tools import utils +from tools import command, utils +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) -def write_bazelrc(*, remote_build, - cuda_version, cudnn_version, rocm_toolkit_path, - cpu, cuda_compute_capabilities, - rocm_amdgpu_targets, target_cpu_features, - wheel_cpu, enable_mkl_dnn, use_clang, clang_path, - clang_major_version, python_version, - enable_cuda, enable_nccl, enable_rocm, - use_cuda_nvcc): - - with open("../.jax_configure.bazelrc", "w") as f: - if not remote_build: - f.write(textwrap.dedent("""\ - build --strategy=Genrule=standalone - """)) - - if use_clang: - f.write(f'build --action_env CLANG_COMPILER_PATH="{clang_path}"\n') - f.write(f'build --repo_env CC="{clang_path}"\n') - f.write(f'build --repo_env BAZEL_COMPILER="{clang_path}"\n') - f.write('build --copt=-Wno-error=unused-command-line-argument\n') - if clang_major_version in (16, 17, 18): - # Necessary due to XLA's old version of upb. See: - # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505 - f.write("build --copt=-Wno-gnu-offsetof-extensions\n") - - if rocm_toolkit_path: - f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" - .format(rocm_toolkit_path=rocm_toolkit_path)) - if rocm_amdgpu_targets: - f.write( - f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"\n') - if cpu is not None: - f.write(f"build --cpu={cpu}\n") - - if target_cpu_features == "release": - if wheel_cpu == "x86_64": - f.write("build --config=avx_windows\n" if utils.is_windows() - else "build --config=avx_posix\n") - elif target_cpu_features == "native": - if utils.is_windows(): - print("--target_cpu_features=native is not supported on Windows; ignoring.") - else: - f.write("build --config=native_arch_posix\n") - - if enable_mkl_dnn: - f.write("build --config=mkl_open_source_only\n") - if enable_cuda: - f.write("build --config=cuda\n") - if use_cuda_nvcc: - f.write("build --config=build_cuda_with_nvcc\n") - else: - f.write("build --config=build_cuda_with_clang\n") - f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if cuda_version: - f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" - .format(cuda_version=cuda_version)) - if cudnn_version: - f.write("build --repo_env HERMETIC_CUDNN_VERSION=\"{cudnn_version}\"\n" - .format(cudnn_version=cudnn_version)) - if cuda_compute_capabilities: - f.write( - f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') - if enable_rocm: - f.write("build --config=rocm_base\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if use_clang: - f.write("build --config=rocm\n") - f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n") - if python_version: - f.write( - "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( - python_version=python_version)) BANNER = r""" _ _ __ __ | | / \ \ \/ / @@ -112,421 +43,559 @@ def write_bazelrc(*, remote_build, """ EPILOG = """ +From the root directory of the JAX repository, run + `python build/build.py build --wheels=` to build JAX + artifacts. -From the 'build' directory in the JAX repository, run - python build.py -or - python3 build.py -to download and build JAX's XLA (jaxlib) dependency. -""" + Multiple wheels can be built with a single invocation of the CLI. + E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin + To update the requirements_lock.txt files, run + `python build/build.py requirements_update` +""" -def _parse_string_as_bool(s): - """Parses a string as a boolean argument.""" - lower = s.lower() - if lower == "true": - return True - elif lower == "false": - return False - else: - raise ValueError(f"Expected either 'true' or 'false'; got {s}") +# Define the build target for each wheel. +WHEEL_BUILD_TARGET_DICT = { + "jaxlib": "//jaxlib/tools:build_wheel", + "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", + "jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", +} -def add_boolean_argument(parser, name, default=False, help_str=None): - """Creates a boolean flag.""" - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--" + name, - nargs="?", - default=default, - const=True, - type=_parse_string_as_bool, - help=help_str) - group.add_argument("--no" + name, dest=name, action="store_false") +def add_global_arguments(parser: argparse.ArgumentParser): + """Adds all the global arguments that applies to all the CLI subcommands.""" + parser.add_argument( + "--python_version", + type=str, + choices=["3.10", "3.11", "3.12", "3.13"], + default=f"{sys.version_info.major}.{sys.version_info.minor}", + help= + """ + Hermetic Python version to use. Default is to use the version of the + Python binary that executed the CLI. + """, + ) + bazel_group = parser.add_argument_group('Bazel Options') + bazel_group.add_argument( + "--bazel_path", + type=str, + default="", + help=""" + Path to the Bazel binary to use. The default is to find bazel via the + PATH; if none is found, downloads a fresh copy of Bazel from GitHub. + """, + ) -def _get_editable_output_paths(output_path): - """Returns the paths to the editable wheels.""" - return ( - os.path.join(output_path, "jaxlib"), - os.path.join(output_path, "jax_gpu_pjrt"), - os.path.join(output_path, "jax_gpu_plugin"), + bazel_group.add_argument( + "--bazel_startup_options", + action="append", + default=[], + help=""" + Additional startup options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. --bazel_startup_options='--nobatch' + """, ) + bazel_group.add_argument( + "--bazel_options", + action="append", + default=[], + help=""" + Additional build options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. --bazel_options='--local_resources=HOST_CPUS' + """, + ) -def main(): - cwd = os.getcwd() - parser = argparse.ArgumentParser( - description="Builds jaxlib from source.", epilog=EPILOG) - add_boolean_argument( - parser, - "verbose", - default=False, - help_str="Should we produce verbose debugging output?") - parser.add_argument( - "--bazel_path", - help="Path to the Bazel binary to use. The default is to find bazel via " - "the PATH; if none is found, downloads a fresh copy of bazel from " - "GitHub.") - parser.add_argument( - "--python_bin_path", - help="Path to Python binary whose version to match while building with " - "hermetic python. The default is the Python interpreter used to run the " - "build script. DEPRECATED: use --python_version instead.") parser.add_argument( - "--target_cpu_features", - choices=["release", "native", "default"], - default="release", - help="What CPU features should we target? 'release' enables CPU " - "features that should be enabled for a release build, which on " - "x86-64 architectures enables AVX. 'native' enables " - "-march=native, which generates code targeted to use all " - "features of the current machine. 'default' means don't opt-in " - "to any architectural features and use whatever the C compiler " - "generates by default.") - add_boolean_argument( - parser, - "use_clang", - default = "true", - help_str=( - "DEPRECATED: This flag is redundant because clang is " - "always used as default compiler." - ), + "--dry_run", + action="store_true", + help="Prints the Bazel command that is going to be executed.", ) + parser.add_argument( - "--clang_path", - help=( - "Path to clang binary to use. The default is " - "to find clang via the PATH." - ), - ) - add_boolean_argument( - parser, - "enable_mkl_dnn", - default=True, - help_str="Should we build with MKL-DNN enabled?", + "--verbose", + action="store_true", + help="Produce verbose output for debugging.", ) - add_boolean_argument( - parser, - "enable_cuda", - help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN." + + +def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): + """Adds all the arguments that applies to the artifact subcommands.""" + parser.add_argument( + "--wheels", + type=str, + default="jaxlib", + help= + """ + A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib", + --wheels="jaxlib,jax-cuda-plugin", etc. + Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt, + jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt + """, ) - add_boolean_argument( - parser, - "use_cuda_nvcc", - default=True, - help_str=( - "Should we build CUDA code using NVCC compiler driver? The default value " - "is true. If --nouse_cuda_nvcc flag is used then CUDA code is built " - "by clang compiler." - ), + + parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' build instead of a wheel.", ) - add_boolean_argument( - parser, - "build_gpu_plugin", - default=False, - help_str=( - "Are we building the gpu plugin in addition to jaxlib? The GPU " - "plugin is still experimental and is not ready for use yet." - ), + + parser.add_argument( + "--output_path", + type=str, + default=os.path.join(os.getcwd(), "dist"), + help="Directory to which the JAX wheel packages should be written.", ) + parser.add_argument( - "--build_gpu_kernel_plugin", - choices=["cuda", "rocm"], - default="", - help=( - "Specify 'cuda' or 'rocm' to build the respective kernel plugin." - " When this flag is set, jaxlib will not be built." - ), + "--configure_only", + action="store_true", + help=""" + If true, writes the Bazel options to the .jax_configure.bazelrc file but + does not build the artifacts. + """, ) - add_boolean_argument( - parser, - "build_gpu_pjrt_plugin", - default=False, - help_str=( - "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built " - "when this flag is True." - ), + + # CUDA Options + cuda_group = parser.add_argument_group('CUDA Options') + cuda_group.add_argument( + "--cuda_version", + type=str, + help= + """ + Hermetic CUDA version to use. Default is to use the version specified + in the .bazelrc. + """, ) - parser.add_argument( - "--gpu_plugin_cuda_version", - choices=["12"], + + cuda_group.add_argument( + "--cuda_major_version", + type=str, default="12", - help="Which CUDA major version the gpu plugin is for.") - parser.add_argument( - "--gpu_plugin_rocm_version", - choices=["60"], - default="60", - help="Which ROCM major version the gpu plugin is for.") - add_boolean_argument( - parser, - "enable_rocm", - help_str="Should we build with ROCm enabled?") - add_boolean_argument( - parser, - "enable_nccl", - default=True, - help_str="Should we build with NCCL enabled? Has no effect for non-CUDA " - "builds.") - add_boolean_argument( - parser, - "remote_build", - default=False, - help_str="Should we build with RBE (Remote Build Environment)?") - parser.add_argument( - "--cuda_version", - default=None, - help="CUDA toolkit version, e.g., 12.3.2") - parser.add_argument( + help= + """ + Which CUDA major version should the wheel be tagged as? Auto-detected if + --cuda_version is set. When --cuda_version is not set, the default is to + set the major version to 12 to match the default in .bazelrc. + """, + ) + + cuda_group.add_argument( "--cudnn_version", - default=None, - help="CUDNN version, e.g., 8.9.7.29") - # Caution: if changing the default list of CUDA capabilities, you should also - # update the list in .bazelrc, which is used for wheel builds. - parser.add_argument( + type=str, + help= + """ + Hermetic cuDNN version to use. Default is to use the version specified + in the .bazelrc. + """, + ) + + cuda_group.add_argument( + "--disable_nccl", + action="store_true", + help="Should NCCL be disabled?", + ) + + cuda_group.add_argument( "--cuda_compute_capabilities", + type=str, default=None, - help="A comma-separated list of CUDA compute capabilities to support.") - parser.add_argument( + help= + """ + A comma-separated list of CUDA compute capabilities to support. Default + is to use the values specified in the .bazelrc. + """, + ) + + cuda_group.add_argument( + "--build_cuda_with_clang", + action="store_true", + help=""" + Should CUDA code be compiled using Clang? The default behavior is to + compile CUDA with NVCC. + """, + ) + + # ROCm Options + rocm_group = parser.add_argument_group('ROCm Options') + rocm_group.add_argument( + "--rocm_version", + type=str, + default="60", + help="ROCm version to use", + ) + + rocm_group.add_argument( "--rocm_amdgpu_targets", + type=str, default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100", - help="A comma-separated list of ROCm amdgpu targets to support.") - parser.add_argument( + help="A comma-separated list of ROCm amdgpu targets to support.", + ) + + rocm_group.add_argument( "--rocm_path", - default=None, - help="Path to the ROCm toolkit.") - parser.add_argument( - "--bazel_startup_options", - action="append", default=[], - help="Additional startup options to pass to bazel.") - parser.add_argument( - "--bazel_options", - action="append", default=[], - help="Additional options to pass to the main Bazel command to be " - "executed, e.g. `run`.") - parser.add_argument( - "--output_path", - default=os.path.join(cwd, "dist"), - help="Directory to which the jaxlib wheel should be written") - parser.add_argument( - "--target_cpu", - default=None, - help="CPU platform to target. Default is the same as the host machine. " - "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.") - parser.add_argument( - "--editable", + type=str, + default="", + help="Path to the ROCm toolkit.", + ) + + # Compile Options + compile_group = parser.add_argument_group('Compile Options') + + compile_group.add_argument( + "--use_clang", + type=utils._parse_string_as_bool, + default="true", + const=True, + nargs="?", + help=""" + Whether to use Clang as the compiler. Not recommended to set this to + False as JAX uses Clang as the default compiler. + """, + ) + + compile_group.add_argument( + "--clang_path", + type=str, + default="", + help=""" + Path to the Clang binary to use. + """, + ) + + compile_group.add_argument( + "--disable_mkl_dnn", action="store_true", - help="Create an 'editable' jaxlib build instead of a wheel.") - parser.add_argument( - "--python_version", + help=""" + Disables MKL-DNN. + """, + ) + + compile_group.add_argument( + "--target_cpu_features", + choices=["release", "native", "default"], + default="release", + help=""" + What CPU features should we target? Release enables CPU features that + should be enabled for a release build, which on x86-64 architectures + enables AVX. Native enables -march=native, which generates code targeted + to use all features of the current machine. Default means don't opt-in + to any architectural features and use whatever the C compiler generates + by default. + """, + ) + + compile_group.add_argument( + "--target_cpu", default=None, - help="hermetic python version, e.g., 3.10") - add_boolean_argument( - parser, - "configure_only", - default=False, - help_str="If true, writes a .bazelrc file but does not build jaxlib.") - add_boolean_argument( - parser, - "requirements_update", - default=False, - help_str="If true, writes a .bazelrc and updates requirements_lock.txt " - "for a corresponding version of Python but does not build " - "jaxlib.") - add_boolean_argument( - parser, - "requirements_nightly_update", - default=False, - help_str="Same as update_requirements, but will consider dev, nightly " - "and pre-release versions of packages.") + help="CPU platform to target. Default is the same as the host machine.", + ) + + compile_group.add_argument( + "--local_xla_path", + type=str, + default=os.environ.get("JAXCI_XLA_GIT_DIR", ""), + help=""" + Path to local XLA repository to use. If not set, Bazel uses the XLA at + the pinned version in workspace.bzl. + """, + ) + +async def main(): + parser = argparse.ArgumentParser( + description=r""" + CLI for building JAX wheel packages from source and for updating the + requirements_lock.txt files + """, + epilog=EPILOG, + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + # Create subparsers for build and requirements_update + subparsers = parser.add_subparsers(dest="command", required=True) + + # requirements_update subcommand + requirements_update_parser = subparsers.add_parser( + "requirements_update", help="Updates the requirements_lock.txt files" + ) + requirements_update_parser.add_argument( + "--nightly_update", + action="store_true", + help=""" + If true, updates requirements_lock.txt for a corresponding version of + Python and will consider dev, nightly and pre-release versions of + packages. + """, + ) + add_global_arguments(requirements_update_parser) + + # Artifact build subcommand + build_artifact_parser = subparsers.add_parser( + "build", help="Builds the jaxlib, plugin, and pjrt artifact" + ) + add_artifact_subcommand_arguments(build_artifact_parser) + add_global_arguments(build_artifact_parser) + + arch = platform.machine() + os_name = platform.system().lower() args = parser.parse_args() - logging.basicConfig() + logger.info("%s", BANNER) + if args.verbose: - logger.setLevel(logging.DEBUG) + logging.getLogger().setLevel(logging.DEBUG) + logger.info("Verbose logging enabled") + + bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) + + logging.debug("Bazel path: %s", bazel_path) + logging.debug("Bazel version: %s", bazel_version) + + executor = command.SubprocessExecutor() - if args.enable_cuda and args.enable_rocm: - parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.") + # Start constructing the Bazel command + bazel_command_base = command.CommandBuilder(bazel_path) + + if args.bazel_startup_options: + logging.debug( + "Additional Bazel startup options: %s", args.bazel_startup_options + ) + for option in args.bazel_startup_options: + bazel_command_base.append(option) + + bazel_command_base.append("run") + + if args.python_version: + logging.debug("Hermetic Python version: %s", args.python_version) + bazel_command_base.append( + f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}" + ) - print(BANNER) + # Enable verbose failures. + bazel_command_base.append("--verbose_failures=true") + + # Requirements update subcommand execution + if args.command == "requirements_update": + requirements_command = copy.deepcopy(bazel_command_base) + if args.bazel_options: + logging.debug( + "Using additional build options: %s", args.bazel_options + ) + for option in args.bazel_options: + requirements_command.append(option) + + if args.nightly_update: + logging.info( + "--nightly_update is set. Bazel will run" + " //build:requirements_nightly.update" + ) + requirements_command.append("//build:requirements_nightly.update") + else: + requirements_command.append("//build:requirements.update") - output_path = os.path.abspath(args.output_path) - os.chdir(os.path.dirname(__file__ or args.prog) or '.') + await executor.run(requirements_command.get_command_as_string(), args.dry_run) + sys.exit(0) - host_cpu = platform.machine() wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", "ppc": "ppc64le", "aarch64": "aarch64", } - # TODO(phawkins): support other bazel cpu overrides. - wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None - else host_cpu) - - # Find a working Bazel. - bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) - print(f"Bazel binary path: {bazel_path}") - print(f"Bazel version: {bazel_version}") - - if args.python_version: - python_version = args.python_version - else: - python_bin_path = utils.get_python_bin_path(args.python_bin_path) - print(f"Python binary path: {python_bin_path}") - python_version = utils.get_python_version(python_bin_path) - print("Python version: {}".format(".".join(map(str, python_version)))) - utils.check_python_version(python_version) - python_version = ".".join(map(str, python_version)) - - print("Use clang: {}".format("yes" if args.use_clang else "no")) - clang_path = args.clang_path - clang_major_version = None - if args.use_clang: - if not clang_path: - clang_path = utils.get_clang_path_or_exit() - print(f"clang path: {clang_path}") - clang_major_version = utils.get_clang_major_version(clang_path) - - print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no")) - print(f"Target CPU: {wheel_cpu}") - print(f"Target CPU features: {args.target_cpu_features}") - - rocm_toolkit_path = args.rocm_path - print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no")) - if args.enable_cuda: - if args.cuda_compute_capabilities is not None: - print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}") - if args.cuda_version: - print(f"CUDA version: {args.cuda_version}") - if args.cudnn_version: - print(f"CUDNN version: {args.cudnn_version}") - print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no")) - - print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no")) - if args.enable_rocm: - if rocm_toolkit_path: - print(f"ROCm toolkit path: {rocm_toolkit_path}") - print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}") - - write_bazelrc( - remote_build=args.remote_build, - cuda_version=args.cuda_version, - cudnn_version=args.cudnn_version, - rocm_toolkit_path=rocm_toolkit_path, - cpu=args.target_cpu, - cuda_compute_capabilities=args.cuda_compute_capabilities, - rocm_amdgpu_targets=args.rocm_amdgpu_targets, - target_cpu_features=args.target_cpu_features, - wheel_cpu=wheel_cpu, - enable_mkl_dnn=args.enable_mkl_dnn, - use_clang=args.use_clang, - clang_path=clang_path, - clang_major_version=clang_major_version, - python_version=python_version, - enable_cuda=args.enable_cuda, - enable_nccl=args.enable_nccl, - enable_rocm=args.enable_rocm, - use_cuda_nvcc=args.use_cuda_nvcc, + target_cpu = ( + wheel_cpus[args.target_cpu] if args.target_cpu is not None else arch ) - if args.requirements_update or args.requirements_nightly_update: - if args.requirements_update: - task = "//build:requirements.update" - else: # args.requirements_nightly_update - task = "//build:requirements_nightly.update" - update_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", task, *args.bazel_options]) - print(" ".join(update_command)) - utils.shell(update_command) - return - - if args.configure_only: - return - - print("\nBuilding XLA and installing it in the jaxlib source tree...") - - command_base = ( - bazel_path, - *args.bazel_startup_options, - "run", - "--verbose_failures=true", - *args.bazel_options, - ) - - if args.build_gpu_plugin and args.editable: - output_path_jaxlib, output_path_jax_pjrt, output_path_jax_kernel = ( - _get_editable_output_paths(output_path) + if args.local_xla_path: + logging.debug("Local XLA path: %s", args.local_xla_path) + bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") + + if args.target_cpu: + logging.debug("Target CPU: %s", args.target_cpu) + bazel_command_base.append(f"--cpu={args.target_cpu}") + + if args.disable_nccl: + logging.debug("Disabling NCCL") + bazel_command_base.append("--config=nonccl") + + git_hash = utils.get_githash() + + # Wheel build command execution + for wheel in args.wheels.split(","): + # Allow CUDA/ROCm wheels without the "jax-" prefix. + if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: + wheel = "jax-" + wheel + + if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + logging.error( + "Incorrect wheel name provided, valid choices are jaxlib," + " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," + " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" + ) + sys.exit(1) + + wheel_build_command = copy.deepcopy(bazel_command_base) + print("\n") + logger.info( + "Building %s for %s %s...", + wheel, + os_name, + arch, ) - else: - output_path_jaxlib = output_path - output_path_jax_pjrt = output_path - output_path_jax_kernel = output_path - - if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: - build_cpu_wheel_command = [ - *command_base, - "//jaxlib/tools:build_wheel", - "--", - f"--output_path={output_path_jaxlib}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.build_gpu_plugin: - build_cpu_wheel_command.append("--skip_gpu_kernels") - if args.editable: - build_cpu_wheel_command.append("--editable") - print(" ".join(build_cpu_wheel_command)) - utils.shell(build_cpu_wheel_command) - - if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ - (args.build_gpu_kernel_plugin == "rocm"): - build_gpu_kernels_command = [ - *command_base, - "//jaxlib/tools:build_gpu_kernels_wheel", - "--", - f"--output_path={output_path_jax_kernel}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.enable_cuda: - build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") - elif args.enable_rocm: - build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + + clang_path = "" + if args.use_clang: + clang_path = args.clang_path or utils.get_clang_path_or_exit() + clang_major_version = utils.get_clang_major_version(clang_path) + logging.debug( + "Using Clang as the compiler, clang path: %s, clang version: %s", + clang_path, + clang_major_version, + ) + + # Use double quotes around clang path to avoid path issues on Windows. + wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + else: + logging.debug("Use Clang: False") + + # Do not apply --config=clang on Mac as these settings do not apply to + # Apple Clang. + if os_name != "darwin": + wheel_build_command.append("--config=clang") + + if not args.disable_mkl_dnn: + logging.debug("Enabling MKL DNN") + wheel_build_command.append("--config=mkl_open_source_only") + + if args.target_cpu_features == "release": + if arch in ["x86_64", "AMD64"]: + logging.debug( + "Using release cpu features: --config=avx_%s", + "windows" if os_name == "windows" else "posix", + ) + wheel_build_command.append( + "--config=avx_windows" + if os_name == "windows" + else "--config=avx_posix" + ) + elif wheel_build_command == "native": + if os_name == "windows": + logger.warning( + "--target_cpu_features=native is not supported on Windows;" + " ignoring." + ) + else: + logging.debug("Using native cpu features: --config=native_arch_posix") + wheel_build_command.append("--config=native_arch_posix") else: - raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") - if args.editable: - build_gpu_kernels_command.append("--editable") - print(" ".join(build_gpu_kernels_command)) - utils.shell(build_gpu_kernels_command) - - if args.build_gpu_plugin or args.build_gpu_pjrt_plugin: - build_pjrt_plugin_command = [ - *command_base, - "//jaxlib/tools:build_gpu_plugin_wheel", - "--", - f"--output_path={output_path_jax_pjrt}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.enable_cuda: - build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}") - build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") - elif args.enable_rocm: - build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}") - build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + logging.debug("Using default cpu features") + + if "cuda" in wheel: + wheel_build_command.append("--config=cuda") + wheel_build_command.append( + f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" + ) + if args.build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + wheel_build_command.append("--config=build_cuda_with_clang") + else: + logging.debug("Building CUDA with NVCC") + wheel_build_command.append("--config=build_cuda_with_nvcc") + + if args.cuda_version: + logging.debug("Hermetic CUDA version: %s", args.cuda_version) + wheel_build_command.append( + f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}" + ) + if args.cudnn_version: + logging.debug("Hermetic cuDNN version: %s", args.cudnn_version) + wheel_build_command.append( + f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}" + ) + if args.cuda_compute_capabilities: + logging.debug( + "Hermetic CUDA compute capabilities: %s", + args.cuda_compute_capabilities, + ) + wheel_build_command.append( + f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" + ) + + if "rocm" in wheel: + wheel_build_command.append("--config=rocm_base") + if args.use_clang: + wheel_build_command.append("--config=rocm") + wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + if args.rocm_path: + logging.debug("ROCm tookit path: %s", args.rocm_path) + wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") + if args.rocm_amdgpu_targets: + logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) + wheel_build_command.append( + f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" + ) + + # Append additional build options at the end to override any options set in + # .bazelrc or above. + if args.bazel_options: + logging.debug( + "Additional Bazel build options: %s", args.bazel_options + ) + for option in args.bazel_options: + wheel_build_command.append(option) + + with open(".jax_configure.bazelrc", "w") as f: + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) + if not jax_configure_options: + logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + sys.exit(1) + f.write(jax_configure_options) + logging.info("Bazel options written to .jax_configure.bazelrc") + + if args.configure_only: + logging.info("--configure_only is set so not running any Bazel commands.") else: - raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") - if args.editable: - build_pjrt_plugin_command.append("--editable") - print(" ".join(build_pjrt_plugin_command)) - utils.shell(build_pjrt_plugin_command) + # Append the build target to the Bazel command. + build_target = WHEEL_BUILD_TARGET_DICT[wheel] + wheel_build_command.append(build_target) + + wheel_build_command.append("--") + + output_path = args.output_path + logger.debug("Artifacts output directory: %s", output_path) + + if args.editable: + logger.info("Building an editable build") + output_path = os.path.join(output_path, wheel) + wheel_build_command.append("--editable") + + wheel_build_command.append(f'--output_path="{output_path}"') + wheel_build_command.append(f"--cpu={target_cpu}") + + if "cuda" in wheel: + wheel_build_command.append("--enable-cuda=True") + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] + else: + cuda_major_version = args.cuda_major_version + wheel_build_command.append(f"--platform_version={cuda_major_version}") + + if "rocm" in wheel: + wheel_build_command.append("--enable-rocm=True") + wheel_build_command.append(f"--platform_version={args.rocm_version}") + + wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") - utils.shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) + await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/build/tools/command.py b/build/tools/command.py new file mode 100644 index 000000000000..48a9bfc1c0d6 --- /dev/null +++ b/build/tools/command.py @@ -0,0 +1,111 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Helper script for the JAX build CLI for running subprocess commands. +import asyncio +import dataclasses +import datetime +import os +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + +class CommandBuilder: + def __init__(self, base_command: str): + self.command = [base_command] + + def append(self, parameter: str): + self.command.append(parameter) + return self + + def get_command_as_string(self) -> str: + return " ".join(self.command) + + def get_command_as_list(self) -> list[str]: + return self.command + +@dataclasses.dataclass +class CommandResult: + """ + Represents the result of executing a subprocess command. + """ + + command: str + return_code: int = 2 # Defaults to not successful + logs: str = "" + start_time: datetime.datetime = dataclasses.field( + default_factory=datetime.datetime.now + ) + end_time: Optional[datetime.datetime] = None + + +async def _process_log_stream(stream, result: CommandResult): + """Logs the output of a subprocess stream.""" + while True: + line_bytes = await stream.readline() + if not line_bytes: + break + line = line_bytes.decode().rstrip() + result.logs += line + logger.info("%s", line) + + +class SubprocessExecutor: + """ + Manages execution of subprocess commands with reusable environment and logging. + """ + + def __init__(self, environment: Dict[str, str] = None): + """ + + Args: + environment: + """ + self.environment = environment or dict(os.environ) + + async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: + """ + Executes a subprocess command. + + Args: + cmd: The command to execute. + dry_run: If True, prints the command instead of executing it. + + Returns: + A CommandResult instance. + """ + result = CommandResult(command=cmd) + if dry_run: + logger.info("[DRY RUN] %s", cmd) + result.return_code = 0 # Dry run is a success + return result + + logger.info("[EXECUTING] %s", cmd) + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=self.environment, + ) + + await asyncio.gather( + _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) + ) + + result.return_code = await process.wait() + result.end_time = datetime.datetime.now() + logger.debug("Command finished with return code %s", result.return_code) + return result diff --git a/build/tools/utils.py b/build/tools/utils.py index 4c8765371316..5d7c8e0f20b2 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -28,25 +28,6 @@ logger = logging.getLogger(__name__) -def is_windows(): - return sys.platform.startswith("win32") - -def shell(cmd): - try: - logger.info("shell(): %s", cmd) - output = subprocess.check_output(cmd) - except subprocess.CalledProcessError as e: - logger.info("subprocess raised: %s", e) - if e.output: - print(e.output) - raise - except Exception as e: - logger.info("subprocess raised: %s", e) - raise - return output.decode("UTF-8").strip() - - -# Bazel BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" BazelPackage = collections.namedtuple( "BazelPackage", ["base_uri", "file", "sha256"] @@ -89,7 +70,6 @@ def shell(cmd): ), } - def download_and_verify_bazel(): """Downloads a bazel binary from GitHub, verifying its SHA256 hash.""" package = bazel_packages.get((platform.system(), platform.machine())) @@ -144,7 +124,6 @@ def progress(block_count, block_size, total_size): return os.path.join(".", package.file) - def get_bazel_paths(bazel_path_flag): """Yields a sequence of guesses about bazel path. @@ -155,7 +134,6 @@ def get_bazel_paths(bazel_path_flag): yield shutil.which("bazel") yield download_and_verify_bazel() - def get_bazel_path(bazel_path_flag): """Returns the path to a Bazel binary, downloading Bazel if not found. @@ -177,10 +155,14 @@ def get_bazel_path(bazel_path_flag): ) sys.exit(-1) - def get_bazel_version(bazel_path): try: - version_output = shell([bazel_path, "--version"]) + version_output = subprocess.run( + [bazel_path, "--version"], + encoding="utf-8", + capture_output=True, + check=True, + ).stdout.strip() except (subprocess.CalledProcessError, OSError): return None match = re.search(r"bazel *([0-9\\.]+)", version_output) @@ -188,7 +170,6 @@ def get_bazel_version(bazel_path): return None return tuple(int(x) for x in match.group(1).split(".")) - def get_clang_path_or_exit(): which_clang_output = shutil.which("clang") if which_clang_output: @@ -202,7 +183,6 @@ def get_clang_path_or_exit(): ) sys.exit(-1) - def get_clang_major_version(clang_path): clang_version_proc = subprocess.run( [clang_path, "-E", "-P", "-"], @@ -215,35 +195,42 @@ def get_clang_major_version(clang_path): return major_version - -# Python -def get_python_bin_path(python_bin_path_flag): - """Returns the path to the Python interpreter to use.""" - path = python_bin_path_flag or sys.executable - return path.replace(os.sep, "/") - - -def get_python_version(python_bin_path): - version_output = shell([ - python_bin_path, - "-c", - ( - 'import sys; print("{}.{}".format(sys.version_info[0], ' - "sys.version_info[1]))" - ), - ]) - major, minor = map(int, version_output.split(".")) - return major, minor - -def check_python_version(python_version): - if python_version < (3, 10): - print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) - sys.exit(-1) +def get_jax_configure_bazel_options(bazel_command: list[str]): + """Returns the bazel options to be written to .jax_configure.bazelrc.""" + # Get the index of the "run" parameter. Build options will come after "run" so + # we find the index of "run" and filter everything after it. + start = bazel_command.index("run") + jax_configure_bazel_options = "" + try: + for i in range(start + 1, len(bazel_command)): + bazel_flag = bazel_command[i] + # On Windows, replace all backslashes with double backslashes to avoid + # unintended escape sequences. + if platform.system() == "Windows": + bazel_flag = bazel_flag.replace("\\", "\\\\") + jax_configure_bazel_options += f"build {bazel_flag}\n" + return jax_configure_bazel_options + except ValueError: + logging.error("Unable to find index for 'run' in the Bazel command") + return "" def get_githash(): try: return subprocess.run( - ["git", "rev-parse", "HEAD"], encoding="utf-8", capture_output=True + ["git", "rev-parse", "HEAD"], + encoding="utf-8", + capture_output=True, + check=True, ).stdout.strip() except OSError: return "" + +def _parse_string_as_bool(s): + """Parses a string as a boolean value.""" + lower = s.lower() + if lower == "true": + return True + elif lower == "false": + return False + else: + raise ValueError(f"Expected either 'true' or 'false'; got {s}") diff --git a/docs/developer.md b/docs/developer.md index cbb60382b7f1..29a3cb6068ac 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -63,7 +63,7 @@ To build `jaxlib` from source, you must also install some prerequisites: To build `jaxlib` for CPU or TPU, you can run: ``` -python build/build.py +python build/build.py build --wheels=jaxlib --verbose pip install dist/*.whl # installs jaxlib (includes XLA) ``` @@ -71,7 +71,7 @@ To build a wheel for a version of Python different from your current system installation pass `--python_version` flag to the build command: ``` -python build/build.py --python_version=3.12 +python build/build.py build --wheels=jaxlib --python_version=3.12 --verbose ``` The rest of this document assumes that you are building for Python version @@ -81,13 +81,13 @@ version, simply append `--python_version=` flag every time you call installation regardless of whether the `--python_version` parameter is passed or not. -There are two ways to build `jaxlib` with CUDA support: (1) use -`python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda -support, or (2) use -`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` +If you would like to build `jaxlib` and the CUDA plugins: Run +``` +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt +``` to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and -jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and -clang, but it can be restricted to clang via the `--nouse_cuda_nvcc` flag. +jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and +clang, but it can be restricted to clang via the `--build_cuda_with_clang` flag. See `python build/build.py --help` for configuration options. Here `python` should be the name of your Python 3 interpreter; on some systems, you @@ -102,11 +102,16 @@ current directory. target dependencies. To download the specific versions of CUDA/CUDNN redistributions, you can use - the following command: + the `--cuda_version` and `--cudnn_version` flags: ```bash - python build/build.py --enable_cuda \ - --cuda_version=12.3.2 --cudnn_version=9.1.1 + python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 + ``` + or + ```bash + python build/build.py build --wheels=jax-cuda-pjrt --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 ``` Please note that these parameters are optional: by default Bazel will @@ -118,7 +123,7 @@ current directory. the following command: ```bash - python build/build.py --enable_cuda \ + python build/build.py build --wheels=jax-cuda-plugin \ --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" @@ -141,7 +146,7 @@ ways to do this: line flag to `build.py` as follows: ``` - python build/build.py --bazel_options=--override_repository=xla=/path/to/xla + python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla ``` - modify the `WORKSPACE` file in the root of the JAX source tree to point to @@ -183,7 +188,7 @@ path of the current session. Ensure `bazel`, `patch` and `realpath` are accessible. Activate the conda environment. ``` -python .\build\build.py +python .\build\build.py build --wheels=jaxlib ``` To build with debug information, add the flag `--bazel_options='--copt=/Z7'`. @@ -203,12 +208,14 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \ The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`, and selecting the appropriate options. -To build jaxlib with ROCM support, you can run the following build command, +To build jaxlib with ROCM support, you can run the following build commands, suitably adjusted for your paths and ROCM version. ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 ``` +to generate three wheels (jaxlib without rocm, jax-rocm-plugin, and +jax-rocm-pjrt) AMD's fork of the XLA repository may include fixes not present in the upstream XLA repository. If you experience problems with the upstream repository, you can @@ -221,7 +228,7 @@ git clone https://github.com/ROCm/xla.git and override the XLA repository with which JAX is built: ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --bazel_options=--override_repository=xla=/rel/xla/ --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py build --wheels=jax-rocm-plugin --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 --local_xla_path=/rel/xla/ ``` For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`. @@ -246,7 +253,7 @@ run `build/build.py` script. To choose a specific version explicitly you may pass `--python_version` argument to the tool: ``` -python build/build.py --python_version=3.12 +python build/build.py build --python_version=3.12 ``` Under the hood, the hermetic Python version is controlled @@ -284,7 +291,7 @@ direct dependencies list and then execute the following command (which will call [pip-compile](https://pypi.org/project/pip-tools/) under the hood): ``` -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` Alternatively, if you need more control, you may run the bazel command @@ -328,7 +335,7 @@ For example: ``` echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` ### Specifying dependencies on nightly wheels @@ -338,7 +345,7 @@ dependencies we provide a special version of the dependency updater command as follows: ``` -python build/build.py --requirements_nightly_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 --nightly_update ``` Or, if you run `bazel` directly (the two commands are equivalent): @@ -469,10 +476,13 @@ or using pytest. ### Using Bazel -First, configure the JAX build by running: +First, configure the JAX build by using the `--configure_only` flag. Pass +`--wheel_list=jaxlib` for CPU tests and CUDA/ROCM for GPU for GPU tests: ``` -python build/build.py --configure_only +python build/build.py build --wheels=jaxlib --configure_only +python build/build.py build --wheels=jax-cuda-plugin --configure_only +python build/build.py build --wheels=jax-rocm-plugin --configure_only ``` You may pass additional options to `build.py` to configure the build; see the @@ -494,14 +504,14 @@ make it available in the hermetic Python. To install a specific version of ``` echo -e "\njaxlib >= 0.4.26" >> build/requirements.in -python build/build.py --requirements_update +python build/build.py requirements_update ``` Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): ``` echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` Once you have `jaxlib` installed hermetically, run: diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 32508f8310bb..135d02aecd8a 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -37,7 +37,7 @@ def repo(): # local checkout by either: # a) overriding the TF repository on the build.py command line by passing a flag # like: - # python build/build.py --bazel_options=--override_repository=xla=/path/to/xla + # python build/build.py build --local_xla_path=/path/to/xla # or # b) by commenting out the http_archive above and uncommenting the following: # local_repository( From f7e9f6253723971ffb682ea9bb494f004a6a2b5e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 25 Nov 2024 14:36:25 -0800 Subject: [PATCH 491/698] Add new CI scripts for building JAX artifacts This commit introduces new CI scripts and environment files for building JAX artifacts. It makes use of the artifact envs inside the "ci/envs/build_artifacts" folder to control the build behavior. For e.g: for building jaxlib, we will need to run `./ci/build_artifacts.sh ./ci/envs/build_artifacts/jaxlib.env` from the JAX GitHub root. PiperOrigin-RevId: 700104283 --- ci/build_artifacts.sh | 85 ++++++++++++++++++++++++++++++++++ ci/envs/default.env | 10 +++- ci/utilities/run_auditwheel.sh | 46 ++++++++++++++++++ 3 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 ci/build_artifacts.sh create mode 100644 ci/utilities/run_auditwheel.sh diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh new file mode 100644 index 000000000000..9f8d54401691 --- /dev/null +++ b/ci/build_artifacts.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +## +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Build JAX artifacts. +# Usage: ./ci/build_artifacts.sh "" +# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt +# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib" +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +artifact="$1" + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt") + +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# Adjust the values when running on Windows x86 to match the config in +# .bazelrc +if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then + os="windows" + arch="amd64" +fi + +if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then + + # Build the jax artifact + if [[ "$artifact" == "jax" ]]; then + python -m build --outdir $JAXCI_OUTPUT_DIR + else + + # Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_" + # flags in the .bazelrc depending upon the platform we are building for. + bazelrc_config="${os}_${arch}" + + # TODO(b/379903748): Add remote cache options for Linux and Windows. + if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then + bazelrc_config="rbe_${bazelrc_config}" + else + bazelrc_config="ci_${bazelrc_config}" + fi + + # Use the "_cuda" configs when building the CUDA artifacts. + if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then + bazelrc_config="${bazelrc_config}_cuda" + fi + + # Build the artifact. + python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose + + # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we + # run `auditwheel show` to verify manylinux compliance. + if [[ "$os" == "linux" ]]; then + ./ci/utilities/run_auditwheel.sh + fi + + fi + +else + echo "Error: Invalid artifact: $artifact. Allowed values are: ${allowed_artifacts[@]}" + exit 1 +fi \ No newline at end of file diff --git a/ci/envs/default.env b/ci/envs/default.env index 528c02701acc..f50b7549b823 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -34,4 +34,12 @@ export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} # Allows overriding the XLA commit that is used. -export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} \ No newline at end of file +export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} + +# Controls the location where the artifacts are written to. +export JAXCI_OUTPUT_DIR="$(pwd)/dist" + +# When enabled, artifacts will be built with RBE. Requires gcloud authentication +# and only certain platforms support RBE. Therefore, this flag is enabled only +# for CI builds where RBE is supported. +export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} \ No newline at end of file diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh new file mode 100644 index 000000000000..30b6a3b51865 --- /dev/null +++ b/ci/utilities/run_auditwheel.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Runs auditwheel to verify manylinux compatibility. + +# Get a list of all the wheels in the output directory. Only look for wheels +# that need to be verified for manylinux compliance. +WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" \)) + +if [[ -z "$WHEELS" ]]; then + echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" + exit 1 +fi + +for wheel in $WHEELS; do + printf "\nRunning auditwheel on the following wheel:" + ls $wheel + OUTPUT_FULL=$(python -m auditwheel show $wheel) + # Remove the wheel name from the output to avoid false positives. + wheel_name=$(basename $wheel) + OUTPUT=${OUTPUT_FULL//${wheel_name}/} + + # If a wheel is manylinux2014 compliant, `auditwheel show` will return the + # platform tag as manylinux_2_17. manylinux2014 is an alias for + # manylinux_2_17. + if echo "$OUTPUT" | grep -q "manylinux_2_17"; then + printf "\n$wheel_name is manylinux2014 compliant.\n" + else + echo "$OUTPUT_FULL" + printf "\n$wheel_name is NOT manylinux2014 compliant.\n" + exit 1 + fi +done \ No newline at end of file From ebea4353f88f168b62e60db0491f50aef6aac012 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 25 Nov 2024 14:56:40 -0800 Subject: [PATCH 492/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/7059553f7e215709642e5a5b19274b0e78d4349a. PiperOrigin-RevId: 700110142 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 135d02aecd8a..1e3df5fafb4d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "40d457a268baf95e42cd95709dedef70c0ea2994" -XLA_SHA256 = "a7bd8cc608964ba91d9edfd6070929cffe09b448e9e36fd8224bc8fc99202db3" +XLA_COMMIT = "7059553f7e215709642e5a5b19274b0e78d4349a" +XLA_SHA256 = "16bf9a4e3e62a5180fddec2526657cd0ba9c2a1a3510458054730e60c9526294" def repo(): tf_http_archive( From ef7df1ae7c258e402e7d379816b4d019a1643d44 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 25 Nov 2024 16:36:05 -0800 Subject: [PATCH 493/698] [pallas_mgpu] Allow trees (eg tuples) to be returned from cond_p expressions. PiperOrigin-RevId: 700136799 --- jax/_src/pallas/mosaic_gpu/lowering.py | 39 ++++++++++++++++++++++---- tests/pallas/mosaic_gpu_test.py | 18 ++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 66437839cce2..198a9d8095d0 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1514,8 +1514,24 @@ def _while_lowering_rule( @register_lowering_rule(lax.cond_p) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in + + # We need the branch return mlir types in order to construct the + # switch operation. To avoid leaking information about what kind of + # mlir types are internal to FragmentedArrays and other mgpu types, + # we run one of the branches in a dummy module that we throw away to + # extract the return types + with ir.InsertionPoint(ir.Module.create().body): + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args + ) + yielded = [ + _ensure_ir_value(out, aval.dtype) or out + for out, aval in zip(outs, ctx.avals_out) + ] + yielded_leaves, _ = jax.tree.flatten(yielded) + switch_op = scf_dialect.IndexSwitchOp( - map(mgpu_utils.dtype_to_ir_type, ctx.avals_out), + [v.type for v in yielded_leaves], _as_index(_ensure_ir_value(index, index_aval.dtype)), ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), num_caseRegions=len(branches) - 1, @@ -1527,16 +1543,27 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): regions = list(switch_op.regions) # Move the default region to the back. regions = regions[1:] + regions[:1] + treedef = None for branch, region in zip(branches, regions): with ir.InsertionPoint(region.blocks.append()): outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args + ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts ) - scf_dialect.yield_([ - _ensure_ir_value(out, aval.dtype) + + yielded = [ + _ensure_ir_value(out, aval.dtype) or out for out, aval in zip(outs, ctx.avals_out) - ]) - return list(switch_op.results) + ] + yielded_leaves, yielded_treedef = jax.tree.flatten(yielded) + if treedef is None: + treedef = yielded_treedef + else: + assert treedef == yielded_treedef + + scf_dialect.yield_(yielded_leaves) + + assert treedef is not None + return treedef.unflatten(list(switch_op.results)) @register_lowering_rule(lax.bitcast_convert_type_p) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 110d83bd992b..cb50b9475cb7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -749,6 +749,24 @@ def kernel(x_ref, o_ref): self.assertIn("acc * 2:", output()) + def test_cond_returning_array(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + acc = x_ref[...].sum() + acc2, acc = jax.lax.cond( + acc % 2 == 0, + lambda: (acc * 2, acc), + lambda: (acc, acc * 2), + ) + o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape) + + x = jnp.arange(256) + np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) + + @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): self.skip_unless_sm90a() From c5dc980db80a55fd2a0583208921f84a854dfc97 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 25 Nov 2024 17:55:27 -0800 Subject: [PATCH 494/698] [mgpu/pallas_mgpu] Pointwise tanh support PiperOrigin-RevId: 700158250 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 ++++++ jax/experimental/mosaic/gpu/fragmented_array.py | 9 +++++++++ tests/mosaic/gpu_test.py | 3 +++ tests/pallas/mosaic_gpu_test.py | 5 +++-- 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 198a9d8095d0..6e7adfc60a53 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1179,6 +1179,12 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) +@register_lowering_rule(lax.tanh_p) +def _tanh_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) + + @register_lowering_rule(lax.logistic_p) def _logistic_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 26157f407844..1a108ec2ee72 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -918,6 +918,15 @@ def cos(self, *, approx: bool = False): self._lift_fast_unary("cos.approx.f32") if approx else mlir_math.cos ) + def tanh(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + if approx and self.mlir_dtype != ir.F32Type.get(): + raise NotImplementedError + return self._pointwise( + self._lift_fast_unary("tanh.approx.f32") if approx else mlir_math.tanh + ) + def rsqrt(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index aeddbc7e033d..2a4efddd88b1 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1354,11 +1354,14 @@ def kernel(ctx, dst, _): ops=( (lambda x: -x, jax.lax.neg), (lambda x: x + 42, lambda x: x + 42), + (lambda x: x.tanh(), jax.lax.tanh), ), dtype=[jnp.float32, jnp.int32, jnp.uint32], ) def test_unary(self, ops, dtype, m=64, n=32): op, np_op = ops + if np_op is jax.lax.tanh and jnp.issubdtype(dtype, jnp.integer): + raise self.skipTest("Tanh not supported for integer types") def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index cb50b9475cb7..8c4bcd117eda 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -70,8 +70,9 @@ class PallasCallTest(PallasTest): ("exp", jax.lax.exp), ("square", lambda x: x ** 2), ("rsqrt", jax.lax.rsqrt), + ("tanh", jax.lax.tanh, 1e-6), ) - def test_unary_ops(self, unary): + def test_unary_ops(self, unary, rtol=1e-7): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), @@ -80,7 +81,7 @@ def kernel(x_ref, o_ref): o_ref[...] = unary(x_ref[...]) x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(kernel(x), unary(x)) + np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol) def test_add_first(self): @functools.partial( From 59e13f81146a65a40b6d25dd73bf19c07d39614e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 25 Nov 2024 18:14:30 -0800 Subject: [PATCH 495/698] Add sharding argument to reshape since it also takes a `shape` argument for the output shape PiperOrigin-RevId: 700163883 --- jax/_src/lax/lax.py | 49 +++++++++++++++++++++--------- jax/_src/pallas/mosaic/lowering.py | 3 +- jax/_src/pallas/triton/lowering.py | 2 +- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/sparse/bcoo.py | 4 ++- tests/pjit_test.py | 34 ++++++++++++--------- 6 files changed, 62 insertions(+), 32 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 75de83bd5542..be7d13195554 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1231,7 +1231,8 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: return broadcast(x, (1,) * (rank - ndim)) def reshape(operand: ArrayLike, new_sizes: Shape, - dimensions: Sequence[int] | None = None) -> Array: + dimensions: Sequence[int] | None = None, + sharding: NamedSharding | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -1285,7 +1286,8 @@ def reshape(operand: ArrayLike, new_sizes: Shape, return reshape_p.bind( operand, *dyn_shape, new_sizes=tuple(static_new_sizes), - dimensions=None if dims is None or same_dims else dims) + dimensions=None if dims is None or same_dims else dims, + sharding=sharding) def pad(operand: ArrayLike, padding_value: ArrayLike, padding_config: Sequence[tuple[int, int, int]]) -> Array: @@ -4654,7 +4656,7 @@ def shape_as_value(shape: core.Shape): ] return concatenate(dims, dimension=0) -def _reshape_shape_rule(operand, *, new_sizes, dimensions): +def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding): if not all(d >= 0 for d in new_sizes): msg = 'reshape new_sizes must all be positive, got {}.' raise TypeError(msg.format(new_sizes)) @@ -4674,7 +4676,9 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions): raise TypeError(msg.format(dimensions, np.shape(operand))) return tuple(new_sizes) -def _reshape_sharding_rule(operand, *, new_sizes, dimensions): +def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): + if sharding is not None: + return sharding filtered_spec = [ (sh, sp) for sh, sp in zip(operand.shape, operand.sharding.spec) if sh != 1 @@ -4687,14 +4691,18 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions): else: sh, sp = next(fs) if n != sh: - raise NotImplementedError + raise ValueError( + 'This reshape is not supported. Please specify the sharding of the' + ' output via the `sharding` argument of reshape.') new_spec.append(sp) return operand.sharding.with_spec(new_spec) -def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions): +def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions, + sharding): if not dyn_shape: out_aval, effects = reshape_p.abstract_eval( - operand.aval, new_sizes=new_sizes, dimensions=dimensions) + operand.aval, new_sizes=new_sizes, dimensions=dimensions, + sharding=sharding) return [out_aval], effects else: # TODO(mattjj, necula): perform more checks like _reshape_shape_rule @@ -4705,18 +4713,29 @@ def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions): return [out_aval], core.no_effects -def _reshape_dtype_rule(operand, *, new_sizes, dimensions): +def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding): return operand.dtype -def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions): +def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): assert ad.is_undefined_primal(operand) if dimensions is None: + if config.sharding_in_types.value: + return [reshape(t, operand.aval.shape, sharding=operand.aval.sharding)] return [reshape(t, operand.aval.shape)] else: - return [transpose(reshape(t, np.take(operand.aval.shape, dimensions)), + if config.sharding_in_types.value: + t_s = operand.sharding.with_spec( + tuple(map(str, np.take(operand.aval.sharding.spec, dimensions)))) + else: + t_s = None + return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), + sharding=t_s), np.argsort(dimensions))] -def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions): +def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions, + sharding): + if sharding is not None: + raise NotImplementedError operand, = batched_args bdim, = batch_dims operand = batching.moveaxis(operand, bdim, 0) @@ -4725,7 +4744,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions): return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0 -def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): +def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): aval_out, = ctx.avals_out if dimensions is not None: x = hlo.transpose(x, mlir.dense_int_array(dimensions)) @@ -4733,12 +4752,14 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) out = mlir.reshape(ctx, x, aval_out) if config.sharding_in_types.value: + if sharding is not None: + assert sharding == aval_out.sharding return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] def _reshape_staging_rule( - trace, x, *dyn, new_sizes, dimensions): - params = dict(new_sizes=new_sizes, dimensions=dimensions) + trace, x, *dyn, new_sizes, dimensions, sharding): + params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding) if not dyn: return trace.default_process_primitive(reshape_p, (x,), params) av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 1f0062cad0f9..590d1db92198 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1849,7 +1849,8 @@ def _convert_element_type_lowering_rule( lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule -def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions): +def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, + sharding): if dimensions is not None: raise NotImplementedError if any(d is None for d in new_sizes): diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index fa49f3b7cbbf..94848236d299 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1612,7 +1612,7 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions): @register_lowering(lax.reshape_p) def _reshape_lowering_rule( - ctx: LoweringRuleContext, a, *, new_sizes, dimensions + ctx: LoweringRuleContext, a, *, new_sizes, dimensions, sharding, ): del new_sizes # Unused. if dimensions is not None: diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c41eda693d7f..188ffeb6d670 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2291,7 +2291,7 @@ def _empty(*, dtype): tf_impl[lax_internal.empty_p] = _empty -def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval): +def _reshape(operand, *, new_sizes, dimensions, sharding, _in_avals, _out_aval): if dimensions is None: dimensions = tf.range(tf.rank(operand)) new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 9dcb0cadc1b2..d8bf1ee4a7bd 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1826,7 +1826,9 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO: return BCOO((new_data, new_indices), shape=out_aval.shape) -def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[int] | None = None) -> BCOO: +def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], + dimensions: Sequence[int] | None = None, + sharding=None) -> BCOO: """Sparse implementation of {func}`jax.lax.reshape`. Args: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e5aa1a604eae..8c362af19d0b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5147,31 +5147,37 @@ def h2(x, y): @parameterized.named_parameters( ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False), ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), - ('3', (8, 1), (1, 4, 2), P('x', None), P(None, 'x', None), True) + ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True) ) - def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, will_error): + def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, + use_sharding_arg): mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(math.prod(src_shape), dtype=np.float32).reshape(src_shape) arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec)) - @jax.jit - def f(x): - y = jnp.reshape(x, dst_shape) + @partial(jax.jit, static_argnums=1) + def f(x, new_sharding): + y = lax.reshape(x, dst_shape, sharding=new_sharding) y = y * 2 self.assertEqual(y.sharding.spec, dst_spec) return y - if will_error: - with self.assertRaises(NotImplementedError): - f(arr) - else: - out = f(arr) - self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec)) - self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2) + new_s = (NamedSharding(mesh.abstract_mesh, dst_spec) + if use_sharding_arg else None) + out = f(arr, new_s) + self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec)) + self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2) + + lowered_text = f.lower(arr, new_s).as_text() + self.assertIn('@Sharding', lowered_text) - lowered_text = f.lower(arr).as_text() - self.assertIn('@Sharding', lowered_text) + def g(x): + out = f(x, new_s) + return jnp.square(jnp.sum(out)) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) def test_select(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) From 627debc78b54983f7c17d55d5c09ed1c251561a2 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 25 Nov 2024 18:29:52 -0800 Subject: [PATCH 496/698] Create a `null_mesh_context` internal context manager to handle null contexts properly. PiperOrigin-RevId: 700167406 --- jax/_src/mesh.py | 34 +++++++++++++++++++++++++--------- jax/_src/pjit.py | 5 ++--- jax/_src/stages.py | 4 ++-- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 3d0e1b0cccf5..214fb190d498 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -455,17 +455,10 @@ def local_mesh(self): _raise_value_error("local_mesh") def __enter__(self): - mesh_context.stack.append(self) - mesh_context.mesh = self - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) - return self + return push_mesh_context(self) def __exit__(self, exc_type, exc_value, traceback): - mesh_context.stack.pop() - mesh_context.mesh = mesh_context.stack[-1] - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) + pop_mesh_context() return False @staticmethod @@ -486,3 +479,26 @@ def __init__(self): self.mesh = self.stack[-1] mesh_context = MeshContext() + +def push_mesh_context(val): + mesh_context.stack.append(val) + mesh_context.mesh = val + jax_config.abstract_mesh_context_manager.set_local( + tuple(m for m in mesh_context.stack if m is not None)) + return val + +def pop_mesh_context(): + mesh_context.stack.pop() + mesh_context.mesh = mesh_context.stack[-1] + jax_config.abstract_mesh_context_manager.set_local( + tuple(m for m in mesh_context.stack if m is not None)) + + +class null_mesh_context: + + def __enter__(self): + return push_mesh_context(None) + + def __exit__(self, *excinfo): + pop_mesh_context() + return False diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fb60f9d52727..b77af1a8f14e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -16,7 +16,6 @@ from collections import defaultdict from collections.abc import Callable, Sequence, Iterable -import contextlib import dataclasses from functools import partial import inspect @@ -696,7 +695,7 @@ def _infer_params_impl( def get_abstract_mesh(in_avals): if not config.sharding_in_types.value: - return contextlib.nullcontext() + return mesh_lib.null_mesh_context() m = None for a in in_avals: # TODO(yashkatariya): Remove this when mesh context can be set by the user. @@ -709,7 +708,7 @@ def get_abstract_mesh(in_avals): m = a.sharding.mesh # type: ignore # TODO(yashkatariya): Remove this when mesh context can be set by the user. if m is None: - return contextlib.nullcontext() + return mesh_lib.null_mesh_context() assert m is not None return m diff --git a/jax/_src/stages.py b/jax/_src/stages.py index b6f3b63d3de4..cc89a3338313 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,7 +30,6 @@ """ from __future__ import annotations -import contextlib import functools from collections.abc import Sequence from dataclasses import dataclass @@ -44,6 +43,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util +from jax._src import mesh as mesh_lib from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.layout import Layout from jax._src.interpreters import mlir @@ -717,7 +717,7 @@ class Traced(Stage): "_args_flat", "_arg_names", "_num_consts"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, abstract_mesh=contextlib.nullcontext(), + lower_callable, abstract_mesh=mesh_lib.null_mesh_context(), args_flat=None, arg_names=None, num_consts: int = 0): self.jaxpr = jaxpr self.args_info = args_info From f828f2d7d0a17baa1eb9add69b80f2d31eaa9c08 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Mon, 25 Nov 2024 19:12:56 -0800 Subject: [PATCH 497/698] [mgpu] Pointwise min PiperOrigin-RevId: 700175724 --- .../mosaic/gpu/fragmented_array.py | 18 ++++++++--- tests/mosaic/gpu_test.py | 27 ++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 31 ++++++++++++------- 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 1a108ec2ee72..64d351aba026 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -622,10 +622,10 @@ def to_layout(self, new_layout: FragmentedLayout): reg, self.shape, new_layout, is_signed=self.is_signed ) - def _pointwise(self, op, *other, output_is_signed: bool | None = None): + def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_dispatch=False): # If our layout is a splat, then we should either dispatch to a non-splat # layout, or broadcast ourselves to the output shape first. - if isinstance(self.layout, WGSplatFragLayout): + if not force_no_dispatch and isinstance(self.layout, WGSplatFragLayout): output_shape = self.shape for i, o in enumerate(other): if not isinstance(o, FragmentedArray): @@ -642,7 +642,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): output_shape = np.broadcast_shapes(output_shape, o.shape) # If we get here then we haven't found any non-splat layout. return self.broadcast(output_shape)._pointwise( - op, *other, output_is_signed=output_is_signed + op, *other, output_is_signed=output_is_signed, force_no_dispatch=True, ) other_arrs = [] @@ -884,7 +884,17 @@ def max(self, other): arith.maxsi if self.is_signed else arith.maxui, other ) else: - return NotImplemented + return NotImplementedError + + def min(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.minimumf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise( + arith.minsi if self.is_signed else arith.minui, other + ) + else: + return NotImplementedError def exp(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 2a4efddd88b1..1c7f22f885ec 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1256,6 +1256,7 @@ class FragmentedArrayTest(TestCase): operator.add, operator.mul, operator.sub, + (lambda x, y: mgpu.FragmentedArray.min(x, y), np.minimum), (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), ), dtype=[jnp.float32, jnp.int32, jnp.uint32], @@ -1285,6 +1286,32 @@ def kernel(ctx, dst, _): ref_rhs = scalar_rhs or ref_x np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + def test_minimum_np_compatibility(self): + one = np.ones((128, 128)).astype(np.float32) + negz = one * -0. + posz = one * 0. + nan = one * np.nan + expectation = (np.minimum(negz, posz) == negz) & (np.minimum(nan, one) != one) + assert np.all(expectation), expectation + + def kernel(ctx, dst, _): + f32 = ir.F32Type.get() + splat = lambda i: mgpu.FragmentedArray.splat(c(i, f32), (128, 128)) + negz = splat(-0.) + posz = splat(0.) + nan = splat(np.nan) + one = splat(1.) + res = (negz.min(posz) == negz) & (one.min(nan) != one) & (nan.min(one) != one) + i8 = ir.IntegerType.get_signless(8) + res.astype(i8, is_signed=False).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((128, 128), np.int8) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + # astype() uses extsi so i1=True becomes -1 + np.testing.assert_array_equal(result == -1, expectation) + @parameterized.product( op=[operator.truediv, operator.floordiv, operator.mod], dtype=[jnp.float32, jnp.int32, jnp.uint32], diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 8c4bcd117eda..993de287f74c 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -83,6 +83,25 @@ def kernel(x_ref, o_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol) + @parameterized.named_parameters( + ("add", lambda x, y: x + y), + ("mul", lambda x, y: x * y), + ("div", lambda x, y: x / y), + ("min", lambda x, y: jnp.minimum(x, y)), + ("max", lambda x, y: jnp.maximum(x, y)), + ) + def test_binary_op(self, bop): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = bop(x_ref[...], y_ref[...]) + + x = jnp.arange(256).astype(jnp.float32) + y = x + 1 + np.testing.assert_array_equal(kernel(x, y), bop(x, y)) + def test_add_first(self): @functools.partial( pl.pallas_call, @@ -111,18 +130,6 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - def test_add_xy(self): - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - ) - def kernel(x_ref, y_ref, o_ref): - o_ref[...] = x_ref[...] + y_ref[...] - - x = jnp.arange(256).astype(jnp.float32) - y = x + 1 - np.testing.assert_array_equal(kernel(x, y), x + y) - def test_add_xy_indexed(self): @functools.partial( pl.pallas_call, From 16a5607c9100e54bd77738a4689db600a89458f5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 26 Nov 2024 02:12:16 -0800 Subject: [PATCH 498/698] Use xla_extension_version instead of jaxlib_version PiperOrigin-RevId: 700265297 --- jax/_src/cache_key.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 324fa85f81ed..c7957ae11a33 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -21,9 +21,9 @@ from typing import cast as type_cast from jax._src import config -from jax._src.lib import version as jaxlib_version from jax._src.lib import version_str as jaxlib_version_str from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir import passmanager as pm import numpy as np @@ -226,8 +226,11 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_dump_hlo_as_long_text = False debug_options.xla_dump_disable_metadata = False debug_options.xla_dump_hlo_pipeline_re = "" - if jaxlib_version > (0, 4, 35): + + # "Requires jaxlib 0.4.36+" + if xla_extension_version > 296: debug_options.xla_gpu_experimental_autotune_cache_mode = 0 + # Optional way to specify the cuda install path to be used by the compiler. # This could possibly affect the cuda version compiled with, but this should # already be included in the platform information (and might not be reflected From b6566c80b099510a59d5777000dee8176cd28831 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 26 Nov 2024 02:13:36 -0800 Subject: [PATCH 499/698] [mosaic_gpu] Fixed unbounded recursion in `FragmentedArray._pointwise` PiperOrigin-RevId: 700265616 --- jax/experimental/mosaic/gpu/fragmented_array.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 64d351aba026..b16cd26da271 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -622,10 +622,10 @@ def to_layout(self, new_layout: FragmentedLayout): reg, self.shape, new_layout, is_signed=self.is_signed ) - def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_dispatch=False): + def _pointwise(self, op, *other, output_is_signed: bool | None = None): # If our layout is a splat, then we should either dispatch to a non-splat # layout, or broadcast ourselves to the output shape first. - if not force_no_dispatch and isinstance(self.layout, WGSplatFragLayout): + if isinstance(self.layout, WGSplatFragLayout): output_shape = self.shape for i, o in enumerate(other): if not isinstance(o, FragmentedArray): @@ -641,9 +641,10 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None, force_no_ else: output_shape = np.broadcast_shapes(output_shape, o.shape) # If we get here then we haven't found any non-splat layout. - return self.broadcast(output_shape)._pointwise( - op, *other, output_is_signed=output_is_signed, force_no_dispatch=True, - ) + if self.shape != output_shape: + return self.broadcast(output_shape)._pointwise( + op, *other, output_is_signed=output_is_signed + ) other_arrs = [] for o in other: From 231967fdb5d9fee2f2fa4e189f90ce7a64a4e009 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 26 Nov 2024 04:05:35 -0800 Subject: [PATCH 500/698] [AutoPGLE] Explicitly ignore host callback pointers Before this change users had to specify remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE. PiperOrigin-RevId: 700289965 --- jax/BUILD | 1 + jax/_src/cache_key.py | 115 +++++++++++++++++++++++----------- jax/_src/compilation_cache.py | 21 +++++-- jax/_src/compiler.py | 28 +++++++-- tests/cache_key_test.py | 63 ++++++++++++++++--- 5 files changed, 173 insertions(+), 55 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 64bfa627f42e..d35ff0e399a6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -426,6 +426,7 @@ pytype_strict_library( name = "compiler", srcs = ["_src/compiler.py"], deps = [ + ":cache_key", ":compilation_cache_internal", ":config", ":mlir", diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index c7957ae11a33..6e7a421482ce 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +import enum import hashlib import io import logging @@ -62,11 +63,23 @@ def custom_hook() -> str: return "" -def get(module: ir.Module, - devices: np.ndarray, - compile_options: xla_client.CompileOptions, - backend: xla_client.Client, - compression_algorithm: str = "zstandard") -> str: +class IgnoreCallbacks(enum.IntEnum): + # Do not remove any callback pointers from precompiled IR. + NO = enum.auto() + # Remove all callback pointers from precompiled IR. + ALL = enum.auto() + # Remove only custom_partitioning callback pointer from precompiled IR. + CUSTOM_PARTITIONING = enum.auto() + + +def get( + module: ir.Module, + devices: np.ndarray, + compile_options: xla_client.CompileOptions, + backend: xla_client.Client, + compression_algorithm: str = "zstandard", + ignore_callbacks: IgnoreCallbacks = IgnoreCallbacks.NO, +) -> str: """Creates a hashed string to use as a key to the compilation cache. Creates a cache key that is a hex-encoded string of a unique hash based on @@ -79,28 +92,47 @@ def get(module: ir.Module, backend: description of the platform (e.g., TPU version) compression_algorithm: a string representing the compression algorithm used for the executable before persisting in the cache + ignore_callbacks: whether to remove the all callback pointer from the + computation. Typical return value example: 'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf' """ entries = [ - ("computation", - lambda hash_obj: _hash_computation(hash_obj, module)), - ("jax_lib version", - lambda hash_obj: hash_obj.update( - bytes(jaxlib_version_str.encode("utf-8")))), - ("XLA flags", - lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())), - ("compile_options", - lambda hash_obj: _hash_serialized_compile_options( - hash_obj, compile_options, - # In case of GPU multi-process tasks we need to strip device - # assignment to use cache key as invariant between processes. - strip_device_assignment=(backend.platform == "gpu"))), - ("accelerator_config", - lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)), - ("compression", - lambda hash_obj: _hash_string(hash_obj, compression_algorithm)), + ( + "computation", + lambda hash_obj: _hash_computation( + hash_obj, module, ignore_callbacks + ), + ), + ( + "jax_lib version", + lambda hash_obj: hash_obj.update( + bytes(jaxlib_version_str.encode("utf-8")) + ), + ), + ( + "XLA flags", + lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()), + ), + ( + "compile_options", + lambda hash_obj: _hash_serialized_compile_options( + hash_obj, + compile_options, + # In case of GPU multi-process tasks we need to strip device + # assignment to use cache key as invariant between processes. + strip_device_assignment=(backend.platform == "gpu"), + ), + ), + ( + "accelerator_config", + lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend), + ), + ( + "compression", + lambda hash_obj: _hash_string(hash_obj, compression_algorithm), + ), ("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())), ] @@ -131,45 +163,56 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): ) -def _remove_custom_partitioning_ptr(m: ir.Module): - """ - Removes custom_partitioning callback pointer from precompiled IR. +def _remove_callbacks(m: ir.Module, ignore_callbacks: IgnoreCallbacks): + """Removes callback pointers from precompiled IR. + Python function pointers are not deterministic across executions. """ def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult: - if (op.name == "stablehlo.custom_call" and - op.attributes["call_target_name"].value == "CustomSPMDPartitioning"): + if op.name == "stablehlo.custom_call" and ( + ( + ignore_callbacks == IgnoreCallbacks.ALL + and op.attributes["call_target_name"].value.endswith("callback") + ) + or op.attributes["call_target_name"].value == "CustomSPMDPartitioning" + ): op.attributes["backend_config"] = ir.StringAttr.get("REMOVED") return ir.WalkResult.ADVANCE + if ignore_callbacks == IgnoreCallbacks.NO: + return m + m.operation.walk(_update_bc_attribute) return m -def _serialize_ir(m: ir.Module) -> bytes: +def _serialize_ir(m: ir.Module, ignore_callbacks: IgnoreCallbacks) -> bytes: output = io.BytesIO() - if config.remove_custom_partitioning_ptr_from_cache_key.value: - m = _remove_custom_partitioning_ptr(type_cast(ir.Module, - m.operation.clone())) + if ignore_callbacks != IgnoreCallbacks.NO: + m = _remove_callbacks( + type_cast(ir.Module, m.operation.clone()), ignore_callbacks + ) m.operation.write_bytecode(file=output) return output.getvalue() -def _canonicalize_ir(m_original: ir.Module) -> bytes: +def _canonicalize_ir( + m_original: ir.Module, ignore_callbacks: IgnoreCallbacks +) -> bytes: with m_original.context: m = type_cast(ir.Module, m_original.operation.clone()) passes = pm.PassManager.parse( "builtin.module(strip-debuginfo)" ) passes.run(m.operation) - return _serialize_ir(m) + return _serialize_ir(m, ignore_callbacks) -def _hash_computation(hash_obj, module): +def _hash_computation(hash_obj, module, ignore_callbacks: IgnoreCallbacks): if config.compilation_cache_include_metadata_in_key.value: - canonical_ir = _serialize_ir(module) + canonical_ir = _serialize_ir(module, ignore_callbacks) else: - canonical_ir = _canonicalize_ir(module) + canonical_ir = _canonicalize_ir(module, ignore_callbacks) hash_obj.update(canonical_ir) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 89dd97175f00..d8724e42975e 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -267,12 +267,21 @@ def put_executable_and_time( cache.put(cache_key, executable_and_time) -def get_cache_key(module: ir.Module, - devices: np.ndarray, - compile_options, - backend) -> str: - return cache_key.get(module, devices, compile_options, backend, - "zstandard" if zstandard is not None else "zlib") +def get_cache_key( + module: ir.Module, + devices: np.ndarray, + compile_options, + backend, + ignore_callbacks: cache_key.IgnoreCallbacks = cache_key.IgnoreCallbacks.NO, +) -> str: + return cache_key.get( + module, + devices, + compile_options, + backend, + "zstandard" if zstandard is not None else "zlib", + ignore_callbacks, + ) def is_initialized() -> bool: diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index ebb1a2b54855..d40ded556a5b 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -24,6 +24,7 @@ from typing import Any, Callable import warnings +from jax._src import cache_key as cache_key_type from jax._src import compilation_cache from jax._src import config as config from jax._src import distributed @@ -33,8 +34,8 @@ from jax._src import profiler from jax._src import traceback_util from jax._src.interpreters import mlir -from jax._src.lib import xla_client as xc from jax._src.lib import version as jaxlib_version +from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir import numpy as np @@ -351,8 +352,18 @@ def compile_or_get_cached( monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') try: + if config.remove_custom_partitioning_ptr_from_cache_key.value: + ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING + else: + ignore_callbacks = cache_key_type.IgnoreCallbacks.NO + cache_key = compilation_cache.get_cache_key( - computation, devices, compile_options, backend) + computation, + devices, + compile_options, + backend, + ignore_callbacks=ignore_callbacks, + ) except xc._xla.XlaRuntimeError as ex: logger.error("compile_or_get_cached: unable to generate cache key, " "skipping the cache: %s", ex) @@ -385,7 +396,12 @@ def compile_or_get_cached( compile_options.executable_build_options.fdo_profile = b"pgle profiled" pgle_profiled_module_key = compilation_cache.get_cache_key( - computation, devices, compile_options, backend) + computation, + devices, + compile_options, + backend, + cache_key_type.IgnoreCallbacks.ALL, + ) compile_options.executable_build_options.fdo_profile = fdo_profile if _is_executable_in_cache(backend, pgle_profiled_module_key): @@ -493,7 +509,11 @@ def _share_fdo_profiles( compile_options.executable_build_options.fdo_profile = b"" profile_key = ( compilation_cache.get_cache_key( - computation, devices, compile_options, backend + computation, + devices, + compile_options, + backend, + cache_key_type.IgnoreCallbacks.ALL, ) + "_fdo_sync" ) diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 8f9c5d0e8b82..f84a9d5fb39f 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -176,7 +176,7 @@ def _infer_sharding_from_operands(mesh, arg_shapes, result_shape): @custom_partitioning def _cp_add(x, y): - return jax.numpy.add(x, y) + return jax.numpy.add(x, y) _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, @@ -199,14 +199,59 @@ def _cp_add(x, y): r'(.*?backend_config\s*=\s*"([^"]*)".*?)' r'\}' ) - with config.remove_custom_partitioning_ptr_from_cache_key(True): - with computation.context: - updated_module = cache_key._remove_custom_partitioning_ptr( - type_cast(ir.Module, computation.operation.clone())) - bcs = [match[2] for - match in re.findall(pattern, str(updated_module), re.DOTALL)] - for bc in bcs: - self.assertEqual(bc, "REMOVED") + with computation.context: + updated_module = cache_key._remove_callbacks( + type_cast(ir.Module, computation.operation.clone()), + ignore_callbacks=cache_key.IgnoreCallbacks.ALL, + ) + bcs = [ + match[2] + for match in re.findall(pattern, str(updated_module), re.DOTALL) + ] + for bc in bcs: + self.assertEqual(bc, "REMOVED") + + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + backend = xla_bridge.get_backend() + hash_without_callback_ptrs = cache_key.get( + computation, + devices, + compile_options, + backend, + ignore_callbacks=cache_key.IgnoreCallbacks.CUSTOM_PARTITIONING, + ) + expected_hash = cache_key.get( + updated_module, devices, compile_options, backend + ) + self.assertEqual(expected_hash, hash_without_callback_ptrs) + + @jtu.skip_on_devices("cpu") + def test_host_callbacks_ptrs_removed(self): + def _host_callback(x, y): + jax.debug.print("x={x[0]} y={y[0]}", x=x, y=y) + + computation = ( + jax.jit(_host_callback) + .lower( + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + ) + .compiler_ir() + ) + pattern = r'(.*?backend_config\s*=\s*"([^"]*)".*?)' + with computation.context: + updated_module = cache_key._remove_callbacks( + type_cast(ir.Module, computation.operation.clone()), + ignore_callbacks=cache_key.IgnoreCallbacks.ALL, + ) + bcs = [ + match[1] + for match in re.findall(pattern, str(updated_module), re.DOTALL) + ] + for bc in bcs: + self.assertEqual(bc, "REMOVED") def test_different_device_assignment(self): computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() From dc11d402f533541887ebfccb81e461fdda69db4a Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 26 Nov 2024 04:08:59 -0800 Subject: [PATCH 501/698] [Pallas TPU] Better error message for lowering `sp.broadcast_to_p` `sp.broadcast_to_p` is a GPU-specific primitive, but it mistakenly appears in TPU lowerings. This PR improves the error message to reflect this. As an example, currently, users will hit this error when doing: ``` def kernel(x_ref, o_ref): m, n = 32, 8 x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], jnp.arange(n, dtype=jnp.int32)[None])) o_ref[...] = x ``` PiperOrigin-RevId: 700290975 --- jax/_src/pallas/mosaic/lowering.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 590d1db92198..af5aa66a3851 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1545,6 +1545,18 @@ def _proxy_reduce(arg, *, axes): lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule +def _broadcast_to_lowering_rule( + ctx: LoweringRuleContext, x, shape: Sequence[int] +): + raise RuntimeError( + "`broadcast_to` is a Triton-specific primitive. Please consider using" + " `jnp.broadcast_to` instead." + ) + + +lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule + + def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding ): From 762301fc5d6dcc53311d8e24deb7a1376d0d4332 Mon Sep 17 00:00:00 2001 From: "labs-code-app[bot]" <161369871+labs-code-app[bot]@users.noreply.github.com> Date: Tue, 26 Nov 2024 13:57:47 +0000 Subject: [PATCH 502/698] Add exec_time_optimization_effort and memory_fitting_effort flags. These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line, e.g. . Valid values are between -1.0 and 1.0, default is 0.0. --- CHANGELOG.md | 1 + jax/_src/compiler.py | 3 +++ jax/_src/config.py | 12 ++++++++++++ 3 files changed, 16 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce8b040439c0..b5758d107077 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now supported on GPU. See {jax-issue}`#24663` for more details. + * Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0. * Bug fixes * Fixed a bug where the GPU implementations of LU and QR decomposition would diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index ebb1a2b54855..75eb723a1c75 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -200,6 +200,9 @@ def get_compile_options( setattr(build_options, name, env_options_overrides.pop(name)) compile_options.env_option_overrides = list(env_options_overrides.items()) + build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value + build_options.memory_fitting_effort = config.memory_fitting_effort.value + debug_options = compile_options.executable_build_options.debug_options if lib.cuda_path is not None: debug_options.xla_gpu_cuda_data_dir = lib.cuda_path diff --git a/jax/_src/config.py b/jax/_src/config.py index 43c29c996cfb..42cfde8788e3 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1993,3 +1993,15 @@ def _update_garbage_collection_guard(state, key, val): 'to use this feature.' ), ) + +exec_time_optimization_effort = float_state( + name='jax_exec_time_optimization_effort', + default=0.0, + help='Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].' +) + +memory_fitting_effort = float_state( + name='jax_memory_fitting_effort', + default=0.0, + help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].' +) From 504c738781910983182aa5df032ba027b8a0d0dc Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Tue, 26 Nov 2024 16:50:49 +0000 Subject: [PATCH 503/698] Use next to tiny as smallest floating point value on Mac ARM --- jax/_src/test_util.py | 9 +++++++-- tests/lax_test.py | 24 ++---------------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c5a713743fb8..c639ebd03586 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -25,6 +25,7 @@ import logging import math import os +import platform import re import sys import tempfile @@ -1704,6 +1705,10 @@ def complex_plane_sample(dtype, size_re=10, size_im=None): size_im = size_re finfo = np.finfo(dtype) + machine = platform.machine() + is_arm_cpu = machine.startswith('aarch') or machine.startswith('arm') + smallest = np.nextafter(finfo.tiny, finfo.max) if is_arm_cpu and platform.system() == 'Darwin' else finfo.tiny + def make_axis_points(size): prec_dps_ratio = 3.3219280948873626 logmin = logmax = finfo.maxexp / prec_dps_ratio @@ -1722,8 +1727,8 @@ def make_axis_points(size): axis_points[1] = finfo.min axis_points[-2] = finfo.max if size > 0: - axis_points[size] = -finfo.tiny - axis_points[-size - 1] = finfo.tiny + axis_points[size] = -smallest + axis_points[-size - 1] = smallest axis_points[0] = -np.inf axis_points[-1] = np.inf return axis_points diff --git a/tests/lax_test.py b/tests/lax_test.py index 10fa8c006184..9ef13efc2ed3 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4400,34 +4400,14 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'tanh': regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj') - elif name == 'arcsin': - if is_arm_cpu and platform.system() == 'Darwin': - regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real') - else: - regions_with_inaccuracies.clear() - - elif name == 'arcsinh': - if is_arm_cpu and platform.system() == 'Darwin': - regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', - 'negj.imag', 'posj.imag') - else: - regions_with_inaccuracies.clear() - elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') - elif name == 'log1p': - if is_arm_cpu and platform.system() == 'Darwin': - regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', - 'posj.imag') - else: - regions_with_inaccuracies.clear() - - elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'tan', - 'arcsinh', 'arccosh', 'arctan', 'arctanh', 'square'}: + elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'tan', 'log1p', + 'arcsin', 'arcsinh', 'arccosh', 'arctan', 'arctanh', 'square'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable From 92e18e6d5ce84350a9bab25f49edbde45cbfc45a Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 26 Nov 2024 08:57:35 -0800 Subject: [PATCH 504/698] [AutoPGLE] Fix pgle test after removing pjit cache. PiperOrigin-RevId: 700359385 --- tests/pgle_test.py | 189 ++++++++++++++++++++++++--------------------- 1 file changed, 102 insertions(+), 87 deletions(-) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 46146abfc7c6..ef91c399db16 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -112,47 +112,73 @@ def f(x): fdo_profile = pgle_profiler.consume_fdo_profile() self.assertEqual(fdo_profile.count(b'custom'), its) + def get_fdo_profiles(self, dump_dir): + jit_f_fdo_profiles = [ + x + for x in os.listdir(dump_dir) + if 'jit_f' in x and x.endswith('.fdo_profile') + ] + return jit_f_fdo_profiles + def testAutoPgle(self): mesh = jtu.create_mesh((2,), ('x',)) - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={ - 'xla_gpu_enable_latency_hiding_scheduler': 'True', - # TODO(patrios): Remove this flag once b/376647494 is fixed. - 'xla_gpu_graph_min_graph_size': '100000', - }, - ) - def f(x): - return x * 2 - - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - expected = x * 2 + with tempfile.TemporaryDirectory() as dump_dir: + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # TODO(patrios): Remove this flag once b/376647494 is fixed. + 'xla_gpu_graph_min_graph_size': '100000', + 'xla_dump_to': dump_dir, + 'xla_gpu_experimental_dump_fdo_profiles': 'True' + }, + ) + def f(x): + return x * 2 - with config.pgle_profiling_runs(2), config.enable_pgle(True): - # Run 1: Module should be compiled without FDO. Two modules are expected - # One is the funtion f, the other one is multi slice module - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 2) + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + expected = x * 2 - # Run 2: Second PGLE run should not recompile the module - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertLess(cache_miss_count[0], 2) + with config.pgle_profiling_runs(2), config.enable_pgle(True): + # Run 1: Module should be compiled without FDO. Two modules are expected + # One is the funtion f, the other one is multi slice module + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) - # Run 3: The module should be recompiled with FDO profiles - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertEqual(cache_miss_count[0], 2) + # Run 2: Second PGLE run. Profile should be empty. + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) + # One for before and one for after optimization. + self.assertLen(fdo_profiles_before_pgle, 2) + # The FDO profile file should be empty. + self.assertEqual( + os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0) + + # Run 3: The module should be recompiled with FDO profiles + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertEqual(cache_miss_count[0], 2) + fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir) + # One for before and one for after optimization. + self.assertLen(fdo_profiles_after_pgle, 4) + + for fdo_profile in fdo_profiles_after_pgle: + if fdo_profile not in fdo_profiles_before_pgle: + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0 + ) - # Run 4: Fast-path should be used after PGLE is done - with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - self.assertArraysEqual(f(x), expected) - self.assertLess(cache_miss_count[0], 2) + # Run 4: Fast-path should be used after PGLE is done + with jtu.count_cached_compilation_cache_miss() as cache_miss_count: + self.assertArraysEqual(f(x), expected) + self.assertLess(cache_miss_count[0], 2) def testAutoPgleWithAot(self): @jax.jit @@ -225,38 +251,27 @@ def f(x): # Run 2: Compilation should not be called with jtu.count_cached_compilation_cache_miss() as cache_miss_count: f(x) - self.assertLess(cache_miss_count[0], 2) + self.assertGreater(cache_miss_count[0], 0) - module_before_pgle = os.listdir(dump_dir) - self.assertNotEmpty(module_before_pgle) + fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) # Run 3: Module should be compiled with FDO and stored to persistent cache with jtu.count_cached_compilation_cache_miss() as cache_miss_count: f(x) self.assertGreater(cache_miss_count[0], 0) # Check if FDO profile file of the biggest module is not empty - module_after_pgle = [ + fdo_profiles_after_pgle = [ x - for x in os.listdir(dump_dir) - if x not in module_before_pgle + for x in self.get_fdo_profiles(dump_dir) + if x not in fdo_profiles_before_pgle ] - self.assertNotEmpty(module_after_pgle) - biggest_module_after_pgle = max( - module_after_pgle, - key=lambda x: os.path.getsize( - os.path.join(dump_dir, x) - ), - ) - base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) + self.assertNotEmpty(fdo_profiles_after_pgle) # Check if FDO profile file in dump directory is not empty - for module in module_after_pgle: - if module.startswith(base_module_name) and module.endswith( - '.fdo_profile' - ): - self.assertGreater( - os.path.getsize(os.path.join(dump_dir, module)), 0 - ) + for fdo_profile in fdo_profiles_after_pgle: + self.assertGreater( + os.path.getsize(os.path.join(dump_dir, fdo_profile)), 0 + ) for pgle_profiler in pjit._pgle_profiler_dict.values(): self.assertTrue(pgle_profiler.is_enabled()) @@ -293,42 +308,42 @@ def check_if_cache_hit(event): self.assertGreater(cache_hit, 0) - def testPassingFDOProfile(self): - mesh = jtu.create_mesh((2,), ('x',)) + def testPassingFDOProfile(self): + mesh = jtu.create_mesh((2,), ('x',)) - @partial( - jax.jit, - in_shardings=NamedSharding(mesh, PartitionSpec('x')), - out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, - ) - def f(x, y): - return x @ y + @partial( + jax.jit, + in_shardings=NamedSharding(mesh, PartitionSpec('x')), + out_shardings=NamedSharding(mesh, PartitionSpec('x')), + compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + ) + def f(x, y): + return x @ y - shape = (16, 16) - x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - y = x + 1 + shape = (16, 16) + x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) + y = x + 1 - with config.pgle_profiling_runs(0): - f_lowered = f.lower(x, y) - compiled = f_lowered.compile() + with config.pgle_profiling_runs(0): + f_lowered = f.lower(x, y) + compiled = f_lowered.compile() - with tempfile.TemporaryDirectory() as cache_dir: - jax.profiler.start_trace(cache_dir) - compiled(x, y) - jax.profiler.stop_trace() - directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) - directories = [d for d in directories if os.path.isdir(d)] - rundir = directories[-1] - logging.info('rundir: %s', rundir) - fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir) - - if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda(): - self.assertIn(b'custom', fdo_profile) - - logging.info('fdo_profile: %s', fdo_profile) - # Test pass fdo_profile as compiler_options API works. - f_lowered.compile(compiler_options={'fdo_profile': fdo_profile}) + with tempfile.TemporaryDirectory() as cache_dir: + jax.profiler.start_trace(cache_dir) + compiled(x, y) + jax.profiler.stop_trace() + directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/')) + directories = [d for d in directories if os.path.isdir(d)] + rundir = directories[-1] + logging.info('rundir: %s', rundir) + fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir) + + if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda(): + self.assertIn(b'custom', fdo_profile) + + logging.info('fdo_profile: %s', fdo_profile) + # Test pass fdo_profile as compiler_options API works. + f_lowered.compile(compiler_options={'fdo_profile': fdo_profile}) if __name__ == '__main__': From e453fa179efa7cb0c55b707c67785ce4289065b3 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Tue, 26 Nov 2024 09:47:22 -0800 Subject: [PATCH 505/698] Update XLA dependency to use revision PiperOrigin-RevId: 700373062 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 1e3df5fafb4d..a76cdaa3af96 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "7059553f7e215709642e5a5b19274b0e78d4349a" -XLA_SHA256 = "16bf9a4e3e62a5180fddec2526657cd0ba9c2a1a3510458054730e60c9526294" +XLA_COMMIT = "f28020d8e2b523765fbccb084ab03ae37cfbfcf5" +XLA_SHA256 = "725a7f38f52cf60b24ee78e6a5006ef9bf177d399729a0ddbf99b8acc33d93ee" def repo(): tf_http_archive( From 6763fcfb4e15cb8cb3260d713df82c836a08918d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 26 Nov 2024 10:48:12 -0800 Subject: [PATCH 506/698] Fix a weird interaction with `set_local` and empty tuples passed to it. PiperOrigin-RevId: 700392735 --- jax/_src/mesh.py | 12 ++++++++---- jax/_src/pjit.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 214fb190d498..cecc24fb2e1a 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -483,15 +483,19 @@ def __init__(self): def push_mesh_context(val): mesh_context.stack.append(val) mesh_context.mesh = val - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) + # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them. + # Right now that leads to weird numerical issues. + non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) + if non_none_meshes: + jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) return val def pop_mesh_context(): mesh_context.stack.pop() mesh_context.mesh = mesh_context.stack[-1] - jax_config.abstract_mesh_context_manager.set_local( - tuple(m for m in mesh_context.stack if m is not None)) + non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) + if non_none_meshes: + jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) class null_mesh_context: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index b77af1a8f14e..a2c0aff98716 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -709,7 +709,7 @@ def get_abstract_mesh(in_avals): # TODO(yashkatariya): Remove this when mesh context can be set by the user. if m is None: return mesh_lib.null_mesh_context() - assert m is not None + assert isinstance(m, AbstractMesh) return m From dbe34299e4aa945550f892c3c9b819e22b76b7f8 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 26 Nov 2024 13:50:50 -0600 Subject: [PATCH 507/698] Change the workflow for opening upstream PRs to post links that open PRs (#157) * Add GH auth token to env * Make the job post a comment with a link to open the PR instead of actually opening the PR --- .github/workflows/rocm-open-upstream-pr.yml | 28 +++++++++------------ 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml index 96c2d6e8128a..7ae0b7a65109 100644 --- a/.github/workflows/rocm-open-upstream-pr.yml +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -10,11 +10,8 @@ jobs: contents: write pull-requests: write runs-on: ubuntu-latest - outputs: - new-pr-link: ${{ steps.create-pr.outputs.link }} env: NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream" - NEW_PR_TITLE: "[ROCM] ${{ github.event.pull_request.title }}" steps: - name: Checkout code uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -24,19 +21,18 @@ jobs: git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }} git rebase --onto origin/main git push origin HEAD - # TODO: Change the base of the PR to upstream main - - name: Create a PR to upstream - id: create-pr - run: | - echo link=$(gh pr create --repo rocm/jax --base main --head $NEW_BRANCH_NAME --title "$NEW_PR_TITLE" --body "${{ github.event.pull_request.body }}") >> "$GITHUB_OUTPUT" - comment-link: - needs: open-upstream - permissions: - pull-requests: write - runs-on: ubuntu-latest - steps: - - name: Leave comment on old PR + - name: Leave link to create PR env: GH_TOKEN: ${{ github.token }} - run: gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body ${{ needs.open-upstream.outputs.new-pr-link.link }} + run: | + # Bash is not friendly with newline characters, so make our own + NL=$'\n' + # Encode the PR title and body for passing as URL get parameters + TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri') + BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: ${{ github.event.pull_request.url }}" '$x|@uri') + # Create a link to the that will open up a new PR form to upstream and autofill the fields + CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC" + # Add a comment with the link to the PR + COMMENT_BODY="Feature branch from main is ready. [Create a new PR]($CREATE_PR_LINK) destined for upstream?" + gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY" From cc51fda35f50dbc6a009a934de31775144dcb1c9 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 26 Nov 2024 14:50:28 -0600 Subject: [PATCH 508/698] Fix rebase command to exclude rocm-main (#158) --- .github/workflows/rocm-open-upstream-pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml index 7ae0b7a65109..674696f3a286 100644 --- a/.github/workflows/rocm-open-upstream-pr.yml +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -19,7 +19,7 @@ jobs: run: | git fetch git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }} - git rebase --onto origin/main + git rebase --onto origin/main origin/rocm-main git push origin HEAD - name: Leave link to create PR env: From 5f3c134167276979e3fca118cb20991db60f3d83 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 26 Nov 2024 14:57:59 -0600 Subject: [PATCH 509/698] Fix user identity for rebase (#159) --- .github/workflows/rocm-open-upstream-pr.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml index 674696f3a286..c9f11883ba7f 100644 --- a/.github/workflows/rocm-open-upstream-pr.yml +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -17,6 +17,8 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Rebase code to main run: | + git config --global user.email "github-actions@github.com" + git config --global user.name "Github Actions" git fetch git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }} git rebase --onto origin/main origin/rocm-main From d30ec2b5b30a91d6dc0e65f145725ceeba1bc9da Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Tue, 22 Oct 2024 10:35:03 -0500 Subject: [PATCH 510/698] [ROCm] fix jax and wheelhouse relative paths --- build/rocm/ci_build | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 1ec5c6e7f36f..c64c62f2c3a0 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -89,9 +89,9 @@ def dist_wheels( mounts = [ "-v", - "./:/jax", + os.path.abspath("./") + ":/jax", "-v", - "./wheelhouse:/wheelhouse", + os.path.abspath("./wheelhouse") + ":/wheelhouse", ] if xla_path: @@ -210,7 +210,7 @@ def test(image_name): # JAX and jaxlib are already installed from wheels mounts = [ "-v", - "./:/jax", + os.path.abspath("./") + ":/jax", ] cmd.extend(mounts) From 694de6b64cdba4f348f346da43b0f9210f112e9c Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Tue, 26 Nov 2024 11:27:32 -0600 Subject: [PATCH 511/698] [ROCm] Change run_multi_gpu set opts --- build/rocm/run_multi_gpu.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/rocm/run_multi_gpu.sh b/build/rocm/run_multi_gpu.sh index b5d5798e7920..aa1d4d0f38ed 100755 --- a/build/rocm/run_multi_gpu.sh +++ b/build/rocm/run_multi_gpu.sh @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -eu +set -xu # Function to run tests with specified GPUs run_tests() { From bbaec6ea59ba312db7f42a843190ff8896ff8739 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Tue, 26 Nov 2024 13:30:31 -0800 Subject: [PATCH 512/698] [JAX] Add Python binding for building a colocated Python program This change adds a Python binding that makes `ifrt::CustomCallProgram` for a colocated Python program. This Python binding will be used internally in the colocated Python API implementation. The API does not yet compile the program into an executable, which will be added separately. PiperOrigin-RevId: 700443656 --- jax/BUILD | 1 + jax/experimental/colocated_python/func.py | 12 +++++++++--- tests/BUILD | 1 + tests/colocated_python_test.py | 20 ++++++++++++++++++-- 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index d35ff0e399a6..4260a8a1acb2 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1198,5 +1198,6 @@ pytype_library( ":util", ":xla_bridge", "//jax/_src/lib", + "//jax/extend:ifrt_programs", ] + py_deps("numpy") + py_deps("cloudpickle"), ) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 3e95ddf03c7e..6639e7eefdd6 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -28,7 +28,8 @@ from jax._src.traceback_util import api_boundary from jax._src.util import wraps from jax.experimental.colocated_python import func_backend -from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs +from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs +from jax.extend.ifrt_programs import ifrt_programs ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct] @@ -141,8 +142,13 @@ def _compile_to_executable( devices: xc.DeviceList, ) -> Callable[..., Any]: """Compiles a Python function into a runtime executable.""" - # TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an - # executable. + pickled_function = _serialize(fun) + program = ifrt_programs.make_colocated_python_program( + name, pickled_function, devices, in_specs_leaves, out_specs_leaves + ) + # TODO(hyeontaek): Compile the program and use the executable. + del program + del name del in_specs_leaves del out_specs_leaves diff --git a/tests/BUILD b/tests/BUILD index f0668b42b309..92a6ed99ceca 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1387,6 +1387,7 @@ jax_multiplatform_test( srcs = ["colocated_python_test.py"], deps = [ "//jax:experimental_colocated_python", + "//jax/extend:ifrt_programs", ], ) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 9f65e3aeced4..f86a68a998f3 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -22,6 +22,8 @@ from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member from jax.experimental import colocated_python from jax.experimental.colocated_python import func as colocated_python_func +from jax.experimental.colocated_python import serialization +from jax.extend.ifrt_programs import ifrt_programs import jax.numpy as jnp import numpy as np @@ -77,8 +79,22 @@ class ColocatedPythonTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_extension_version < 290: - self.skipTest("Requires xla_extension_version >= 290") + if xla_extension_version < 298: + self.skipTest("Requires xla_extension_version >= 298") + + def testMakeColocatedPythonProgram(self): + def add_one(x): + return x + 1 + + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) + aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) + + pickled_function = serialization._serialize(add_one) + program = ifrt_programs.make_colocated_python_program( + "add_one", pickled_function, [cpu_devices[0]], [aval], [aval] + ) + del program def testSimpleFunction(self): @colocated_python.colocated_python From f3cfe477c8da62a5be74cc87f4c4787408136967 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 26 Nov 2024 15:37:17 -0600 Subject: [PATCH 513/698] Fix the link to the downstream PR (#160) --- .github/workflows/rocm-open-upstream-pr.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml index c9f11883ba7f..a8748d2d84f6 100644 --- a/.github/workflows/rocm-open-upstream-pr.yml +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -22,7 +22,8 @@ jobs: git fetch git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }} git rebase --onto origin/main origin/rocm-main - git push origin HEAD + # Force push here so that we don't run into conflicts with the origin branch + git push origin HEAD --force - name: Leave link to create PR env: GH_TOKEN: ${{ github.token }} @@ -31,7 +32,7 @@ jobs: NL=$'\n' # Encode the PR title and body for passing as URL get parameters TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri') - BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: ${{ github.event.pull_request.url }}" '$x|@uri') + BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri') # Create a link to the that will open up a new PR form to upstream and autofill the fields CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC" # Add a comment with the link to the PR From c835a78d1dc2011385ce88ddbdb3d87c2b610f77 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 26 Nov 2024 16:27:17 -0600 Subject: [PATCH 514/698] Use the reference format for links instead of inline (#162) --- .github/workflows/rocm-open-upstream-pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-open-upstream-pr.yml b/.github/workflows/rocm-open-upstream-pr.yml index a8748d2d84f6..bd14fa050577 100644 --- a/.github/workflows/rocm-open-upstream-pr.yml +++ b/.github/workflows/rocm-open-upstream-pr.yml @@ -36,6 +36,6 @@ jobs: # Create a link to the that will open up a new PR form to upstream and autofill the fields CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC" # Add a comment with the link to the PR - COMMENT_BODY="Feature branch from main is ready. [Create a new PR]($CREATE_PR_LINK) destined for upstream?" + COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK" gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY" From 8df2766466add14025d13b5603898ea0943788b3 Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Tue, 26 Nov 2024 16:22:24 -0600 Subject: [PATCH 515/698] Add argument to override base docker in dockerfile --- build/rocm/Dockerfile.ms | 3 ++- build/rocm/ci_build | 9 +++++++++ build/rocm/ci_build.sh | 7 ++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index e20291cefd63..575dce87664e 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -1,5 +1,6 @@ ################################################################################ -FROM ubuntu:20.04 AS rocm_base +ARG BASE_DOCKER=ubuntu:22.04 +FROM $BASE_DOCKER AS rocm_base ################################################################################ RUN --mount=type=cache,target=/var/cache/apt \ diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 1ec5c6e7f36f..f3b8ae401649 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -143,6 +143,7 @@ def _fetch_jax_metadata(xla_path): def dist_docker( rocm_version, + base_docker, python_versions, xla_path, rocm_build_job="", @@ -168,6 +169,7 @@ def dist_docker( "--build-arg=ROCM_VERSION=%s" % rocm_version, "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, + "--build-arg=BASE_DOCKER=%s" % base_docker, "--build-arg=PYTHON_VERSION=%s" % python_version, "--build-arg=JAX_VERSION=%(jax_version)s" % md, "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, @@ -231,6 +233,12 @@ def test(image_name): def parse_args(): p = argparse.ArgumentParser() + p.add_argument( + "--base-docker", + default="", + help="Argument to override base docker in dockerfile", + ) + p.add_argument( "--python-versions", type=lambda x: x.split(","), @@ -308,6 +316,7 @@ def main(): ) dist_docker( args.rocm_version, + args.base_docker, args.python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 0a50b5845d69..386f70ee1a96 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -48,7 +48,7 @@ PYTHON_VERSION="3.10" ROCM_VERSION="6.1.3" ROCM_BUILD_JOB="" ROCM_BUILD_NUM="" -BASE_DOCKER="ubuntu:20.04" +BASE_DOCKER="ubuntu:22.04" CUSTOM_INSTALL="" JAX_USE_CLANG="" POSITIONAL_ARGS=() @@ -90,6 +90,10 @@ while [[ $# -gt 0 ]]; do ROCM_BUILD_NUM="$2" shift 2 ;; + --base_docker) + BASE_DOCKER="$2" + shift 2 + ;; --use_clang) JAX_USE_CLANG="$2" shift 2 @@ -154,6 +158,7 @@ fi # which is the ROCm image that is shipped for users to use (i.e. distributable). ./build/rocm/ci_build \ --rocm-version $ROCM_VERSION \ + --base-docker $BASE_DOCKER \ --python-versions $PYTHON_VERSION \ --xla-source-dir=$XLA_CLONE_DIR \ --rocm-build-job=$ROCM_BUILD_JOB \ From 3d8063209ebfac675eac2142f5097ba71a5f4fff Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Tue, 26 Nov 2024 16:35:53 -0600 Subject: [PATCH 516/698] Update http to https in amd artifactory url. --- build/rocm/tools/get_rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py index 993c2f94b558..d9cb9ea3f25a 100644 --- a/build/rocm/tools/get_rocm.py +++ b/build/rocm/tools/get_rocm.py @@ -229,7 +229,7 @@ def _build_installer_url(rocm_version, metadata): rv = parse_version(rocm_version) - base_url = "http://artifactory-cdn.amd.com/artifactory/list" + base_url = "https://artifactory-cdn.amd.com/artifactory/list" if md["ID"] == "ubuntu": fmt = "amdgpu-install-internal_%(rocm_major)s.%(rocm_minor)s-%(os_version)s-1_all.deb" From 9c423796334a6cd5dce893c26d638bc27149d580 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 26 Nov 2024 14:43:23 -0800 Subject: [PATCH 517/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/9d5bac13a3a9284959cf85e9bcb959eb147151cf. PiperOrigin-RevId: 700465720 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index a76cdaa3af96..01f5f862db4d 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "f28020d8e2b523765fbccb084ab03ae37cfbfcf5" -XLA_SHA256 = "725a7f38f52cf60b24ee78e6a5006ef9bf177d399729a0ddbf99b8acc33d93ee" +XLA_COMMIT = "9d5bac13a3a9284959cf85e9bcb959eb147151cf" +XLA_SHA256 = "f3681cee585ac605d43a7981254ab480a5f4202cc0ab30151d7f2151bddb0d15" def repo(): tf_http_archive( From 10fdee34d6585f4a2974f15f0770a07a217903d5 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 26 Nov 2024 15:08:25 -0800 Subject: [PATCH 518/698] Move `tsl/platform/{build_config,build_config_root,rules_cc}.bzl` to `xla/tsl/platform` PiperOrigin-RevId: 700472724 --- jaxlib/jax.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 2bae7ab2a203..976e5f26cb4b 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,8 +20,8 @@ load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_roc load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") load("@rules_python//python:defs.bzl", "py_test") -load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") +load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") # Explicitly re-exports names to avoid "unused variable" warnings from .bzl # lint tools. From afcef6779168a814392022acaf011112d5c6a503 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 18:15:32 -0800 Subject: [PATCH 519/698] Install git before actions/checkout This fixes the workflow failing at "Build and install JAX" step as it wasn't able to run git command to fetch the `jaxlib` git hash Without git present on the PATH, it seems that `actions/checkout` (from its logs) will download the code with the GitHub REST API. This results in the code not being a git repository and therefore any subsequent git commands fail. PiperOrigin-RevId: 700518101 --- .github/workflows/asan.yaml | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index d261ba3a09c2..d0d889729448 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -25,14 +25,8 @@ jobs: run: shell: bash -l {0} steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - path: jax - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: python/cpython - path: cpython - ref: v3.13.0 + # Install git before actions/checkout as otherwise it will download the code with the GitHub + # REST API and therefore any subsequent git commands will fail. - name: Install clang 18 env: DEBIAN_FRONTEND: noninteractive @@ -42,6 +36,14 @@ jobs: zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: jax + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: python/cpython + path: cpython + ref: v3.13.0 - name: Build CPython with ASAN enabled env: ASAN_OPTIONS: detect_leaks=0 From c6866d05dba096f4be67c0368926eff7ffcbe121 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 18:16:47 -0800 Subject: [PATCH 520/698] Add a check for return codes of `executor.run` so that we propagate error codes correctly PiperOrigin-RevId: 700518396 --- build/build.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/build/build.py b/build/build.py index 12ad0fa3b011..0eb6fd9a83ee 100755 --- a/build/build.py +++ b/build/build.py @@ -399,8 +399,11 @@ async def main(): else: requirements_command.append("//build:requirements.update") - await executor.run(requirements_command.get_command_as_string(), args.dry_run) - sys.exit(0) + result = await executor.run(requirements_command.get_command_as_string(), args.dry_run) + if result.return_code != 0: + raise RuntimeError(f"Command failed with return code {result.return_code}") + else: + sys.exit(0) wheel_cpus = { "darwin_arm64": "arm64", @@ -594,7 +597,11 @@ async def main(): wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") - await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) + result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) + if result.return_code != 0: + raise RuntimeError(f"Command failed with return code {result.return_code}") + else: + sys.exit(0) if __name__ == "__main__": From 13726690ddc72acb78b2c61411ffaab89e2a6b0a Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 18:41:22 -0800 Subject: [PATCH 521/698] Add new CI script to run Bazel GPU (non-RBE) jobs This commit adds the CI script needed for running Bazel GPU (non-RBE) tests. These run two Bazel commands: Single accelerator tests with one GPU a piece and multi-accelerator tests with all GPUs PiperOrigin-RevId: 700523594 --- ci/envs/default.env | 16 +++++- ci/run_bazel_test_cpu_rbe.sh | 4 +- ci/run_bazel_test_gpu_non_rbe.sh | 87 ++++++++++++++++++++++++++++++++ ci/run_bazel_test_gpu_rbe.sh | 2 +- 4 files changed, 105 insertions(+), 4 deletions(-) create mode 100755 ci/run_bazel_test_gpu_non_rbe.sh diff --git a/ci/envs/default.env b/ci/envs/default.env index f50b7549b823..e3bf1a5ab47a 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -42,4 +42,18 @@ export JAXCI_OUTPUT_DIR="$(pwd)/dist" # When enabled, artifacts will be built with RBE. Requires gcloud authentication # and only certain platforms support RBE. Therefore, this flag is enabled only # for CI builds where RBE is supported. -export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} \ No newline at end of file +export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} + +# ############################################################################# +# Test script specific environment variables. +# ############################################################################# +# The maximum number of tests to run per GPU when running single accelerator +# tests with parallel execution with Bazel. The GPU limit is set because we +# need to allow about 2GB of GPU RAM per test. Default is set to 12 because we +# use L4 machines which have 24GB of RAM but can be overriden if we use a +# different GPU type. +export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12} + +# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override +# this value in the Github action workflow files. +export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 6ba9f6dce239..248111e0247a 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -50,7 +50,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ) --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ --test_env=JAX_NUM_GENERATED_CASES=25 \ --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64=0 \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ //tests:cpu_tests //tests:backend_independent_tests @@ -61,7 +61,7 @@ else --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ --test_env=JAX_NUM_GENERATED_CASES=25 \ --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64=0 \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ //tests:cpu_tests //tests:backend_independent_tests diff --git a/ci/run_bazel_test_gpu_non_rbe.sh b/ci/run_bazel_test_gpu_non_rbe.sh new file mode 100755 index 000000000000..7828cf41c60e --- /dev/null +++ b/ci/run_bazel_test_gpu_non_rbe.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Run Bazel GPU tests without RBE. This runs two commands: single accelerator +# tests with one GPU a piece, multiaccelerator tests with all GPUS. +# Requires that jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are stored +# inside the ../dist folder +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel GPU tests (single accelerator and multiaccelerator tests) directly +# on the VM without RBE. +nvidia-smi +echo "Running single accelerator tests (without RBE)..." + +# Set up test environment variables. +export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +export num_test_jobs=$((gpu_count * JAXCI_MAX_TESTS_PER_GPU)) +export num_cpu_cores=$(nproc) + +# tests_jobs = max(gpu_count * max_tests_per_gpu, num_cpu_cores) +if [[ $num_test_jobs -gt $num_cpu_cores ]]; then + num_test_jobs=$num_cpu_cores +fi +# End of test environment variables setup. + +# Runs single accelerator tests with one GPU apiece. +# It appears --run_under needs an absolute path. +# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR` +# should match the VM's CPU core count (set in `--local_test_jobs`). +bazel test --config=ci_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --//jax:build_jaxlib=false \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_env=JAX_ACCELERATOR_COUNT=$gpu_count \ + --test_env=JAX_TESTS_PER_ACCELERATOR=$JAXCI_MAX_TESTS_PER_GPU \ + --local_test_jobs=$num_test_jobs \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests + +echo "Running multi-accelerator tests (without RBE)..." +# Runs multiaccelerator tests with all GPUs directly on the VM without RBE.. +bazel test --config=ci_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --//jax:build_jaxlib=false \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --jobs=8 \ + --test_tag_filters=multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests/pallas:gpu_tests \ No newline at end of file diff --git a/ci/run_bazel_test_gpu_rbe.sh b/ci/run_bazel_test_gpu_rbe.sh index 0c004c584300..17bd8d9db4f8 100755 --- a/ci/run_bazel_test_gpu_rbe.sh +++ b/ci/run_bazel_test_gpu_rbe.sh @@ -46,6 +46,6 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ --test_tag_filters=-multiaccelerator \ --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64=0 \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file From 0d2dfea4b1c0ac1bb4a4493c158d83d2fae4707c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 26 Nov 2024 20:00:19 -0800 Subject: [PATCH 522/698] Add a private `set_mesh` API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet). Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on. PiperOrigin-RevId: 700537898 --- jax/_src/config.py | 9 ++++- jax/_src/core.py | 2 +- jax/_src/interpreters/pxla.py | 11 +++++-- jax/_src/mesh.py | 62 +++++++++++++++++++++++++++-------- jax/_src/pjit.py | 4 +-- tests/pjit_test.py | 26 +++++++++++++-- 6 files changed, 92 insertions(+), 22 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 43c29c996cfb..eea686fa0f3c 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -212,6 +212,7 @@ def trace_context(): return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value, abstract_mesh_context_manager.value, + device_context.value, compute_on_context_manager.value, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, @@ -245,6 +246,7 @@ def trace_context(): axis_env_state = () mesh_context_manager = () abstract_mesh_context_manager = () + device_context = () xla_metadata_context_manager = () compute_on_context_manager = () @@ -255,12 +257,14 @@ def trace_context(): mesh_context_manager = context.mesh_context_manager if context and context.abstract_mesh_context_manager: abstract_mesh_context_manager = context.abstract_mesh_context_manager + if context and context.device_context: + device_context = context.device_context if context and context.xla_metadata_context_manager: xla_metadata_context_manager = context.xla_metadata_context_manager if context and context.compute_on_context_manager: compute_on_context_manager = context.compute_on_context_manager return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager, - xla_metadata_context_manager, + device_context, xla_metadata_context_manager, compute_on_context_manager, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, @@ -976,6 +980,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: axis_env_state = config_ext.Config((), include_in_jit_key=True) mesh_context_manager = config_ext.Config((), include_in_jit_key=True) abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True) + device_context = config_ext.Config((), include_in_jit_key=True) compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) else: @@ -1019,6 +1024,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): axis_env_state: Hashable = () mesh_context_manager: Hashable = () abstract_mesh_context_manager: Hashable = () + device_context: Hashable = () compute_on_context_manager: Hashable = () xla_metadata_context_manager: Hashable = () @@ -1086,6 +1092,7 @@ def set_local(self, value): axis_env_state = JitConfig('axis_env_state') mesh_context_manager = JitConfig('mesh_context_manager') abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager') + device_context = JitConfig('device_context') compute_on_context_manager = JitConfig('compute_on_context_manager') xla_metadata_context_manager = JitConfig('xla_metadata_context_manager') diff --git a/jax/_src/core.py b/jax/_src/core.py index 86646faa980b..122b7bcf5eb2 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1605,7 +1605,7 @@ def get_sharding(sharding, ndim): assert len(sharding.spec) == ndim return sharding - context_mesh = mesh_lib.mesh_context.mesh + context_mesh = mesh_lib.abstract_mesh_context.mesh # TODO(yashkatariya): Error out and ask users to set the context mesh in their # code. if context_mesh is None: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 2164c1a914c9..11df2d38f21d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2193,8 +2193,15 @@ def lower_sharding_computation( assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( len(out_shardings), len(out_layouts), len(global_out_avals)) - devices_from_context = (None if context_mesh is None or context_mesh.empty - else context_mesh._flat_devices_tuple) + if config.sharding_in_types.value: + # TODO(yashkatariya): Thread it via jit path and remove the None check by + # making tests go via set_mesh API always. + devices_from_context = ( + None if mesh_lib.device_context.concrete_mesh is None + else mesh_lib.device_context.concrete_mesh._flat_devices_tuple) + else: + devices_from_context = (None if context_mesh is None or context_mesh.empty + else context_mesh._flat_devices_tuple) # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. unique_intermediate_shardings = util.stable_unique( diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index cecc24fb2e1a..a83287d5ecad 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -455,10 +455,10 @@ def local_mesh(self): _raise_value_error("local_mesh") def __enter__(self): - return push_mesh_context(self) + return push_abstract_mesh_context(self) def __exit__(self, exc_type, exc_value, traceback): - pop_mesh_context() + pop_abstract_mesh_context() return False @staticmethod @@ -473,27 +473,29 @@ def _raise_value_error(name): raise ValueError(f"AbstractMesh does not implement {name}") -class MeshContext(threading.local): +class AbstractMeshContext(threading.local): def __init__(self): self.stack = [None] self.mesh = self.stack[-1] -mesh_context = MeshContext() +abstract_mesh_context = AbstractMeshContext() -def push_mesh_context(val): - mesh_context.stack.append(val) - mesh_context.mesh = val +def push_abstract_mesh_context(val): + abstract_mesh_context.stack.append(val) + abstract_mesh_context.mesh = val # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them. # Right now that leads to weird numerical issues. - non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) + non_none_meshes = tuple(m for m in abstract_mesh_context.stack + if m is not None) if non_none_meshes: jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) return val -def pop_mesh_context(): - mesh_context.stack.pop() - mesh_context.mesh = mesh_context.stack[-1] - non_none_meshes = tuple(m for m in mesh_context.stack if m is not None) +def pop_abstract_mesh_context(): + abstract_mesh_context.stack.pop() + abstract_mesh_context.mesh = abstract_mesh_context.stack[-1] + non_none_meshes = tuple(m for m in abstract_mesh_context.stack + if m is not None) if non_none_meshes: jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) @@ -501,8 +503,40 @@ def pop_mesh_context(): class null_mesh_context: def __enter__(self): - return push_mesh_context(None) + return push_abstract_mesh_context(None) def __exit__(self, *excinfo): - pop_mesh_context() + pop_abstract_mesh_context() return False + + +@contextlib.contextmanager +def set_mesh(mesh: Mesh): + with (mesh.abstract_mesh, jax_config.sharding_in_types(True), + enter_device_context(mesh)): + yield + + +class DeviceContext(threading.local): + def __init__(self): + self.stack = [None] + self.concrete_mesh = self.stack[-1] + +device_context = DeviceContext() + + +@contextlib.contextmanager +def enter_device_context(mesh: Mesh): + device_context.stack.append(mesh) + device_context.concrete_mesh = mesh + non_none_meshes = tuple(m for m in device_context.stack if m is not None) + if non_none_meshes: + jax_config.device_context.set_local(non_none_meshes) + try: + yield + finally: + device_context.stack.pop() + device_context.concrete_mesh = device_context.stack[-1] + non_none_meshes = tuple(m for m in device_context.stack if m is not None) + if non_none_meshes: + jax_config.device_context.set_local(non_none_meshes) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a2c0aff98716..5d3fb5a3c67c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -644,8 +644,8 @@ def _infer_params_impl( attr_token = _attr_token(flat_fun, in_type) abstract_mesh = ( - get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None - else mesh_lib.mesh_context.mesh) + get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None + else mesh_lib.abstract_mesh_context.mesh) with abstract_mesh: jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, dbg, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8c362af19d0b..e541c6346666 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4622,6 +4622,28 @@ def f(x): ins, _ = f.lower(np.arange(8)).compile().input_shardings self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0])) + def test_sharding_in_types_with_set_mesh(self): + if config.use_shardy_partitioner.value: + self.skipTest("ShiT doesn't work with shardy") + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + with mesh_lib.set_mesh(mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + self.assertEqual(x.sharding.spec, s.spec) + x = x * 2 + self.assertEqual(x.sharding.spec, s.spec) + x = x * x + self.assertEqual(x.sharding.spec, s.spec) + return x + + out = f(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2)) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") @@ -5229,7 +5251,7 @@ def test_shard_map_full_manual(self): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) - self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective) return x * y @jax.jit @@ -5254,7 +5276,7 @@ def test_shard_map_dot(self): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) - self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y return jax.lax.psum(z, axis_name='y') From 47d1960926c834695f876abe26c8428d23785dc2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 26 Nov 2024 21:01:06 -0800 Subject: [PATCH 523/698] Update the render documentation job to use the new self-hosted runners PiperOrigin-RevId: 700550934 --- .github/workflows/ci-build.yaml | 11 ++++++++--- jax/_src/scipy/linalg.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 2b555d492644..9828a160f2e1 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -144,13 +144,19 @@ jobs: documentation_render: name: Documentation - render documentation - runs-on: ubuntu-latest + runs-on: linux-x86-n2-16 + container: + image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 10 strategy: matrix: python-version: ['3.10'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Image Setup + run: | + apt update + apt install -y libssl-dev libsqlite3-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: @@ -170,8 +176,7 @@ jobs: pip install -r docs/requirements.txt - name: Render documentation run: | - sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html - + sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html jax2tf_test: name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 1c5eba988e6a..2e3632700759 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -1679,7 +1679,7 @@ def polar(a: ArrayLike, side: str = 'right', *, method: str = 'qdwh', eps: float U is a Unitary Matrix: - >>> jnp.round(U.T @ U) + >>> jnp.round(U.T @ U) # doctest: +SKIP Array([[ 1., -0., -0.], [-0., 1., 0.], [-0., 0., 1.]], dtype=float32) From 7a2070e7da00aa1fb2ef79b4b9c0d0f445eae248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 27 Nov 2024 00:27:14 -0800 Subject: [PATCH 524/698] [Mosaic:TPU] Enable broadcast from 1-D vectors PiperOrigin-RevId: 700592669 --- jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index c0b2c6c96e7e..1be81e733161 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1110,12 +1110,10 @@ class VectorLayoutInferer { return success(); } if (auto src_ty = dyn_cast(some_src_ty)) { - TPU_CHECK_OP(src_ty.getRank() >= 2, "source rank below 2D unsupported"); - TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported"); auto some_layout = getLayout(op.getSource()); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); auto &layout = *some_layout; - if (layout.implicit_dim() != ImplicitDim::kNone) { + if (layout.implicit_dim() != ImplicitDim::kNone && src_ty.getRank() > 1) { VectorLayout layout_2d(layout.bitwidth(), layout.offsets(), layout.tiling(), ImplicitDim::kNone); if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) { From 7f14de0469f1bad8d1ad65fad5f1d85fe8ded469 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 27 Nov 2024 04:24:37 -0800 Subject: [PATCH 525/698] [mosaic_gpu] Warmup before measuring the running time in `profiler.measure` PiperOrigin-RevId: 700650380 --- jax/experimental/mosaic/gpu/profiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 0594e9239be7..e51a7b842931 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -104,6 +104,7 @@ def run(*args, **kwargs): raise ValueError("Can only measure functions with at least one output") return outs, _event_elapsed(start_event, end_event) + jax.block_until_ready(run(*args, **kwargs)) # Warmup. outs, elapsed = run(*args, **kwargs) return outs, float(elapsed) From 03b6945ee7ddeef4d9e2f24aad926e9b57904814 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 27 Nov 2024 07:07:23 -0800 Subject: [PATCH 526/698] Integrate LLVM at llvm/llvm-project@b214ca82daee Updates LLVM usage to match [b214ca82daee](https://github.com/llvm/llvm-project/commit/b214ca82daee) PiperOrigin-RevId: 700689999 --- jaxlib/mosaic/dialect/tpu/tpu.td | 80 ++++++++++++++++---------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 8a4f573bce24..5dad1309ae91 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -44,7 +44,7 @@ class TPU_Attr traits = []> } // TODO(b/369418606): Find out the way to verify vreg size. -def TPU_Vreg : Type; +def TPU_Vreg : Type; class TPU_Type traits = []> : TypeDef { @@ -179,8 +179,8 @@ def TPU_ReductionKindAttr } def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> { - let arguments = (ins AnyVector:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); - let results = (outs AnyVector:$output); + let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) }]; @@ -217,11 +217,11 @@ def TPU_LoadOp : TPU_Op<"load"> { // TODO(jevinjiang): migrate tpu.strided_store to general vector store op. def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { let arguments = (ins - AnyVector:$valueToStore, + AnyVectorOfNonZeroRank:$valueToStore, AnyMemRef:$base, Variadic:$indices, DenseI32ArrayAttr:$strides, - Optional:$mask // Elementwise mask. + Optional:$mask // Elementwise mask. ); let results = (outs); let assemblyFormat = [{ @@ -236,7 +236,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> { Variadic:$indices, DenseI32ArrayAttr:$strides ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) }]; @@ -245,7 +245,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> { def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let arguments = (ins - AnyVector:$valueToStore, + AnyVectorOfNonZeroRank:$valueToStore, AnyMemRef:$base, Variadic:$indices, DenseI32ArrayAttr:$strides @@ -291,7 +291,7 @@ def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { // TODO(jevinjiang): deprecate to use dynamic_rotate. def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let arguments = (ins - AnyVector:$value, + AnyVectorOfNonZeroRank:$value, SI32Attr:$amount, SI32Attr:$dimension, // When the stride is specified, the rotation amount for each index on the @@ -299,7 +299,7 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { OptionalAttr:$stride, OptionalAttr:$stride_dimension ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value) }]; @@ -308,7 +308,7 @@ def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { let arguments = (ins - AnyVector:$value, + AnyVectorOfNonZeroRank:$value, I32:$amount, SI32Attr:$dimension, // When the stride is specified, the rotation amount for each index on the @@ -316,7 +316,7 @@ def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { OptionalAttr:$stride, OptionalAttr:$stride_dimension ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) }]; @@ -325,7 +325,7 @@ def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { def TPU_IotaOp : TPU_Op<"iota", [Pure]> { let arguments = (ins OptionalAttr:$dimension); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ attr-dict `:` type($output) }]; } @@ -333,22 +333,22 @@ def TPU_IotaOp : TPU_Op<"iota", [Pure]> { // b/376295711 def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { let arguments = (ins - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, I32Attr:$dimension, I32Attr:$times ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }]; } def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { let arguments = (ins - AnyVector:$source, // All sublanes should be equal. + AnyVectorOfNonZeroRank:$source, // All sublanes should be equal. I32Attr:$lane // Coordinates of the first element to take. ); // Output shape should be the same, except for position dim which contains // the newly inserted dimension. - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $lane attr-dict `:` type($source) `->` type($output) }]; @@ -357,30 +357,30 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { // Integer unpacks are always signed at the moment. def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { let arguments = (ins - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, I32Attr:$index ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; } // Integer packs are always signed at the moment. def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> { let arguments = (ins - Variadic:$sources, + Variadic:$sources, TPU_PackFormatEnum:$pack_format ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; } def TPU_GatherOp : TPU_Op<"gather", [Pure]> { let arguments = (ins - AnyVector:$source, + AnyVectorOfNonZeroRank:$source, DenseI32ArrayAttr:$indices, I32Attr:$dimension ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `[` $indices `]` `in` $dimension attr-dict `:` type($source) `->` type($output) @@ -389,11 +389,11 @@ def TPU_GatherOp : TPU_Op<"gather", [Pure]> { def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> { let arguments = (ins - AnyVector:$source, - AnyVector:$indices, // If this is 2D, only the first row matters. + AnyVectorOfNonZeroRank:$source, + AnyVectorOfNonZeroRank:$indices, // If this is 2D, only the first row matters. I32Attr:$dimension ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `[` $indices `]` `in` $dimension attr-dict `:` type($source) `,` type($indices) `->` type($output) @@ -424,9 +424,9 @@ def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension // TODO(apaszke): Think hard about precision def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { let arguments = (ins - AnyVector:$lhs, - AnyVector:$rhs, - AnyVector:$acc, + AnyVectorOfNonZeroRank:$lhs, + AnyVectorOfNonZeroRank:$rhs, + AnyVectorOfNonZeroRank:$acc, // These flags are deprecated - if dimension_numbers are defined, // these flags are ignored. They will always be false after canonicalize. DefaultValuedAttr:$transpose_lhs, @@ -435,7 +435,7 @@ def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { // NOTE: User-level optional, once canonicalized, always present. OptionalAttr:$dimension_numbers ); - let results = (outs AnyVector:$result); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result) }]; @@ -445,10 +445,10 @@ def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> { let arguments = (ins - Variadic:$sources, + Variadic:$sources, I32Attr:$dimension ); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $sources `in` $dimension attr-dict `:` type($sources) `->` type($output) }]; @@ -456,8 +456,8 @@ def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure]> { } def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> { - let arguments = (ins AnyVector:$input); - let results = (outs AnyVector:$output); + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; let hasVerifier = 1; } @@ -469,16 +469,16 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> { } def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { - let arguments = (ins Variadic:$input); - let results = (outs AnyVector:$output); + let arguments = (ins Variadic:$input); + let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; } def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> { - let arguments = (ins AnyVector:$input); - let results = (outs Variadic:$output); + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs Variadic:$output); let hasCanonicalizeMethod = 1; let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) @@ -722,8 +722,8 @@ def TPU_DelayOp : TPU_Op<"delay"> { // Expands the granularity of mask to subelements. def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> { - let arguments = (ins AnyVector:$input); - let results = (outs AnyVector:$result); + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs AnyVectorOfNonZeroRank:$result); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($result) }]; @@ -749,7 +749,7 @@ def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { let arguments = (ins); - let results = (outs AnyVector:$output); + let results = (outs AnyVectorOfNonZeroRank:$output); } def TPU_LogOp : TPU_Op<"log"> { From f3acfa93bb66e91d975ab0c4eb823b033f55cd36 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 27 Nov 2024 08:20:11 -0800 Subject: [PATCH 527/698] [mgpu] FragentedArray.foreach() can now optionally return a new array PiperOrigin-RevId: 700708119 --- .../mosaic/gpu/fragmented_array.py | 22 ++++++++++--- tests/mosaic/gpu_test.py | 33 +++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index b16cd26da271..6b288906e967 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1243,15 +1243,29 @@ def select(self, on_true, on_false): lambda t, p, f: arith.select(p, t, f), self, on_false, ) - def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): + def foreach( + self, + fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None], + *, + create_array=False, + is_signed=None, + ): """Call a function for each value and index.""" index = ir.IndexType.get() - for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True): - assert len(idx) == len(self.shape), (idx, self.shape) + new_regs = None + if create_array: + new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type)) + for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): + reg = self.registers[reg_idx] + assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) [elems] = ir.VectorType(reg.type).shape for i in range(elems): i = c(i, index) - fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i))) + val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if create_array: + new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + + return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) def store_untiled(self, ref: ir.Value): if not ir.MemRefType.isinstance(ref.type): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1c7f22f885ec..7dadc71fdcba 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1361,6 +1361,39 @@ def kernel(ctx, dst, _): rhs = rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs)) + def test_foreach(self): + dtype = jnp.int32 + swizzle = 128 + tile = 64, swizzle // jnp.dtype(dtype).itemsize + shape = 128, 192 + tiled_shape = mgpu.tile_shape(shape, tile) + mlir_dtype = utils.dtype_to_ir_type(dtype) + cst = 9999 + def causal(val, idx): + row, col = idx + mask = arith.cmpi(arith.CmpIPredicate.uge, row, col) + return arith.select(mask, val, c(cst, mlir_dtype)) + + tiling = mgpu.TileTransform(tile) + def kernel(ctx, dst, smem): + x = iota_tensor(shape[0], shape[1], dtype) + x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem, dst_ref=dst) + ctx.await_async_copy(0) + + iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape) + result = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + (), + jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + )() + expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst + np.testing.assert_array_equal(result, expected) + @parameterized.product( op=[operator.and_, operator.or_, operator.xor], dtype=[jnp.uint32], From 8477580d95130567ab3d65242af72da1d85ab66c Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 27 Nov 2024 08:33:20 -0800 Subject: [PATCH 528/698] [mgpu pallas] Layout iota operation. PiperOrigin-RevId: 700711177 --- jax/_src/pallas/mosaic_gpu/BUILD | 2 +- jax/_src/pallas/mosaic_gpu/primitives.py | 30 ++++++++++++++++++++++++ jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 11 +++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index ad418e2b936d..3d6e82d443b4 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -91,7 +91,7 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", - "//jax:effects", + "//jax:mlir", "//jax:mosaic_gpu", "//jax:tree_util", "//jax:util", diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 0f25f9808ac1..c1ea3e9870b6 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -25,8 +25,10 @@ from jax._src import state from jax._src import tree_util from jax._src import util +from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import llvm as llvm_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -692,3 +694,31 @@ def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): def commit_smem(): """Commits all writes to SMEM, making them visible to loads, TMA and WGMMA.""" commit_smem_p.bind() + + +broadcasted_iota_p = jax_core.Primitive("broadcasted_iota") + +@broadcasted_iota_p.def_abstract_eval +def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): + del layout, dimension + return jax_core.ShapedArray(shape, dtype) + +@lowering.register_lowering_rule(broadcasted_iota_p) +def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout): + del ctx + undef = llvm_dialect.mlir_undef(mlir.dtype_to_ir_type(dtype)) + is_signed = ( + jnp.issubdtype(dtype, jnp.signedinteger) + if jnp.issubdtype(dtype, jnp.integer) + else None + ) + mlir_dtype = mlir.dtype_to_ir_type(dtype) + return mgpu.FragmentedArray.splat( + undef, shape, layout.value, is_signed=is_signed + ).foreach( + lambda _, idx: arith_dialect.index_cast(mlir_dtype, idx[dimension]), create_array=True, is_signed=is_signed + ) + + +def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None): + return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 3903f3a9c0ae..2a6a6fa83663 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -36,6 +36,7 @@ from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast +from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 993de287f74c..dba1e67acf02 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -241,6 +241,17 @@ def kernel(x_ref, o_ref): # are never written to. np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16]) + def test_iota(self): + dtype, dimension = jnp.int8, 1 + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + ) + def kernel(o_ref): + o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA) + + np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)) + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_smem_to_gmem(self, indexer): @functools.partial( From d449f12a2eed9111dbb044fe7312e93ae0932ba4 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Wed, 27 Nov 2024 08:34:12 -0800 Subject: [PATCH 529/698] Fix early exiting when building multiple wheels PiperOrigin-RevId: 700711389 --- build/build.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/build/build.py b/build/build.py index 0eb6fd9a83ee..25a873d89e24 100755 --- a/build/build.py +++ b/build/build.py @@ -598,10 +598,12 @@ async def main(): wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) + # Exit with error if any wheel build fails. if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") - else: - sys.exit(0) + + # Exit with success if all wheels in the list were built successfully. + sys.exit(0) if __name__ == "__main__": From df8ecb971a1d0e30d1d74dd7515a202923211621 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 27 Nov 2024 08:44:13 -0800 Subject: [PATCH 530/698] [mgpu] Debug print for mlir vectors. PiperOrigin-RevId: 700714031 --- jax/experimental/mosaic/gpu/utils.py | 46 +++++++++++++++++++--------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 0ce1140cfa07..fcba0518620b 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -107,28 +107,46 @@ def c(val: int | float, ty): raise NotImplementedError(ty) return arith.constant(ty, attr) +def _debug_scalar_ty_format(arg): + ty_format = None + if ir.IndexType.isinstance(arg.type): + return "%llu" + if ir.IntegerType.isinstance(arg.type): + width = ir.IntegerType(arg.type).width + ty_format = "%llu" + if width < 64: + arg = arith.extui(ir.IntegerType.get_signless(64), arg) + if ir.F32Type.isinstance(arg.type): + ty_format = "%f" + if ir.F16Type.isinstance(arg.type): + ty_format = "%f" + arg = arith.extf(ir.F32Type.get(), arg) + + return ty_format, arg def debug_print(fmt, *args, uniform=True): type_formats = [] new_args = [] for arg in args: - ty_format = None - if ir.IndexType.isinstance(arg.type): - ty_format = "%llu" - if ir.IntegerType.isinstance(arg.type): - width = ir.IntegerType(arg.type).width - ty_format = "%llu" - if width < 64: - arg = arith.extui(ir.IntegerType.get_signless(64), arg) - if ir.F32Type.isinstance(arg.type): - ty_format = "%f" - if ir.F16Type.isinstance(arg.type): - ty_format = "%f" - arg = arith.extf(ir.F32Type.get(), arg) + if ir.VectorType.isinstance(arg.type): + index = ir.IndexType.get() + vec_ty = ir.VectorType(arg.type) + if len(vec_ty.shape) > 1: + raise NotImplementedError(vec_ty) + vec_args = [ + vector.extractelement(arg, position=c(i, index)) + for i in range(vec_ty.shape[0]) + ] + ty_formats, args = zip(*map(_debug_scalar_ty_format,vec_args)) + ty_format = f"[{','.join(ty_formats)}]" + new_args += args + else: + ty_format, arg = _debug_scalar_ty_format(arg) + new_args.append(arg) + if ty_format is None: raise NotImplementedError(arg.type) type_formats.append(ty_format) - new_args.append(arg) ctx = ( functools.partial(single_thread, per_block=False) if uniform From df6758f021167b1c0b85f5d6e4986f6f0d2a1169 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Wed, 27 Nov 2024 09:38:36 -0800 Subject: [PATCH 531/698] Update XLA dependency to use revision PiperOrigin-RevId: 700728296 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 01f5f862db4d..c14005d2ebdf 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "9d5bac13a3a9284959cf85e9bcb959eb147151cf" -XLA_SHA256 = "f3681cee585ac605d43a7981254ab480a5f4202cc0ab30151d7f2151bddb0d15" +XLA_COMMIT = "c7fdcbc588fa9ea021cf8766530604e8d0fef332" +XLA_SHA256 = "c0e82c28e5e74065c8446199af657af71ae2f786ba33ddb23d6e1bbcd4463d50" def repo(): tf_http_archive( From 83b54d97e7077990f21b2c2653e5df6b14150f27 Mon Sep 17 00:00:00 2001 From: Jed Borovik Date: Wed, 27 Nov 2024 13:54:33 -0500 Subject: [PATCH 532/698] Add version check for effort flags --- jax/_src/compiler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 75eb723a1c75..87c3f62fe552 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -34,6 +34,7 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir import numpy as np @@ -190,6 +191,10 @@ def get_compile_options( assert device_assignment.computation_count() == num_partitions compile_options.device_assignment = device_assignment + if xla_extension_version >= 294: + build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value + build_options.memory_fitting_effort = config.memory_fitting_effort.value + if env_options_overrides is not None: # Some overrides are passed directly on build_options. overrides_on_build_options = [ @@ -200,9 +205,6 @@ def get_compile_options( setattr(build_options, name, env_options_overrides.pop(name)) compile_options.env_option_overrides = list(env_options_overrides.items()) - build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value - build_options.memory_fitting_effort = config.memory_fitting_effort.value - debug_options = compile_options.executable_build_options.debug_options if lib.cuda_path is not None: debug_options.xla_gpu_cuda_data_dir = lib.cuda_path From c2c177eee85e2b57192782759aabf3469ec25935 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 27 Nov 2024 11:23:34 -0800 Subject: [PATCH 533/698] [AutoPGLE] Update fdo_profile comment. PiperOrigin-RevId: 700759386 --- jax/_src/pjit.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 5d3fb5a3c67c..7f29d745e48d 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1662,10 +1662,8 @@ def _pjit_call_impl_python( pgle_compile_options['fdo_profile'] = fdo_profile compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items()) - # TODO(patrios): Do not pass mutable profile session through cached lowering - # chain. Instead we need to move profilers dictionary to pxla module and use - # module as key. Right now we can't do that since there is no way to evict - # _pjit_lower_cached cache for in PGLE mode. + # Passing mutable PGLE profile here since it should be extracted by JAXPR to + # initialize the fdo_profile compile option. compiled = _resolve_and_lower( args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, From 6e72592be6e87a15c46eacee5c03d4510a3dff40 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 27 Nov 2024 11:32:18 -0800 Subject: [PATCH 534/698] [Pallas] Fix float -> int casting on Triton backend. PiperOrigin-RevId: 700761545 --- jax/_src/pallas/triton/lowering.py | 27 +++++++++++++++++++++------ tests/pallas/ops_test.py | 4 ---- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 94848236d299..fe641ba29494 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1469,10 +1469,22 @@ def _float_int_cast( dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: return _not_equal(src, _full(src.type, 0), signed=signed) - elif signed: - return arith_dialect.fptosi(dst_type, src) else: - return arith_dialect.fptoui(dst_type, src) + # We clamp the float value to the min/max integer destination value + # in order to match JAX/XLA casting behavior. Note that this differs + # from numpy casting behavior. + if signed: + maxint = 2**(dst_element_type.width-1) - 1 + minint = -2**(dst_element_type.width-1) + else: + maxint = 2**dst_element_type.width - 1 + minint = 0 + src = arith_dialect.minimumf(src, _full(src.type, maxint)) + src = arith_dialect.maximumf(src, _full(src.type, minint)) + if signed: + return arith_dialect.fptosi(dst_type, src) + else: + return arith_dialect.fptoui(dst_type, src) def _int_float_cast( @@ -1499,10 +1511,12 @@ def _cast( src, _dtype_to_ir_type(dst_type), signed=jnp.issubdtype(src_type, jnp.signedinteger), + dst_signed=jnp.issubdtype(dst_type, jnp.signedinteger), ) -def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: +def _ir_cast(src: ir.Value, dst_type: ir.Type, *, + signed: bool, dst_signed: bool = False) -> ir.Value: if ir.RankedTensorType.isinstance( src.type ) and not ir.RankedTensorType.isinstance(dst_type): @@ -1527,7 +1541,8 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: dst_element_type, ir.F32Type ): return _ir_cast( - _ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False + _ir_cast(src, ir.F32Type.get(), signed=False), + dst_type, signed=False, dst_signed=dst_signed ) if isinstance(src_element_type, ir.FloatType) and isinstance( @@ -1543,7 +1558,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: if isinstance(src_element_type, ir.FloatType) and isinstance( dst_element_type, ir.IntegerType ): - return _float_int_cast(src, dst_type, signed=signed) + return _float_int_cast(src, dst_type, signed=dst_signed) if isinstance(src_element_type, ir.IntegerType) and isinstance( dst_element_type, ir.FloatType ): diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d7c1bac5dc61..214e258892ec 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -556,10 +556,6 @@ def test_cast(self, from_dtype, to_dtype, data): self.skipTest("Not supported: bad canonicalization") if from_dtype == "bool" and to_dtype in {"int16", "int8"}: self.skipTest("Not supported: cannot extend to sub-32 bit types") - if jtu.test_device_matches(["gpu"]): - if (from_dtype in {"bfloat16", "float32"} and - to_dtype in {"int8", "int16", "int32"}): - self.skipTest("TODO: wrong result on GPU") if from_dtype == "bfloat16": from_dtype = jnp.bfloat16 From cc5036cc18bc585b0d92a4f606956da084effbad Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 27 Nov 2024 12:42:26 -0800 Subject: [PATCH 535/698] Raise a better error message if anything other than a sequence of ints is passed to make_mesh or create_device_mesh PiperOrigin-RevId: 700779838 --- jax/_src/mesh_utils.py | 9 ++++++++- jax/_src/sharding_impls.py | 1 + tests/mesh_utils_test.py | 6 ++++++ tests/pjit_test.py | 6 ++++++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index 16e34e1afaef..d227b1eeeea9 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -705,6 +705,12 @@ def _transpose_trick( *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] ) +def _validate_axis_shapes(axis_shapes: Sequence[int], arg_name: str, + fun_name: str): + if not all(isinstance(s, int) for s in axis_shapes): + raise ValueError( + f'{arg_name} passed to {fun_name} should be a sequence of ints. Got' + f' {axis_shapes}') def create_device_mesh( mesh_shape: Sequence[int], @@ -740,7 +746,8 @@ def create_device_mesh( """ if devices is None: devices = xb.devices() - if np.prod(mesh_shape) != len(devices): + _validate_axis_shapes(mesh_shape, 'mesh_shape', 'create_device_mesh') + if math.prod(mesh_shape) != len(devices): raise ValueError( f'Number of devices {len(devices)} must equal the product ' f'of mesh_shape {mesh_shape}' diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 8abe58e52a74..39d8aedfe7ad 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1714,6 +1714,7 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], """ if devices is None: devices = xla_bridge.devices() + mesh_utils._validate_axis_shapes(axis_shapes, 'axis_shapes', 'make_mesh') axis_size = math.prod(axis_shapes) if axis_size > len(devices): raise ValueError( diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 66f1fc9f6cfb..d4db8fd3d406 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -353,6 +353,12 @@ def test_create_device_mesh_for_nd_torus( ) self.assertArraysEqual(assignment, expected_assignment_matrix) + def test_create_device_mesh_non_int_error(self): + with self.assertRaisesRegex( + ValueError, + "mesh_shape passed to create_device_mesh should be a sequence of ints"): + mesh_utils.create_device_mesh(((4,), 4)) + @parameterized.named_parameters( ('2x2x1', mock_2x2x1_devices,), ('2x2x4', mock_2x2x4_devices, ), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e541c6346666..6bd05536cebc 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4458,6 +4458,12 @@ def g(x): self.assertEqual(out2.sharding, s) self.assertEqual(out2.dtype, np.float32) + def test_make_mesh_non_int_error(self): + with self.assertRaisesRegex( + ValueError, + "axis_shapes passed to make_mesh should be a sequence of ints"): + jax.make_mesh(((4,), 4), ('x', 'y')) + def test_jnp_array_reshard_error(self): if jax.device_count() < 2: self.skipTest('Requires >=2 devices') From a212a29dc654986ba2d85255a71797490c631c2a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 27 Nov 2024 13:05:57 -0800 Subject: [PATCH 536/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e2fe67323ea46076a61230952a3551df04ec559d. PiperOrigin-RevId: 700786259 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index c14005d2ebdf..5eea114774f4 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "c7fdcbc588fa9ea021cf8766530604e8d0fef332" -XLA_SHA256 = "c0e82c28e5e74065c8446199af657af71ae2f786ba33ddb23d6e1bbcd4463d50" +XLA_COMMIT = "e2fe67323ea46076a61230952a3551df04ec559d" +XLA_SHA256 = "0cdc3108f44f8ab37c90e165bae3bc72e16d049ad18c46d2aa8004f93df2d9f9" def repo(): tf_http_archive( From 8c521547b7b8e9fdbd9720450c36d32d779ab9db Mon Sep 17 00:00:00 2001 From: Enrique Piqueras <19157096+epiqueras@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:29:27 -0800 Subject: [PATCH 537/698] Add experimental JAX roofline API. --- jax/BUILD | 1 + jax/experimental/roofline/__init__.py | 29 ++ jax/experimental/roofline/roofline.py | 342 ++++++++++++++++++++ jax/experimental/roofline/rooflines.py | 270 ++++++++++++++++ tests/BUILD | 6 + tests/roofline_test.py | 426 +++++++++++++++++++++++++ 6 files changed, 1074 insertions(+) create mode 100644 jax/experimental/roofline/__init__.py create mode 100644 jax/experimental/roofline/roofline.py create mode 100644 jax/experimental/roofline/rooflines.py create mode 100644 tests/roofline_test.py diff --git a/jax/BUILD b/jax/BUILD index 4260a8a1acb2..053b05027a2e 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -227,6 +227,7 @@ py_library_providing_imports_info( "_src/state/**/*.py", "_src/third_party/**/*.py", "experimental/key_reuse/**/*.py", + "experimental/roofline/**/*.py", "image/**/*.py", "interpreters/**/*.py", "lax/**/*.py", diff --git a/jax/experimental/roofline/__init__.py b/jax/experimental/roofline/__init__.py new file mode 100644 index 000000000000..8d76c46858c7 --- /dev/null +++ b/jax/experimental/roofline/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from jax.experimental.roofline.roofline import ( + RooflineRuleContext as RooflineRuleContext, +) +from jax.experimental.roofline.roofline import RooflineShape as RooflineShape +from jax.experimental.roofline.roofline import RooflineResult as RooflineResult +from jax.experimental.roofline.roofline import roofline as roofline +from jax.experimental.roofline.roofline import register_roofline as register_roofline +from jax.experimental.roofline.roofline import ( + register_standard_roofline as register_standard_roofline, +) +from jax.experimental.roofline.roofline import roofline_and_grad as roofline_and_grad + + +import jax.experimental.roofline.rooflines as rooflines + +del rooflines diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py new file mode 100644 index 000000000000..42f72f005034 --- /dev/null +++ b/jax/experimental/roofline/roofline.py @@ -0,0 +1,342 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Protocol, Sequence +import numpy as np + +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax._src import api +from jax._src import core +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src import util +from jax._src.api import make_jaxpr +from jax._src.interpreters.partial_eval import dce_jaxpr +from jax._src.interpreters.xla import abstractify +from jax._src.mesh import AbstractMesh, Mesh +from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map +from jax.experimental import shard_map + + +ShapeDtypeStructTree = Any + + +map = util.safe_map + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineRuleContext: + name_stack: source_info_util.NameStack + primitive: core.Primitive + avals_in: Sequence[core.AbstractValue] + avals_out: Sequence[core.AbstractValue] + jaxpr_eqn_ctx: core.JaxprEqnContext + mesh: Mesh | AbstractMesh + pin_lhs_in_vmem: bool + pin_rhs_in_vmem: bool + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineShape: + shape: tuple[int, ...] + dtype: np.dtype + + @classmethod + def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape": + if not isinstance(aval, core.ShapedArray): + raise TypeError(f"Expected ShapedArray, got {type(aval)}.") + if not isinstance(aval.dtype, np.dtype): + raise TypeError(f"Expected numpy dtype, got {type(aval.dtype)}.") + return cls(shape=aval.shape, dtype=aval.dtype) + + @property + def size(self) -> int: + return int(np.prod(self.shape)) + + @property + def bytes(self) -> int: + return int(self.size * self.dtype.itemsize) + + @classmethod + def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int: + return sum(cls.from_aval(aval).bytes for aval in avals) + + +@dataclass(frozen=True, slots=True, kw_only=True) +class RooflineResult: + flops: int = 0 + ici_bytes: dict[str, int] = field(default_factory=dict) + ici_latency: dict[str, int] = field(default_factory=dict) + hbm_bytes: int = 0 + peak_hbm_bytes: int = 0 + + @classmethod + def zeros(cls) -> "RooflineResult": + return cls() + + def __add__(self, other: "RooflineResult") -> "RooflineResult": + def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: + return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)} + + return RooflineResult( + flops=self.flops + other.flops, + ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes), + ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency), + hbm_bytes=self.hbm_bytes + other.hbm_bytes, + peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes), + ) + + def __mul__(self, constant: int | float) -> "RooflineResult": + return RooflineResult( + flops=int(self.flops * constant), + ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()}, + ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()}, + hbm_bytes=int(self.hbm_bytes * constant), + peak_hbm_bytes=int(self.peak_hbm_bytes * constant), + ) + + def __rmul__(self, constant: int | float) -> "RooflineResult": + return self.__mul__(constant) + + +class _RooflineRule(Protocol): + def __call__( + self, ctx: RooflineRuleContext, *args: RooflineShape, **kw + ) -> RooflineResult: ... + + +_rooflines: dict[core.Primitive, _RooflineRule] = {} + + +def _roofline_interpreter( + f_name: str, + jaxpr: core.Jaxpr, + mesh: Mesh | AbstractMesh, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, +) -> RooflineResult: + name_stack = source_info_util.new_name_stack(util.wrap_name(f_name, "roofline")) + + result = RooflineResult.zeros() + + env: dict[core.Var, RooflineShape] = {} + + def write(v: core.Var, node: RooflineShape): + assert node is not None + env[v] = node + + def read(v: core.Atom) -> RooflineShape: + if type(v) is core.Literal: + return RooflineShape.from_aval(abstractify(v.val)) + else: + assert isinstance(v, core.Var) + return env[v] + + def aval(v: core.Atom) -> core.AbstractValue: + if type(v) is core.Literal: + return abstractify(v.val) + else: + return v.aval + + def calculate_peak_hbm_bytes() -> int: + return int( + sum(np.prod(shape.shape) * shape.dtype.itemsize for shape in env.values()) + ) + + make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x)) + map( + write, + jaxpr.constvars, + map(make_roofline_shape, jaxpr.constvars), + ) + map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars)) + last_used = core.last_used(jaxpr) + for eqn in jaxpr.eqns: + source_info = eqn.source_info.replace( + name_stack=name_stack + eqn.source_info.name_stack + ) + with source_info_util.user_context( + eqn.source_info.traceback, name_stack=source_info.name_stack + ): + if "jaxpr" in eqn.params: + result += _roofline_interpreter( + util.wrap_name(f_name, eqn.primitive.name), + eqn.params["jaxpr"], + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) + else: + if eqn.primitive not in _rooflines: + msg = f"No roofline rule for {eqn.primitive}." + for attr in dir(eqn): + if not attr.startswith("_"): + msg += f"\n{attr}: {getattr(eqn, attr)}" + raise NotImplementedError(msg) + rule = _rooflines[eqn.primitive] + result += rule( + RooflineRuleContext( + name_stack=source_info.name_stack, + primitive=eqn.primitive, + avals_in=map(aval, eqn.invars), + avals_out=map(aval, eqn.outvars), + jaxpr_eqn_ctx=eqn.ctx, + mesh=mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ), + *map(read, eqn.invars), + **eqn.params, + ) + + map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars)) + core.clean_up_dead_vars(eqn, env, last_used) + result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes()) + + return result + + +def _f_with_vjp(f: Callable): + @util.wraps(f) + def wrapped(*args): + primals, f_vjp = api.vjp(f, *args) + return f_vjp(tree_map(jnp.bfloat16, primals)) + + return wrapped + + +def roofline( + f: Callable, + mesh: Mesh | AbstractMesh, + in_specs: shard_map.Specs, + out_specs: shard_map.Specs, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, + vjp: bool = False, + print_jaxpr: bool = False, +) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]: + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + wrapped_f = shard_map.shard_map(f, mesh, in_specs, out_specs) + if vjp: + wrapped_f = _f_with_vjp(wrapped_f) + + jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args) + + def make_sharded_shape_dtype_struct( + shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs + ) -> api.ShapeDtypeStruct: + return api.ShapeDtypeStruct( + shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec) + ) + + out_specs_flat = broadcast_prefix(out_specs, out_shapes) + flat_out_shapes, treedef = tree_flatten(out_shapes) + flat_out_shapes = map( + make_sharded_shape_dtype_struct, flat_out_shapes, out_specs_flat + ) + out_shapes = tree_unflatten(treedef, flat_out_shapes) + + used_outputs = (True,) * len(jaxpr.jaxpr.outvars) + jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs) + try: + jaxpr = [e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p][ + -1 + ].params["jaxpr"] + except KeyError: + raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.") + + if print_jaxpr: + print(jaxpr) + + return out_shapes, _roofline_interpreter( + util.fun_qual_name(f), + jaxpr, + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) + + return wrapped + + +def register_roofline(prim: core.Primitive): + def register(rule: _RooflineRule): + _rooflines[prim] = rule + return rule + + return register + + +def register_standard_roofline(prim: core.Primitive): + def standard_rule(ctx: RooflineRuleContext, *args, **kwargs): + return RooflineResult.zeros() + + _rooflines[prim] = standard_rule + + +def roofline_and_grad( + f: Callable, + mesh: Mesh | AbstractMesh, + in_specs: shard_map.Specs, + out_specs: shard_map.Specs, + *, + pin_lhs_in_vmem: bool = False, + pin_rhs_in_vmem: bool = False, + print_jaxpr: bool = False, +) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]: + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + primal_shapes, fwd_result = roofline( + f, + mesh, + in_specs, + out_specs, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + print_jaxpr=print_jaxpr, + )(*args) + + return ( + primal_shapes, + fwd_result, + roofline( + f, + mesh, + in_specs, + out_specs, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + vjp=True, + print_jaxpr=print_jaxpr, + )( + *tree_map( + lambda x: api.ShapeDtypeStruct( + x.shape, + jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16, + sharding=x.sharding, + ), + args, + ) + )[1], + ) + + return wrapped diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py new file mode 100644 index 000000000000..cfdb6358bc76 --- /dev/null +++ b/jax/experimental/roofline/rooflines.py @@ -0,0 +1,270 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +from dataclasses import replace +import itertools as it +import numpy as np + +from jax._src import ad_util +from jax._src import core, util +from jax._src import ops +from jax._src import prng +from jax._src import random +from jax._src.lax import ( + ann, + convolution, + fft, + lax, + linalg, + parallel as lax_parallel, + slicing, + special, + windowed_reductions, +) +from jax.experimental import roofline +from jax.experimental import shard_map + + +for prim in it.chain( + ad_util.__dict__.values(), + ann.__dict__.values(), + convolution.__dict__.values(), + fft.__dict__.values(), + lax.__dict__.values(), + linalg.__dict__.values(), + ops.__dict__.values(), + prng.__dict__.values(), + random.__dict__.values(), + shard_map.__dict__.values(), + slicing.__dict__.values(), + special.__dict__.values(), + windowed_reductions.__dict__.values(), +): + if isinstance(prim, core.Primitive): + roofline.register_standard_roofline(prim) + + +@roofline.register_roofline(lax.dot_general_p) +def _dot_general_roofline( + ctx: roofline.RooflineRuleContext, + *args, + dimension_numbers: lax.DotDimensionNumbers, + **kw, +) -> roofline.RooflineResult: + lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + (lhs_contract, _), (lhs_batch, _) = dimension_numbers + + flops = ( + 2 + * lhs.size + * rhs.size + / np.prod([lhs.shape[i] for i in lhs_contract]) + / np.prod([lhs.shape[i] for i in lhs_batch]) + ) + + hbm_bytes = 0 + if not ctx.pin_lhs_in_vmem: + hbm_bytes += lhs.bytes + hbm_bytes += out.bytes + if not ctx.pin_rhs_in_vmem: + hbm_bytes += rhs.bytes + + return roofline.RooflineResult(flops=int(flops), hbm_bytes=hbm_bytes) + + +def _return_zeros_if_one_sized_axis( + ctx: roofline.RooflineRuleContext, axes: tuple[str, ...] +) -> roofline.RooflineResult | None: + axes_size = np.prod([ctx.mesh.shape[axis] for axis in axes]) + if axes_size > 1: + return None + return roofline.RooflineResult( + ici_bytes={axis: 0 for axis in axes}, + ici_latency={axis: 0 for axis in axes}, + ) + + +def _ring_collective_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + is_reduce: bool = True, + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axes): + return zeros_result + + mesh = ctx.mesh.shape + current_shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in) + if is_reduce: + current_shard_size /= np.prod([mesh[axis] for axis in axes]) + + # We model the slowest color as the bottleneck. + sorted_axes = sorted(axes, key=lambda x: mesh[x], reverse=True) + num_axes = len(sorted_axes) + + ici_bytes = 0 + # Phase split. + current_shard_size //= num_axes + for axis in sorted_axes: + axis_size = mesh[axis] + # Do phase. + ici_bytes += current_shard_size * (axis_size - 1) + # Increase shard size. + current_shard_size *= axis_size + + # Bottleneck is the longest axis. + ici_latency = mesh[sorted_axes[0]] * num_axes + + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in sorted_axes}, + ici_latency={axis: int(ici_latency) for axis in sorted_axes}, + ) + + +roofline.register_roofline(lax_parallel.reduce_scatter_p)( + lambda *args, axis_name, **kw: _ring_collective_roofline(*args, axes=axis_name, **kw) +) +roofline.register_roofline(lax_parallel.all_gather_p)( + lambda *args, axis_name, **kw: _ring_collective_roofline( + *args, axes=axis_name, is_reduce=False, **kw + ) +) + + +def _scalar_collective_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in] + ctx = replace(ctx, avals_in=[core.ShapedArray((1,), shape.dtype) for shape in shapes]) + return _ring_collective_roofline(ctx, *args, axes=axes, is_reduce=False, **kw) + + +roofline.register_roofline(lax_parallel.pmin_p)(_scalar_collective_roofline) +roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline) + + +@roofline.register_roofline(shard_map.psum2_p) +def _psum2_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + ring_roofline = _ring_collective_roofline(ctx, *args, axes=axes, **kw) + + def double_dict(d: dict[str, int]) -> dict[str, int]: + return {k: v * 2 for k, v in d.items()} + + return roofline.RooflineResult( + ici_bytes=double_dict(ring_roofline.ici_bytes), + ici_latency=double_dict(ring_roofline.ici_latency), + ) + + +@roofline.register_roofline(lax_parallel.all_to_all_p) +def _all_to_all_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis_name: tuple[str, ...], + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name): + return zeros_result + + mesh = ctx.mesh.shape + size = roofline.RooflineShape.total_bytes(ctx.avals_in) * np.prod([ + mesh[axis] for axis in axis_name + ]) + + smallest_axis = sorted(axis_name, key=lambda x: mesh[x])[0] + num_axes = len(axis_name) + bisection_bw = mesh[smallest_axis] ** (num_axes - 1) + if mesh[smallest_axis] > 2: + # Times 2 because of wraparound. + bisection_bw *= 2 + + # Half the data needs to cross the bisection on average. + ici_bytes = size / 2 / bisection_bw + + # The latency is the max number of hops across the mesh. + ici_latency = sum(mesh[axis] / 2 for axis in axis_name) + + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in axis_name}, + ici_latency={axis: int(ici_latency) for axis in axis_name}, + ) + + +@roofline.register_roofline(lax_parallel.ppermute_p) +def _ppermute_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis_name: tuple[str, ...], + perm: tuple[tuple[int, int], ...], + **kw, +) -> roofline.RooflineResult: + if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name): + return zeros_result + + mesh = ctx.mesh.shape + mesh_dims: list[int] = [mesh.get(axis, 1) for axis in axis_name] + shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in) + + ici_contention: dict[tuple[tuple[int, ...], ...], float] = defaultdict(float) + ici_latency = 0 + + for src, dst in perm: + if src == dst: + continue + # Perms are linearized. + src_coords = tuple(int(i) for i in np.unravel_index(src, mesh_dims)) + dst_coords = tuple(int(i) for i in np.unravel_index(dst, mesh_dims)) + + ici_latency_for_perm = 0 + + # For each dimension. + for i in range(len(axis_name)): + dim_size = mesh_dims[i] + src_pos = src_coords[i] + dst_pos = dst_coords[i] + + if src_pos != dst_pos: + # Calculate distance with wraparound. + clockwise_dist = (dst_pos - src_pos) % dim_size + counter_dist = (src_pos - dst_pos) % dim_size + direction = 1 if clockwise_dist <= counter_dist else -1 + + curr_pos = src_pos + while curr_pos != dst_pos: + curr_coords = util.tuple_update(src_coords, i, curr_pos) + next_pos = (curr_pos + direction) % dim_size + next_coords = util.tuple_update(curr_coords, i, next_pos) + ici_contention[tuple(sorted([curr_coords, next_coords]))] += 1 + curr_pos = next_pos + + distance = min(clockwise_dist, counter_dist) + ici_latency_for_perm += distance + + ici_latency = max(ici_latency, ici_latency_for_perm) + + ici_bytes = shard_size * max(ici_contention.values(), default=0) + return roofline.RooflineResult( + ici_bytes={axis: int(ici_bytes) for axis in axis_name}, + ici_latency={axis: int(ici_latency) for axis in axis_name}, + ) diff --git a/tests/BUILD b/tests/BUILD index 92a6ed99ceca..97f8a3634d99 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1200,6 +1200,12 @@ jax_multiplatform_test( srcs = ["key_reuse_test.py"], ) +jax_multiplatform_test( + name = "roofline_test", + srcs = ["roofline_test.py"], + enable_backends = ["cpu"], +) + jax_multiplatform_test( name = "x64_context_test", srcs = ["x64_context_test.py"], diff --git a/tests/roofline_test.py b/tests/roofline_test.py new file mode 100644 index 000000000000..e5003947181b --- /dev/null +++ b/tests/roofline_test.py @@ -0,0 +1,426 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from functools import partial +import contextlib + +from absl.testing import absltest +from jax.sharding import PartitionSpec as P +import jax +import jax.lax as lax +import jax.numpy as jnp + +from jax._src import test_util as jtu + +from jax.experimental import roofline + + +jax.config.parse_flags_with_absl() + + +def create_inputs( + *shardings: P, + dtype: jnp.dtype = jnp.float32, + mesh_shape: tuple[int, ...] = (2, 2, 2), +) -> tuple[jax.sharding.Mesh, tuple[jax.ShapeDtypeStruct, ...]]: + mesh = jtu.create_mesh(mesh_shape, ("x", "y", "z")) + arrays = [] + for sharding in shardings: + array = jax.ShapeDtypeStruct( + (8, 8), dtype, sharding=jax.sharding.NamedSharding(mesh, sharding) + ) + arrays.append(array) + return mesh, tuple(arrays) + + +# Run all tests with 8 CPU devices. +_exit_stack = contextlib.ExitStack() + + +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) + + +def tearDownModule(): + _exit_stack.close() + + +class RooflineTest(jtu.JaxTestCase): + def test_scalar_collectives(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(P("z", None), P(("x", "y"), None)), + ) + def scalar_collectives(a, b): + a = lax.pmin(a, ("x", "y")) + b = lax.pmax(b, "z") + return a, b + + _, results = scalar_collectives(a, b) + + itemsize = 4 + + axis_size = 2 + axis_size_m1 = axis_size - 1 + + xy_num_axes = 2 + xy_ici_bytes = int( + itemsize + # 2 phases. + * ( + (1 / xy_num_axes * axis_size_m1) + (1 * axis_size / xy_num_axes * axis_size_m1) + ) + ) + # 2 phases times 2 hops. + xy_ici_latency = 2 * 2 + + z_ici_bytes = int(itemsize * 1 * axis_size_m1) + # 2 hops. + z_ici_latency = 2 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_collective_matmul(self): + a_spec = P(None, "x") + b_spec = P(None, "x") + c_spec = P("x", None) + mesh, (a, b, c) = create_inputs(a_spec, b_spec, c_spec, dtype=jnp.int8) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec, c_spec), + out_specs=a_spec, + ) + def collective_matmul(a, b, c): + a = lax.all_gather(a, "x", axis=1, tiled=True) + # Test broadcasting and slicing works. + a = a[None, :, :] + b = b[:, None, :] + ab = jnp.einsum("bij,jbk->ikb", a, b).astype(jnp.int8)[..., 0] + abc = jnp.einsum("ik,kc->ic", ab, c).astype(jnp.int8) + abc = lax.psum_scatter(abc, "x", scatter_dimension=1, tiled=True) + return abc + + _, results = collective_matmul(a, b, c) + + itemsize = 1 + m, k, n = 8, 4, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_mk = mk + + # Times 2 for ag + rs. + ici_bytes = 2 * int(itemsize * sharded_mk * axis_size_m1) + ici_latency = 2 * 2 + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=2 * itemsize * (mk + kn + mn), + # Right after all_gather. + peak_hbm_bytes=itemsize * (mk * axis_size + mk + kn), + ) + self.assertDataclassEqual(results, expected) + + def test_matmul_psum(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), None) + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=P("z", None), + ) + def matmul_psum(a, b): + c = a @ b + c = lax.psum(c, ("x", "y")) + return c + + _, results = matmul_psum(a, b) + + itemsize = 4 + m, k, n = 4, 2, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + num_axes = 2 + sharded_mn = mn / axis_size / num_axes + + # Times 2 for ag + rs. + ici_bytes = 2 * int( + itemsize + # 2 phases. + * ( + (sharded_mn / num_axes * axis_size_m1) + + (sharded_mn * axis_size / num_axes * axis_size_m1) + ) + ) + ici_latency = 2 * 2 * 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={axis: ici_bytes for axis in ("x", "y")}, + ici_latency={axis: ici_latency for axis in ("x", "y")}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mn), + ) + self.assertDataclassEqual(results, expected) + + def test_all_to_all(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(P(("z", "x", "y"), None), P(("x", "y", "z"), None)), + ) + def all_to_all(a, b): + a = lax.all_to_all(a, ("x", "y"), split_axis=0, concat_axis=1, tiled=True) + b = lax.all_to_all(b, "z", split_axis=0, concat_axis=1, tiled=True) + return a, b + + _, results = all_to_all(a, b) + + itemsize = 4 + + xy_size = itemsize * 8 * 8 / 2 + # Half the data over 2 links. + xy_ici_bytes = int(xy_size / 2 / 2) + # 2 hops. + xy_ici_latency = 2 + + z_size = itemsize * 8 * 8 / 2 / 2 + # Half the data over 1 link. + z_ici_bytes = int(z_size / 2) + # 1 hop. + z_ici_latency = 1 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_ppermute(self): + a_spec = P("z", ("x", "y")) + b_spec = P(("x", "y"), "z") + mesh, (a, b) = create_inputs(a_spec, b_spec) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=(a_spec, b_spec), + ) + def ppermute(a, b): + a = lax.ppermute(a, ("x", "y"), perm=((0, 3), (3, 0), (1, 2), (2, 1))) + b = lax.ppermute(b, "z", perm=((1, 0), (0, 1))) + return a, b + + _, results = ppermute(a, b) + + itemsize = 4 + shard_size = itemsize * 4 * 2 + + # At most 2 shards contend for 1 link. + xy_ici_bytes = int(shard_size * 2) + # 2 hops. + xy_ici_latency = 2 + + # No contention but there is a single link. + z_ici_bytes = int(shard_size * 2) + # 1 hop. + z_ici_latency = 1 + expected = roofline.RooflineResult( + ici_bytes={"x": xy_ici_bytes, "y": xy_ici_bytes, "z": z_ici_bytes}, + ici_latency={"x": xy_ici_latency, "y": xy_ici_latency, "z": z_ici_latency}, + peak_hbm_bytes=itemsize * 2 * 4 * 2, + ) + self.assertDataclassEqual(results, expected) + + def test_grad_matmuls(self): + a_spec = P(None, "x") + b_spec = P(None, None) + mesh, (a, b) = create_inputs(a_spec, b_spec, dtype=jnp.int8) + + @partial( + roofline.roofline_and_grad, + mesh=mesh, + in_specs=(a_spec, b_spec), + # Numerically incorrect AD, but tests that we handle it properly. + out_specs=P("x", None), + ) + def collective_matmul(a, b): + a = lax.all_gather(a, "x", axis=1, tiled=True) + return a @ b + + c, fwd_results, bwd_results = collective_matmul(a, b) + + itemsize = 1 + m, k, n = 8, 8, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_mk = mk // axis_size + + ici_bytes = int(itemsize * sharded_mk * axis_size_m1) + ici_latency = 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mk + kn), + ) + self.assertDataclassEqual(fwd_results, expected) + + bwd_itemsize = 2 + # 2 for psum + 1 for rs. + bwd_ici_bytes = 3 * int(bwd_itemsize * sharded_mk * axis_size_m1) + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": bwd_ici_bytes}, + ici_latency={"x": 3 * ici_latency}, + hbm_bytes=2 * bwd_itemsize * (mk + kn + mn), + # Residuals + cotangents. + peak_hbm_bytes=bwd_itemsize * (mk + kn + mn), + ) + self.assertDataclassEqual(bwd_results, expected) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=c.sharding.spec, + out_specs=c.sharding.spec, + ) + def mul_2(c): + return c * 2 + + results = mul_2(c) + self.assertLen(results, 2) + + def test_one_sized_axis_collectives(self): + a_spec = P("x") + mesh, (a,) = create_inputs(a_spec, mesh_shape=(1, 2, 4)) + + @partial( + roofline.roofline, + mesh=mesh, + in_specs=a_spec, + out_specs=a_spec, + ) + def one_sized_axis_collectives(a): + a = lax.pmin(a, "x") + a = lax.all_gather(a, "x", axis=1, tiled=True) + a = lax.psum_scatter(a, "x", scatter_dimension=1, tiled=True) + a = lax.psum(a, "x") + a = lax.all_to_all(a, "x", split_axis=0, concat_axis=1, tiled=True) + a = lax.ppermute(a, "x", perm=((1, 0), (0, 1))) + return a + + _, results = one_sized_axis_collectives(a) + expected = roofline.RooflineResult( + ici_bytes={"x": 0}, + ici_latency={"x": 0}, + peak_hbm_bytes=4 * 8 * 8, + ) + self.assertDataclassEqual(results, expected) + + def test_remat(self): + a_spec = P("x", None) + b_spec = P("x", None) + mesh, (a, b) = create_inputs(a_spec, b_spec) + + def fsdp_checkpoint_policy(prim, *args, **kwargs): + if prim == lax.all_gather_p and kwargs["axis_name"] == "x": + return True + return False + + @partial( + roofline.roofline_and_grad, + mesh=mesh, + in_specs=(a_spec, b_spec), + out_specs=P("x", None), + ) + @partial(jax.checkpoint, policy=fsdp_checkpoint_policy) + def collective_matmul(a, b): + b = lax.all_gather(b, "x", axis=0, tiled=True) + return a @ b + + _, fwd_results, bwd_results = collective_matmul(a, b) + + itemsize = 4 + m, k, n = 4, 8, 8 + mk = m * k + kn = k * n + mn = m * n + + axis_size = 2 + axis_size_m1 = axis_size - 1 + sharded_kn = kn // axis_size + + ici_bytes = int(itemsize * sharded_kn * axis_size_m1) + ici_latency = 2 + expected = roofline.RooflineResult( + flops=2 * m * k * n, + ici_bytes={"x": ici_bytes}, + ici_latency={"x": ici_latency}, + hbm_bytes=itemsize * (mk + kn + mn), + peak_hbm_bytes=itemsize * (mk + kn), + ) + self.assertDataclassEqual(fwd_results, expected) + + bwd_itemsize = 2 + # Remat ag + rs. + bwd_ici_bytes = 2 * int(bwd_itemsize * sharded_kn * axis_size_m1) + expected = roofline.RooflineResult( + flops=2 * 2 * m * k * n, + ici_bytes={"x": bwd_ici_bytes}, + ici_latency={"x": 2 * ici_latency}, + hbm_bytes=2 * bwd_itemsize * (mk + kn + mn), + # Residuals + cotangents. + # We gather kn while computing the kn cotangents. + peak_hbm_bytes=bwd_itemsize * (kn + kn + mn), + ) + self.assertDataclassEqual(bwd_results, expected) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 132ad251a127d2d536d5e6fd1c3351e09a1a1a29 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 27 Nov 2024 15:13:56 -0800 Subject: [PATCH 538/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d0769456f09dde192ab6e6421648bddc38908b39. PiperOrigin-RevId: 700817614 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 5eea114774f4..177eaff34b23 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e2fe67323ea46076a61230952a3551df04ec559d" -XLA_SHA256 = "0cdc3108f44f8ab37c90e165bae3bc72e16d049ad18c46d2aa8004f93df2d9f9" +XLA_COMMIT = "d0769456f09dde192ab6e6421648bddc38908b39" +XLA_SHA256 = "e3bef213eef3ea6a8459bfdfe863b19380219af0a6c5a017226b1bdc1bda17e9" def repo(): tf_http_archive( From b62ca8b15b09620074fb20848ec38cf2d3cce8a3 Mon Sep 17 00:00:00 2001 From: Vadym Matsishevskyi Date: Wed, 27 Nov 2024 18:38:35 -0800 Subject: [PATCH 539/698] Rework custom hermetic python instructions. The focus was shifted from how one should build custom python, as it seems like people don't really have issues with that and the process is fairly standard. Instead the focus was made on demystifying of what hermetic (custom or not) Python actually is and explaining how a user can customize the build while still keeping it as close to a regular Python workflow as possible. PiperOrigin-RevId: 700863168 --- docs/developer.md | 249 +++++++++++++++++++++++++++++----------------- 1 file changed, 156 insertions(+), 93 deletions(-) diff --git a/docs/developer.md b/docs/developer.md index 29a3cb6068ac..e6bdf53f1112 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -359,99 +359,162 @@ accept pre-release, dev and nightly packages, it will also search https://pypi.anaconda.org/scientific-python-nightly-wheels/simple as an extra index url and will not put hashes in the resultant requirements lock file. -### Building with pre-release Python version - -We support all of the current versions of Python out of the box, but if you need -to build and test against a different version (for example the latest unstable -version which hasn't been released officially yet) please follow the -instructions below. - -1) Make sure you have installed necessary linux packages needed to build Python - interpreter itself and key packages (like `numpy` or `scipy`) from source. On - a typical Debian system you may need to install the following packages: - -``` -sudo apt-get update -sudo apt-get build-dep python3 -y -sudo apt-get install pkg-config zlib1g-dev libssl-dev -y -# to build scipy -sudo apt-get install libopenblas-dev -y -``` - -2) Check your `WORKSPACE` file and make sure it - has `custom_python_interpreter()` entry there, pointing to the version of - Python you want to build. - -3) Run `bazel build @python_dev//:python_dev -repo_env=HERMETIC_PYTHON_VERSION=3.12` - to build Python interpreter. Note, it is easy to confuse Python version used - to conduct the build (which is needed for technical reasons and is defined by - `HERMETIC_PYTHON_VERSION=3.12`) and the version of Python you are building - (defined by whichever version you specified in `custom_python_interpreter()` - on step 2). For build to succeed, please make sure that hermetic Python you - choose to conduct the build already exists in your configuraiton (the actual - version does not matter, as long as it is a working one). By default, Python - binary will be built with GCC compiler. If you wish to build it with clang, - you need to set corresponding env variables to do so ( - e.g. `--repo_env=CC=/usr/lib/llvm-17/bin/clang --repo_env=CXX=/usr/lib/llvm-17/bin/clang++`). - -4) Check the output of the previous command. At the very end of it you will find - a code snippet for `python_register_toolchains()` entry with your newly built - Python in it. Copy that code snippet in your `WORKSPACE` file either right - after `python_init_toolchains()` entry (to add the new version of Python) or - instead of it (to replace an existing version, like replacing `3.12` with - custom built variant of `3.12`). The code snippet is generated to match your - actual setup, so it should work as is, but you can customize it if you choose - so (for example to change location of Python's `.tgz` file so it could be - downloaded remotely instead of being on local machine). - -5) Make sure there is an entry for your Python's version in `requirements` - parameter for `python_init_repositories()` in your WORKSPACE file. For - example for `Python 3.13` it should have something - like `"3.13": "//build:requirements_lock_3_13.txt"`. Note, the key in the - `requirements` parameter must always be in `"major.minor"` version format, so - even if you are building Python version `3.13.0rc1` the corresponding - `requirements` entry must still be `"3.13": "//build:requirements_lock_3_13.txt"`, - **not** `"3.13.0rc1": "//build:requirements_lock_3_13_0rc1.txt"`. - -6) For unstable versions of Python, optionally (but highly recommended) - run `bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13"`, - where `3.13` is the version of Python interpreter you built on step 3. - This will make `pip` pull and build from sources (for packages which don't - have binaries published yet, for - example `numpy`, `scipy`, `matplotlib`, `zstandard`) all of the JAX's python - dependencies. It is recommended to do this step first (i.e. independently of - actual JAX build) for all unstable versions of Python to avoid conflict - between building JAX itself and building of its Python dependencies. For - example, we normally build JAX with clang but building `matplotlib` from - sources with clang fails out of the box due to differences in LTO behavior ( - Link Time Optimization, triggered by `-flto` flag) between GCC and clang, and - matplotlib assumes GCC by default. - If you build against a stable version of Python, or in general you do not - expect any of your Python dependencies to be built from sources (i.e. binary - distributions for the corresponding Python version already exist in the - repository) this step is not needed. - -7) Congrats, you've built and configured your custom Python for JAX project! You - may now execute your built/test commands as usual, just make - sure `HERMETIC_PYTHON_VERSION` environment variable is set and points to your - new version. - -8) Note, if you were building a pre-release version of Python, updating of - `requirements_lock_.txt` files with your newly built Python - is likely to fail, because package repositories will not have matching - binary packages. When there are no binary packages available `pip-compile` - proceeds with building them from sources, which is likely to fail because it - is more restrictive than doing the same thing during `pip` installation. - The recommended way to update requirements lock file for unstable versions of - Python is to update requirements for the latest stable version (e.g. `3.12`) - without hashes (therefore special `//build:requirements_dev.update` target) - and then copy the results to the unstable Python's lock file (e.g. `3.13`): -``` -bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="3.12" -cp build/requirements_lock_3_12.txt build/requirements_lock_3_13.txt -bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13" -# You may need to edit manually the resultant lock file, depending on how ready -# your dependencies are for the new version of Python. +### Customizing hermetic Python (Advanced Usage) + +We support all of the current versions of Python out of the box, so unless your +workflow has very special requirements (such as ability to use your own custom +Python interpreter) you may safely skip this section entirely. + +In short, if you rely on a non-standard Python workflow you still can achieve +the great level of flexibility in hermetic Python setup. Conceptually there will +be only one difference compared to non-hermetic case: you will need to think in +terms of files, not installations (i.e. think what files your build actually +depends on, not what files need to be installed on your system), the rest is +pretty much the same. + +So, in practice, to gain full control over your Python environment, hermetic or +not you need to be able to do the following three things: + +1) Specify which python interpreter to use (i.e. pick actual `python` or + `python3` binary and libs that come with it in the same folder). +2) Specify a list of Python dependencies (e.g. `numpy`) and their actual + versions. +3) Be able to add/remove/update dependencies in the list easily. Each + dependency itself could be custom too (self-built for example). + +You already know how to do all of the steps above in a non-hermetic Python +environment, here is how you do the same in the hermetic one (by approaching it +in terms of files, not installations): + +1) Instead of installing Python, get Python interpreter in a `tar` or `zip` + file. Depending on your case you may simply pull one of many existing ones + (such as [python-build-standalone](https://github.com/indygreg/python-build-standalone/releases)), + or build your own and pack it in an archive (following official + [build instructions](https://devguide.python.org/getting-started/setup-building/#compile-and-build) + will do just fine). E.g. on Linux it will look something like the following: + ``` + ./configure --prefix python + make -j12 + make altinstall + tar -czpf my_python.tgz python + ``` + Once you have the tarball ready, plug it in the build by pointing + `HERMETIC_PYTHON_URL` env var to the archive (either local one or from the + internet): + ``` + --repo_env=HERMETIC_PYTHON_URL="file:///local/path/to/my_python.tgz" + --repo_env=HERMETIC_PYTHON_SHA256= + + # OR + --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" + --repo_env=HERMETIC_PYTHON_SHA256= + + # We assume that top-level folder in the tarbal is called "python", if it is + # something different just pass additional HERMETIC_PYTHON_PREFIX parameter + --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" + --repo_env=HERMETIC_PYTHON_SHA256= + --repo_env=HERMETIC_PYTHON_PREFIX="my_python/install" + ``` + +2) Instead of doing `pip install` create `requirements_lock.txt` file with + full transitive closure of your dependencies. You may also depend on the + existing ones already checked in this repo (as long as they work with your + custom Python version). There are no special instructions on how you do it, + you may follow steps recommended in [Specifying Python dependencies](#specifying-python-dependencies) + from this doc, just call pip-compile directly (note, the lock file must be + hermetic, but you can always generate it from non-hermetic python if you'd + like) or even create it manually (note, hashes are optional in lock files). + + +3) If you need to update or customize your dependencies list, you may once again + follow the [Specifying Python dependencies](#specifying-python-dependencies) + instructions to update `requirements_lock.txt`, call pip-compile directly or + modify it manually. If you have a custom package you want to use just point + to its `.whl` file directly (remember, work in terms of files, not + installations) from your lock (note, `requirements.txt` and + `requirements_lock.txt` files support local wheel references). If your + `requirements_lock.txt` is already specified as a dependency to + `python_init_repositories()` in `WORKSPACE` file you don't have to do + anything else. Otherwise you can point to your custom file as follows: + ``` + --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/custom_requirements_lock.txt" + ``` + Also note if you use `HERMETIC_REQUIREMENTS_LOCK` then it fully controls list + of your dependencies and the automatic local wheels resolution logic + described in [Specifying dependencies on local wheels](#specifying-dependencies-on-local-wheels) + gets disabled to not interfere with it. + +That is it. To summarize: if you have an archive with Python interpreter in it +and a requirements_lock.txt file with full transitive closure of your +dependencies then you fully control your Python environment. + +#### Custom hermetic Python examples + +Note, for all of the examples below you may also set the environment variables +globally (i.e. `export` in your shell instead of `--repo_env` argument to your +command) so calling bazel via `build/build.py` will work just fine. + +Build with custom `Python 3.13` from the internet, using default +`requirements_lock_3_13.txt` already checked in this repo (i.e. custom +interpreter but default dependencies): +``` +bazel build + --repo_env=HERMETIC_PYTHON_VERSION=3.13 + --repo_env=HERMETIC_PYTHON_URL="https://github.com/indygreg/python-build-standalone/releases/download/20241016/cpython-3.13.0+20241016-x86_64-unknown-linux-gnu-install_only.tar.gz" + --repo_env=HERMETIC_PYTHON_SHA256="2c8cb15c6a2caadaa98af51df6fe78a8155b8471cb3dd7b9836038e0d3657fb4" +``` + +Build with custom Python 3.13 from local file system and custom lock file +(assuming the lock file was put in `jax/build` folder of this repo before +running the command): +``` +bazel test + --repo_env=HERMETIC_PYTHON_VERSION=3.13 + --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz" + --repo_env=HERMETIC_PYTHON_PREFIX="prefix/to/strip/in/cython/tar/gz/archive" + --repo_env=HERMETIC_PYTHON_SHA256= + --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt" +``` + +If default python interpreter is good enough for you and you just need a custom +set of dependencies: +``` +bazel test + --repo_env=HERMETIC_PYTHON_VERSION=3.13 + --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt" +``` + +Note, you can have multiple different `requirement_lock.txt` files corresponding +to the same Python version to support different scenarios. You can control +which one is selected by specifying `HERMETIC_PYTHON_VERSION`. For example in +`WORKSPACE` file: +``` +requirements = { + "3.10": "//build:requirements_lock_3_10.txt", + "3.11": "//build:requirements_lock_3_11.txt", + "3.12": "//build:requirements_lock_3_12.txt", + "3.13": "//build:requirements_lock_3_13.txt", + "3.13-scenario1": "//build:scenario1_requirements_lock_3_13.txt", + "3.13-scenario2": "//build:scenario2_requirements_lock_3_13.txt", +}, +``` +Then you can build and test different combinations of stuff without changing +anything in your environment: +``` +# To build with scenario1 dependendencies: +bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 + +# To build with scenario2 dependendencies: +bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2 + +# To build with default dependendencies: +bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13 + +# To build with scenario1 dependendencies and custom Python 3.13 interpreter: +bazel test + --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 + --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz" + --repo_env=HERMETIC_PYTHON_SHA256= ``` ## Installing `jax` From 34fe66b08b3720bd37ca12d3527cbf31dabbd9b8 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 28 Nov 2024 02:19:59 -0800 Subject: [PATCH 540/698] [mgpu] foreach should not try to create an array if it didn't create the registers due to create_array=False. PiperOrigin-RevId: 700955830 --- jax/experimental/mosaic/gpu/fragmented_array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 6b288906e967..c2cd8c21c132 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1265,7 +1265,8 @@ def foreach( if create_array: new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) - return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) + if create_array: + return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) def store_untiled(self, ref: ir.Value): if not ir.MemRefType.isinstance(ref.type): From a158e02b7d1c1a50e53adfec7f48bec69cc0dc5b Mon Sep 17 00:00:00 2001 From: Fabian Mentzer Date: Thu, 28 Nov 2024 05:34:52 -0800 Subject: [PATCH 541/698] Reverts cc5036cc18bc585b0d92a4f606956da084effbad PiperOrigin-RevId: 700998046 --- jax/_src/mesh_utils.py | 9 +-------- jax/_src/sharding_impls.py | 1 - tests/mesh_utils_test.py | 6 ------ tests/pjit_test.py | 6 ------ 4 files changed, 1 insertion(+), 21 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index d227b1eeeea9..16e34e1afaef 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -705,12 +705,6 @@ def _transpose_trick( *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] ) -def _validate_axis_shapes(axis_shapes: Sequence[int], arg_name: str, - fun_name: str): - if not all(isinstance(s, int) for s in axis_shapes): - raise ValueError( - f'{arg_name} passed to {fun_name} should be a sequence of ints. Got' - f' {axis_shapes}') def create_device_mesh( mesh_shape: Sequence[int], @@ -746,8 +740,7 @@ def create_device_mesh( """ if devices is None: devices = xb.devices() - _validate_axis_shapes(mesh_shape, 'mesh_shape', 'create_device_mesh') - if math.prod(mesh_shape) != len(devices): + if np.prod(mesh_shape) != len(devices): raise ValueError( f'Number of devices {len(devices)} must equal the product ' f'of mesh_shape {mesh_shape}' diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 39d8aedfe7ad..8abe58e52a74 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1714,7 +1714,6 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], """ if devices is None: devices = xla_bridge.devices() - mesh_utils._validate_axis_shapes(axis_shapes, 'axis_shapes', 'make_mesh') axis_size = math.prod(axis_shapes) if axis_size > len(devices): raise ValueError( diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index d4db8fd3d406..66f1fc9f6cfb 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -353,12 +353,6 @@ def test_create_device_mesh_for_nd_torus( ) self.assertArraysEqual(assignment, expected_assignment_matrix) - def test_create_device_mesh_non_int_error(self): - with self.assertRaisesRegex( - ValueError, - "mesh_shape passed to create_device_mesh should be a sequence of ints"): - mesh_utils.create_device_mesh(((4,), 4)) - @parameterized.named_parameters( ('2x2x1', mock_2x2x1_devices,), ('2x2x4', mock_2x2x4_devices, ), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6bd05536cebc..e541c6346666 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4458,12 +4458,6 @@ def g(x): self.assertEqual(out2.sharding, s) self.assertEqual(out2.dtype, np.float32) - def test_make_mesh_non_int_error(self): - with self.assertRaisesRegex( - ValueError, - "axis_shapes passed to make_mesh should be a sequence of ints"): - jax.make_mesh(((4,), 4), ('x', 'y')) - def test_jnp_array_reshard_error(self): if jax.device_count() < 2: self.skipTest('Requires >=2 devices') From b09b0779e044a461d80620868f27150e8f54645e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 28 Nov 2024 05:40:20 -0800 Subject: [PATCH 542/698] [Mosaic GPU] Add support for fast upcasts of s8 to bf16 for vectors of 4 elements To complement the current path that only handles 2 elements. PiperOrigin-RevId: 700998965 --- .../mosaic/gpu/fragmented_array.py | 39 ++++++++++++------- tests/mosaic/gpu_test.py | 10 ++--- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index c2cd8c21c132..dc2a5f0d891a 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1032,12 +1032,11 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): ) reg_type = self.registers.flat[0].type is_vector_reg = ir.VectorType.isinstance(reg_type) - reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else () - if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,): + reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,) + [vector_len] = reg_shape # This is meant to be a 1D assertion. + if cur_dtype == i8 and self.is_signed and new_dtype == bf16 and vector_len in {2, 4}: new_registers = np.empty_like(self.registers) - for idx, reg in np.ndenumerate(self.registers): - reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) - val_16 = llvm.extractelement(reg_16, c(0, i32)) + def upcast_to_bf16(reg, high): # We first embed the s8 into a bf16 with the exponent equal to # bias + mantissa bits. Then, we zero the msb that didn't fit into the # mantissa, zero out all bits other than msb, and subtract the last @@ -1045,24 +1044,36 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): # lsb of the exponent (msb of the second byte) is zero, which allows us # to losslesly pack the msb there. When 1, it doubles the value of s2, # making the result negative. - new_val_32 = llvm.inline_asm( + return llvm.inline_asm( i32, - [val_16], - """ - { + [reg], + f""" + {{ .reg .b32 s<3>; - prmt.b32 s0, $1, 0x43, 0x4140; + prmt.b32 s0, $1, 0x43, {0x4342 if high else 0x4140}; and.b32 s1, s0, 0xff7fff7f; and.b32 s2, s0, 0xff80ff80; sub.bf16x2 $0, s1, s2; - } + }} """, "=r,r", ) - new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32)) - new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32)) + empty_vec_32 = llvm.mlir_undef(ir.VectorType.get((vector_len // 2,), i32)) + for idx, reg in np.ndenumerate(self.registers): + if vector_len == 2: + reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) + new_reg_32 = upcast_to_bf16(reg_16, high=False) + new_vec_32 = llvm.insertelement(empty_vec_32, new_reg_32, c(0, i32)) + elif vector_len == 4: + reg_32 = vector.bitcast(ir.VectorType.get((1,), i32), reg) + low = upcast_to_bf16(reg_32, high=False) + high = upcast_to_bf16(reg_32, high=True) + new_vec_32 = llvm.insertelement(empty_vec_32, low, c(0, i32)) + new_vec_32 = llvm.insertelement(new_vec_32, high, c(1, i32)) + else: + raise NotImplementedError(vector_len) new_registers[idx] = vector.bitcast( - ir.VectorType.get((2,), new_dtype), new_vec + ir.VectorType.get((vector_len,), new_dtype), new_vec_32 ) return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 7dadc71fdcba..71f2d383f809 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1608,19 +1608,19 @@ def kernel(ctx, out, *_): np.testing.assert_array_equal(result, x) - @parameterized.named_parameters( - ("_bf16", jnp.bfloat16) - ) - def test_fast_i8_convert(self, jax_dtype_to): - jax_dtype_to = jnp.dtype(jax_dtype_to) + @parameterized.parameters(2, 4) + def test_fast_i8_convert(self, reg_length): + jax_dtype_to = jnp.dtype(jnp.bfloat16) jax_dtype_from = jnp.dtype(jnp.int8) mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) def kernel(ctx, inp, out, smem): del ctx, smem arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) + assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length] arr.astype(mlir_dtype_to).store_untiled(out) x = jnp.arange(-128, 128, dtype=jax_dtype_from) + x = jnp.tile(x, reg_length // 2) reference = x.astype(jax_dtype_to) result = mgpu.as_gpu_kernel( From 14ddb81949675f1e72f0af0802f25cc916e9974e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 28 Nov 2024 05:41:36 -0800 Subject: [PATCH 543/698] [Mosaic GPU] Avoid double-predication when async_copy predicate is specified PiperOrigin-RevId: 700999181 --- jax/experimental/mosaic/gpu/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 409a87eb9af7..d9894051dc81 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -356,7 +356,7 @@ def async_copy( arrive: bool | None = None, uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, - predicate: ir.Value | None = None, + predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. ): index = ir.IndexType.get() i16 = ir.IntegerType.get_signless(16) @@ -504,7 +504,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): uniform_ctx = ( functools.partial(utils.single_thread, per_block=False) - if uniform + if uniform and predicate is None else contextlib.nullcontext ) From d5bfafbcb6937ead3b4b5005c6b9ce7498732430 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 28 Nov 2024 06:01:33 -0800 Subject: [PATCH 544/698] [mgpu] Added a missed case for debug_print types and raise a proper error if a type is unexpected. PiperOrigin-RevId: 701003002 --- jax/experimental/mosaic/gpu/utils.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index fcba0518620b..2279df4f3984 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -108,21 +108,18 @@ def c(val: int | float, ty): return arith.constant(ty, attr) def _debug_scalar_ty_format(arg): - ty_format = None if ir.IndexType.isinstance(arg.type): - return "%llu" + return "%llu", arg if ir.IntegerType.isinstance(arg.type): - width = ir.IntegerType(arg.type).width - ty_format = "%llu" - if width < 64: + if ir.IntegerType(arg.type).width < 64: arg = arith.extui(ir.IntegerType.get_signless(64), arg) + return "%llu", arg if ir.F32Type.isinstance(arg.type): - ty_format = "%f" + return "%f", arg if ir.F16Type.isinstance(arg.type): - ty_format = "%f" arg = arith.extf(ir.F32Type.get(), arg) - - return ty_format, arg + return "%f", arg + raise NotImplementedError(f"Can't print the type {arg.type}") def debug_print(fmt, *args, uniform=True): type_formats = [] From b801539f5c9a3857ce0d274f8ca61f5c4259b5ee Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 28 Nov 2024 08:34:19 -0800 Subject: [PATCH 545/698] [Pallas][Mosaic GPU] Add support for compressing squeezed dims in async_copy + grid fixes This change removes the need to flatten the batch dimension into sequence dimensions in the flash attention kernel. The critical thing here is the observation that we can in fact collapse all squeezed dimension into a single one in the TMA descriptor, letting us reduce its rank when necessary. Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU lowering, which I've fixed. PiperOrigin-RevId: 701035277 --- jax/_src/pallas/mosaic_gpu/lowering.py | 29 +++--- jax/experimental/mosaic/gpu/core.py | 89 +++++++++++++++++-- .../pallas/ops/gpu/attention_mgpu.py | 34 +++---- tests/mosaic/gpu_test.py | 2 +- 4 files changed, 110 insertions(+), 44 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6e7adfc60a53..87dfe2ce776e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -360,10 +360,6 @@ def lower_jaxpr_to_module( assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims - if len(grid_mapping.grid) > 3: - raise NotImplementedError( - "Only <=3D grids are supported in Mosaic GPU lowering." - ) if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "Dynamic grid bounds not supported in the Mosaic GPU lowering." @@ -397,16 +393,19 @@ def lower_jaxpr_to_module( f" {max_concurrent_steps=}, {delay_release=}" ) - block = (128, 1, 1) - grid = grid_mapping.grid if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid[:-1] - - grid = [d for i, d in enumerate(grid) if i not in sequential_axes] - if len(grid) < 3: - grid += (1,) * (3 - len(grid)) + logical_grid = grid_mapping.grid[:-1] else: + block = (128, 1, 1) + logical_grid = grid_mapping.grid + + parallel_grid = [ + d for i, d in enumerate(logical_grid) if i not in sequential_axes + ] + if len(parallel_grid) < 3: + parallel_grid += (1,) * (3 - len(parallel_grid)) + elif len(parallel_grid) > 3: raise NotImplementedError( "Only <=3D grids are supported in Mosaic GPU lowering." ) @@ -500,7 +499,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): _program_id(next(parallel_count)) if axis not in sequential_axes else None - for axis in range(len(grid_mapping.grid)) + for axis in range(len(logical_grid)) ] def make_program_ids(step: ir.Value): @@ -788,7 +787,7 @@ def _(step, carry): prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel( body, - grid=grid, + grid=parallel_grid, cluster=(), block=block, in_shapes=in_structs_gmem, @@ -806,7 +805,9 @@ def _(step, carry): prof_spec=prof_spec, ) - return LoweringResult(module, grid, block, out_structs_gmem, prof_ctx) + return LoweringResult( + module, parallel_grid, block, out_structs_gmem, prof_ctx + ) mosaic_lowering_rules = {} diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index d9894051dc81..16b7f1f59c33 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -234,6 +234,57 @@ def batch(self, leading_rank: int) -> MemRefTransform: ) +@dataclasses.dataclass(frozen=True) +class CollapseLeadingIndicesTransform(MemRefTransform): + """Collapses leading indices into one.""" + strides: tuple[int, ...] + + @functools.cached_property + def common_stride(self) -> int: + return math.gcd(*self.strides) + + def apply(self, ref: ir.Value) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + strides, offset = ref_ty.get_strides_and_offset() + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + raise NotImplementedError("Dynamic offsets are not supported") + max_bound = sum( + (d - 1) * s // self.common_stride + for d, s in zip( + ref_ty.shape[: len(self.strides)], strides[: len(self.strides)] + ) + ) + 1 + new_shape = [max_bound, *ref_ty.shape[len(self.strides):]] + new_strides = [self.common_stride, *strides[len(self.strides):]] + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ref_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + return memref.reinterpret_cast( + new_ref_ty, ref, [], [], [], + static_offsets=[offset], + static_sizes=new_shape, + static_strides=new_strides, + ) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + flat_idx = c(0, index) + for i, s in zip(idx[:len(self.strides)], self.strides): + flat_idx = arith.addi( + flat_idx, arith.muli(i, c(s // self.common_stride, index)) + ) + return (flat_idx, *idx[len(self.strides):]) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + if any(s != 1 for s in shape[:len(self.strides)]): + raise ValueError("Expected leading indices to be squeezed") + return (1, *shape[len(self.strides):]) + + def batch(self, leading_rank: int) -> MemRefTransform: + raise NotImplementedError # Unused + + OnDeviceProfiler = profiler.OnDeviceProfiler @@ -397,6 +448,17 @@ def async_copy( or gmem_ref.owner.opview.OPERATION_NAME != expected_name ): raise ValueError("GMEM reference in async_copy must be a kernel argument") + gmem_ref_ty = ir.MemRefType(gmem_ref.type) + gmem_strides, _ = gmem_ref_ty.get_strides_and_offset() + if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape): + raise NotImplementedError( + "async_copy assumes the GMEM reference is contiguous" + ) + if any(s * element_bytewidth % 16 != 0 for s in gmem_strides[:-1]): + raise ValueError( + "async_copy requires all GMEM strides except the last one to be a" + " multiple of 16 bytes" + ) base_indices, slice_shape, is_squeezed = utils.parse_indices( gmem_slice, ir.MemRefType(gmem_ref.type).shape @@ -421,9 +483,25 @@ def async_copy( dyn_base_indices = t.transform_index(dyn_base_indices) slice_shape = t.transform_shape(slice_shape) + num_squeezed_dims = len(squeezed_dims) + if len(slice_shape) > 5: + # We can try to collapse all squeezed dims into one. + if len(slice_shape) - num_squeezed_dims + 1 > 5: + raise ValueError( + "Async copies only support striding up to 5 dimensions" + ) + collapse = CollapseLeadingIndicesTransform( + tuple(gmem_strides[d] for d in squeezed_dims) + ) + gmem_transform = (*gmem_transform, collapse) + dyn_base_indices = collapse.transform_index(dyn_base_indices) + slice_shape = collapse.transform_shape(slice_shape) + num_squeezed_dims = 1 + del squeezed_dims, sliced_dims # Those no longer make sense. + smem_ref_ty = ir.MemRefType(smem_ref.type) # We moved all squeezed dims to the front. - if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape): + if slice_shape[num_squeezed_dims:] != tuple(smem_ref_ty.shape): raise ValueError( "Expected the SMEM reference to have the same shape as the" f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" @@ -437,7 +515,7 @@ def async_copy( dyn_base_indices = list(dyn_base_indices) slice_shape = list(slice_shape) - assert all(d == 1 for d in slice_shape[:len(squeezed_dims)]) + assert all(d == 1 for d in slice_shape[:num_squeezed_dims]) collective_size = 1 if collective is not None: if isinstance(collective, gpu.Dimension): @@ -446,14 +524,14 @@ def async_copy( if collective_size > 1: def partition_dim(dim: int, idx: ir.Value, num_chunks: int): # No need to partition squeezed dims. They don't even exist in smem_ref. - assert dim >= len(squeezed_dims) + assert dim >= num_squeezed_dims nonlocal smem_ref slice_shape[dim] //= num_chunks block_offset = arith.muli(idx, c(slice_shape[dim], index)) dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) smem_ref = utils.memref_slice( smem_ref, - (slice(None),) * (dim - len(squeezed_dims)) + (slice(None),) * (dim - num_squeezed_dims) + (utils.ds(block_offset, slice_shape[dim]),), ) stride = 1 @@ -508,9 +586,6 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): else contextlib.nullcontext ) - rank = len(slice_shape) - if rank > 5: # TODO: apaszke - Implement stride compression - raise ValueError("Async copies only support striding up to 5 dimensions") if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index c4ac7e625942..78db197c673d 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -61,25 +61,14 @@ def attention(q, k, v, config: TuningConfig): raise ValueError(f"{head_dim=} must be divisible by 64") if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") - # Squash batch and sequence dimensions. - # This is required because CUDA grid/TMA descriptors have a limited number of - # slice dimensions. - # TODO(apaszke): Implement slice squashing for TMAs. - q = jnp.reshape(q, (batch_size * q_seq_len, num_q_heads, head_dim)) - k = jnp.reshape(k, (batch_size * kv_seq_len, num_kv_heads, head_dim)) - v = jnp.reshape(v, (batch_size * kv_seq_len, num_kv_heads, head_dim)) max_concurrent_steps = min( config.max_concurrent_steps, kv_seq_len // config.block_kv ) block_q, block_kv = config.block_q, config.block_kv - num_q_tiles, rem = divmod(q_seq_len, block_q * 2) - if rem: - raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") def kernel(q_ref, k_ref, v_ref, out_ref, scoped): - bidx = lax.div(lax.axis_index("bq"), num_q_tiles) - qidx = lax.rem(lax.axis_index("bq"), num_q_tiles) + batch = lax.axis_index("batch") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") qo_smem2, k_smem, v_smem = smem_buffers @@ -93,11 +82,11 @@ def perform_schedule_barrier(): def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] - q_seq_base = qidx * (2 * block_q) + wg_idx * block_q + bidx * q_seq_len + q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") plgpu.copy_gmem_to_smem( - q_ref.at[pl.ds(q_seq_base, block_q), q_head], + q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], qo_smem, q_barriers.at[wg_idx], ) @@ -167,7 +156,7 @@ def _wait(): qo_smem[...] = acc.astype(dtype) plgpu.commit_smem() plgpu.copy_smem_to_gmem( - qo_smem, out_ref.at[pl.ds(q_seq_base, block_q), q_head], + qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) @@ -175,16 +164,14 @@ def _memory_wg(): plgpu.set_max_registers(40, action="decrease") kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) for i in range(max_concurrent_steps): - start = i * block_kv + bidx * kv_seq_len - s = (pl.ds(start, block_kv), kv_head) + s = (batch, pl.ds(i * block_kv, block_kv), kv_head) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) - start = tma_step * block_kv + bidx * kv_seq_len - s = (pl.ds(start, block_kv), kv_head) + s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barrier) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barrier) @@ -199,10 +186,13 @@ def kv_epilogue(i, _): def run(refs): q_ref, k_ref, v_ref, out_ref = refs + num_q_tiles, rem = divmod(q_seq_len, block_q * 2) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") mesh = plgpu.GPUMesh( - grid=(batch_size * num_q_tiles, num_q_heads), + grid=(batch_size, num_q_tiles, num_q_heads), num_threads=3, - axis_names=("bq", "heads", "wg"), + axis_names=("batch", "q_seq", "heads", "wg"), approx_math=True, ) @pl.core_map(mesh) @@ -236,7 +226,7 @@ def _kernel_entry(): ) _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) - return jnp.reshape(out, [batch_size, q_seq_len, num_q_heads, head_dim]) + return out @jax.jit diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 71f2d383f809..39182841f190 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1240,7 +1240,7 @@ def run_kernel(shape): x = np.arange(np.prod(shape)).reshape(shape) _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) - with self.assertRaisesRegex(ValueError, "only support striding up to 5"): + with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"): run_kernel([1] * 6) with self.assertRaisesRegex( From db158e6c7aff1d90fc1985bbad1d410b815343a9 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 28 Nov 2024 09:09:58 -0800 Subject: [PATCH 546/698] [Mosaic GPU] Improve the implementation of max and exp Both are very important for FlashAttention and both were poorly mapped to PTX. For exp, we really do not care about denormals when running in approximate mode, since they would produce results so close to 1 that it really doesn't matter. For max, LLVM ended up generating a while bunch of comparisons and selects and failed to take advantage of the max instructions present in GPUs. Both of those changes _significantly_ improve the performance of Mosaic attention kernels for heads smaller than 256 (when the pointwise part dominates the execution time). In one example I looked at, the utilization jumps from 55% to 64%. PiperOrigin-RevId: 701042779 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- .../mosaic/gpu/fragmented_array.py | 61 +++++++++++++------ 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 87dfe2ce776e..42d3fd97f8d5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1230,7 +1230,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): raise NotImplementedError - return x.reduce(arith_dialect.maxnumf, axes[0]) + return x.reduce("max", axes[0]) case _: raise NotImplementedError(f"Unsupported layout {x.layout}") diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index dc2a5f0d891a..ef52c30e270c 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -879,7 +879,10 @@ def _compare(self, other, *, f_pred, si_pred, ui_pred): def max(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.maximumf, other) + maximumf = arith.maximumf + if ir.F32Type.isinstance(self.mlir_dtype): + maximumf = self._lift_fast_instr("max.NaN.f32") + return self._pointwise(maximumf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise( arith.maxsi if self.is_signed else arith.maxui, other @@ -907,8 +910,8 @@ def exp(self, *, approx: bool = False): log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634)) def fast_exp(x): scaled = arith.mulf(x, log2e) - return llvm.inline_asm(f32, [scaled], "ex2.approx.f32 $0, $1;", "=f,f") - return self._pointwise(self._lift_fast_unary(fast_exp)) + return llvm.inline_asm(f32, [scaled], "ex2.approx.ftz.f32 $0, $1;", "=f,f") + return self._pointwise(self._lift_fast_instr(fast_exp)) return self._pointwise(mlir_math.exp) def sin(self, *, approx: bool = False): @@ -917,7 +920,7 @@ def sin(self, *, approx: bool = False): if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( - self._lift_fast_unary("sin.approx.f32") if approx else mlir_math.sin + self._lift_fast_instr("sin.approx.f32") if approx else mlir_math.sin ) def cos(self, *, approx: bool = False): @@ -926,7 +929,7 @@ def cos(self, *, approx: bool = False): if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( - self._lift_fast_unary("cos.approx.f32") if approx else mlir_math.cos + self._lift_fast_instr("cos.approx.f32") if approx else mlir_math.cos ) def tanh(self, *, approx: bool = False): @@ -935,7 +938,7 @@ def tanh(self, *, approx: bool = False): if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( - self._lift_fast_unary("tanh.approx.f32") if approx else mlir_math.tanh + self._lift_fast_instr("tanh.approx.f32") if approx else mlir_math.tanh ) def rsqrt(self, *, approx: bool = False): @@ -944,31 +947,36 @@ def rsqrt(self, *, approx: bool = False): if approx and self.mlir_dtype != ir.F32Type.get(): raise NotImplementedError return self._pointwise( - self._lift_fast_unary("rsqrt.approx.f32") if approx else mlir_math.rsqrt + self._lift_fast_instr("rsqrt.approx.f32") if approx else mlir_math.rsqrt ) @staticmethod - def _lift_fast_unary( + def _lift_fast_instr( instr: str | Callable[[ir.Value], ir.Value], ) -> Callable[[ir.Value], ir.Value]: - def fast_instr(x): + def fast_instr(*args): f32 = ir.F32Type.get() - if x.type == f32: + arg_ty = args[0].type + assert all(a.type == arg_ty for a in args) + if arg_ty == f32: if isinstance(instr, str): - return llvm.inline_asm(f32, [x], instr + " $0, $1;", "=f,f") + args_ptx = ", ".join(f"${i}" for i in range(len(args) + 1)) + return llvm.inline_asm( + f32, args, f"{instr} {args_ptx};", "=f" + ",f" * len(args) + ) else: - return instr(x) - elif ir.VectorType.isinstance(x.type): + return instr(*args) + elif ir.VectorType.isinstance(arg_ty): index = ir.IndexType.get() - result = llvm.mlir_undef(x.type) - [vec_len] = ir.VectorType(x.type).shape + result = llvm.mlir_undef(arg_ty) + [vec_len] = ir.VectorType(arg_ty).shape for i in range(vec_len): - v = vector.extractelement(x, position=c(i, index)) - vr = fast_instr(v) + vs = [vector.extractelement(a, position=c(i, index)) for a in args] + vr = fast_instr(*vs) result = vector.insertelement(vr, result, position=c(i, index)) return result else: - raise NotImplementedError(x.type) + raise NotImplementedError(arg_ty) return fast_instr def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): @@ -1156,7 +1164,20 @@ def reduce_sum(self, scratch) -> ir.Value: utils.warpgroup_barrier() # Make sure everyone is done using scratch. return result - def reduce(self, op, axis): + def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): + if isinstance(op, str): + match op: + case "max": + if ir.F32Type.isinstance(self.mlir_dtype): + op = self._lift_fast_instr("max.NaN.f32") + elif ir.FloatType.isinstance(self.mlir_dtype): + op = arith.maximumf + elif ir.IntegerType.isinstance(self.mlir_dtype): + op = arith.maxsi if self.is_signed else arith.maxui + else: + raise NotImplementedError(self.mlir_dtype) + case _: + raise ValueError(f"Unrecognized reduction operator: {op}") if self.layout != WGMMA_LAYOUT: raise NotImplementedError(self.layout) if axis != 1: @@ -1421,7 +1442,7 @@ def load_tiled( tiled_shape = ref_ty.shape if len(tiled_shape) % 2: raise ValueError("Tiled reference must have even rank") - tiling = Tiling((tiled_shape[len(tiled_shape) // 2:],)) + tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],)) shape = tiling.untile_shape(tiled_shape) registers = np.full(layout.registers_shape(shape), None, dtype=object) reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) From 456dfeb0aebdecbec7ae32f75164c21aaff21e60 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 28 Nov 2024 09:23:32 -0800 Subject: [PATCH 547/698] [Take 2] Raise a better error message if anything other than a sequence of ints is passed to make_mesh or create_device_mesh Reverts a158e02b7d1c1a50e53adfec7f48bec69cc0dc5b PiperOrigin-RevId: 701045239 --- jax/_src/mesh_utils.py | 29 +++++++++++++++++++++++------ jax/_src/sharding_impls.py | 14 +++++++++++--- tests/mesh_utils_test.py | 6 ++++++ tests/pjit_test.py | 8 ++++++++ 4 files changed, 48 insertions(+), 9 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index 16e34e1afaef..588863d1f244 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -705,6 +705,15 @@ def _transpose_trick( *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] ) +def _canonicalize_axis_sizes(axis_sizes: Sequence[int] + ) -> tuple[int, ...] | None: + new_sizes = [] + for s in axis_sizes: + try: + new_sizes.append(int(s)) + except: + return None + return tuple(new_sizes) def create_device_mesh( mesh_shape: Sequence[int], @@ -740,17 +749,25 @@ def create_device_mesh( """ if devices is None: devices = xb.devices() - if np.prod(mesh_shape) != len(devices): + + new_mesh_shape = _canonicalize_axis_sizes(mesh_shape) + if new_mesh_shape is None: + raise ValueError( + f'`mesh_shape` passed to `create_device_mesh` should be a sequence of' + f' ints. Got {mesh_shape}') + del mesh_shape + + if math.prod(new_mesh_shape) != len(devices): raise ValueError( f'Number of devices {len(devices)} must equal the product ' - f'of mesh_shape {mesh_shape}' + f'of mesh_shape {new_mesh_shape}' ) last_device = devices[-1] handler = device_kind_handler_dict.get(last_device.device_kind, None) if handler is not None: result = handler( - mesh_shape, devices, contiguous_submeshes=contiguous_submeshes + new_mesh_shape, devices, contiguous_submeshes=contiguous_submeshes ) if result is not None: return result @@ -758,15 +775,15 @@ def create_device_mesh( if last_device.platform == 'tpu': physical_mesh = _get_physical_tpu_mesh(devices) if contiguous_submeshes: - physical_mesh = _transpose_trick(physical_mesh, mesh_shape) + physical_mesh = _transpose_trick(physical_mesh, new_mesh_shape) device_mesh, _ = _create_device_mesh_for_nd_torus( physical_mesh, - mesh_shape, + new_mesh_shape, allow_split_physical_axes=allow_split_physical_axes, ) return device_mesh else: - device_mesh = np.asarray(devices).reshape(mesh_shape) + device_mesh = np.asarray(devices).reshape(new_mesh_shape) return device_mesh diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 8abe58e52a74..5e1def1079ac 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1714,11 +1714,18 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], """ if devices is None: devices = xla_bridge.devices() - axis_size = math.prod(axis_shapes) + new_axis_shapes = mesh_utils._canonicalize_axis_sizes(axis_shapes) + if new_axis_shapes is None: + raise ValueError( + '`axis_shapes` passed to `make_mesh` should be a sequence of ints.' + f' Got {axis_shapes}') + del axis_shapes + + axis_size = math.prod(new_axis_shapes) if axis_size > len(devices): raise ValueError( f'Number of devices {len(devices)} must be >= the product ' - f'of mesh_shape {axis_shapes}') + f'of mesh_shape {new_axis_shapes}') elif axis_size < len(devices): devices = devices[:axis_size] if devices[0].device_kind in (mesh_utils._TPU_V5_LITE, mesh_utils._TPU_V5E): @@ -1726,5 +1733,6 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], else: allow_split_physical_axes = False mesh_devices = mesh_utils.create_device_mesh( - axis_shapes, devices, allow_split_physical_axes=allow_split_physical_axes) + new_axis_shapes, devices, + allow_split_physical_axes=allow_split_physical_axes) return mesh_lib.Mesh(mesh_devices, axis_names) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 66f1fc9f6cfb..4f1b1fb037d6 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -353,6 +353,12 @@ def test_create_device_mesh_for_nd_torus( ) self.assertArraysEqual(assignment, expected_assignment_matrix) + def test_create_device_mesh_non_int_error(self): + with self.assertRaisesRegex( + ValueError, + "`mesh_shape` passed to `create_device_mesh` should be a sequence of ints"): + mesh_utils.create_device_mesh(((4,), 4)) + @parameterized.named_parameters( ('2x2x1', mock_2x2x1_devices,), ('2x2x4', mock_2x2x4_devices, ), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e541c6346666..af81b35570d6 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4458,6 +4458,14 @@ def g(x): self.assertEqual(out2.sharding, s) self.assertEqual(out2.dtype, np.float32) + def test_make_mesh_non_int_error(self): + with self.assertRaisesRegex( + ValueError, + "`axis_shapes` passed to `make_mesh` should be a sequence of ints"): + jax.make_mesh(((4,), 4), ('x', 'y')) + + jax.make_mesh((1, np.int32(1), np.int64(1)), ('x', 'y', 'z')) # doesn't crash + def test_jnp_array_reshard_error(self): if jax.device_count() < 2: self.skipTest('Requires >=2 devices') From f73de23026f62c6b6b9fd412442dbac602b186b0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 28 Nov 2024 15:38:29 -0800 Subject: [PATCH 548/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/fc80c5576b71c986fbd4505a59826f7d433878bc. PiperOrigin-RevId: 701110365 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 177eaff34b23..4957bdcf10b4 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "d0769456f09dde192ab6e6421648bddc38908b39" -XLA_SHA256 = "e3bef213eef3ea6a8459bfdfe863b19380219af0a6c5a017226b1bdc1bda17e9" +XLA_COMMIT = "fc80c5576b71c986fbd4505a59826f7d433878bc" +XLA_SHA256 = "00cc8500299f22d4ec047e08e9b7bd357ba74e9e67acca4131231fadc47ab90a" def repo(): tf_http_archive( From f10d3eb312559a52afa99d0421a2d5a1833ba34e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 29 Nov 2024 02:20:35 -0800 Subject: [PATCH 549/698] [Mosaic GPU] Allow contracting ops into FMAs Using FMAs can significantly increase the ALU throughput and only increases the precision. We use this capability to reduce the number of operations needed to evaluate the softmax part of attention. PiperOrigin-RevId: 701226007 --- jax/_src/pallas/mosaic_gpu/lowering.py | 9 +++- .../mosaic/gpu/fragmented_array.py | 49 ++++++++++++++----- .../pallas/ops/gpu/attention_mgpu.py | 11 +++-- 3 files changed, 51 insertions(+), 18 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 42d3fd97f8d5..0407d76643b8 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1199,6 +1199,13 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): return a.exp(approx=ctx.module_ctx.approx_math) +@register_lowering_rule(lax.exp2_p) +def _exp2_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + a = _ensure_fa(x, x_aval.dtype) + return a.exp2(approx=ctx.module_ctx.approx_math) + + @register_lowering_rule(lax.reduce_sum_p) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in @@ -1216,7 +1223,7 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): raise NotImplementedError - return x.reduce(arith_dialect.addf, axes[0]) + return x.reduce("add", axes[0]) case _: raise NotImplementedError(f"Unsupported layout {x.layout}") diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ef52c30e270c..daabe9b0f060 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -700,7 +700,7 @@ def __neg__(self): def __add__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.addf, other) + return self._pointwise(addf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.addi, other) else: @@ -711,7 +711,7 @@ def __radd__(self, other): def __mul__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.mulf, other) + return self._pointwise(mulf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.muli, other) else: @@ -722,7 +722,7 @@ def __rmul__(self, other): def __sub__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(arith.subf, other) + return self._pointwise(subf, other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.subi, other) else: @@ -730,7 +730,7 @@ def __sub__(self, other): def __rsub__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): - return self._pointwise(lambda s, o: arith.subf(o, s), other) + return self._pointwise(lambda s, o: subf(o, s), other) elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(lambda s, o: arith.subi(o, s), other) else: @@ -904,16 +904,20 @@ def exp(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx: - f32 = ir.F32Type.get() - if self.mlir_dtype != f32: - raise NotImplementedError - log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634)) - def fast_exp(x): - scaled = arith.mulf(x, log2e) - return llvm.inline_asm(f32, [scaled], "ex2.approx.ftz.f32 $0, $1;", "=f,f") - return self._pointwise(self._lift_fast_instr(fast_exp)) + dtype = self.mlir_dtype + log2e = arith.constant(dtype, ir.FloatAttr.get(dtype, 1.4426950408889634)) + return (self * log2e).exp2() return self._pointwise(mlir_math.exp) + def exp2(self, *, approx: bool = False): + if not ir.FloatType.isinstance(self.mlir_dtype): + raise NotImplementedError + if approx: + if not ir.F32Type.isinstance(self.mlir_dtype): + raise NotImplementedError(self.mlir_dtype) + return self._pointwise(self._lift_fast_instr("ex2.approx.ftz.f32")) + return self._pointwise(mlir_math.exp2) + def sin(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError @@ -1125,7 +1129,7 @@ def upcast_to_bf16(reg, high): # NOTE: scratch can be reused immediately once this function returns. def reduce_sum(self, scratch) -> ir.Value: if ir.FloatType.isinstance(self.mlir_dtype): - op = arith.addf + op = addf elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.addi else: @@ -1167,6 +1171,13 @@ def reduce_sum(self, scratch) -> ir.Value: def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): if isinstance(op, str): match op: + case "add": + if ir.FloatType.isinstance(self.mlir_dtype): + op = addf + elif ir.IntegerType.isinstance(self.mlir_dtype): + op = arith.addi + else: + raise NotImplementedError(self.mlir_dtype) case "max": if ir.F32Type.isinstance(self.mlir_dtype): op = self._lift_fast_instr("max.NaN.f32") @@ -1653,3 +1664,15 @@ def tree_unflatten(cls, aux, flat_registers): layout, reg_shape, is_signed = aux registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) + + +# We allow contractions, to potentially take advantage of FMA instructions. +# They can change the results, but the precision should only increase. +def addf(a: ir.Value, b: ir.Value): + return arith.addf(a, b, fastmath=arith.FastMathFlags.contract) + +def subf(a: ir.Value, b: ir.Value): + return arith.subf(a, b, fastmath=arith.FastMathFlags.contract) + +def mulf(a: ir.Value, b: ir.Value): + return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 78db197c673d..6f02396ccb92 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -16,6 +16,7 @@ import dataclasses import functools import itertools +import math import jax from jax import lax from jax._src import test_util as jtu # noqa: F401 @@ -118,11 +119,13 @@ def compute_qk(acc_ref): plgpu.barrier_arrive(k_consumed_barrier) # Softmax - m_ij = jnp.maximum(m_i, qk.max(axis=1)) - alpha = jnp.exp(m_i - m_ij) + # We keep m scaled by log2e to use FMA instructions when computing p. + log2e = math.log2(math.e) + m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e) + alpha = jnp.exp2(m_i - m_ij) m_i = m_ij - p = jnp.exp(qk - lax.broadcast_in_dim(m_ij, (block_q, block_kv), [0])) - acc *= lax.broadcast_in_dim(alpha, (block_q, head_dim), [0]) + p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0])) + acc *= lax.broadcast_in_dim(alpha, acc.shape, [0]) l_i *= alpha p16 = p.astype(dtype) From ea69401effaea5fce2f52e96dbdd42d7c9cdf287 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 29 Nov 2024 09:34:04 -0800 Subject: [PATCH 550/698] [mgpu] Fixed off-by-one issue in pointwise argument shuffling when leading argument is splat. Also adapted the test to catch a possible regression. The issue appeared in >2 operands. PiperOrigin-RevId: 701306731 --- jax/experimental/mosaic/gpu/fragmented_array.py | 2 +- tests/mosaic/gpu_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index daabe9b0f060..2f53b08e3af4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -632,7 +632,7 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): continue elif not isinstance(o.layout, WGSplatFragLayout): return o._pointwise( - lambda o, *args: op(*args[:i], o, *args[i:]), + lambda o, this, *args: op(this, *args[:i], o, *args[i:]), self, *other[:i], *other[i + 1 :], diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 39182841f190..b9d0a591554d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1565,14 +1565,14 @@ def kernel(ctx, src, dst, _): assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) - (pi_arr_sq + pi_arr_cube).store_untiled(dst) + (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) inp = jnp.ones_like(out_shape) * 3.14 result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), inp, out_shape, () )(inp) - np.testing.assert_allclose(result, np.full((128, 32), 3.14 ** 2 + 3.14 ** 3, np.float32)) + np.testing.assert_allclose(result, np.full((128, 32), 3.14, np.float32)) @parameterized.product(in_shape=((128, 128), (128, 64), (64, 128))) From c3c21c74627e04ad40815eba414f665e7e7b9b77 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 29 Nov 2024 09:38:54 -0800 Subject: [PATCH 551/698] [mgpu_pallas] Better support for unsigned integers and floats in iota. PiperOrigin-RevId: 701307324 --- jax/_src/pallas/mosaic_gpu/primitives.py | 22 +++++++++++++++++++--- tests/pallas/mosaic_gpu_test.py | 5 +++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index c1ea3e9870b6..3c5eaa7910dd 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -706,17 +706,33 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): @lowering.register_lowering_rule(broadcasted_iota_p) def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout): del ctx - undef = llvm_dialect.mlir_undef(mlir.dtype_to_ir_type(dtype)) + # Unsigned integers (as opposed to signless) cause MLIR verification + # errors so we only use signless like Mosaic GPU does. + # + # TODO(cperivol): use mgpu.utils.dtype_to_ir_type() instead. + mlir_dtype = ( + ir.IntegerType.get_signless(dtype.itemsize * 8) + if jnp.issubdtype(dtype, jnp.integer) + else mlir.dtype_to_ir_type(dtype) + ) + undef = llvm_dialect.mlir_undef(mlir_dtype) is_signed = ( jnp.issubdtype(dtype, jnp.signedinteger) if jnp.issubdtype(dtype, jnp.integer) else None ) - mlir_dtype = mlir.dtype_to_ir_type(dtype) + + i32 = ir.IntegerType.get_signless(32) + def _cast(x): + if ir.FloatType.isinstance(mlir_dtype): + x = arith_dialect.index_cast(i32, x) + return arith_dialect.uitofp(mlir_dtype, x) + else: + return arith_dialect.index_cast(mlir_dtype, x) return mgpu.FragmentedArray.splat( undef, shape, layout.value, is_signed=is_signed ).foreach( - lambda _, idx: arith_dialect.index_cast(mlir_dtype, idx[dimension]), create_array=True, is_signed=is_signed + lambda _, idx: _cast(idx[dimension]), create_array=True, is_signed=is_signed ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index dba1e67acf02..fa8597ab3195 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -241,8 +241,9 @@ def kernel(x_ref, o_ref): # are never written to. np.testing.assert_array_equal(kernel(x)[:, :16], y[:, :16]) - def test_iota(self): - dtype, dimension = jnp.int8, 1 + @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) + def test_iota(self, dtype): + dimension = 1 @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct((128, 128), dtype), From 031c0acf5053d42c65632b2b0f83a04b61768dd2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 29 Nov 2024 12:07:42 -0800 Subject: [PATCH 552/698] Add new CI scripts for running Pytests This commit adds the new CI scripts for running Pytests. It makes use of the pytest envs inside the "ci/envs/run_tests" folder to control the build behavior. For e.g: for running the GPU tests with Pytest, we will need to run `./ci/run_pytest.sh ./ci/envs/run_tests/pytest_gpu.env`. Note that Pytests need JAX wheels to be installed on the system to work. The `install_wheels_locally.sh` script installs these wheels in CI builds. PiperOrigin-RevId: 701331411 --- ci/envs/default.env | 10 +++++ ci/run_pytest_cpu.sh | 45 +++++++++++++++++++ ci/run_pytest_gpu.sh | 61 ++++++++++++++++++++++++++ ci/run_pytest_tpu.sh | 61 ++++++++++++++++++++++++++ ci/utilities/install_wheels_locally.sh | 33 ++++++++++++++ 5 files changed, 210 insertions(+) create mode 100644 ci/run_pytest_cpu.sh create mode 100644 ci/run_pytest_gpu.sh create mode 100644 ci/run_pytest_tpu.sh create mode 100644 ci/utilities/install_wheels_locally.sh diff --git a/ci/envs/default.env b/ci/envs/default.env index e3bf1a5ab47a..ae434dc61c8f 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -57,3 +57,13 @@ export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12} # Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override # this value in the Github action workflow files. export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} + +# Pytest specific environment variables below. Used in run_pytest_*.sh scripts. +# Sets the number of TPU cores for the TPU machine type. These values are +# defined in the TPU GitHub Actions workflow. +export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} + +# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels +# on the system. By default, it is set to match the version of the hermetic +# Python used by Bazel for building the wheels. +export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh new file mode 100644 index 000000000000..2b19ca5ddaa5 --- /dev/null +++ b/ci/run_pytest_cpu.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pyest CPU tests. Requires a jaxlib wheel to be present +# inside the $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +# Set up all test environment variables +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true +export TF_CPP_MIN_LOG_LEVEL=0 +# End of test environment variable setup + +echo "Running CPU tests..." +"$JAXCI_PYTHON" -m pytest -n auto --tb=short --maxfail=20 tests examples \ No newline at end of file diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh new file mode 100644 index 000000000000..7bc2492781b2 --- /dev/null +++ b/ci/run_pytest_gpu.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt +# wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels inside the +# $JAXCI_OUTPUT_DIR directory on the system. +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +nvidia-smi + +# Set up all test environment variables +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true +export NCCL_DEBUG=WARN +export TF_CPP_MIN_LOG_LEVEL=0 + +# Set the number of processes to run to be 4x the number of GPUs. +export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +export num_processes=`expr 4 \* $gpu_count` + +export XLA_PYTHON_CLIENT_ALLOCATOR=platform +export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 +# End of test environment variable setup + +echo "Running GPU tests..." +"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ +tests examples \ +--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ +--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \ +--deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ +--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ No newline at end of file diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh new file mode 100644 index 000000000000..783d2f9feca5 --- /dev/null +++ b/ci/run_pytest_tpu.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Pyest CPU tests. Requires a jaxlib wheel to be present +# inside $JAXCI_OUTPUT_DIR (../dist) +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. +echo "Installing wheels locally..." +source ./ci/utilities/install_wheels_locally.sh + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +export PY_COLORS=1 +export JAX_SKIP_SLOW_TESTS=true + +"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))" + +"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' +"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' +"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' +strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on' +"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' + +echo "Running TPU tests..." +export JAX_PLATFORMS=tpu,cpu +# Run single-accelerator tests in parallel +export JAX_ENABLE_TPU_XDIST=true + +"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \ +--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ +--maxfail=20 -m "not multiaccelerator" tests examples + +# Run Pallas printing tests, which need to run with I/O capturing disabled. +export TPU_STDERR_LOG_LEVEL=0 +"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest + +# Run multi-accelerator across all chips +"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests \ No newline at end of file diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh new file mode 100644 index 000000000000..181256b90804 --- /dev/null +++ b/ci/utilities/install_wheels_locally.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Install wheels stored in `JAXCI_OUTPUT_DIR` on the system using the Python +# binary set in JAXCI_PYTHON. Use the absolute path to the `find` utility to +# avoid using the Windows version of `find` on Msys. +WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*" -o -name "*jax*cuda*pjrt*" -o -name "*jax*cuda*plugin*" \)) ) + +if [[ -z "$WHEELS" ]]; then + echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" + exit 1 +fi + +echo "Installing the following wheels:" +echo "${WHEELS[@]}" +"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" + +echo "Installing the JAX package in editable mode at the current commit..." +# Install JAX package at the current commit. +"$JAXCI_PYTHON" -m pip install -U -e . From 47858c4ac2fd4757a3b6fc5bb2981b71a71f00c2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 29 Nov 2024 16:23:26 -0800 Subject: [PATCH 553/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/479fb21237319d091ee93e86619c8d4d88bda079. PiperOrigin-RevId: 701368225 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4957bdcf10b4..ea549693abe1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fc80c5576b71c986fbd4505a59826f7d433878bc" -XLA_SHA256 = "00cc8500299f22d4ec047e08e9b7bd357ba74e9e67acca4131231fadc47ab90a" +XLA_COMMIT = "479fb21237319d091ee93e86619c8d4d88bda079" +XLA_SHA256 = "1ad7137b77bffbb11a959a12dee9329eec4f4cae0c0a1d963144579992e059aa" def repo(): tf_http_archive( From cd578d97e8367dce90c96705c386df5aaa299988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tor=20Gunnar=20H=C3=B8st=20Houeland?= <887395+houeland@users.noreply.github.com> Date: Sat, 30 Nov 2024 18:55:00 +0000 Subject: [PATCH 554/698] Fix jnp.matmul return shape documentation If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13) --- jax/_src/numpy/lax_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5f380fad902c..a61b1d67f53e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9076,7 +9076,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, Returns: array containing the matrix product of the inputs. Shape is ``a.shape[:-1]`` - if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading + if ``b.ndim == 1``, otherwise the shape is ``(..., K, M)``, where leading dimensions of ``a`` and ``b`` are broadcast together. See Also: From db4b3f2922ed5d6ec74cc37319cea3063350ef52 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 30 Nov 2024 15:24:30 -0800 Subject: [PATCH 555/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/20d4636c743e53f070612d6b4c6ebd03b2b28bf5. PiperOrigin-RevId: 701562320 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index ea549693abe1..23c476815e04 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "479fb21237319d091ee93e86619c8d4d88bda079" -XLA_SHA256 = "1ad7137b77bffbb11a959a12dee9329eec4f4cae0c0a1d963144579992e059aa" +XLA_COMMIT = "20d4636c743e53f070612d6b4c6ebd03b2b28bf5" +XLA_SHA256 = "9f0c1ba3b0220a9d922acd3df09a0d65a56d37d6e9a5080079fc86b8f67c83fc" def repo(): tf_http_archive( From a1dfdc1d6164ad49afb337da9effd269d430d68b Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Sat, 30 Nov 2024 21:26:07 -0800 Subject: [PATCH 556/698] C++ tree with path API * Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening. * Moves all the key classes down to C++ level, while keeping the APIs unchanged. * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy. * Registered defaultdict and ordereddict via the keypath API now. PiperOrigin-RevId: 701613257 --- jax/_src/tree_util.py | 120 +++++++++++++++-------- tests/package_structure_test.py | 48 ++++++++-- tests/tree_util_test.py | 165 +++++++++++++++++++++++++++++++- 3 files changed, 288 insertions(+), 45 deletions(-) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index bd2ca2df7b0c..73cff5aa8042 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -22,10 +22,11 @@ from functools import partial import operator as op import textwrap -from typing import Any, NamedTuple, TypeVar, Union, overload +from typing import Any, NamedTuple, TypeVar, overload from jax._src import traceback_util from jax._src.lib import pytree +from jax._src.lib import xla_extension_version from jax._src.util import safe_zip, set_module from jax._src.util import unzip2 @@ -209,12 +210,21 @@ def all_leaves(iterable: Iterable[Any], _Children = TypeVar("_Children", bound=Iterable[Any]) _AuxData = TypeVar("_AuxData", bound=Hashable) +KeyEntry = TypeVar("KeyEntry", bound=Any) +KeyLeafPair = tuple[KeyEntry, Any] +KeyLeafPairs = Iterable[KeyLeafPair] +KeyPath = tuple[KeyEntry, ...] @export -def register_pytree_node(nodetype: type[T], - flatten_func: Callable[[T], tuple[_Children, _AuxData]], - unflatten_func: Callable[[_AuxData, _Children], T]) -> None: +def register_pytree_node( + nodetype: type[T], + flatten_func: Callable[[T], tuple[_Children, _AuxData]], + unflatten_func: Callable[[_AuxData, _Children], T], + flatten_with_keys_func: ( + Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None + ) = None, +) -> None: """Extends the set of types that are considered internal nodes in pytrees. See :ref:`example usage `. @@ -279,9 +289,20 @@ def register_pytree_node(nodetype: type[T], >>> jax.jit(f)(m) Array([1., 2., 3., 4., 5.], dtype=float32) """ - default_registry.register_node(nodetype, flatten_func, unflatten_func) - none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) - dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) + if xla_extension_version >= 299: + default_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func + ) + none_leaf_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func + ) + dispatch_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func + ) + else: + default_registry.register_node(nodetype, flatten_func, unflatten_func) + none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) + dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) @@ -452,21 +473,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool return all(tree_leaves(tree, is_leaf=is_leaf)) -register_pytree_node( - collections.OrderedDict, - lambda x: (tuple(x.values()), tuple(x.keys())), - lambda keys, values: collections.OrderedDict(safe_zip(keys, values))) - -def _flatten_defaultdict(d): - keys = tuple(sorted(d)) - return tuple(d[k] for k in keys), (d.default_factory, keys) - -register_pytree_node( - collections.defaultdict, - _flatten_defaultdict, - lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values))) - - class _HashableCallableShim: """Object that delegates __call__, __hash__, and __eq__ to another object.""" @@ -578,11 +584,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any, # flatten_one_level is not exported. -def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: +def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]: """Flatten the given pytree node by one level. Args: - pytree: A valid pytree node, either built-in or registered via + tree: A valid pytree node, either built-in or registered via :func:`register_pytree_node` or related functions. Returns: @@ -601,9 +607,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: >>> meta ('a', 'b') """ - out = default_registry.flatten_one_level(pytree) + out = default_registry.flatten_one_level(tree) if out is None: - raise ValueError(f"can't tree-flatten type: {type(pytree)}") + raise ValueError(f"can't tree-flatten type: {type(tree)}") else: return out @@ -739,10 +745,12 @@ class FlattenedIndexKey(): def __str__(self): return f'[]' -BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey] -KeyEntry = TypeVar("KeyEntry", bound=Hashable) -KeyPath = tuple[KeyEntry, ...] +if xla_extension_version >= 299: + SequenceKey = pytree.SequenceKey # type: ignore + DictKey = pytree.DictKey # type: ignore + GetAttrKey = pytree.GetAttrKey # type: ignore + FlattenedIndexKey = pytree.FlattenedIndexKey # type: ignore @export @@ -764,6 +772,7 @@ def keystr(keys: KeyPath): return ''.join(map(str, keys)) +# TODO(ivyzheng): remove this after _child_keys() also moved to C++. class _RegistryWithKeypathsEntry(NamedTuple): flatten_with_keys: Callable[..., Any] unflatten_func: Callable[..., Any] @@ -780,7 +789,6 @@ def flatten_with_keys(xs): flatten_with_keys, _registry[ty].from_iter ) - _registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {} _register_keypaths( @@ -803,13 +811,9 @@ def flatten_with_keys(xs): @export def register_pytree_with_keys( nodetype: type[T], - flatten_with_keys: Callable[ - [T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData] - ], + flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]], unflatten_func: Callable[[_AuxData, Iterable[Any]], T], - flatten_func: None | ( - Callable[[T], tuple[Iterable[Any], _AuxData]] - ) = None, + flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None, ): """Extends the set of types that are considered internal nodes in pytrees. @@ -870,7 +874,9 @@ def flatten_func_impl(tree): return [c for _, c in key_children], treedef flatten_func = flatten_func_impl - register_pytree_node(nodetype, flatten_func, unflatten_func) + register_pytree_node( + nodetype, flatten_func, unflatten_func, flatten_with_keys + ) _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( flatten_with_keys, unflatten_func ) @@ -1092,6 +1098,40 @@ def flatten_func(x): return nodetype +if xla_extension_version >= 299: + register_pytree_with_keys( + collections.OrderedDict, + lambda x: (tuple((DictKey(k), x[k]) for k in x.keys()), tuple(x.keys())), + lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), + ) + + def _flatten_defaultdict_with_keys(d): + keys = tuple(sorted(d)) + return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys) + + register_pytree_with_keys( + collections.defaultdict, + _flatten_defaultdict_with_keys, + lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), + ) +else: + register_pytree_node( + collections.OrderedDict, + lambda x: (tuple(x.values()), tuple(x.keys())), + lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), + ) + + def _flatten_defaultdict(d): + keys = tuple(sorted(d)) + return tuple(d[k] for k in keys), (d.default_factory, keys) + + register_pytree_node( + collections.defaultdict, + _flatten_defaultdict, + lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), + ) + + @export def register_static(cls: type[H]) -> type[H]: """Registers `cls` as a pytree with no leaves. @@ -1144,6 +1184,8 @@ def tree_flatten_with_path( which contains a leaf and its key path. The second element is a treedef representing the structure of the flattened tree. """ + if xla_extension_version >= 299: + return default_registry.flatten_with_path(tree, is_leaf) _, tree_def = tree_flatten(tree, is_leaf) return _generate_key_paths(tree, is_leaf), tree_def @@ -1164,13 +1206,15 @@ def tree_leaves_with_path( - :func:`jax.tree_util.tree_leaves` - :func:`jax.tree_util.tree_flatten_with_path` """ - return _generate_key_paths(tree, is_leaf) + return tree_flatten_with_path(tree, is_leaf)[0] # generate_key_paths is not exported. def generate_key_paths( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tuple[KeyPath, Any]]: + if xla_extension_version >= 299: + return tree_leaves_with_path(tree, is_leaf) return list(_generate_key_paths_((), tree, is_leaf)) _generate_key_paths = generate_key_paths # alias for backward compat diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index 25468c4ba700..d80c750ae859 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -34,17 +34,53 @@ class PackageStructureTest(jtu.JaxTestCase): _mod("jax.errors", exclude=["JaxRuntimeError"]), _mod( "jax.numpy", - exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating", - "dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo", - "flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim", - "number", "object_", "printoptions", "save", "savez", "set_printoptions", - "shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"] + exclude=[ + "array_repr", + "array_str", + "can_cast", + "character", + "complexfloating", + "dtype", + "iinfo", + "index_exp", + "inexact", + "integer", + "iterable", + "finfo", + "flexible", + "floating", + "generic", + "get_printoptions", + "ndarray", + "ndim", + "number", + "object_", + "printoptions", + "save", + "savez", + "set_printoptions", + "shape", + "signedinteger", + "size", + "s_", + "unsignedinteger", + "ComplexWarning", + ], ), _mod("jax.numpy.linalg"), _mod("jax.nn.initializers"), _mod( "jax.tree_util", - exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"], + exclude=[ + "PyTreeDef", + "default_registry", + "KeyEntry", + "KeyPath", + "DictKey", + "GetAttrKey", + "SequenceKey", + "FlattenedIndexKey", + ], ), ]) def test_exported_names_match_module(self, module_name, include, exclude): diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index a3a8bc96eae0..bd0497a33820 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +from collections.abc import Hashable import dataclasses import functools import pickle @@ -24,14 +25,20 @@ from jax import flatten_util from jax import tree_util from jax._src import test_util as jtu +from jax._src.lib import xla_extension_version from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp +# Easier to read. +SequenceKey = tree_util.SequenceKey +DictKey = tree_util.DictKey +GetAttrKey = tree_util.GetAttrKey +FlattenedIndexKey = tree_util.FlattenedIndexKey + def _dummy_func(*args, **kwargs): return - ATuple = collections.namedtuple("ATuple", ("foo", "bar")) class ANamedTupleSubclass(ATuple): @@ -758,6 +765,78 @@ def is_empty(x): ], ) + def testTreeFlattenWithPathBuiltin(self): + x = (1, {"a": 2, "b": 3}) + flattened = tree_util.tree_flatten_with_path(x) + _, tdef = tree_util.tree_flatten(x) + self.assertEqual( + flattened[0], + [ + ((SequenceKey(0),), 1), + ((SequenceKey(1), DictKey("a")), 2), + ((SequenceKey(1), DictKey("b")), 3), + ], + ) + self.assertEqual(flattened[1], tdef) + + def testTreeFlattenWithPathCustom(self): + x = [ + AnObject2( + x=12, + y={"foo": SpecialWithKeys(x=2, y=3), "bar": None}, + z="constantdef", + ), + 5, + ] + flattened, _ = tree_util.tree_flatten_with_path(x) + self.assertEqual( + flattened, + [ + ((SequenceKey(0), "x"), 12), + ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("x")), 2), + ((SequenceKey(0), "y", DictKey("foo"), GetAttrKey("y")), 3), + ((SequenceKey(1),), 5), + ], + ) + + def testFlattenWithPathDefaultDict(self): + if xla_extension_version < 299: + self.skipTest("Skipping for Python-based with path APIs.") + d = collections.defaultdict(int, {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) + leaves, treedef = tree_util.tree_flatten_with_path(d) + self.assertEqual( + leaves, + [ + ((DictKey("a"),), 1), + ((DictKey("b"),), 2), + ((DictKey("c"), DictKey("a")), 1), + ((DictKey("c"), DictKey("b")), 2), + ], + ) + restored_d = tree_util.tree_unflatten(treedef, [l for _, l in leaves]) + self.assertEqual(list(restored_d.keys()), ["a", "b", "c"]) + _, from_flatten = tree_util.tree_flatten(d) + self.assertEqual(treedef, from_flatten) + + def testFlattenWithPathOrderedDict(self): + if xla_extension_version < 299: + self.skipTest("Skipping for Python-based with path APIs.") + d = collections.OrderedDict({"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) + leaves, treedef = tree_util.tree_flatten_with_path(d) + self.assertEqual( + leaves, + [ + ((DictKey("b"),), 2), + ((DictKey("a"),), 1), + ((DictKey("c"), DictKey("a")), 1), + ((DictKey("c"), DictKey("b")), 2), + ], + ) + restored_d = tree_util.tree_unflatten(treedef, [l for _, l in leaves]) + self.assertEqual(list(restored_d.keys()), ["b", "a", "c"]) + _, from_flatten = tree_util.tree_flatten(d) + self.assertEqual(treedef, from_flatten) + def testFlattenOneLevel(self): EmptyTuple = collections.namedtuple("EmptyTuple", ()) tree1 = {'a': 1, @@ -838,6 +917,90 @@ def testBadFlattenNonIterableLeaves(self): tree_util.tree_flatten(t) +class TreeKeyTest(absltest.TestCase): + + def testBasic(self): + if xla_extension_version < 299: + self.skipTest("Skipping for Python-based with path APIs.") + + def assert_equal_and_hash_equal(a, b): + self.assertEqual(a, b) + self.assertEqual(hash(a), hash(b)) + + key = SequenceKey(idx=1) + self.assertEqual(str(key), "[1]") + self.assertEqual(key.idx, 1) + assert_equal_and_hash_equal(key, SequenceKey(1)) + + class DictKeyEntry(Hashable): + + def __init__(self, s: str): + self.s = s + + def __hash__(self): + return hash(self.s) + + def __eq__(self, other): + return self.s == other.s + + key = DictKey(key="foo") + self.assertEqual(str(key), "['foo']") + self.assertEqual(key.key, "foo") + assert_equal_and_hash_equal(key, DictKey("foo")) + assert_equal_and_hash_equal( + DictKey(DictKeyEntry("foo")), DictKey(DictKeyEntry("foo")) + ) + + key = GetAttrKey(name="bar") + self.assertEqual(str(key), ".bar") + self.assertEqual(key.name, "bar") + assert_equal_and_hash_equal(key, GetAttrKey("bar")) + + key = FlattenedIndexKey(1) + self.assertEqual(str(key), "[]") + self.assertEqual(key.key, 1) + assert_equal_and_hash_equal(key, FlattenedIndexKey(1)) + self.assertNotEqual(hash(key), hash(SequenceKey(1))) + + def testPatternMatching(self): + keys = [ + SequenceKey(1), + DictKey("foo"), + GetAttrKey("bar"), + FlattenedIndexKey(1), + ] + for key in keys: + match key: + case jax.tree_util.SequenceKey(idx=idx): + self.assertEqual(idx, 1) + case jax.tree_util.DictKey(key=key): + self.assertEqual(key, "foo") + case jax.tree_util.GetAttrKey(name=name): + self.assertEqual(name, "bar") + case jax.tree_util.FlattenedIndexKey(key=idx_key): + self.assertEqual(idx_key, 1) + case _: + raise ValueError(f"key not matched: {key}") + match [ + DictKey("foo"), + ]: + case [DictKey("foo"), *_]: + pass + case _: + raise ValueError(f"keys are not matched: {keys}") + + def testPickle(self): + keys = [ + SequenceKey(1), + DictKey("foo"), + GetAttrKey("bar"), + FlattenedIndexKey(1), + ] + for key in keys: + unpickled = pickle.loads(pickle.dumps(key)) + self.assertEqual(key, unpickled) + + class StaticTest(parameterized.TestCase): @parameterized.parameters( From e124c051f2c6f70b52ca87e10fefe8a5ce9e0d15 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 1 Dec 2024 15:10:42 -0800 Subject: [PATCH 557/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/41e12cc0247edf4ffb1569f2a25c61cec924c755. PiperOrigin-RevId: 701766499 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 23c476815e04..6344349e9ad0 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "20d4636c743e53f070612d6b4c6ebd03b2b28bf5" -XLA_SHA256 = "9f0c1ba3b0220a9d922acd3df09a0d65a56d37d6e9a5080079fc86b8f67c83fc" +XLA_COMMIT = "41e12cc0247edf4ffb1569f2a25c61cec924c755" +XLA_SHA256 = "1eb835846d906813264909024036be866d1ee191959fbe13b0f2eeacbf539fea" def repo(): tf_http_archive( From bd66f5280bde9867b64da1b7d8e2cc7ad2105b67 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 2 Dec 2024 04:26:23 -0800 Subject: [PATCH 558/698] [Mosaic GPU] Add a bank-conflict checker to tiled transfer + transfer planner Instead of only allowing a fixed set of layouts that we've hand verified as bank-conflict free, we now simulate the transactions performed within each warp and verify that no bank conflicts happen. If we detect that the simple schedule does not work out, we attempt to partition the threads in a warp into two groups and stagger the transfers in a way that lets us avoid conflicts. This allows us to match the hand-designed transfer schedule I wrote for 32-bit types, and even generalizes it to more cases automatically (e.g. swizzle=32). PiperOrigin-RevId: 701919158 --- .../mosaic/gpu/fragmented_array.py | 226 ++++++++++++++++-- tests/mosaic/gpu_test.py | 24 +- 2 files changed, 225 insertions(+), 25 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 2f53b08e3af4..4d78270e5009 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -20,7 +20,7 @@ import functools import math from collections.abc import Callable -from typing import Iterable, Sequence, TypeVar +from typing import Iterable, Protocol, Sequence, TypeVar import jax from jaxlib.mlir import ir @@ -42,6 +42,8 @@ WARPGROUP_SIZE = utils.WARPGROUP_SIZE WARP_SIZE = 32 WARPS_IN_WARPGROUP = WARPGROUP_SIZE // WARP_SIZE +SMEM_BANKS = 32 +SMEM_BANK_BYTES = 4 c = utils.c @@ -1455,11 +1457,15 @@ def load_tiled( raise ValueError("Tiled reference must have even rank") tiling = Tiling((tiled_shape[len(tiled_shape) // 2 :],)) shape = tiling.untile_shape(tiled_shape) - registers = np.full(layout.registers_shape(shape), None, dtype=object) + zero = ( + vector.splat( + ir.VectorType.get((layout.vector_length,), dtype), c(0, dtype) + ), + ) + registers = np.full(layout.registers_shape(shape), zero, dtype=object) reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): update(registers, llvm.load(reg_ty, ptr)) - assert all(r is not None for r in registers.flat) case WGMMAFragLayout(): bw = mgpu.bytewidth(dtype) m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape @@ -1611,6 +1617,8 @@ def transfer_tiled2( tiled_strides = list(tiling.tile_strides(tuple(ref_strides))) tiled_shape = list(tiling.tile_shape(tuple(ref_ty.shape))) + lane_strides = [tiled_strides[d] for d in layout.lane_dims] + lane_shape = [tiled_shape[d] for d in layout.lane_dims] if tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): @@ -1618,12 +1626,9 @@ def transfer_tiled2( full_tiling = Tiling((ref_tiling_shape, *tiling.tiles)) full_layout = dataclasses.replace(layout, tiling=full_tiling) - # XXX: This method is still slightly incompete. For example, it does not - # verify that the vector transfers don't cross swizzle tile boundaries. It - # also does not guarantee that the transfer pattern does not cause bank - # conflicts. For that reason, we only allow a select subset of layouts. - if layout != _tiled_wgmma_layout(shape) or bw > 2: - raise NotImplementedError("transfer_tiled2 not general enough yet") + plan = plan_tiled_transfer( + tiled_shape, tiled_strides, lane_shape, lane_strides, layout, bw, swizzle + ) dyn_tiled_strides = [c(s) for s in tiled_strides] lane_offset = utils.dyn_dot(full_layout.lane_indices(), dyn_tiled_strides) @@ -1632,14 +1637,20 @@ def transfer_tiled2( if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): raise ValueError("Tiled stores can be performed into SMEM") ptr = utils.memref_ptr(ref, memory_space=3) + _as_consts = lambda consts: [c(const) for const in consts.tolist()] + # This has bits set only for the offset bits that influence swizzling. + swizzle_mask = swizzle_block_elems - swizzle_tile_elems for tile_idx in np.ndindex(*tiled_shape): - const_offset = sum(i * s for i, s in zip(tile_idx, tiled_strides)) + indices = np.asarray([f(tile_idx) for f in plan.tile_index_transforms]) + const_offset = np.dot(indices, tiled_strides) # We split the offset into a part that interacts with swizzling and a # part that doesn't. This lets us generate better code because constant # offsets can be fused into load and store instructions. - const_offset_swizzle = const_offset % swizzle_block_elems + const_offset_swizzle = const_offset & swizzle_mask const_offset_no_swizzle = const_offset - const_offset_swizzle - offset_pre_swizzle = arith.addi(dyn_offset, c(const_offset_swizzle)) + offset_pre_swizzle = arith.addi( + dyn_offset, plan.select(_as_consts(const_offset_swizzle)) + ) swizzle_group = arith.remui( arith.divui(offset_pre_swizzle, c(swizzle_group_elems)), c(swizzle_groups_per_block), @@ -1647,12 +1658,24 @@ def transfer_tiled2( swizzle_bits = arith.muli(swizzle_group, c(swizzle_tile_elems)) offset = arith.xori(offset_pre_swizzle, swizzle_bits) reg_ptr = utils.getelementptr(ptr, [offset], dtype) - reg_ptr = utils.getelementptr(reg_ptr, [const_offset_no_swizzle], dtype) - reg_idx = tiling.tile_indices(full_tiling.untile_indices(tile_idx)) - def get_register(regs, reg_idx=reg_idx): - return regs[reg_idx] - def update_registers(regs, new, reg_idx=reg_idx): - regs[reg_idx] = new + offset_no_swizzle = plan.select(_as_consts(const_offset_no_swizzle)) + reg_ptr = utils.getelementptr(reg_ptr, [offset_no_swizzle], dtype) + reg_idxs = [ + tiling.tile_indices(full_tiling.untile_indices(idx)) + for idx in indices.tolist() + ] + def get_register(regs, reg_idxs=reg_idxs): + return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) + def update_registers(regs, new, reg_idxs=reg_idxs): + # TODO(apaszke): If the staggering forms a permutation with a small + # cycle length, then instead of blending at each step we could construct + # a small routing network (kind of like a sorting network) to fix up + # each cycle separately after all the loads are performed. + # This would be especially useful for dims that are powers of two and + # staggered by another power of 2, since all cycles are of length 2 (and + # we could save half the selects). + for i, reg_idx in enumerate(reg_idxs): + regs[reg_idx] = plan.select_if_group(i, regs[reg_idx], new) yield get_register, update_registers, reg_ptr def tree_flatten(self): @@ -1666,6 +1689,173 @@ def tree_unflatten(cls, aux, flat_registers): return cls(_registers=registers, _layout=layout, _is_signed=is_signed) +class TransferPlan(Protocol): + IndexTransform = Callable[[tuple[int, ...]], tuple[int, ...]] + tile_index_transforms: tuple[IndexTransform, ...] + + def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: + """Selects the value corresponding to the group of the current thread. + + The argument must be of the same length as tile_index_transforms. + """ + raise NotImplementedError + + def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: + """Returns `new` if the current thread belongs to the given group and `old` otherwise. + + group_idx must be between 0 and len(tile_index_transforms) - 1. + """ + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class TrivialTransferPlan(TransferPlan): + @property + def tile_index_transforms(self): + return (lambda x: x,) + + def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: + assert len(group_elems) == 1 + return group_elems[0] + + def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: + assert group_idx == 0 + return new + + +@dataclasses.dataclass(frozen=True) +class StaggeredTransferPlan(TransferPlan): + stagger: int + dim: int + size: int + group_pred: ir.Value + + @property + def tile_index_transforms(self): + dim = self.dim + def rotate(idx: tuple[int, ...]) -> tuple[int, ...]: + return ( + *idx[:dim], (idx[dim] + self.stagger) % self.size, *idx[dim + 1 :], + ) + return (lambda x: x, rotate) + + def select(self, group_elems: Sequence[ir.Value]) -> ir.Value: + assert len(group_elems) == 2 + return arith.select(self.group_pred, group_elems[1], group_elems[0]) + + def select_if_group(self, group_idx: int, old: ir.Value, new: ir.Value) -> ir.Value: + assert 0 <= group_idx <= 1 + sides = [old, new] if group_idx == 0 else [new, old] + return arith.select(self.group_pred, *sides) + + +def plan_tiled_transfer( + tiled_shape: Sequence[int], + tiled_strides: Sequence[int], + lane_shape: Sequence[int], + lane_strides: Sequence[int], + layout: TiledLayout, + bw: int, + swizzle: int, +) -> TransferPlan: + i32 = ir.IntegerType.get_signless(32) + c = lambda x: arith.constant(i32, x) + swizzle_tile_elems = 16 // bw + swizzle_group_elems = 128 // bw + # Below, all calculations are in elements, not in bytes, since it should + # generalize better to sub-byte types. + # Here, we verify two conditions: + # 1. Each vector transfer only accesses addresses that fall within a single + # swizzle tile (if not we'd need to split it and swizzle parts differently). + transfer_alignment = math.gcd(*( + s + for i, (s, d) in enumerate_negative(list(zip(tiled_strides, tiled_shape))) + if d > 1 or i in {layout.warp_dim, *layout.lane_dims} + )) + if ( + swizzle_tile_elems % transfer_alignment + and layout.vector_length <= transfer_alignment + ): + raise ValueError( + "Failed to prove that vector transfers don't cross swizzle tile" + " boundaries. This check is incomplete, and does not guarantee that" + " this is a user error, but it might be." + str(transfer_alignment) + ) + + # 2. The transfer pattern does not cause bank conflicts. + # TODO(apaszke): For now, when performing transfers narrower than a bank, + # we simply narrow each bank to the transfer width. The truth is more likely + # that bank conflicts only don't occur if the addresses mapping to the same + # bank are contiguous, but that's a more complicated check to perform. + transfer_bytes = layout.vector_length * bw + if transfer_bytes > SMEM_BANK_BYTES * 4: + raise NotImplementedError + if bw > SMEM_BANK_BYTES: + raise NotImplementedError + smem_bank_bytes = min(SMEM_BANK_BYTES, transfer_bytes) + num_banks = SMEM_BANKS * (SMEM_BANK_BYTES // smem_bank_bytes) + elems_per_bank = smem_bank_bytes // bw + num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) + wavefront_lanes = WARP_SIZE // num_wavefronts + + lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) + def has_bank_conflicts(tile_idx_transform): + tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) + tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] + lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] + assert lane_tile_idx.shape[1] in {1, WARP_SIZE} + lane_tile_offsets = np.dot(lane_tile_idx, tiled_strides) + offsets = lane_tile_offsets + lane_offsets_in_tile # [#tiles, #lanes] + assert offsets.shape[-1] == WARP_SIZE + swizzle_groups = (offsets // swizzle_group_elems) % (swizzle // 16) + swizzle_bits = swizzle_groups * swizzle_tile_elems + lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks + wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) + # Order of threads within the wavefront is unimportant. + wavefront_banks = np.sort(wavefront_banks, axis=-1) + # There are no conflicts if each wavefront only contains unique banks. + return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) + + # We don't need any special treatment if there are no conflicts when each lane + # transfers the same tile at a time. + if not has_bank_conflicts(lambda tile_idx: tile_idx): + return TrivialTransferPlan() + + # Otherwise, we will try to partition the lanes into two groups and have + # each group store to different tile. The only tile dimensions that can help + # us with bank conflicts are those that have multiple elements and a stride + # that's not a multiple of the number of banks. + # + # Note that the code is set up so that we could also consider partitioning + # the lanes into more groups, but the selects will become more expensive if + # we do that. It's a possibility we have if we need it. + candidate_dims = ( + i for i, (s, d) in enumerate(zip(tiled_strides, tiled_shape)) + if d > 1 and s % (SMEM_BANKS * elems_per_bank) + ) + for dim in candidate_dims: + for group_stride in (1, 2, 4, 8, 16): + # We change the group assignment each group_stride lanes. + lane_id = np.arange(WARP_SIZE)[:, None] + lane_group = (lane_id // group_stride) % 2 + # We only consider a transformation where the second group stores to a + # tile that's a constant offset (modulo dim size) from the first one. + for stagger in range(1, tiled_shape[dim]): + offset = np.zeros(len(tiled_shape), np.int64) + offset[dim] = stagger + transform = lambda idx: (idx + offset * lane_group) % tiled_shape + if not has_bank_conflicts(transform): + # We've found a strategy that avoids bank conflicts! + lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE)) + group_idx = arith.remui(arith.divui(lane_idx, c(group_stride)), c(2)) + group_pred = arith.cmpi(arith.CmpIPredicate.ne, group_idx, c(0)) + return StaggeredTransferPlan( + stagger, dim, tiled_shape[dim], group_pred + ) + raise ValueError( + "Failed to synthesize a transfer pattern that avoids bank conflicts" + ) + # We allow contractions, to potentially take advantage of FMA instructions. # They can change the results, but the precision should only increase. def addf(a: ir.Value, b: ir.Value): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index b9d0a591554d..87916ac1c364 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1755,18 +1755,21 @@ def kernel(ctx, dst, _): @parameterized.product( load_tiled=[False, True], store_tiled=[False, True], - dtype=[jnp.int16], + dtype=[jnp.int8, jnp.int16, jnp.int32], swizzle=[32, 64, 128], - num_col_tiles=[1, 2, 4], + num_col_tiles=[1, 2, 3], ) def test_copy_tiled(self, load_tiled, store_tiled, dtype, swizzle, num_col_tiles): mlir_dtype = utils.dtype_to_ir_type(dtype) - col_tiling = swizzle // bytewidth(mlir_dtype) + bw = bytewidth(mlir_dtype) + col_tiling = swizzle // bw m, n = 128, col_tiling * num_col_tiles tiling = (64, col_tiling) tiled_layout = fa._tiled_wgmma_layout((m, n)) load_layout = tiled_layout if load_tiled else mgpu.WGMMA_LAYOUT store_layout = tiled_layout if store_tiled else mgpu.WGMMA_LAYOUT + if (not load_tiled or not store_tiled) and bw == 4 and swizzle == 32: + self.skipTest("Old code path does not support this") def kernel(ctx, in_, out, smems): smem_in, smem_out, barrier = smems ctx.async_copy(src_ref=in_, dst_ref=smem_in, swizzle=swizzle, barrier=barrier) @@ -1800,14 +1803,21 @@ def kernel(ctx, in_, out, smems): # Verify that we don't use too many registers for the transfers. # We verify LDS and STS separately, because they might use two different # methods of computing offsets and we don't rely on CSE between them. - register_pattern = re.compile(r"(R[0-9]+)") expected_regs = swizzle // bytewidth(mlir_dtype) // 8 + # When the bytewidth is smaller than 2 the swizzle pattern changes every 2 + # column tiles, so we only need half the registers. + if load_tiled and store_tiled: # The old code doesn't optimize properly. + if bytewidth(mlir_dtype) < 2: + expected_regs //= 2 for instr in ("STS", "LDS"): with self.subTest(instr + " count"): addrs = re.findall(instr + r".* \[(.*)\]", get_sass()) - chain = itertools.chain.from_iterable - used_regs = set(chain(register_pattern.findall(addr) for addr in addrs)) - self.assertLen(used_regs, expected_regs) + def get_reg(addr): + if (pos := addr.find("+")) != -1: + return addr[:pos] + return addr + used_regs = {get_reg(addr) for addr in addrs} + self.assertLessEqual(len(used_regs), expected_regs) if __name__ == "__main__": From a4e742d2fe17ae134bcd8b42b56085913dd40a14 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Mon, 2 Dec 2024 06:51:49 +0000 Subject: [PATCH 559/698] Save residuals in the decode attention pallas kernel --- .../pallas/ops/gpu/decode_attention.py | 71 ++++++++++++++--- tests/pallas/gpu_attention_test.py | 77 +++++++++++++++---- 2 files changed, 123 insertions(+), 25 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index d09f1fbac113..1c558c220ea9 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -143,6 +143,7 @@ def decode_attn_unbatched( grid: tuple[int, ...] | None, interpret: bool, debug: bool, + return_residuals: bool ): num_heads, head_dim = q.shape k_seq_len, _ = k.shape @@ -215,7 +216,10 @@ def decode_attn_unbatched( l_next = (l * correction).sum(axis=0) eps = jnp.finfo(l_next.dtype).eps o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps) - return o + if return_residuals: + return o, (l_next, m_next) + else: + return o @functools.partial( @@ -230,6 +234,7 @@ def decode_attn_unbatched( "grid", "interpret", "debug", + "return_residuals" ], ) def mqa( @@ -247,6 +252,7 @@ def mqa( grid: tuple[int, ...] | None = None, interpret: bool = False, debug: bool = False, + return_residuals: bool = False ): sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) bs = q.shape[0] @@ -265,6 +271,7 @@ def mqa( grid=grid, interpret=interpret, debug=debug, + return_residuals=return_residuals ) return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len) @@ -281,6 +288,7 @@ def mqa( "grid", "interpret", "debug", + "return_residuals" ], ) def gqa( @@ -298,6 +306,7 @@ def gqa( grid: tuple[int, ...] | None = None, interpret: bool = False, debug: bool = False, + return_residuals: bool = False, ): sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) batch_size, q_heads, head_dim = q.shape @@ -331,14 +340,23 @@ def gqa( grid=grid, interpret=interpret, debug=debug, + return_residuals=return_residuals, ) with_kv_heads = jax.vmap(inner) - o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed, - start_idx, kv_seq_len) - return o.reshape(batch_size, q_heads, head_dim) + o, *res = jax.vmap(with_kv_heads)( + q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len + ) + o = o.reshape(batch_size, q_heads, head_dim) + if return_residuals: + l, m = res[0] + l = l.reshape(batch_size, q_heads) + m = m.reshape(batch_size, q_heads) + return o, (l, m) + else: + return o -@functools.partial(jax.jit, static_argnames=["sm_scale"]) +@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"]) def mqa_reference( q, # [bs, num_q_heads, head_dim] k, # [bs, k_seq_len, head_dim] @@ -346,10 +364,16 @@ def mqa_reference( start_idx=None, # [bs] kv_seq_len=None, # [bs] sm_scale=None, + return_residuals=False ): + original_dtype = q.dtype + q = q.astype(jnp.float32) + k = k.astype(jnp.float32) bs = q.shape[0] sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32) + if sm_scale is not None and sm_scale != 1.0: + logits = logits * sm_scale if start_idx is not None or kv_seq_len is not None: start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None @@ -358,8 +382,17 @@ def mqa_reference( & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) mask = mask[:, None, :] logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) - weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) - return jnp.einsum("bns,bsd->bnd", weights, v) + + m = logits.max(axis=-1) + s = jnp.exp(logits - m[..., None]) + l = s.sum(axis=-1) + s = s / l[..., None] + o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype) + + if return_residuals: + return o, (l, m) + else: + return o @functools.partial(jax.jit, static_argnames=["sm_scale"]) @@ -387,7 +420,7 @@ def mha_reference( return jnp.einsum("bns,bsnd->bnd", weights, v) -@functools.partial(jax.jit, static_argnames=["sm_scale"]) +@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"]) def gqa_reference( q, # [bs, num_q_heads, head_dim] k, # [bs, k_seq_len, num_k_heads, head_dim] @@ -395,7 +428,11 @@ def gqa_reference( start_idx=None, # [bs] kv_seq_len=None, # [bs] sm_scale=None, + return_residuals=False ): + original_dtype = q.dtype + q = q.astype(jnp.float32) + k = k.astype(jnp.float32) sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1])) bs, num_q_heads, head_dim = q.shape num_kv_heads = k.shape[2] @@ -412,6 +449,8 @@ def gqa_reference( logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype( jnp.float32 ) + if sm_scale is not None and sm_scale != 1.0: + logits = logits * sm_scale if start_idx is not None or kv_seq_len is not None: start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,)) kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None @@ -420,7 +459,17 @@ def gqa_reference( & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None])) mask = mask[:, None, None, :] logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min) - weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) - o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed) + + m = logits.max(axis=-1) + s = jnp.exp(logits - m[..., None]) + l = s.sum(axis=-1) + s = s / l[..., None] + o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype) o = o.reshape(bs, num_q_heads, head_dim) - return o + + if return_residuals: + l = l.reshape(bs, num_q_heads) + m = m.reshape(bs, num_q_heads) + return o, (l, m) + else: + return o diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index afd2f6ae3fcf..3b4aa1551591 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -21,6 +21,7 @@ from jax import random from jax._src import config from jax._src import test_util as jtu + if sys.platform != "win32": from jax.experimental.pallas.ops.gpu import decode_attention else: @@ -48,8 +49,9 @@ def setUp(self): self.skipTest("On CPU, the test works only in interpret mode") if jax.config.x64_enabled: self.skipTest("The test works only in 32-bit") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): + if jtu.test_device_matches( + ["cuda"] + ) and not jtu.is_cuda_compute_capability_at_least("8.0"): self.skipTest("Only works on GPU with capability >= sm80") if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") @@ -62,8 +64,10 @@ class DecodeAttentionTest(PallasBaseTest): @parameterized.named_parameters(*[ ( - (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_" - f"{start_idx=}_{kv_seq_len=}"), + ( + f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_" + f"{start_idx=}_{kv_seq_len=}_{return_residuals=}" + ), batch_size, seq_len, num_heads, @@ -71,6 +75,7 @@ class DecodeAttentionTest(PallasBaseTest): kwargs, start_idx, kv_seq_len, + return_residuals, ) for ( batch_size, @@ -85,6 +90,7 @@ class DecodeAttentionTest(PallasBaseTest): ] for start_idx in [None, 123] for kv_seq_len in [None, 250] + for return_residuals in [False, True] ]) @jax.numpy_dtype_promotion("standard") def test_mqa( @@ -96,6 +102,7 @@ def test_mqa( kwargs, start_idx, kv_seq_len, + return_residuals, ): del kwargs @@ -104,16 +111,36 @@ def test_mqa( k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16) v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16) - o = decode_attention.mqa(q, k, v, start_idx=start_idx, - kv_seq_len=kv_seq_len, interpret=self.INTERPRET) - o_ref = decode_attention.mqa_reference(q, k, v, start_idx=start_idx, - kv_seq_len=kv_seq_len) + o, *res = decode_attention.mqa( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + interpret=self.INTERPRET, + ) + o_ref, *res_ref = decode_attention.mqa_reference( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + ) np.testing.assert_allclose(o, o_ref, atol=0.05) + if return_residuals: + l, m = res[0] + l_ref, m_ref = res_ref[0] + np.testing.assert_allclose(l, l_ref, atol=0.05) + np.testing.assert_allclose(m, m_ref, atol=0.05) @parameterized.named_parameters(*[ ( - (f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}" - f"_{kwargs=}_{start_idx=}_{kv_seq_len=}"), + ( + f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}" + f"_{kwargs=}_{start_idx=}_{kv_seq_len=}_{return_residuals=}" + ), batch_size, seq_len, num_q_heads, @@ -122,6 +149,7 @@ def test_mqa( kwargs, start_idx, kv_seq_len, + return_residuals, ) for ( batch_size, @@ -137,6 +165,7 @@ def test_mqa( ] for start_idx in [None, 123] for kv_seq_len in [None, 250] + for return_residuals in [False, True] ]) @jax.numpy_dtype_promotion("standard") def test_gqa( @@ -149,6 +178,7 @@ def test_gqa( kwargs, start_idx, kv_seq_len, + return_residuals, ): del kwargs @@ -162,11 +192,30 @@ def test_gqa( v = random.normal( k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16 ) - o = decode_attention.gqa(q, k, v, start_idx=start_idx, - kv_seq_len=kv_seq_len, interpret=self.INTERPRET) - o_ref = decode_attention.gqa_reference(q, k, v, start_idx=start_idx, - kv_seq_len=kv_seq_len) + o, *res = decode_attention.gqa( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + interpret=self.INTERPRET, + ) + o_ref, *res_ref = decode_attention.gqa_reference( + q, + k, + v, + start_idx=start_idx, + kv_seq_len=kv_seq_len, + return_residuals=return_residuals, + ) np.testing.assert_allclose(o, o_ref, atol=0.05) + if return_residuals: + l, m = res[0] + l_ref, m_ref = res_ref[0] + np.testing.assert_allclose(l, l_ref, atol=0.05) + np.testing.assert_allclose(m, m_ref, atol=0.05) + class DecodeAttentionInterpretTest(DecodeAttentionTest): INTERPRET = True From 5d5b06cf8a3aeea269d01ac3cbb5b25e017be4a9 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 2 Dec 2024 07:31:30 -0800 Subject: [PATCH 560/698] [jax] Canonicalize dtypes when checking if dtypes present in target dtypes list. PiperOrigin-RevId: 701961663 --- jax/_src/lax/lax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index be7d13195554..631fab903277 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3704,12 +3704,14 @@ def maybe_convert_dtype(input_dtype, target_dtype): return input_dtype if not isinstance(target_dtype, tuple): target_dtype = (target_dtype,) - return input_dtype if input_dtype in target_dtype else target_dtype[0] + if np.dtype(input_dtype) in map(np.dtype, target_dtype): + return input_dtype + return target_dtype[0] if algorithm == DotAlgorithmPreset.BF16_BF16_F32: lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type) rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type) - if lhs_dtype == dtypes.bfloat16: + if np.dtype(lhs_dtype) == dtypes.bfloat16: out_dtype = maybe_convert_dtype(out_dtype, (np.float32, dtypes.bfloat16)) else: From aff7714dc0f49cc0097e4db08e028b68182c8ab9 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 2 Dec 2024 07:32:12 -0800 Subject: [PATCH 561/698] [Pallas:MGPU] Fix an overly strict precision requirement in tests They started failing after we allowed LLVM to perform contractions of adds and muls, but the difference is tiny. PiperOrigin-RevId: 701961845 --- tests/pallas/mosaic_gpu_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index fa8597ab3195..e936206aa4bf 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -492,9 +492,7 @@ def layer_norm_np(x): jax.random.uniform(jax.random.key(42), shape=(256,), dtype=jnp.float32) * input_factor ) - # TODO(cperivol): find out why in this particular case we have a small-ish error. - rtol = 1e-07 if input_factor > 10 else 5e-5 - np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=rtol) + np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5) def test_print(self): @functools.partial( From 8a3161953c4b2234b3b6bcf1e5fe640c7b00ba9a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 2 Dec 2024 08:25:32 -0800 Subject: [PATCH 562/698] [Pallas:MGPU] Make the shapes from the attention example more interesting This bumps up the number of heads and removes the batch_size=2 case: it's very similar to batch_size=1 and doubles the script runtime. We also don't do full autotuning by default since the largest size that works usually performs the best. PiperOrigin-RevId: 701976192 --- .../pallas/ops/gpu/attention_mgpu.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 6f02396ccb92..b684aef409f1 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -249,25 +249,26 @@ def attention_reference(q, k, v): def main(unused_argv): - num_q_heads = 1 - num_kv_heads = 1 - problem_it = itertools.product((1, 2), (4096, 32768,), (64, 128, 256,)) + num_q_heads = 16 + num_kv_heads = 16 + problem_it = itertools.product((1,), (4096, 32768,), (64, 128, 256,)) for batch_size, seq_len, head_dim in problem_it: q_seq_len = kv_seq_len = seq_len print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}" f"{num_q_heads=:<4} {head_dim=:<6} ====") - param_it = itertools.product((64,), (64, 128, 256)) - best = None k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) - for block_q, block_kv in param_it: + block_q = 64 + best = None + for block_kv in (256, 128, 64): config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2) try: out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v) - out_ref = attention_reference(q, k, v) - np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + if seq_len < 32768: + out_ref = attention_reference(q, k, v) + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) except ValueError as e: if "exceeds available shared memory" in e.args[0]: continue @@ -285,6 +286,7 @@ def main(unused_argv): ) if best is None or runtime_us < best[0]: best = (runtime_us, achieved_tc_util) + break # Remove this for full autotuning. if best is not None: print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization") From 97d201e2f18bbed6c214bf231eef3016d72c46ee Mon Sep 17 00:00:00 2001 From: charleshofer Date: Mon, 2 Dec 2024 10:28:23 -0600 Subject: [PATCH 563/698] Update ci-build.yaml to use specific image --- .github/workflows/ci-build.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index ead7f4c5ad69..1ad6db4baedd 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -102,6 +102,8 @@ jobs: documentation: name: Documentation - test code snippets runs-on: ubuntu-latest + container: + image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 timeout-minutes: 10 strategy: matrix: From bbc4a20c85ac9c1ada0245313d8cacdd33e98fe8 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Mon, 2 Dec 2024 23:10:19 +0530 Subject: [PATCH 564/698] Update the docstring of jax.lax.switch --- jax/_src/lax/control_flow/conditionals.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 9e1f7e04c741..7a5e9e9dee8a 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -87,6 +87,7 @@ def switch(index, branches, *operands): Args: index: Integer scalar type, indicating which branch function to apply. branches: Sequence of functions (A -> B) to be applied based on ``index``. + All branches must return the same output structure. operands: Operands (A) input to whichever branch is applied. Returns: From b1423a366931b4e4f96902117489dbf9aba8d847 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 2 Dec 2024 09:43:57 -0800 Subject: [PATCH 565/698] [Pallas:MGPU] Fix a use-after-free in lowering The lifetime of values is bound to the ops that produce them, which are deleted after the `with` block. The lifetime of types is bound to the context. PiperOrigin-RevId: 701997797 --- jax/_src/pallas/mosaic_gpu/lowering.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0407d76643b8..c3563906ec63 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1538,14 +1538,13 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args ) - yielded = [ - _ensure_ir_value(out, aval.dtype) or out + yielded_types, _ = jax.tree.flatten([ + (_ensure_ir_value(out, aval.dtype) or out).type for out, aval in zip(outs, ctx.avals_out) - ] - yielded_leaves, _ = jax.tree.flatten(yielded) + ]) switch_op = scf_dialect.IndexSwitchOp( - [v.type for v in yielded_leaves], + yielded_types, _as_index(_ensure_ir_value(index, index_aval.dtype)), ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), num_caseRegions=len(branches) - 1, From a7039a275ecb837f83848154104c613b849f9481 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Dec 2024 10:20:34 -0800 Subject: [PATCH 566/698] jnp.reshape: raise TypeError when specifying newshape --- jax/_src/numpy/lax_numpy.py | 17 ++++------------- tests/lax_numpy_test.py | 9 ++------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a61b1d67f53e..6a0e4059c4ae 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2143,20 +2143,11 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(micky774): deprecated 2024-5-9, remove after deprecation expires. + # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. if not isinstance(newshape, DeprecatedArg): - if shape is not None: - raise ValueError( - "jnp.reshape received both `shape` and `newshape` arguments. Note that " - "using `newshape` is deprecated, please only use `shape` instead." - ) - deprecations.warn( - "jax-numpy-reshape-newshape", - ("The newshape argument of jax.numpy.reshape is deprecated. " - "Please use the shape argument instead."), stacklevel=2) - shape = newshape - del newshape - elif shape is None: + raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." + " Use shape instead.") + if shape is None: raise TypeError( "jnp.shape requires passing a `shape` argument, but none was given." ) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ef80e368c9c7..ef7faf9e3559 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3428,13 +3428,8 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CompileAndCheck(jnp_fun, args_maker) def testReshapeDeprecatedArgs(self): - msg = "The newshape argument of jax.numpy.reshape is deprecated." - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-reshape-newshape"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) - with assert_warns_or_errors(msg): + msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." + with self.assertRaisesRegex(TypeError, msg): jnp.reshape(jnp.arange(4), newshape=(2, 2)) @jtu.sample_product( From 784ebeabc890284f48f2b13f71d97bd71fedb0fc Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 2 Dec 2024 10:31:45 -0800 Subject: [PATCH 567/698] [Mosaic GPU] Automatically squash a >3D logical grid into a 3D physical CUDA grid. PiperOrigin-RevId: 702013252 --- jax/_src/pallas/mosaic_gpu/lowering.py | 101 ++++++++++++++++++++----- tests/pallas/mosaic_gpu_test.py | 59 +++++++++++++++ 2 files changed, 143 insertions(+), 17 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index c3563906ec63..5d33d8895b61 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -201,6 +201,7 @@ class ModuleContext: ] name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches + squashed_dims: tuple[int, ...] def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: """Reserves a barrier. @@ -403,12 +404,15 @@ def lower_jaxpr_to_module( parallel_grid = [ d for i, d in enumerate(logical_grid) if i not in sequential_axes ] - if len(parallel_grid) < 3: + if len(parallel_grid) <= 3: + squashed_dims = () parallel_grid += (1,) * (3 - len(parallel_grid)) - elif len(parallel_grid) > 3: - raise NotImplementedError( - "Only <=3D grids are supported in Mosaic GPU lowering." - ) + else: + # If we have >3 parallel dimensions, we merge all leading dimensions + # into the first (Dimension.x) CUDA grid dimension. + squashed_dims = parallel_grid[:-2] + parallel_grid = [math.prod(parallel_grid[:-2]), *parallel_grid[-2:]] + if sequential_axes: # TODO(slebedev): Support multiple sequential axes. if len(sequential_axes) > 1: @@ -496,7 +500,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): parallel_count = it.count() program_ids_template = [ - _program_id(next(parallel_count)) + _program_id(next(parallel_count), squashed_dims=squashed_dims) if axis not in sequential_axes else None for axis in range(len(logical_grid)) @@ -520,6 +524,7 @@ def make_program_ids(step: ir.Value): runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), + squashed_dims=squashed_dims, ) del runtime_smem, grouped_barriers, runtime_barriers @@ -911,12 +916,42 @@ def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): raise NotImplementedError("pl.program_id() is not supported in this context") return ctx.module_ctx.program_ids[axis] - -def _program_id(axis: int) -> ir.Value: - return arith_dialect.index_cast( - ir.IntegerType.get_signless(32), - gpu_dialect.block_id(gpu_dialect.Dimension(axis)), - ) +def _unravel_program_id( + block_id: ir.Value, + axis: int, + dimensions: tuple[int, ...], + row_major: bool = False +) -> ir.Value: + """Computes the program ID for axes compressed into one block dimension.""" + if row_major: + div_value = math.prod(dimensions[axis+1:]) + else: + div_value = math.prod(dimensions[:axis]) + div_value = _as_index(_i32_constant(div_value)) + pid = arith_dialect.divui(block_id, div_value) + axis_size = _as_index(_i32_constant(dimensions[axis])) + pid = arith_dialect.remui(pid, axis_size) + return arith_dialect.index_cast(ir.IntegerType.get_signless(32), pid) + + +def _program_id(parallel_axis: int, squashed_dims: tuple[int, ...]) -> ir.Value: + if squashed_dims: + if parallel_axis < len(squashed_dims): + # All squashed dimensions are mapped to Dimension.x. + block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) + return _unravel_program_id(block_id, parallel_axis, squashed_dims) + else: + # Handle unsquashed axes. + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_id(gpu_dialect.Dimension( + parallel_axis - len(squashed_dims) + 1)), + ) + else: + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_id(gpu_dialect.Dimension(parallel_axis)), + ) @register_lowering_rule(primitives.num_programs_p) @@ -1244,16 +1279,44 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): @register_lowering_rule(lax.axis_index_p) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): + i32 = ir.IntegerType.get_signless(32) grid_names = ctx.module_ctx.grid_mapping.grid_names + squashed_dims = ctx.module_ctx.squashed_dims + if squashed_dims: + unsquashed_names = grid_names[-3:] + squashed_names = grid_names[:-3] + else: + # These are unused but initialized for type checkers. + unsquashed_names = () + squashed_names = () if grid_names and axis_name in grid_names: if axis_name == grid_names[-1]: return mgpu.warpgroup_idx(sync=True) else: - idx = grid_names.index(axis_name) - return arith_dialect.index_cast( - ir.IntegerType.get_signless(32), - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) + if squashed_dims: + if axis_name in unsquashed_names: + # We add 1 to the index because the first dimension is the + # squashed dimension. + # e.g. for the grid (a, b, c, d, wg) + # squashed = (a, b) Mapped to Dimension.x (0) + # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) + idx = unsquashed_names.index(axis_name) + 1 + return arith_dialect.index_cast( + i32, + gpu_dialect.block_id(gpu_dialect.Dimension(idx)), + ) + elif axis_name in squashed_names: + # All squashed dimensions are mapped to Dimension.x. + block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) + axis = squashed_names.index(axis_name) + return _unravel_program_id(block_id, axis, squashed_dims) + else: + if axis_name in grid_names: + idx = grid_names.index(axis_name) + return arith_dialect.index_cast( + i32, + gpu_dialect.block_id(gpu_dialect.Dimension(idx)), + ) raise ValueError( "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" ) @@ -1669,10 +1732,14 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value: def _i32_constant(v: int) -> ir.Value: + if v < jnp.iinfo(jnp.int32).min or v > jnp.iinfo(jnp.int32).max: + raise ValueError(f"Integer constant out of range for i32: {v}") return arith_dialect.constant(ir.IntegerType.get_signless(32), v) def _i64_constant(v: int) -> ir.Value: + if v < jnp.iinfo(jnp.int64).min or v > jnp.iinfo(jnp.int64).max: + raise ValueError(f"Integer constant out of range for i64: {v}") return arith_dialect.constant(ir.IntegerType.get_signless(64), v) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e936206aa4bf..d27c7d887db1 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -609,6 +609,30 @@ def kernel(o_ref): jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32), ) + def test_program_id_in_squashed_grid(self): + # Tests whether a grid with >3 logical dimensions is correctly squashed to + # 3 CUDA grid dimensions. + grid = (2, 3, 4, 5) + @functools.partial( + pl.pallas_call, + in_specs=(), + out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)), + out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32), + grid=grid, + ) + def kernel(o_ref): + mult = 1 + idx = 0 + for axis in range(len(grid)-1, -1, -1): + idx += pl.program_id(axis) * mult + mult *= pl.num_programs(axis) + o_ref[...] = jnp.full(o_ref.shape, idx) + + np.testing.assert_array_equal( + kernel()[:, :, :, :, 0], + jnp.arange(math.prod(grid), dtype=jnp.int32).reshape(*grid) + ) + def test_program_id_in_block_spec(self): @functools.partial( pl.pallas_call, @@ -1383,6 +1407,41 @@ def kernel(): f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) + def test_multiple_wg_with_squashed_grid(self): + # Tests whether a grid with >3 logical dimensions is correctly squashed to + # 3 CUDA grid dimensions. + b = 4 + x_dim = 3 + y_dim = 5 + z_dim = 7 + num_threads = 2 + mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim), + num_threads=num_threads, + axis_names=("b", "x", "y", "z", "wg")) + + @jax.jit + def f(): + @pl.run_state + def inner(y_ref): + @pl.core_map(mesh) + def _(): + b_idx = jax.lax.axis_index("b") + x_idx = jax.lax.axis_index("x") + y_idx = jax.lax.axis_index("y") + z_idx = jax.lax.axis_index("z") + wg_idx = jax.lax.axis_index("wg") + bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) + y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( + bxyzw_idx, (128,) + ) + y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32) + return inner(y_init) + result = f()[:, :, :, :, :, 0] + ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( + result.shape) + np.testing.assert_array_equal(result, ref) + + def test_cross_wg_barrier(self): mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) From f182aa8eddbdb99d2299f2eb2153b03078a124da Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Dec 2024 10:57:57 -0800 Subject: [PATCH 568/698] Skip vecmat & matvec in NumPy tests. --- tests/lax_numpy_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ef80e368c9c7..535e2c632027 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6446,7 +6446,8 @@ def testWrappedSignaturesMatch(self): _available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all if dtype != dtypes.bfloat16] -UNIMPLEMENTED_UFUNCS = {'spacing'} +# TODO(jakevdp): implement missing ufuncs. +UNIMPLEMENTED_UFUNCS = {'spacing', 'matvec', 'vecmat'} def _all_numpy_ufuncs() -> Iterator[str]: From c9a5902216b61bddb88415b9da6bbcf589ef12ea Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 2 Dec 2024 12:39:56 -0800 Subject: [PATCH 569/698] [jax] Typing on common_devices_indices_map PiperOrigin-RevId: 702053791 --- jax/_src/sharding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index cee3542f0006..23f0ef13cb00 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -43,7 +43,8 @@ def _addressable_devices_indices_map( if d.process_index == d.client.process_index()} @cache(max_size=4096, trace_context_in_key=False) -def common_devices_indices_map(s, global_shape: Shape) -> Mapping[Device, Index]: +def common_devices_indices_map( + s: Sharding, global_shape: Shape) -> Mapping[Device, Index]: s.shard_shape(global_shape) # raises a good error message hlo_sharding = s._to_xla_hlo_sharding(len(global_shape)) indices = op_sharding_to_indices(hlo_sharding, global_shape, From f8b753cf93a4e8576efe88c6650e7a64968e5fbc Mon Sep 17 00:00:00 2001 From: charleshofer Date: Mon, 2 Dec 2024 15:56:10 -0600 Subject: [PATCH 570/698] Update ci-build.yaml --- .github/workflows/ci-build.yaml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 1ad6db4baedd..33d413062ad7 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -101,9 +101,7 @@ jobs: documentation: name: Documentation - test code snippets - runs-on: ubuntu-latest - container: - image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04 + runs-on: ROCM-Ubuntu timeout-minutes: 10 strategy: matrix: @@ -147,10 +145,6 @@ jobs: python-version: ['3.10'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Image Setup - run: | - apt update - apt install -y libssl-dev libsqlite3-dev - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: From 9f20301739105125e119ff34be3bb275a26fd40c Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 2 Dec 2024 14:42:05 -0800 Subject: [PATCH 571/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/6070a19a7c7d1c62e71d73a4ee5a710c641dad2c. PiperOrigin-RevId: 702090722 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 6344349e9ad0..75abd0ab5032 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "41e12cc0247edf4ffb1569f2a25c61cec924c755" -XLA_SHA256 = "1eb835846d906813264909024036be866d1ee191959fbe13b0f2eeacbf539fea" +XLA_COMMIT = "6070a19a7c7d1c62e71d73a4ee5a710c641dad2c" +XLA_SHA256 = "fb12e680ac47facf40bda36d436195e1d5454f48d02fb8540a9c6363cfe03cb1" def repo(): tf_http_archive( From 0134fa834cc2e3930e4f87796ccb939e5660cd4f Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Mon, 2 Dec 2024 14:53:26 -0800 Subject: [PATCH 572/698] Update Cloud TPU workflow with new build.py usage PiperOrigin-RevId: 702094141 --- .github/workflows/cloud-tpu-ci-nightly.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 7d7bc84fe135..b598565223c0 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -63,9 +63,11 @@ jobs: $PYTHON -m pip uninstall -y jax jaxlib libtpu if [ "${{ matrix.jaxlib-version }}" == "head" ]; then # Build and install jaxlib at head - $PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \ - --bazel_options="--override_repository=xla=$(pwd)/xla" \ - --bazel_options=--color=yes + $PYTHON build/build.py build --wheels=jaxlib \ + --bazel_options=--config=rbe_linux_x86_64 \ + --local_xla_path=$(pwd)/xla" \ + --verbose + $PYTHON -m pip install dist/*.whl # Install "jax" at head From b3c405c2f55b5c1d2cb42755bfcb8ed8a5e03b38 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 2 Dec 2024 15:58:13 -0800 Subject: [PATCH 573/698] [shape_poly] Remove obsolete part of the shape polymorphism documentation The section of division limitations is now obsolte, because JAX can represent division symbolically. --- docs/export/shape_poly.md | 41 +-------------------------------------- 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 6ad7fb5c2b09..9254030a4e1c 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -159,7 +159,7 @@ new shape: It is possible to convert dimension expressions explicitly to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`. The result of these operations can be used as regular JAX arrays, -bug cannot be used anymore as dimensions in shapes. +but cannot be used anymore as dimensions in shapes, e.g., in `reshape`: ```python >>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))( @@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass These errors arise in a pre-processing step before the compilation. -### Division of symbolic dimensions is partially supported - -JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. -In particular, JAX will handle the cases when either (a) there -is no remainder, or (b) the divisor is a constant -in which case there may be a constant remainder. - -For example, the code below results in a division error when trying to -compute the inferred dimension for a `reshape` operation: - -```python ->>> b, = export.symbolic_shape("b") ->>> export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b,), dtype=np.int32)) -Traceback (most recent call last): -jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1). -The remainder mod(b, - 2) should be 0. - -``` - -Note that the following will succeed: - -```python ->>> b, = export.symbolic_shape("b") ->>> # We specify that the first dimension is a multiple of 4 ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((4*b,), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,2*b]),) - ->>> # We specify that some other dimension is even ->>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( -... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32)) ->>> exp.out_avals -(ShapedArray(int32[2,15*b]),) - -``` - (shape_poly_debugging)= ## Debugging From 6a8bbcbadfe93cfa2d9f03fcb5be43a44cab6f28 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 18 Nov 2024 17:06:28 -0800 Subject: [PATCH 574/698] Add an option to deactivate automatic cluster detection in jax.distributed.initialize(). --- jax/_src/distributed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index f80f90bde186..e9796d61c6f3 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -62,7 +62,8 @@ def initialize(self, if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')): local_device_ids = list(map(int, env_ids.split(","))) - if None in (coordinator_address, num_processes, process_id, local_device_ids): + if (cluster_detection_method != 'deactivate' and + None in (coordinator_address, num_processes, process_id, local_device_ids)): (coordinator_address, num_processes, process_id, local_device_ids) = ( clusters.ClusterEnv.auto_detect_unset_distributed_params( coordinator_address, @@ -217,7 +218,8 @@ def initialize(coordinator_address: str | None = None, cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed run. Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment, and launch the applicatoin with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``. - Legacy auto-detect options (OMPI, Slurm) remain enabled. + Legacy auto-detect options "ompi" (OMPI) and "slurm" (Slurm) remain enabled. "deactivate" bypasses + automatic cluster detection. initialization_timeout: Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. From 7bd81dbe0d1c3c7a17a6c5a50f65d8caf09cb5aa Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Tue, 19 Nov 2024 21:36:25 -0500 Subject: [PATCH 575/698] [Mosaic GPU] Improve default kernel name and add option to customize This allows users to distinguish Mosaic GPU kernels from other kernels when using profiling programs such as Nsight Systems. The new default behavior is to use `mosaic_gpu__kernel` as the kernel name, where `` is the name of the Mosaic GPU Python kernel function passed to `as_gpu_kernel` or `as_torch_gpu_kernel`. We also add a new `kernel_name` optional argument to `as_gpu_kernel` and `as_torch_gpu_kernel`. If `kernel_name` is not `None`, the resulting kernel name is `mosaic_gpu__kernel`. This is useful when the Mosaic GPU Python kernel function is constructed through metaprogramming so that the final specialized kernel can have different meaningful names depending on the metaparameters. Previously the kernel name was always `main_kernel`. --- jax/experimental/mosaic/gpu/core.py | 11 +++++-- jaxlib/mosaic/gpu/BUILD | 1 + jaxlib/mosaic/gpu/custom_call.cc | 50 ++++++++++++++++++++++++++--- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 16b7f1f59c33..4b4882f65f0c 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -847,6 +847,7 @@ def _lower_as_gpu_kernel( out_shape, smem_scratch_shape: ShapeTree | Union[ShapeTree], module_name: str, + kernel_name: str | None = None, prof_spec: profiler.ProfilerSpec | None = None, ): ptr_ty = ir.Type.parse("!llvm.ptr") @@ -873,6 +874,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: module = ir.Module.create() attrs = module.operation.attributes attrs["sym_name"] = ir.StringAttr.get(module_name) + if kernel_name is None: + kernel_name = getattr(body, "__name__", "anonymous") with ir.InsertionPoint(module.body): _declare_runtime_functions() gmem_scratch_bytes = 0 @@ -882,7 +885,7 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: ir.Attribute.parse("#llvm.linkage"), addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. ) - @func.FuncOp.from_py_func(ptr_ty, ptr_ty) + @func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}") def main(token_ptr, buffers): nonlocal gmem_scratch_bytes token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) @@ -947,6 +950,7 @@ def as_gpu_kernel( prof_spec: profiler.ProfilerSpec | None = None, cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", + kernel_name: str | None = None, ): if isinstance(in_shape, list): in_shape = tuple(in_shape) @@ -956,7 +960,7 @@ def as_gpu_kernel( module, out_shape, unwrap_output_tuple = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec + module_name, kernel_name, prof_spec ) ) @@ -1014,6 +1018,7 @@ def as_torch_gpu_kernel( prof_spec: profiler.ProfilerSpec | None = None, cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", + kernel_name: str | None = None, ): try: import torch @@ -1032,7 +1037,7 @@ def as_torch_gpu_kernel( module, out_shape, unwrap_output_tuple = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec + module_name, kernel_name, prof_spec ) ) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 1f78782a0891..cb52488e79cc 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -126,6 +126,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 2d479f712408..54792b3097f7 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "llvm/include/llvm/ADT/SmallVector.h" #include "llvm/include/llvm/Support/CodeGen.h" @@ -415,6 +416,40 @@ GetKernelCache() { return std::make_pair(&context_cache, &mutex); } +absl::StatusOr> GetHostAndInitFuncNames( + mlir::ModuleOp module_op) { + // We look for two top level C-interface functions: + // - "host" function with symbol name "_mlir_ciface_" + // - "init" function with symbol name "_mlir_ciface__init" + constexpr std::string_view prefix = "_mlir_ciface_"; + std::vector names; + for (mlir::LLVM::LLVMFuncOp llvm_func : + module_op.getOps()) { + if (llvm_func.getName().starts_with(prefix)) { + names.push_back(llvm_func.getName().str()); + } + } + if (auto size = names.size(); size != 2) { + return absl::InternalError(absl::StrFormat( + "Expected to locate 2 symbols with %s prefix in the MLIR module, found " + "%d instead.", + prefix, size)); + } + // _mlir_ciface__init now follows _mlir_ciface_ + std::sort(names.begin(), names.end()); + + std::string host_func_name = names[0]; + std::string init_func_name = names[1]; + + if (init_func_name != absl::StrCat(host_func_name, "_init")) { + return absl::InternalError(absl::StrFormat( + "Expected init function name to equal the concatenation of the host " + "function name and the \"_init\" suffix, instead got " + "init_func_name=%s, host_func_name=%s.", + init_func_name, host_func_name)); + } + return std::make_pair(host_func_name, init_func_name); +} absl::StatusOr CompileAndInit(const char* module) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); @@ -430,9 +465,16 @@ absl::StatusOr CompileAndInit(const char* module) { return maybe_engine.status(); } mlir::ExecutionEngine* execution_engine = maybe_engine->get(); - auto main = execution_engine->lookupPacked("_mlir_ciface_main"); - auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); - if (!init || !main) { + + auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op); + if (!host_and_init_func_names.ok()) { + return host_and_init_func_names.status(); + } + auto [host_name, init_name] = host_and_init_func_names.value(); + + auto host = execution_engine->lookupPacked(host_name); + auto init = execution_engine->lookupPacked(init_name); + if (!init || !host) { return absl::InternalError("Failed to retrieve kernel function"); } void* module_ptr = nullptr; @@ -442,7 +484,7 @@ absl::StatusOr CompileAndInit(const char* module) { void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); return CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*main)); + reinterpret_cast(*host)); } // Each compiled kernel has a unique init func, and each kernel is used from From f43fa9fc78a6a16af0d94a37a85062ff38076a54 Mon Sep 17 00:00:00 2001 From: nireekshak Date: Tue, 3 Dec 2024 05:07:53 +0000 Subject: [PATCH 576/698] Fix some typos --- docs/gpu_performance_tips.md | 2 +- docs/gradient-checkpointing.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 2 +- docs/notebooks/Common_Gotchas_in_JAX.md | 2 +- docs/notebooks/autodiff_remat.ipynb | 2 +- docs/notebooks/autodiff_remat.md | 2 +- docs/xla_flags.md | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index 1f5cc0727605..5a760db98684 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -112,7 +112,7 @@ don't seem useful for multi-host communication yet. ## Multi-Process -We recommand using one process per GPU and not one per node. In some +We recommend using one process per GPU and not one per node. In some cases, this can speed up jitted computation. The {func}`jax.distributed.initialize` API will automatically understand that configuration when run under SLURM. However, this only a rule of diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 33efaed6274b..3ef927e056f2 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -443,7 +443,7 @@ print_fwd_bwd(f, 3.) When differentiated functions are staged out to XLA for compilation — for example by applying {func}`jax.jit` to a function which contains a {func}`jax.grad` call — XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **{func}`jax.checkpoint` often isn't needed for differentiated functions under a {func}`jax.jit`**. XLA will optimize things for you. -One exception is when using staged-out control flow, like {func}`jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass `scan` and the corresponding backward-pass `scan`), typically aren't aren't as thorough. As a result, it's often a good idea to use {func}`jax.checkpoint` on the body function passed to {func}`jax.lax.scan`. +One exception is when using staged-out control flow, like {func}`jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass `scan` and the corresponding backward-pass `scan`), typically aren't as thorough. As a result, it's often a good idea to use {func}`jax.checkpoint` on the body function passed to {func}`jax.lax.scan`. For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a {func}`jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this: diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 02077d2a6b00..8823fac13042 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -202,7 +202,7 @@ "id": "cDpQ5u63Ba_H" }, "source": [ - "It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results." + "It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results." ] }, { diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index f35c5ead13b7..1529dcef5e37 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -121,7 +121,7 @@ print(jit(pure_uses_internal_state)(5.)) +++ {"id": "cDpQ5u63Ba_H"} -It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results. +It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a python object which introduces state to retrieve the next element. Therefore, it is incompatible with JAX's functional programming model. In the code below, there are some examples of incorrect attempts to use iterators with JAX. Most of them return an error, but some give unexpected results. ```{code-cell} ipython3 :id: w99WXa6bBa_H diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 82381838a5aa..feb906546341 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -1129,7 +1129,7 @@ "source": [ "When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you.\n", "\n", - "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n", + "One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.\n", "\n", "For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:" ] diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 0a6c84b2d88f..8ba87dcfee18 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -490,7 +490,7 @@ print_fwd_bwd(f, 3.) When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, **`jax.checkpoint` often isn't needed for differentiated functions under a `jax.jit`**. XLA will optimize things for you. -One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`. +One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren't as thorough. As a result, it's often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`. For example, one common pattern in large [Transformer models](https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)) is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this: diff --git a/docs/xla_flags.md b/docs/xla_flags.md index b332940ccb9d..fd351a7966b2 100644 --- a/docs/xla_flags.md +++ b/docs/xla_flags.md @@ -59,7 +59,7 @@ XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py | Flag | Type | Notes | | ---- | ---- | ----- | | `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding. | -| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes doesn't match what can Be saved in place in the stacked variables. Can increase memory pressure. | +| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. | | `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions. | | `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass.
If set to `auto`, it will be enabled based on the target. | | `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. | From 0bb68f6ad2c467633973f48b49d2d27e28dc9bd2 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 3 Dec 2024 03:57:38 -0800 Subject: [PATCH 577/698] [Pallas:MGPU] Add tests for attention with non-trivial batch size PiperOrigin-RevId: 702280467 --- tests/pallas/mgpu_attention_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index 3fa4f6a6f2dd..43727f47338b 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -49,7 +49,7 @@ def setUp(self): self.skipTest("Only works on GPU with capability sm90a") @parameterized.product( - batch_size=(1,), + batch_size=(1, 4), q_seq_len=(4096,), kv_seq_len=(4096,), num_q_and_kv_heads=((4, 1), # MQA From abf8f43007893e084ac8dc23c59a4dd14ca13973 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 3 Dec 2024 06:25:55 -0800 Subject: [PATCH 578/698] [jax] Improve naming of `DotAlgorithmPreset` properties and simplify return types. PiperOrigin-RevId: 702317395 --- jax/_src/lax/lax.py | 78 ++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 631fab903277..ad08e1335a40 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -876,7 +876,7 @@ def __str__(self) -> str: return self.name @property - def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: + def supported_lhs_types(self) -> tuple[DTypeLike, ...] | None: match self: case ( DotAlgorithmPreset.DEFAULT @@ -887,7 +887,7 @@ def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: ): return None case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32: - return np.float16 + return (np.float16,) case ( DotAlgorithmPreset.BF16_BF16_BF16 | DotAlgorithmPreset.BF16_BF16_F32 @@ -897,13 +897,13 @@ def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: # type. If not, we explicitly cast to bfloat16. return (dtypes.bfloat16, np.float32) case DotAlgorithmPreset.F64_F64_F64: - return np.float64 + return (np.float64,) case _: - return np.float32 + return (np.float32,) @property - def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None: - return self.lhs_precision_type + def supported_rhs_types(self) -> tuple[DTypeLike, ...] | None: + return self.supported_lhs_types @property def accumulation_type(self) -> DTypeLike | None: @@ -927,12 +927,19 @@ def accumulation_type(self) -> DTypeLike | None: def supported_output_types(self) -> tuple[DTypeLike, ...] | None: match self: case ( - DotAlgorithmPreset.ANY_F8_ANY_F8_F32 | - DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM + DotAlgorithmPreset.ANY_F8_ANY_F8_F32 + | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM ): - return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn, - dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, - dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz) + return ( + np.float32, + np.float16, + dtypes.bfloat16, + dtypes.float8_e4m3fn, + dtypes.float8_e5m2, + dtypes.float8_e5m2fnuz, + dtypes.float8_e4m3fnuz, + dtypes.float8_e4m3b11fnuz, + ) case DotAlgorithmPreset.F16_F16_F32: return (np.float32, np.float16) case _: @@ -3699,35 +3706,32 @@ def get_algorithm_compute_types( rhs_dtype: DTypeLike, out_dtype: DTypeLike | None = None, ) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]: - def maybe_convert_dtype(input_dtype, target_dtype): - if target_dtype is None: - return input_dtype - if not isinstance(target_dtype, tuple): - target_dtype = (target_dtype,) - if np.dtype(input_dtype) in map(np.dtype, target_dtype): - return input_dtype - return target_dtype[0] + if isinstance(algorithm, DotAlgorithm): + return ( + algorithm.lhs_precision_type, + algorithm.rhs_precision_type, + algorithm.accumulation_type, + ) + + supported_output_types = algorithm.supported_output_types if algorithm == DotAlgorithmPreset.BF16_BF16_F32: - lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type) - rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type) - if np.dtype(lhs_dtype) == dtypes.bfloat16: - out_dtype = maybe_convert_dtype(out_dtype, - (np.float32, dtypes.bfloat16)) - else: - out_dtype = maybe_convert_dtype(out_dtype, np.float32) - return lhs_dtype, rhs_dtype, out_dtype - else: - if isinstance(algorithm, DotAlgorithmPreset): - supported_output_types = algorithm.supported_output_types - else: - supported_output_types = (algorithm.accumulation_type,) + # If dtype is anything other than float32, it will be cast to bfloat16. + if np.dtype(lhs_dtype) != np.float32: + supported_output_types = (np.float32, dtypes.bfloat16) - return ( - maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type), - maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type), - maybe_convert_dtype(out_dtype, supported_output_types), - ) + def maybe_convert_dtype(input_dtype, target_dtypes): + if target_dtypes is None: + return input_dtype + if np.dtype(input_dtype) in map(np.dtype, target_dtypes): + return input_dtype + return target_dtypes[0] + + return ( + maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types), + maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types), + maybe_convert_dtype(out_dtype, supported_output_types), + ) def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, From a54319ec1886ed920d50cacf10e147a743888464 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 3 Dec 2024 08:04:42 -0800 Subject: [PATCH 579/698] [jax] Make `DotAlgorithmPreset.supported_output_types` a function of the input types. PiperOrigin-RevId: 702342849 --- jax/_src/lax/lax.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ad08e1335a40..f6d35cf74c40 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -923,8 +923,15 @@ def accumulation_type(self) -> DTypeLike | None: case _: return np.float32 - @property - def supported_output_types(self) -> tuple[DTypeLike, ...] | None: + def supported_output_types( + self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike + ) -> tuple[DTypeLike, ...] | None: + if np.dtype(lhs_dtype) != np.dtype(rhs_dtype): + raise ValueError( + f"The dot algorithm '{self}' requires both inputs to have the same " + f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.' + ) + match self: case ( DotAlgorithmPreset.ANY_F8_ANY_F8_F32 @@ -942,6 +949,11 @@ def supported_output_types(self) -> tuple[DTypeLike, ...] | None: ) case DotAlgorithmPreset.F16_F16_F32: return (np.float32, np.float16) + case DotAlgorithmPreset.BF16_BF16_F32: + if np.dtype(lhs_dtype) == dtypes.bfloat16: + return (np.float32, dtypes.bfloat16) + else: + return (np.float32,) case _: accumulation_type = self.accumulation_type return None if accumulation_type is None else (accumulation_type,) @@ -3713,13 +3725,6 @@ def get_algorithm_compute_types( algorithm.accumulation_type, ) - supported_output_types = algorithm.supported_output_types - - if algorithm == DotAlgorithmPreset.BF16_BF16_F32: - # If dtype is anything other than float32, it will be cast to bfloat16. - if np.dtype(lhs_dtype) != np.float32: - supported_output_types = (np.float32, dtypes.bfloat16) - def maybe_convert_dtype(input_dtype, target_dtypes): if target_dtypes is None: return input_dtype @@ -3727,11 +3732,12 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return input_dtype return target_dtypes[0] - return ( - maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types), - maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types), - maybe_convert_dtype(out_dtype, supported_output_types), + lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types) + rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types) + out_type = maybe_convert_dtype( + out_dtype, algorithm.supported_output_types(lhs_dtype, rhs_dtype) ) + return lhs_dtype, rhs_dtype, out_type def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, From 4b6035cad2beb0490ae9d2640a380d343ce602a2 Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Tue, 3 Dec 2024 08:28:49 -0800 Subject: [PATCH 580/698] Fix doc typo PiperOrigin-RevId: 702349782 --- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 3c5eaa7910dd..85b7364ce2cc 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -401,7 +401,7 @@ def wgmma( a, b: pallas_core.TransformedRef, ) -> None: - """Performs and asynchronous warp group matmul-accumulate on the given references. + """Performs an asynchronous warp group matmul-accumulate on the given references. Conceptually, this is equivalent to doing ``acc[...] += a[...] @ b[...]``, except that the computation is performed asynchronously. From 2afc65a165e4abd3255aff6685f8226222606bd8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Dec 2024 08:44:35 -0800 Subject: [PATCH 581/698] Fix nightly numpy test --- tests/lax_numpy_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 535e2c632027..a85c7832e6c8 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6262,6 +6262,7 @@ def testWrappedSignaturesMatch(self): 'isnat', 'loadtxt', 'matrix', + 'matvec', 'may_share_memory', 'memmap', 'min_scalar_type', @@ -6287,7 +6288,8 @@ def testWrappedSignaturesMatch(self): 'show_runtime', 'test', 'trapz', - 'typename'} + 'typename', + 'vecmat'} # symbols removed in NumPy 2.0 skip |= {'add_docstring', From cc95327a579db1f4bd7300a7f606e23890204d0b Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 3 Dec 2024 08:52:24 -0800 Subject: [PATCH 582/698] Fix missing quotes in local xla path PiperOrigin-RevId: 702356318 --- .github/workflows/cloud-tpu-ci-nightly.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index b598565223c0..fe879617c8a7 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -54,6 +54,10 @@ jobs: with: repository: openxla/xla path: xla + # We need to mark the GitHub workspace as safe as otherwise git commands will fail. + - name: Mark GitHub workspace as safe + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Install JAX test requirements run: | $PYTHON -m pip install -U -r build/test-requirements.txt @@ -65,7 +69,7 @@ jobs: # Build and install jaxlib at head $PYTHON build/build.py build --wheels=jaxlib \ --bazel_options=--config=rbe_linux_x86_64 \ - --local_xla_path=$(pwd)/xla" \ + --local_xla_path="$(pwd)/xla" \ --verbose $PYTHON -m pip install dist/*.whl From c4d19ca83cdcfbf2d34e2affb86946da2f4773dc Mon Sep 17 00:00:00 2001 From: Goran Flegar Date: Tue, 3 Dec 2024 10:55:13 -0800 Subject: [PATCH 583/698] Integrate Triton up to [9732c047](https://github.com/openai/triton/commits/9732c04701bd856daca89bde38bafa4636ca56a8) PiperOrigin-RevId: 702397897 --- jaxlib/gpu/triton_kernels.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index c4a9af5ffe2e..22397ff908bc 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -466,7 +466,8 @@ KernelCall::KernelCall(Kernel kernel, uint32_t grid_0, uint32_t grid_1, absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) { std::vector params; - params.reserve(parameters_.size()); + // We need an additional parameter for the scratchpad buffer. + params.reserve(parameters_.size() + 1); for (size_t i = 0; i < parameters_.size(); ++i) { const Parameter& param = parameters_[i]; if (std::holds_alternative(param.value)) { @@ -492,6 +493,14 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) { param.value))); } } + // Triton's kernel ABI expects an additional scratchpad global memory. + // For now it is only used for on-device creation of TMA descriptors, which + // we do not use yet, so we are just replacing this argument with a null + // pointer. + // TODO: b/381242007 - Allocate a proper buffer if we want to use + // device-side TMA APIs. + void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns. + params.push_back(&scratch_ptr); return kernel_.Launch(stream, grid_, params.data()); } From 73962b740890a728295fa09f515dcf96cb820822 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 3 Dec 2024 11:14:41 -0800 Subject: [PATCH 584/698] Reverts a54319ec1886ed920d50cacf10e147a743888464 PiperOrigin-RevId: 702405512 --- jax/_src/lax/lax.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f6d35cf74c40..ad08e1335a40 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -923,15 +923,8 @@ def accumulation_type(self) -> DTypeLike | None: case _: return np.float32 - def supported_output_types( - self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike - ) -> tuple[DTypeLike, ...] | None: - if np.dtype(lhs_dtype) != np.dtype(rhs_dtype): - raise ValueError( - f"The dot algorithm '{self}' requires both inputs to have the same " - f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.' - ) - + @property + def supported_output_types(self) -> tuple[DTypeLike, ...] | None: match self: case ( DotAlgorithmPreset.ANY_F8_ANY_F8_F32 @@ -949,11 +942,6 @@ def supported_output_types( ) case DotAlgorithmPreset.F16_F16_F32: return (np.float32, np.float16) - case DotAlgorithmPreset.BF16_BF16_F32: - if np.dtype(lhs_dtype) == dtypes.bfloat16: - return (np.float32, dtypes.bfloat16) - else: - return (np.float32,) case _: accumulation_type = self.accumulation_type return None if accumulation_type is None else (accumulation_type,) @@ -3725,6 +3713,13 @@ def get_algorithm_compute_types( algorithm.accumulation_type, ) + supported_output_types = algorithm.supported_output_types + + if algorithm == DotAlgorithmPreset.BF16_BF16_F32: + # If dtype is anything other than float32, it will be cast to bfloat16. + if np.dtype(lhs_dtype) != np.float32: + supported_output_types = (np.float32, dtypes.bfloat16) + def maybe_convert_dtype(input_dtype, target_dtypes): if target_dtypes is None: return input_dtype @@ -3732,12 +3727,11 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return input_dtype return target_dtypes[0] - lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types) - rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types) - out_type = maybe_convert_dtype( - out_dtype, algorithm.supported_output_types(lhs_dtype, rhs_dtype) + return ( + maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types), + maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types), + maybe_convert_dtype(out_dtype, supported_output_types), ) - return lhs_dtype, rhs_dtype, out_type def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, From 2dae81a8ed21daf1c4cbbbd21d255aba3e5a6d80 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Tue, 3 Dec 2024 12:55:58 -0800 Subject: [PATCH 585/698] [Pallas TPU] Enable test for `jnp.logical_not` because it's now supported PiperOrigin-RevId: 702439876 --- tests/pallas/ops_test.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 214e258892ec..8586ae346654 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -808,14 +808,6 @@ def test_elementwise(self, fn, dtype): ): self.skipTest(f"{fn.__name__} not implemented on TPU") - # TODO: https://github.com/jax-ml/jax/issues/24243 - if ( - jtu.test_device_matches(["tpu"]) - and fn == jnp.logical_not - and not self.INTERPRET - ): - self.skipTest("logical_not on TPU is only supported in interpret mode") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype), From 9e5edb7015ac80679b3a5f5e9868e30d9e726cb3 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Tue, 3 Dec 2024 14:57:58 -0800 Subject: [PATCH 586/698] [Mosaic TPU] Support packed type matmul with arbitrary shapes. This cl removes all the shape constrains in matmul for all types. We only need to mask out subelement on contracting dim. Instead of unpacking data and applying masks, we create a VREG-sized i32 "mask" which contains subelement mask info to logical and with target vreg. Through this way, in order to mask sub-elements, each target vreg only needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select + packing). PiperOrigin-RevId: 702480077 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 9 +++ .../tpu/transforms/apply_vector_layout.cc | 78 ++++++++++++------- .../tpu/transforms/infer_vector_layout.cc | 75 ++++-------------- 3 files changed, 73 insertions(+), 89 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 3271c0874572..2107cd7fcf82 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -544,6 +544,15 @@ LogicalResult MatmulOp::verify() { // however, a good start and the recommended place to add more invariants. const VectorType lhs_ty = getLhs().getType(); const VectorType rhs_ty = getRhs().getType(); + const VectorType acc_ty = getAcc().getType(); + const VectorType res_ty = getResult().getType(); + if (acc_ty != res_ty) { + return emitOpError( + "Not implemented: matmul acc and result have different types"); + } + if (acc_ty.getElementTypeBitWidth() != 32) { + return emitOpError("Expected matmul acc to be 32-bit"); + } if (getTransposeLhs()) { emitOpError( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8ade7450881a..cac197479f4a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1764,19 +1764,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, // TODO(tlongeri): This should be part of the tpu::MatmulOp verifier TPU_ASSERT_EQ_OP(lhs_shape.size(), 2); TPU_ASSERT_EQ_OP(rhs_shape.size(), 2); - // The code below puts no constraints on the second dimension of both lhs and - // rhs. However, leading axis of lhs and rhs needs to be a multiple of native - // tiling for packed types. - if (layout_lhs.packing() != 1 && lhs_shape[0] % layout_lhs.tiling()[0] != 0) { - return op.emitOpError( - "Not implemented: Unsupported LHS shape with padded tiling and " - "narrower data type"); - } - if (layout_rhs.packing() != 1 && rhs_shape[0] % layout_rhs.tiling()[0] != 0) { - return op.emitOpError( - "Not implemented: Unsupported RHS shape with padded tiling and " - "narrower data type"); - } const int64_t padded_lhs_rows = llvm::alignTo(lhs_shape[0], layout_lhs.tiling()[0]); @@ -1787,10 +1774,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, const int64_t padded_rhs_cols = llvm::alignTo(rhs_shape[1], layout_rhs.tiling()[1]); - if (llvm::alignTo(lhs_shape[0], layout_acc.tiling()[0]) != padded_lhs_rows) { - return op.emitOpError( - "Not implemented: Matmul acc requires less padding than lhs"); - } FAILUREOR_ASSIGN_OR_RETURN( xla::Array lhs_vregs, disassemble(builder, layout_lhs, lhs, ctx.target_shape)); @@ -1801,7 +1784,6 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, xla::Array rhs_vregs, disassemble(builder, layout_rhs, rhs, ctx.target_shape)); TPU_ASSERT_EQ_OP(padded_lhs_rows, lhs_vregs.dim(0) * layout_lhs.tiling()[0]); - TPU_ASSERT_EQ_OP(padded_lhs_rows, acc_vregs.dim(0) * layout_acc.tiling()[0]); TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]); const VectorType i32_vreg_ty = @@ -1823,12 +1805,14 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, // We can also extend this helper function with padding_top and padding_left // based on the offsets in vregs. - // TODO(b/341729764): Support mask subelements. + const Value i32_zeros_vreg = builder.create( + op.getLoc(), + DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0))); + const Value i32_max_vreg = builder.create( + op.getLoc(), DenseElementsAttr::get( + i32_vreg_ty, builder.getI32IntegerAttr(0xffffffff))); auto maskVregs = [&](xla::Array &vregs, int64_t padding_bottom, int64_t padding_right) { - const Value i32_zeros_vreg = builder.create( - op.getLoc(), - DenseElementsAttr::get(i32_vreg_ty, builder.getI32IntegerAttr(0))); auto vreg_ty = cast(vregs.begin()->getType()); int packing = vreg_ty.getRank() > 2 ? vreg_ty.getShape()[2] : 1; // Mask out the bottom. @@ -1836,14 +1820,49 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, // We have limited the row size of LHS and RHS need to be a multiple of // native tiling at the beginning of this rule. Therefore, it is safe to // bitcast to x32 vreg for masking. - CHECK_EQ(padding_bottom % packing, 0); - padding_bottom /= packing; - auto mask_bottom = getX32VmaskByPaddingEnd(0, padding_bottom); + int sub_padding = padding_bottom % packing; + int x32_padding_bottom = padding_bottom / packing; + auto mask_bottom = getX32VmaskByPaddingEnd(0, x32_padding_bottom); + // Create an int32 vreg which contains subelement masking and then + // logical_and with target vreg to mask out the unaligned paddings. + // Eg. if padding_bottom = 5, packing = 2, and assume the vreg shape is + // [8, 128], then the mask will be: + // + // sublane 0: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 1: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 2: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 3: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 4: [0xffffffff, 0xffffffff, ..., 0xffffffff] + // sublane 5: [0x0000ffff, 0x0000ffff, ..., 0x0000ffff] + // sublane 6: [0 , 0 , ..., 0 ] + // sublane 7: [0 , 0 , ..., 0 ] + // + // Through this way, in order to mask sub-elements, each target vreg only + // needs to apply 1 op (logical_and) instead of 3 ops (unpacking + select + // + packing). + Value partial_sublane_mask = builder.create( + op.getLoc(), + DenseElementsAttr::get( + i32_vreg_ty, + builder.getI32IntegerAttr( + 0xffffffff >> + (sub_padding * vreg_ty.getElementTypeBitWidth())))); + // Insert 0xffffffff above the blended sublane. + Value sublane_mask = builder.create( + getX32VmaskByPaddingEnd(0, x32_padding_bottom + 1), i32_max_vreg, + partial_sublane_mask); + // Insert 0 below the blended sublane. + sublane_mask = builder.create(mask_bottom, sublane_mask, + i32_zeros_vreg); for (int64_t i = 0; i < vregs.dim(1); ++i) { Value &vreg = vregs({vregs.dim(0) - 1, i}); Value i32_vreg = builder.create(i32_vreg_ty, vreg); - i32_vreg = builder.create(mask_bottom, i32_vreg, - i32_zeros_vreg); + if (sub_padding > 0) { + i32_vreg = builder.create(i32_vreg, sublane_mask); + } else { + i32_vreg = builder.create(mask_bottom, i32_vreg, + i32_zeros_vreg); + } vreg = builder.create(vreg_ty, i32_vreg); } } @@ -1929,8 +1948,9 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, lhs_zeros_vreg); xla::Array target_rhs_vregs( {target_rhs_row_vregs, target_rhs_col_vregs}, rhs_zeros_vreg); - xla::Array target_acc_vregs({acc_vregs.dim(0), target_acc_col_vregs}, - acc_zeros_vreg); + xla::Array target_acc_vregs( + {lhs_vregs.dim(0) * layout_lhs.packing(), target_acc_col_vregs}, + acc_zeros_vreg); target_lhs_vregs.UpdateSlice(lhs_vregs, {0, 0}); target_rhs_vregs.UpdateSlice(rhs_vregs, {0, 0}); target_acc_vregs.UpdateSlice(acc_vregs, {0, 0}); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 1be81e733161..bc733742df19 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -903,66 +903,21 @@ class VectorLayoutInferer { } LogicalResult infer(tpu::MatmulOp op) { - auto get_operand_layout = - [&](Value v, llvm::StringRef operand_name, - std::optional major_multiple = std::nullopt, - std::optional minor_multiple = - std::nullopt) -> std::optional { - auto layout = getLayout(v); - if (!layout.has_value()) { - op->emitOpError("Internal error: assert failed: Operand ") - << operand_name << " has no vector layout"; - return std::nullopt; - } - auto vty = cast(v.getType()); - auto tiling = nativeTiling(vty.getElementTypeBitWidth()); - auto shape = vty.getShape().take_back(2); - if (shape[0] % major_multiple.value_or(tiling[0]) != 0 || - shape[1] % minor_multiple.value_or(tiling[1]) != 0) { - op->emitOpError("Matmul operand ") - << operand_name << " must have a shape divisible by (" - << major_multiple.value_or(tiling[0]) << ", " - << minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0] - << ", " << shape[1] << ")"; - return std::nullopt; - } - // Override tiling to match the native one. - return VectorLayout(layout->bitwidth(), {0, 0}, tiling, - ImplicitDim::kNone); - }; - auto res_ty = dyn_cast(op->getResult(0).getType()); - TPU_CHECK_OP(res_ty, "only vector results supported"); - TPU_CHECK_OP(res_ty.getElementTypeBitWidth() == kNativeBitwidth, - "only 32-bit matmul results supported"); - std::array in_layout; - CHECK_EQ(op->getNumOperands(), 3); - std::optional lhs_major_multiple; - std::optional rhs_major_multiple; - // We don't restrict the first lhs axis when the data is not packed. - if (cast(op->getOperand(0).getType()) - .getElementTypeBitWidth() == kNativeBitwidth) { - lhs_major_multiple = 1; - } - // We don't restrict the first rhs axis when the data is not packed. - if (cast(op->getOperand(1).getType()) - .getElementTypeBitWidth() == kNativeBitwidth) { - rhs_major_multiple = 1; - } - in_layout[0] = - get_operand_layout(op->getOperand(0), "lhs", lhs_major_multiple, 1); - if (!in_layout[0].has_value()) { - return failure(); - } - in_layout[1] = - get_operand_layout(op->getOperand(1), "rhs", rhs_major_multiple, 1); - if (!in_layout[1].has_value()) { - return failure(); - } - in_layout[2] = get_operand_layout(op->getOperand(2), "result", 1, 1); - if (!in_layout[2].has_value()) { - return failure(); - } - setLayout(op, in_layout, + auto lhs_bitwidth = op.getLhs().getType().getElementTypeBitWidth(); + auto rhs_bitwidth = op.getRhs().getType().getElementTypeBitWidth(); + auto acc_bitwidth = op.getAcc().getType().getElementTypeBitWidth(); + auto res_bitwidth = op.getResult().getType().getElementTypeBitWidth(); + TPU_CHECK_OP(acc_bitwidth == kNativeBitwidth, + "Expected 32-bit acc in tpu::MatmulOp"); + TPU_CHECK_OP(res_bitwidth == kNativeBitwidth, + "Expected 32-bit result in tpu::MatmulOp"); + auto lhs_layout = VectorLayout( + lhs_bitwidth, {0, 0}, nativeTiling(lhs_bitwidth), ImplicitDim::kNone); + auto rhs_layout = VectorLayout( + rhs_bitwidth, {0, 0}, nativeTiling(rhs_bitwidth), ImplicitDim::kNone); + auto acc_layout = VectorLayout( + acc_bitwidth, {0, 0}, nativeTiling(acc_bitwidth), ImplicitDim::kNone); + setLayout(op, {lhs_layout, rhs_layout, acc_layout}, VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, ImplicitDim::kNone)); return success(); From ceeed909dc324179efc725189d757b3f4d236cb4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 3 Dec 2024 15:38:56 -0800 Subject: [PATCH 587/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/05f004e8368c955b872126b1c978c60e33bbc5c8. PiperOrigin-RevId: 702493080 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 75abd0ab5032..232d19cfbe26 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "6070a19a7c7d1c62e71d73a4ee5a710c641dad2c" -XLA_SHA256 = "fb12e680ac47facf40bda36d436195e1d5454f48d02fb8540a9c6363cfe03cb1" +XLA_COMMIT = "05f004e8368c955b872126b1c978c60e33bbc5c8" +XLA_SHA256 = "f0bedada96f5f1d09f5047c7f9db32d460d147bd0f192607cfbbee9fe5ee2d5f" def repo(): tf_http_archive( From 0140a98e34786790332e60d1d3b4d8a82d29d896 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Dec 2024 15:43:33 -0800 Subject: [PATCH 588/698] Improve trace-time performance of jnp.isscalar --- jax/_src/numpy/lax_numpy.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6a0e4059c4ae..5af8c6ddad19 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool: >>> jnp.isscalar(slice(10)) False """ - if (isinstance(element, (np.ndarray, jax.Array)) - or hasattr(element, '__jax_array__') - or np.isscalar(element)): + if np.isscalar(element): + return True + elif isinstance(element, (np.ndarray, jax.Array)): + return element.ndim == 0 + elif hasattr(element, '__jax_array__'): return asarray(element).ndim == 0 return False From fcf0b6d3daca6001cfa190f433553d5e85a86796 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 3 Dec 2024 16:24:34 -0800 Subject: [PATCH 589/698] Add _raw_platform to work around extra platform normalization logic and enable GPU aot compilation without a GPU present. Fixes https://github.com/jax-ml/jax/issues/23971 PiperOrigin-RevId: 702506848 --- jax/_src/interpreters/pxla.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 11df2d38f21d..7e13285c4a56 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2220,7 +2220,10 @@ def lower_sharding_computation( out_shardings = _concretize_abstract_shardings( out_shardings, global_out_avals, device_assignment) - platforms = lowering_platforms or (backend.platform,) + # TODO(parkers): One _raw_platform has been unified with platform, + # change this back to just read platform. + platforms = lowering_platforms or ( + getattr(backend, "_raw_platform", backend.platform),) committed = bool( devices_from_context or From fd4b160880d5331de1df556bbcc6b7af5b548ce5 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 3 Dec 2024 16:53:20 -0800 Subject: [PATCH 590/698] Use JAX's default device instead of jax.devices()[0], if set. PiperOrigin-RevId: 702515221 --- jax/_src/dispatch.py | 2 +- jax/_src/interpreters/pxla.py | 4 ++-- jax/_src/pallas/mosaic/BUILD | 1 + jax/_src/pallas/mosaic/pipeline.py | 3 ++- jax/extend/backend.py | 3 +++ 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b3f16c724ee4..54c3a43e8a84 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -440,7 +440,7 @@ def _device_put_sharding_impl(x, aval, device, copy): return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], [device]) - sh = SingleDeviceSharding(pxla._get_default_device() + sh = SingleDeviceSharding(pxla.get_default_device() if device is None else device) return _DeferredShardArg(x, sh, aval, device is not None, copy) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7e13285c4a56..98cfbbb1d589 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1710,7 +1710,7 @@ class DeviceAssignmentMismatchError(Exception): ] -def _get_default_device() -> xc.Device: +def get_default_device() -> xc.Device: if isinstance(config.default_device.value, str): return xb.get_backend(config.default_device.value).local_devices()[0] else: @@ -1749,7 +1749,7 @@ def _get_and_check_device_assignment( if first_sharding_info is None and devices: final_device_assignment = devices elif first_sharding_info is None: - final_device_assignment = (_get_default_device(),) + final_device_assignment = (get_default_device(),) else: final_device_assignment = first_sharding_info[0] # type: ignore return xb.get_device_backend(final_device_assignment[0]), final_device_assignment diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index bf0d83bb3dc9..c9754933908f 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -124,6 +124,7 @@ py_library( "//jax:pallas", "//jax:util", "//jax/_src/pallas", + "//jax/extend:backend", ] + py_deps("numpy"), ) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 0112b3cb4dbb..6ddb21e77bd8 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -33,6 +33,7 @@ from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import primitives as tpu_primitives from jax.experimental import pallas as pl +from jax.extend.backend import get_default_device import jax.numpy as jnp import numpy as np @@ -75,7 +76,7 @@ def add_leaves(i, x): @jax_util.cache(trace_context_in_key=False) def _get_tpu_generation() -> int: - kind = jax.devices()[0].device_kind + kind = get_default_device().device_kind if kind.endswith(' lite'): kind = kind[:-len(' lite')] assert kind[:5] == "TPU v", kind diff --git a/jax/extend/backend.py b/jax/extend/backend.py index b1e471133482..8d5488baba16 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -24,3 +24,6 @@ get_backend as get_backend, register_backend_factory as register_backend_factory, ) +from jax._src.interpreters.pxla import ( + get_default_device as get_default_device +) From f6f4ef06cd39e3760d4ed3d281e0169d08d840a6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Dec 2024 17:20:40 -0800 Subject: [PATCH 591/698] Fix indexing corner case with empty ellipses --- jax/_src/numpy/lax_numpy.py | 13 ++++++++----- tests/lax_numpy_indexing_test.py | 8 ++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5af8c6ddad19..3d99405428de 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -11971,6 +11971,14 @@ def _int(aval): def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], normalize_indices: bool = True) -> _Indexer: + # Check whether advanced indices are contiguous. We must do this before + # removing ellipses (https://github.com/jax-ml/jax/issues/25109) + # If advanced idexing axes do not appear contiguously, NumPy semantics + # move the advanced axes to the front. + is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray)) + or isscalar(e) for e in idx]) + advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1) + # Remove ellipses and add trailing slice(None)s. idx = _canonicalize_tuple_index(len(x_shape), idx) @@ -11987,10 +11995,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], # Check for advanced indexing: # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing - # Do the advanced indexing axes appear contiguously? If not, NumPy semantics - # move the advanced axes to the front. - advanced_axes_are_contiguous = False - advanced_indexes: Sequence[Array | np.ndarray] | None = None # The positions of the advanced indexing axes in `idx`. @@ -12009,7 +12013,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j) for e, i, j in advanced_pairs) advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs) - advanced_axes_are_contiguous = bool(np.all(np.diff(idx_advanced_axes) == 1)) x_axis = 0 # Current axis in x. y_axis = 0 # Current axis in y, before collapsing. See below. diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 392af2688c1d..ab625d10b4d8 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -399,6 +399,14 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)), out_shape=(3,)), ]), + ("EllipsisWithArrayIndices", [ + IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 1]), ..., np.array([0, 1])), + out_shape=(2, 4)), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])), + out_shape=(2, 3)), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])), + out_shape=(3, 2)), + ]), ] From 8c78c1e735a7f9ed35665b9ed4364a843df1599e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 3 Dec 2024 18:06:46 -0800 Subject: [PATCH 592/698] Convert MSYS' Linux-like paths to Windows paths in JAX CI. This is necessary on Windows because some applications such as Bazel/Docker do not understand/handle Linux-like paths that msys uses. The script first converts all `JAXCI.*DIR` variables containing msys-like paths to Windows paths using cygpath. Then it sources those variables into the shell environment so that any tools that use those variables can run correctly. PiperOrigin-RevId: 702533023 --- .../convert_msys_paths_to_win_paths.py | 80 +++++++++++++++++++ ci/utilities/setup_build_environment.sh | 7 ++ 2 files changed, 87 insertions(+) create mode 100644 ci/utilities/convert_msys_paths_to_win_paths.py diff --git a/ci/utilities/convert_msys_paths_to_win_paths.py b/ci/utilities/convert_msys_paths_to_win_paths.py new file mode 100644 index 000000000000..6164e6a5e29d --- /dev/null +++ b/ci/utilities/convert_msys_paths_to_win_paths.py @@ -0,0 +1,80 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Converts MSYS Linux-like paths stored in env variables to Windows paths. + +This is necessary on Windows, because some applications do not understand/handle +Linux-like paths MSYS uses, for example, Bazel. +""" +import argparse +import os +import subprocess + +def msys_to_windows_path(msys_path): + """Converts an MSYS path to a Windows path using cygpath. + + Args: + msys_path: The MSYS path to convert. + + Returns: + The corresponding Windows path. + """ + try: + # Use cygpath with the -w flag to convert to Windows format + process = subprocess.run(['cygpath', '-w', msys_path], capture_output=True, text=True, check=True) + windows_path = process.stdout.strip() + return windows_path + except FileNotFoundError: + print("Error: cygpath not found. Make sure it's in your PATH.") + return None + except subprocess.CalledProcessError as e: + print(f"Error converting path: {e}") + return None + +def should_convert(var: str, + convert: list[str] | None): + """Check the variable name against convert list""" + if var in convert: + return True + else: + return False + +def main(parsed_args: argparse.Namespace): + converted_paths = {} + + for var, value in os.environ.items(): + if not value or not should_convert(var, + parsed_args.convert): + continue + converted_path = msys_to_windows_path(value) + converted_paths[var] = converted_path + + var_str = '\n'.join(f'export {k}="{v}"' + for k, v in converted_paths.items()) + # The string can then be piped into `source`, to re-set the + # 'converted' variables. + print(var_str) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=( + 'Convert MSYS paths in environment variables to Windows paths.')) + parser.add_argument('--convert', + nargs='+', + required=True, + help='Space separated list of environment variables to convert. E.g: --convert env_var1 env_var2') + args = parser.parse_args() + + main(args) diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index e77e84f3c07f..964a6e4ac679 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -68,4 +68,11 @@ if [[ ! -z ${JAXCI_XLA_GIT_DIR} ]]; then echo "and JAXCI_XLA_COMMIT in your environment. Note that the Bazel RBE test" echo "commands overrides the XLA repository and thus require a local copy of" echo "XLA to run." +fi + +# On Windows, convert MSYS Linux-like paths to Windows paths. +if [[ $(uname -s) =~ "MSYS_NT" ]]; then + echo 'Converting MSYS Linux-like paths to Windows paths (for Bazel, Python, etc.)' + # Convert all "JAXCI.*DIR" variables + source <(python3 ./ci/utilities/convert_msys_paths_to_win_paths.py --convert $(env | grep "JAXCI.*DIR" | awk -F= '{print $1}')) fi \ No newline at end of file From cb2cf56e9761921bd4ed1edf8394cda8ad4c3a01 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 4 Dec 2024 01:26:22 -0800 Subject: [PATCH 593/698] [Mosaic GPU] Add missing import. PiperOrigin-RevId: 702629468 --- jax/experimental/mosaic/gpu/examples/matmul.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 2ca22f54e1b4..7aa96e7fa5d3 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -20,6 +20,7 @@ import jax from jax import random +from jax._src import test_util as jtu # noqa: F401 from jax._src.interpreters import mlir from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu import * # noqa: F403 From 1ddba9b1d08b0775ff17132aac2081da2e1602c6 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 4 Dec 2024 04:20:49 -0800 Subject: [PATCH 594/698] [mgpu_pallas] Optionally pass default value instead of raising an error when trying to ensure ir Value. PiperOrigin-RevId: 702672662 --- jax/_src/pallas/mosaic_gpu/lowering.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5d33d8895b61..b120d297456b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -142,6 +142,7 @@ def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources: # Assume that unsupported primitives are neutral wrt resource usage. continue rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params) + return rs @@ -1592,6 +1593,15 @@ def _while_lowering_rule( def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in + def _yielded_values(outs, avals): + ret = [] + for out, aval in zip(outs, avals): + if isinstance(out, mgpu.FragmentedArray): + ret.append(out) + else: + ret.append(_ensure_ir_value(out, aval.dtype)) + return ret + # We need the branch return mlir types in order to construct the # switch operation. To avoid leaking information about what kind of # mlir types are internal to FragmentedArrays and other mgpu types, @@ -1601,10 +1611,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args ) - yielded_types, _ = jax.tree.flatten([ - (_ensure_ir_value(out, aval.dtype) or out).type - for out, aval in zip(outs, ctx.avals_out) - ]) + yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))] switch_op = scf_dialect.IndexSwitchOp( yielded_types, @@ -1626,11 +1633,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts ) - yielded = [ - _ensure_ir_value(out, aval.dtype) or out - for out, aval in zip(outs, ctx.avals_out) - ] - yielded_leaves, yielded_treedef = jax.tree.flatten(yielded) + yielded_leaves, yielded_treedef = jax.tree.flatten(_yielded_values(outs, ctx.avals_out)) if treedef is None: treedef = yielded_treedef else: From 3895e0372caa9e7dc393b0ef203ae988ebd57f86 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Wed, 4 Dec 2024 06:35:47 -0800 Subject: [PATCH 595/698] [mgpu_pallas] Allow loading scalars or indexing arrays from gmem using splat. PiperOrigin-RevId: 702704429 --- jax/_src/pallas/mosaic_gpu/lowering.py | 11 +++++++++++ tests/pallas/mosaic_gpu_test.py | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index b120d297456b..dbd7cb13c078 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1039,10 +1039,15 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") + x_aval = ctx.avals_in[0] + transforms = jax.tree.unflatten(tree, leaves) x_smem, transforms = _handle_reshaping(x_smem, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) + + print("ctx:", ctx) + print("transforms:", transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != (64, swizzle // x_aval.dtype.itemsize): @@ -1051,6 +1056,12 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle ) case (): + # Handle scalar indexing. + if not ctx.avals_out[0].shape: + is_signed = mgpu_utils.is_signed(x_aval.dtype) + val = memref_dialect.load(x_smem, []) + return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed) + return mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d27c7d887db1..d87f4a1aa373 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -574,6 +574,18 @@ def kernel(x_ref, o_ref): self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) + def test_load_scalar(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), + in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.broadcast_to(x_ref[10], (128,)) + + np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)), + jnp.full((128,), 10, dtype=jnp.int32)) + def test_run_scoped(self): def kernel(x_ref, o_ref): def body(tmp_ref): From 11090be0b37e9c79c9e9d4fd5e1507c0120aa1cb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 4 Dec 2024 06:54:11 -0800 Subject: [PATCH 596/698] [Mosaic GPU] Add an optimization barrier The barrier is a no-op at runtime, but appears as a side-effecting op to LLVM which prevents it from moving the (even pure) computations that involve the supplied arrays past the barrier. PiperOrigin-RevId: 702709125 --- jax/_src/pallas/mosaic_gpu/lowering.py | 6 + jax/experimental/mosaic/gpu/__init__.py | 1 + .../mosaic/gpu/fragmented_array.py | 105 ++++++++++++++++++ jax/experimental/mosaic/gpu/wgmma.py | 58 ++-------- jax/experimental/pallas/mosaic_gpu.py | 2 +- .../pallas/ops/gpu/attention_mgpu.py | 17 ++- tests/mosaic/gpu_test.py | 14 +++ 7 files changed, 147 insertions(+), 56 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index dbd7cb13c078..2900da133158 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1681,6 +1681,12 @@ def _bitcast_convert_type_lowering_rule( ) +@register_lowering_rule(lax.optimization_barrier_p) +def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): + args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) + return mgpu.optimization_barrier(*args) + + def _bcast( x: ir.Value, y: ir.Value, diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 7857ffb3c09b..c1daa33576bb 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -43,6 +43,7 @@ WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, + optimization_barrier as optimization_barrier, ) from .utils import ( BarrierRef as BarrierRef, diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 4d78270e5009..2da3de70658b 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1866,3 +1866,108 @@ def subf(a: ir.Value, b: ir.Value): def mulf(a: ir.Value, b: ir.Value): return arith.mulf(a, b, fastmath=arith.FastMathFlags.contract) + + +def optimization_barrier(*arrays: mgpu.FragmentedArray): + """Acts as an optimization barrier for LLVM. + + Passing arrays through this function will make sure that they are computed + before any side-effecting operations that follow this barrier. + """ + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + + regs = [] + reg_dtypes = [] + reg_constraints = [] + ptx_lines = ["// Optimization barrier"] + repack_fns = [] + # We unpack each array into a flat list of registers, and prepare the + # functions that invert the transform in repack_fns. + for array in arrays: + ptx_lines.append("// Next array") + reg_ty = array.registers.flat[0].type + dtype = array.mlir_dtype + num_prev_cstr = len(reg_constraints) + if ir.F32Type.isinstance(dtype): + if ir.VectorType.isinstance(reg_ty): + [vec_len] = ir.VectorType(reg_ty).shape + array_regs = [ # pylint: disable=g-complex-comprehension + vector.extractelement(reg, position=c(pos, index)) + for reg in array.registers.flat + for pos in range(vec_len) + ] + def _repack(regs, reg_ty=reg_ty): + reg = llvm.mlir_undef(reg_ty) + [vec_len] = ir.VectorType(reg_ty).shape + for i_elem in range(vec_len): + reg = llvm.insertelement( + reg, next(regs), arith.constant(i32, i_elem) + ) + return reg + repack_fns.append(_repack) + else: + array_regs = list(array.registers.flat) + repack_fns.append(lambda regs: next(regs)) + reg_constraint = "f" + elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): + if not ir.VectorType.isinstance(reg_ty): + raise NotImplementedError(array.mlir_dtype) + [vec_len] = ir.VectorType(reg_ty).shape + if vec_len != 2: + raise NotImplementedError(vec_len) + i32_reg_ty = ir.VectorType.get((1,), i32) + array_regs = [ + vector.extractelement( + vector.bitcast(i32_reg_ty, reg), position=c(0, index) + ) + for reg in array.registers.flat + ] + reg_constraint = "r" + def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): + return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) + repack_fns.append(_repack) + else: + raise NotImplementedError(array.mlir_dtype) + regs += array_regs + reg_dtypes += [array_regs[0].type] * len(array_regs) + reg_constraints += [f"={reg_constraint}"] * len(array_regs) + reg_constraints += [reg_constraint] * len(array_regs) + ptx_lines += [ + f"mov.b32 ${i}, ${len(array_regs)+i}" + for i in range(num_prev_cstr, num_prev_cstr + len(array_regs)) + ] + reg_constraints = ",".join(reg_constraints) + ptx = ";\n\t".join(ptx_lines) + ";" + struct_ty = ir.Type.parse( + f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" + ) + result_struct = llvm.inline_asm( + struct_ty, regs, ptx, reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [ + llvm.extractvalue(dtype, result_struct, [i]) + for i, dtype in enumerate(reg_dtypes) + ] + i32 = ir.IntegerType.get_signless(32) + results = [] + regs_it = iter(regs) + for array, repack_fn in zip(arrays, repack_fns, strict=True): + num_regs = array.registers.size + reg_ty = array.registers.flat[0].type + if ir.VectorType.isinstance(reg_ty): + reg_ty = ir.VectorType(reg_ty) + new_registers = np.empty((num_regs,), dtype=object) + for i_vreg in range(num_regs): + reg = repack_fn(regs_it) + assert reg.type == reg_ty, (reg.type, reg_ty) + new_registers[i_vreg] = reg + results.append( + FragmentedArray( + _registers=new_registers.reshape(array.registers.shape), + _layout=array.layout, + _is_signed=array.is_signed, + ) + ) + return results[0] if len(arrays) == 1 else results diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index ba0f130364ff..6f4d96fbd218 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -23,6 +23,7 @@ from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import vector +from jaxlib.mlir.dialects import nvvm import numpy as np import jax.experimental.mosaic.gpu as mgpu @@ -445,58 +446,13 @@ def wgmma( def wgmma_fence(array: mgpu.FragmentedArray): """Fences the array construction from WGMMA instructions. - This is a little workaround to force LLVM to initialize the PTX registers - before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats - in-register computation as pure and can move it after the fence, which is - explicitly disallowed by the PTX programming model. + LLVM treats in-register computation as pure and can move it after the fence, + which is explicitly disallowed by the PTX programming model. For that reason, + we insert an LLVM optimization barrier before the fence. """ - i32 = ir.IntegerType.get_signless(32) - index = ir.IndexType.get() - dtype = array.mlir_dtype - src_vec_ty = ir.VectorType(array.registers.flat[0].type) - assert src_vec_ty.shape == [2] - - if dtype == ir.F32Type.get(): - regs = [ # pylint: disable=g-complex-comprehension - vector.extractelement(reg, position=c(pos, index)) - for reg in array.registers.flat - for pos in range(2) - ] - reg_dtype = dtype - reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs) - ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))] - elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): - regs = [_as_i32_reg(reg) for reg in array.registers.flat] - reg_dtype = i32 - reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs) - ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))] - else: - raise NotImplementedError(dtype) - reg_constraints = ",".join(reg_constraints_list) - # Copy over the registers. ptxas should be able to remove the moves. - ptx_lines.append("wgmma.fence.sync.aligned") - ptx = ";\n".join(ptx_lines) + ";\n" - dtype_str = str(reg_dtype) - struct_ty = ir.Type.parse( - f"!llvm.struct<({','.join(dtype_str for _ in regs)})>" - ) - acc_struct = llvm.inline_asm( - struct_ty, regs, ptx, reg_constraints, - asm_dialect=0, has_side_effects=True, - ) - regs = [ - llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs)) - ] - if dtype == ir.F32Type.get(): - registers = _as_fragmented_reg_ndarray( - regs, array.mlir_dtype, array.registers.shape - ) - elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get(): - regs = [_unpack_i32(src_vec_ty, r) for r in regs] - registers = np.asarray(regs, dtype=object).reshape(array.registers.shape) - else: - raise NotImplementedError(dtype) - return mgpu.FragmentedArray(_registers=registers, _layout=array.layout, _is_signed=array.is_signed) + array = mgpu.optimization_barrier(array) + nvvm.wgmma_fence_aligned() + return array def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]): diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 2a6a6fa83663..8da2a5095927 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -31,12 +31,12 @@ from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait +from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast -from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index b684aef409f1..294ef153ff93 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -129,10 +129,19 @@ def compute_qk(acc_ref): l_i *= alpha p16 = p.astype(dtype) - plgpu.barrier_wait(v_barriers.at[slot]) - perform_schedule_barrier() - - l_i += p.sum(axis=1) + def end_softmax_barriers(): + plgpu.barrier_arrive(schedule_barrier) # Done with softmax! + plgpu.barrier_wait(v_barriers.at[slot]) + plgpu.barrier_wait(schedule_barrier) # Wait until TensorCore is free. + # Can't fully explain why, but empirically the ordering here influences + # the performance of the final kernel quite significantly. + if head_dim <= 128: + l_i += p.sum(axis=1) + acc, l_i, m_i, p16 = lax.optimization_barrier((acc, l_i, m_i, p16)) + end_softmax_barriers() + else: + end_softmax_barriers() + l_i += p.sum(axis=1) # PV def compute_pv(acc_ref): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 87916ac1c364..80c4048720a3 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1674,6 +1674,20 @@ def kernel(ctx, inp, out, smem): )(x) np.testing.assert_array_equal(result, reference) + @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) + def test_optimization_barrier(self, dtype): + def kernel(ctx, inp, out, smem): + del ctx, smem + arr = mgpu.FragmentedArray.load_strided(inp) + arr2 = arr * 2 + arr, arr2 = mgpu.optimization_barrier(arr, arr2) + (arr + arr2).store_untiled(out) + + x = jnp.arange(256, dtype=dtype) + + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None) + np.testing.assert_array_equal(f(x), x * 3) + class ProfilerTest(TestCase): From 5a250097e46667ed3ea25b7600b7d65d40938854 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 4 Dec 2024 10:00:28 -0500 Subject: [PATCH 597/698] Fix Windows portability problem in compilation cache test. --- tests/compilation_cache_test.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 0f949aaf1490..428e518eab51 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -18,6 +18,7 @@ from functools import partial import logging import math +import os import platform import unittest from unittest import mock @@ -539,6 +540,7 @@ def test_backend_serialization_deserialization(self): def test_persistent_cache_enable_xla_caches(self): if jtu.jaxlib_version() <= (0, 4, 35): self.skipTest("Test requires AutotuneCacheMode bindings") + s = os.sep with config.compilation_cache_dir("jax-cache"): with config.persistent_cache_enable_xla_caches("none"): compile_options = compiler.get_compile_options( @@ -552,15 +554,15 @@ def test_persistent_cache_enable_xla_caches(self): compile_options = compiler.get_compile_options( num_replicas=1, num_partitions=1 ) - self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "jax-cache/xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, f"jax-cache{s}xla_gpu_kernel_cache_file") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) - self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "jax-cache/xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) with config.persistent_cache_enable_xla_caches("xla_gpu_kernel_cache_file"): compile_options = compiler.get_compile_options( num_replicas=1, num_partitions=1 ) - self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "jax-cache/xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, f"jax-cache{s}xla_gpu_kernel_cache_file") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) @@ -570,7 +572,7 @@ def test_persistent_cache_enable_xla_caches(self): ) self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) - self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "jax-cache/xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) @jtu.with_config( From 2ac26924578efc7df0e6897237bad01eda081a8d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 4 Dec 2024 07:55:01 -0800 Subject: [PATCH 598/698] Disable backwards compatibility test for Triton IR. Triton doesn't promise backwards compatibility of its IR, so the test is misguided: it is testing a property that isn't true. If we wanted to promise backwards compatibility, we would need to use a versioned IR across the boundary. PiperOrigin-RevId: 702725103 --- tests/pallas/export_back_compat_pallas_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 9e9935884b3a..1b810bcb6f26 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -18,6 +18,7 @@ """ import math +import unittest from absl.testing import absltest import jax @@ -47,6 +48,9 @@ def setUp(self): self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() + @unittest.skip("TODO(necula): This test is checking backwards compatibility " + "of Triton IR, but Triton doesn't promise backwards " + "compatibility for its IR.") def test_triton_add_one(self): def func(x): def add_one(x_ref, o_ref): From 12b45b32357a5bf4f1ad26c0bb4657113d0ef1c7 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Dec 2024 07:58:59 -0800 Subject: [PATCH 599/698] [pallas:mosaic_gpu] `emit_pipeline` no longer ignores transforms PiperOrigin-RevId: 702726201 --- jax/_src/pallas/mosaic_gpu/core.py | 19 +++++++++++++++++++ jax/_src/pallas/mosaic_gpu/pipeline.py | 8 +++++++- tests/pallas/mosaic_gpu_test.py | 25 +++++++++++++++++++------ 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index d1f75009c33d..0cc1f77adba3 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -138,6 +138,14 @@ class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC): def to_gpu_transform(self) -> mgpu.MemRefTransform: pass + def batch(self, leading_rank: int): + """Returns a transform that accepts a ref with the extra `leading_rank` dims. + + The returned transform should leave the leading dimensions unchanged and + only apply to the suffix of the shape. + """ + raise NotImplementedError + def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: return aval.update( shape=self.to_gpu_transform().transform_shape(aval.shape) @@ -161,6 +169,9 @@ def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: ref, transforms=(*ref.transforms, UntileRef(self.tiling)) ) + def batch(self, leading_rank: int): + return self + def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) @@ -228,6 +239,11 @@ def __post_init__(self): if set(self.permutation) != set(range(len(self.permutation))): raise ValueError(f"Permutation {self.permutation} is not a permutation.") + def batch(self, leading_rank: int): + return TransposeTransform( + (*range(leading_rank), *(d + leading_rank for d in self.permutation)) + ) + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: return dataclasses.replace( ref, @@ -304,6 +320,9 @@ def __post_init__(self): " accepted." ) + def batch(self, leading_rank: int): + return self + def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: return dataclasses.replace( ref, transforms=(*ref.transforms, UnswizzleRef(self.swizzle)) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 9fcca6acdacc..069b8d9e78d3 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -195,7 +195,13 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( [ - gpu_core.SMEM((max_concurrent_steps, *spec.block_shape), ref.dtype) # type: ignore + gpu_core.SMEM( + (max_concurrent_steps, *spec.block_shape), # type: ignore + ref.dtype, + transforms=tuple( + t.batch(1) for t in getattr(spec, "transforms", ()) + ), + ) if _in_smem(spec) else None for spec, ref in zip(it.chain(in_specs, out_specs), gmem_refs) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d87f4a1aa373..283e3e1a83c6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -820,7 +820,6 @@ def kernel(x_ref, o_ref): x = jnp.arange(256) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) - @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): self.skip_unless_sm90a() @@ -1233,23 +1232,37 @@ def body(step, _): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) - def test_emit(self): + @parameterized.parameters( + ((),), + ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),), + ) + def test_emit(self, transforms): num_steps = 4 def kernel(x_gmem, o_gmem): plgpu.emit_pipeline( kernel_body, - in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], - out_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + in_specs=[ + plgpu.GPUBlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + out_specs=[ + plgpu.GPUBlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], grid=(num_steps,), max_concurrent_steps=2, )(x_gmem, o_gmem) def kernel_body(x_smem, o_smem): + # +1 for the indexing done by ``emit_pipeline`. + self.assertLen(x_smem.transforms, len(transforms) + 1) o_smem[...] = x_smem[...] + 1.0 - x = jnp.arange(32 * num_steps * 16) - x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + x = jnp.arange(64 * num_steps * 64) + x = x.reshape(-1, num_steps * 64).astype(jnp.float32) kernel_fn = pl.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], From 46eb77bee357e4fecdc57525a0d49c8a2f4b2f2b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Dec 2024 08:20:05 -0800 Subject: [PATCH 600/698] [pallas:mosaic_gpu] Use `jax.tree_util.register_dataclass` for transforms PiperOrigin-RevId: 702733084 --- jax/_src/pallas/mosaic_gpu/core.py | 36 +++++------------------------- 1 file changed, 5 insertions(+), 31 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 0cc1f77adba3..d77ae4358703 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -104,7 +104,6 @@ def __str__(self) -> str: return self.value def __call__( - self, shape: tuple[int, ...], dtype: jnp.dtype, @@ -161,7 +160,6 @@ class TilingTransform(MemoryRefTransform): shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a tiling of (64, 32) will be tiled as (4, 8, 64, 32). """ - tiling: tuple[int, ...] def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef: @@ -176,10 +174,10 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) -@tree_util.register_pytree_node_class +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class UntileRef(state_types.Transform): - tiling: tuple[int, ...] + tiling: tuple[int, ...] = dataclasses.field(metadata=dict(static=True)) def transform_shape(self, shape): if shape is None: @@ -214,14 +212,6 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) - def tree_flatten(self): - return (), (self.tiling,) - - @classmethod - def tree_unflatten(cls, metadata, arrays): - assert not arrays - return cls(*metadata) - def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: inverse = [-1] * len(permutation) @@ -257,7 +247,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(self.permutation) -@tree_util.register_pytree_node_class +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class TransposeRef(state_types.Transform): permutation: tuple[int, ...] @@ -287,14 +277,6 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) - def tree_flatten(self): - return (), (self.permutation,) - - @classmethod - def tree_unflatten(cls, metadata, arrays): - assert not arrays - return cls(*metadata) - def transpose_ref( ref: pallas_core.TransformedRef | Any, @@ -345,10 +327,10 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: return aval -@tree_util.register_pytree_node_class +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class UnswizzleRef(state_types.Transform): - swizzle: int + swizzle: int = dataclasses.field(metadata=dict(static=True)) def untransform_index( self, idxs: tuple[Index, ...] @@ -369,14 +351,6 @@ def untransform_index( raise ValueError("Swizzled dims cannot be sliced") return idxs, self - def tree_flatten(self): - return (), (self.swizzle,) - - @classmethod - def tree_unflatten(cls, metadata, arrays): - assert not arrays - return cls(*metadata) - @dataclasses.dataclass class GPUBlockSpec(pallas_core.BlockSpec): From bdadc53ebcd40a5091d66d2586deba82fe5e01ca Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 4 Dec 2024 09:50:45 -0800 Subject: [PATCH 601/698] Disable JaxAotTest.test_topology_pjit_serialize on GPU, which fails in CI. PiperOrigin-RevId: 702759889 --- tests/aot_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/aot_test.py b/tests/aot_test.py index bca0d66ed384..194982e046ba 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -62,6 +62,7 @@ def verify_serialization(lowered): jax.pmap(lambda x: x * x).lower( np.zeros((len(jax.devices()), 4), dtype=np.float32))) + @jtu.skip_on_devices('gpu') # Test fails in CI def test_topology_pjit_serialize(self): try: aot_topo = topologies.get_topology_desc( From 681b9c2ebe1a4b09ff7c9149e26aafc282507a84 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 4 Dec 2024 09:51:53 -0800 Subject: [PATCH 602/698] Disable pgle_test on non-GPU platforms. PGLE is only intended to work on GPU. PiperOrigin-RevId: 702760248 --- tests/pgle_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pgle_test.py b/tests/pgle_test.py index ef91c399db16..fb144cacbc98 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -46,6 +46,9 @@ class PgleTest(jtu.JaxTestCase): def setUp(self): super().setUp() + if not jtu.test_device_matches(["gpu"]): + self.skipTest('Profile-guideded latency estimation only supported on GPU') + cc.set_cache_dir(None) cc.reset_cache() From 653f65452d4a0b793b8e66f4a81057426f9648b0 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 4 Dec 2024 09:58:45 -0800 Subject: [PATCH 603/698] Fix the broken behavior of not resetting the abstract_mesh and device_context properly during `__exit__`. PiperOrigin-RevId: 702762477 --- jax/_src/mesh.py | 46 +++++++++------------------------------------- jax/_src/pjit.py | 5 +++-- jax/_src/stages.py | 4 ++-- 3 files changed, 14 insertions(+), 41 deletions(-) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index a83287d5ecad..d7a87b6d6b33 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -455,10 +455,15 @@ def local_mesh(self): _raise_value_error("local_mesh") def __enter__(self): - return push_abstract_mesh_context(self) + abstract_mesh_context.stack.append(self) + abstract_mesh_context.mesh = self + jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh) + return self def __exit__(self, exc_type, exc_value, traceback): - pop_abstract_mesh_context() + abstract_mesh_context.stack.pop() + abstract_mesh_context.mesh = abstract_mesh_context.stack[-1] + jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh) return False @staticmethod @@ -480,35 +485,6 @@ def __init__(self): abstract_mesh_context = AbstractMeshContext() -def push_abstract_mesh_context(val): - abstract_mesh_context.stack.append(val) - abstract_mesh_context.mesh = val - # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them. - # Right now that leads to weird numerical issues. - non_none_meshes = tuple(m for m in abstract_mesh_context.stack - if m is not None) - if non_none_meshes: - jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) - return val - -def pop_abstract_mesh_context(): - abstract_mesh_context.stack.pop() - abstract_mesh_context.mesh = abstract_mesh_context.stack[-1] - non_none_meshes = tuple(m for m in abstract_mesh_context.stack - if m is not None) - if non_none_meshes: - jax_config.abstract_mesh_context_manager.set_local(non_none_meshes) - - -class null_mesh_context: - - def __enter__(self): - return push_abstract_mesh_context(None) - - def __exit__(self, *excinfo): - pop_abstract_mesh_context() - return False - @contextlib.contextmanager def set_mesh(mesh: Mesh): @@ -529,14 +505,10 @@ def __init__(self): def enter_device_context(mesh: Mesh): device_context.stack.append(mesh) device_context.concrete_mesh = mesh - non_none_meshes = tuple(m for m in device_context.stack if m is not None) - if non_none_meshes: - jax_config.device_context.set_local(non_none_meshes) + jax_config.device_context.set_local(device_context.concrete_mesh) try: yield finally: device_context.stack.pop() device_context.concrete_mesh = device_context.stack[-1] - non_none_meshes = tuple(m for m in device_context.stack if m is not None) - if non_none_meshes: - jax_config.device_context.set_local(non_none_meshes) + jax_config.device_context.set_local(device_context.concrete_mesh) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 7f29d745e48d..462be851efd7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -16,6 +16,7 @@ from collections import defaultdict from collections.abc import Callable, Sequence, Iterable +import contextlib import dataclasses from functools import partial import inspect @@ -695,7 +696,7 @@ def _infer_params_impl( def get_abstract_mesh(in_avals): if not config.sharding_in_types.value: - return mesh_lib.null_mesh_context() + return contextlib.nullcontext() m = None for a in in_avals: # TODO(yashkatariya): Remove this when mesh context can be set by the user. @@ -708,7 +709,7 @@ def get_abstract_mesh(in_avals): m = a.sharding.mesh # type: ignore # TODO(yashkatariya): Remove this when mesh context can be set by the user. if m is None: - return mesh_lib.null_mesh_context() + return contextlib.nullcontext() assert isinstance(m, AbstractMesh) return m diff --git a/jax/_src/stages.py b/jax/_src/stages.py index cc89a3338313..b6f3b63d3de4 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,6 +30,7 @@ """ from __future__ import annotations +import contextlib import functools from collections.abc import Sequence from dataclasses import dataclass @@ -43,7 +44,6 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src import mesh as mesh_lib from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.layout import Layout from jax._src.interpreters import mlir @@ -717,7 +717,7 @@ class Traced(Stage): "_args_flat", "_arg_names", "_num_consts"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, abstract_mesh=mesh_lib.null_mesh_context(), + lower_callable, abstract_mesh=contextlib.nullcontext(), args_flat=None, arg_names=None, num_consts: int = 0): self.jaxpr = jaxpr self.args_info = args_info From 721b517e998f0a123ba550bf11481865a2342970 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 4 Dec 2024 10:11:30 -0800 Subject: [PATCH 604/698] [Pallas] Update changelog for `pl.estimate_cost` PiperOrigin-RevId: 702767883 --- docs/pallas/CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index bd86741c9165..2687cbc909fb 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -20,6 +20,11 @@ Remember to align the itemized text with the first line of an item within a list {func}`jax.experimental.tpu.run_scoped`. Both are now available in {mod}`jax.experimental.pallas`. +* New functionality + + * Added a cost estimate tool {func}`pl.estimate_cost` for automatically + constructing a kernel cost estimate from a JAX reference function. + ## Released with jax 0.4.34 (October 4, 2024) * Changes From 8563449ac32df52cd7b1f3e17d90f7e64cb0dd59 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 4 Dec 2024 11:56:23 -0800 Subject: [PATCH 605/698] CI: update array-api-tests to latest commit --- .github/workflows/jax-array-api.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 8f2029eb9191..54a2bf469a38 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -28,7 +28,7 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20 + ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} From 9e2708eb57ab1810ee576d7da6d489c64bb995ad Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 4 Dec 2024 12:11:54 -0800 Subject: [PATCH 606/698] [sharding_in_types] Use `set_mesh` API to trigger sharding_in_types instead of the config option. PiperOrigin-RevId: 702814257 --- jax/_src/interpreters/pxla.py | 11 +-- jax/_src/pjit.py | 11 +-- jax/_src/test_util.py | 11 +++ tests/pjit_test.py | 160 +++++++++++++++------------------- 4 files changed, 88 insertions(+), 105 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 98cfbbb1d589..41f91202e1de 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2193,15 +2193,8 @@ def lower_sharding_computation( assert len(out_shardings) == len(out_layouts) == len(global_out_avals), ( len(out_shardings), len(out_layouts), len(global_out_avals)) - if config.sharding_in_types.value: - # TODO(yashkatariya): Thread it via jit path and remove the None check by - # making tests go via set_mesh API always. - devices_from_context = ( - None if mesh_lib.device_context.concrete_mesh is None - else mesh_lib.device_context.concrete_mesh._flat_devices_tuple) - else: - devices_from_context = (None if context_mesh is None or context_mesh.empty - else context_mesh._flat_devices_tuple) + devices_from_context = (None if context_mesh is None or context_mesh.empty + else context_mesh._flat_devices_tuple) # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. unique_intermediate_shardings = util.stable_unique( diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 462be851efd7..1f8378d8487e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -707,9 +707,6 @@ def get_abstract_mesh(in_avals): f'Mesh for all inputs should be equal. Got one mesh: {m} and' f' another mesh: {a.sharding.mesh}') m = a.sharding.mesh # type: ignore - # TODO(yashkatariya): Remove this when mesh context can be set by the user. - if m is None: - return contextlib.nullcontext() assert isinstance(m, AbstractMesh) return m @@ -1791,8 +1788,12 @@ def _pjit_lower( lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): - mesh, api_name = ((resource_env.physical_mesh, 'pjit') - if resource_env is not None else (None, 'jit')) + if config.sharding_in_types.value: + mesh = mesh_lib.device_context.concrete_mesh + api_name = 'jit' + else: + mesh, api_name = ((resource_env.physical_mesh, 'pjit') + if resource_env is not None else (None, 'jit')) return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c639ebd03586..0bd5c7b139a1 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -51,6 +51,7 @@ from jax._src import pjit as pjit_lib from jax._src import stages from jax._src import xla_bridge +from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir from jax._src.interpreters import pxla @@ -1442,6 +1443,16 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) +def with_user_mesh(sizes, names): + def decorator(fn): + def mesh_fn(*args, **kwargs): + mesh = create_mesh(sizes, names) + with mesh_lib.set_mesh(mesh): + return fn(*args, **kwargs, mesh=mesh) + return mesh_fn + return decorator + + def create_mesh(mesh_shape, axis_names, iota_order=False): size = math.prod(mesh_shape) if len(jax.devices()) < size: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index af81b35570d6..52261ef025bd 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4630,38 +4630,16 @@ def f(x): ins, _ = f.lower(np.arange(8)).compile().input_shardings self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0])) - def test_sharding_in_types_with_set_mesh(self): - if config.use_shardy_partitioner.value: - self.skipTest("ShiT doesn't work with shardy") - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - with mesh_lib.set_mesh(mesh): - np_inp = np.arange(16.).reshape(8, 2) - s = NamedSharding(mesh, P('x', 'y')) - arr = jax.device_put(np_inp, s) - - @jax.jit - def f(x): - self.assertEqual(x.sharding.spec, s.spec) - x = x * 2 - self.assertEqual(x.sharding.spec, s.spec) - x = x * x - self.assertEqual(x.sharding.spec, s.spec) - return x - - out = f(arr) - self.assertEqual(out.sharding, s) - self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2)) - def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") -@jtu.with_config(jax_sharding_in_types=True, jax_use_shardy_partitioner=False) +@jtu.with_config(jax_use_shardy_partitioner=False) class ShardingInTypesTest(jtu.JaxTestCase): - def test_basic_mul(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_basic_mul(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4696,8 +4674,8 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - def test_fully_replicated_array_mul(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_fully_replicated_array_mul(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr1 = jax.device_put(np_inp1, s) @@ -4745,8 +4723,8 @@ def g(x, y): ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'), ('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce') ) - def test_dot_general(self, spec1, spec2, out_spec, collective_name): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2)) @@ -4781,8 +4759,8 @@ def g(x, y): self.assertEqual(out[0].sharding, arr1.sharding) self.assertEqual(out[1].sharding, arr2.sharding) - def test_dot_general_out_type(self): - mesh = jtu.create_mesh((4,), ('x',)) + @jtu.with_user_mesh((4,), ('x',)) + def test_dot_general_out_type(self, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x'))) @@ -4824,8 +4802,8 @@ def f(x, y): "dot_general requires contracting dimensions to have consistent sharding", TypeError), ) - def test_dot_general_error(self, spec1, spec2, error_msg, error_type): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2)) @@ -4837,8 +4815,8 @@ def f(x, y): with self.assertRaisesRegex(error_type, error_msg): f(arr1, arr2) - def test_dot_general_batch_error(self): - mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_dot_general_batch_error(self, mesh): arr1 = jax.device_put(np.ones((8, 4, 2)), NamedSharding(mesh, P('x', 'y', 'z'))) arr2 = jax.device_put(np.ones((8, 2, 4)), @@ -4856,9 +4834,8 @@ def test_dot_general_batch_error(self): ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) - def test_aval_repr(self): - mesh = jtu.create_mesh((2, 2), ('model', 'data')) - + @jtu.with_user_mesh((2, 2), ('model', 'data')) + def test_aval_repr(self, mesh): aval = core.ShapedArray((128, 64), np.float32, sharding=NamedSharding(mesh, P('model', 'data'))) self.assertEqual(aval.str_short(), 'float32[128@model,64@data]') @@ -4876,14 +4853,14 @@ def test_aval_repr(self): self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]') @parameterized.named_parameters( - ('all', None, P('x', 'y'), P()), - ('first', 0, P('x', 'y'), P('y')), - ('second', 1, P('x', 'y'), P('x')), - ('first2', 0, P(('x', 'y'), None), P(None)), + ('all', None, P('x', 'y'), P(), True), + ('first', 0, P('x', 'y'), P('y'), True), + ('second', 1, P('x', 'y'), P('x'), True), + ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, in_spec) arr = jax.device_put(np_inp, s) @@ -4907,14 +4884,14 @@ def f(x): self.assertIn('all-reduce', compiled_text) @parameterized.named_parameters( - ('all', None, P('x', 'y'), P()), - ('first', 0, P('x', 'y'), P('y')), - ('second', 1, P('x', 'y'), P('x')), - ('first2', 0, P(('x', 'y'), None), P(None)), + ('all', None, P('x', 'y'), P(), True), + ('first', 0, P('x', 'y'), P('y'), True), + ('second', 1, P('x', 'y'), P('x'), True), + ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - def test_reduce_max(self, axis, in_spec, out_spec, reduce=True): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, in_spec) arr = jax.device_put(np_inp, s) @@ -4954,8 +4931,8 @@ def g(x): ('2', 2, P('x', 'y', None)), ('-1', -1, P('x', 'y', None)), ) - def test_broadcast_in_dim(self, axis, out_spec): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_broadcast_in_dim(self, axis, out_spec, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -4980,8 +4957,8 @@ def f(x): ('3', 3), ('4', 4), ) - def test_integer_pow(self, pow): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_integer_pow(self, pow, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -5010,12 +4987,13 @@ def test_broadcasting_nary_error(self): def f(x, y): return x + y - with self.assertRaisesRegex( - ValueError, "Mesh for all inputs should be equal"): - f(arr1, arr2) + with config.sharding_in_types(True): + with self.assertRaisesRegex( + ValueError, "Mesh for all inputs should be equal"): + f(arr1, arr2) - def test_sin_unop(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_sin_unop(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -5032,8 +5010,8 @@ def f(x): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) - def test_jnp_array(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_jnp_array(self, mesh): np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -5048,8 +5026,8 @@ def f(x): f(arr) - def test_lax_transpose_rule(self): - mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_lax_transpose_rule(self, mesh): np_inp = np.arange(16).reshape(4, 2, 2) s = NamedSharding(mesh, P('x', 'y', 'z')) arr = jax.device_put(np_inp, s) @@ -5067,8 +5045,8 @@ def f(x): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) - def test_broadcasted_iota_with_sharding(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_broadcasted_iota_with_sharding(self, mesh): np_inp = np.arange(4) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np_inp, s) @@ -5094,8 +5072,8 @@ def g(x): _, out = g(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - def test_einsum_with_out_type(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_einsum_with_out_type(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) @@ -5140,8 +5118,8 @@ def h2(x, y): self.assertEqual(out[0].sharding, arr3.sharding) self.assertEqual(out[1].sharding, arr4.sharding) - def test_einsum_inverse(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_einsum_inverse(self, mesh): np_inp = np.arange(64.) @jax.jit @@ -5179,9 +5157,9 @@ def h2(x, y): ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True) ) + @jtu.with_user_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, - use_sharding_arg): - mesh = jtu.create_mesh((2,), ('x',)) + use_sharding_arg, mesh): np_inp = np.arange(math.prod(src_shape), dtype=np.float32).reshape(src_shape) arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec)) @@ -5209,8 +5187,8 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - def test_select(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_select(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr1 = jax.device_put(np_inp, s) @@ -5234,8 +5212,8 @@ def f(pred, on_true, on_false): TypeError, "select cases must have the same shardings"): f(arr1 == arr2, arr1, arr3) - def test_device_put_reshard(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_device_put_reshard(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -5250,8 +5228,8 @@ def f(x): self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - def test_shard_map_full_manual(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_shard_map_full_manual(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5275,8 +5253,8 @@ def f(x, y): self.assertArraysEqual(out, (np_inp * np_inp) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - def test_shard_map_dot(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_shard_map_dot(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) @@ -5302,8 +5280,8 @@ def f(x, y): self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - def test_slice(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_slice(self, mesh): np_inp = np.arange(16.).reshape(4, 4) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5333,8 +5311,8 @@ def g(x): with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) - def test_squeeze(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_squeeze(self, mesh): np_inp = np.arange(16.).reshape(4, 4, 1) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) @@ -5359,8 +5337,8 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - def test_pad(self): - mesh = jtu.create_mesh((2, 1), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_pad(self, mesh): np_inp = np.arange(8.) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) @@ -5401,8 +5379,8 @@ def g(x): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) - def test_concatenate(self): - mesh = jtu.create_mesh((2, 1), ('x', 'y')) + @jtu.with_user_mesh((2, 1), ('x', 'y')) + def test_concatenate(self, mesh): np_inp = np.arange(16.).reshape(4, 4) s = NamedSharding(mesh, P('x', 'y')) arr1 = jax.device_put(np_inp, s) @@ -5443,8 +5421,8 @@ def g(x, y): out = jax.jit(jax.grad(g))(arr1, arr2) self.assertEqual(out.sharding, s) - def test_scan(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_scan(self, mesh): carry = jax.device_put(np.arange(16.).reshape(2, 8), NamedSharding(mesh, P(None, 'x'))) arr = jax.device_put(np.arange(128.).reshape(8, 8, 2), @@ -5481,8 +5459,8 @@ def g(carry, arr): ValueError, "0th dimension of all xs should be replicated"): f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) - def test_argminmax(self): - mesh = jtu.create_mesh((2, 2), ('x', 'y')) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_argminmax(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) From 1a3c9c44dc95ecc00545eff35ce3ed91977690ea Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 4 Dec 2024 12:45:51 -0800 Subject: [PATCH 607/698] [Pallas] Fix type annotation on TritonCompilerParams PiperOrigin-RevId: 702825912 --- jax/_src/pallas/triton/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py index a61dfd61b9b1..097f8497e8f7 100644 --- a/jax/_src/pallas/triton/core.py +++ b/jax/_src/pallas/triton/core.py @@ -35,4 +35,4 @@ class TritonCompilerParams(pallas_core.CompilerParams): PLATFORM: ClassVar[str] = "triton" num_warps: int | None = None num_stages: int | None = None - serialized_metadata: str | None = None + serialized_metadata: bytes | None = None From a735bf83e5e956895587b50bf7bef7f79d1fca0d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 4 Dec 2024 14:03:45 -0800 Subject: [PATCH 608/698] Simply abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py PiperOrigin-RevId: 702852769 --- jax/_src/config.py | 8 ++++- jax/_src/core.py | 4 +-- jax/_src/mesh.py | 55 ++++++++++++----------------------- jax/_src/pjit.py | 16 +++++----- jax/_src/stages.py | 6 ++-- jax/experimental/shard_map.py | 4 +-- tests/pjit_test.py | 4 +-- 7 files changed, 43 insertions(+), 54 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index eea686fa0f3c..b0214a8c8dcb 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1079,8 +1079,9 @@ class JitConfig: def __init__(self, name): self._name = name + @property def value(self): - return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) + return self.get_local() def get_local(self): return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) @@ -1088,6 +1089,11 @@ def get_local(self): def set_local(self, value): update_thread_local_jit_state(**{self._name: value}) + def swap_local(self, new_value): + prev_value = self.value + self.set_local(new_value) + return prev_value + trace_state = JitConfig('trace_state') axis_env_state = JitConfig('axis_env_state') mesh_context_manager = JitConfig('mesh_context_manager') diff --git a/jax/_src/core.py b/jax/_src/core.py index 122b7bcf5eb2..30893ce99ce4 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1605,10 +1605,10 @@ def get_sharding(sharding, ndim): assert len(sharding.spec) == ndim return sharding - context_mesh = mesh_lib.abstract_mesh_context.mesh + context_mesh = mesh_lib.get_abstract_mesh() # TODO(yashkatariya): Error out and ask users to set the context mesh in their # code. - if context_mesh is None: + if not context_mesh: return None assert sharding is None return NamedSharding(context_mesh, P(*[None] * ndim)) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index d7a87b6d6b33..87b931683215 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -454,18 +454,6 @@ def local_devices(self): def local_mesh(self): _raise_value_error("local_mesh") - def __enter__(self): - abstract_mesh_context.stack.append(self) - abstract_mesh_context.mesh = self - jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh) - return self - - def __exit__(self, exc_type, exc_value, traceback): - abstract_mesh_context.stack.pop() - abstract_mesh_context.mesh = abstract_mesh_context.stack[-1] - jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh) - return False - @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): jax_config.abstract_mesh_context_manager.set_local(mesh) @@ -478,37 +466,32 @@ def _raise_value_error(name): raise ValueError(f"AbstractMesh does not implement {name}") -class AbstractMeshContext(threading.local): - def __init__(self): - self.stack = [None] - self.mesh = self.stack[-1] +@contextlib.contextmanager +def set_abstract_mesh(mesh: AbstractMesh): + prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh) + try: + yield + finally: + jax_config.abstract_mesh_context_manager.set_local(prev_val) -abstract_mesh_context = AbstractMeshContext() +def get_abstract_mesh(): + return jax_config.abstract_mesh_context_manager.value @contextlib.contextmanager -def set_mesh(mesh: Mesh): - with (mesh.abstract_mesh, jax_config.sharding_in_types(True), - enter_device_context(mesh)): +def set_concrete_mesh(mesh: Mesh): + prev_val = jax_config.device_context.swap_local(mesh) + try: yield + finally: + jax_config.device_context.set_local(prev_val) - -class DeviceContext(threading.local): - def __init__(self): - self.stack = [None] - self.concrete_mesh = self.stack[-1] - -device_context = DeviceContext() +def get_concrete_mesh(): + return jax_config.device_context.value @contextlib.contextmanager -def enter_device_context(mesh: Mesh): - device_context.stack.append(mesh) - device_context.concrete_mesh = mesh - jax_config.device_context.set_local(device_context.concrete_mesh) - try: +def set_mesh(mesh: Mesh): + with (set_abstract_mesh(mesh.abstract_mesh), + jax_config.sharding_in_types(True), set_concrete_mesh(mesh)): yield - finally: - device_context.stack.pop() - device_context.concrete_mesh = device_context.stack[-1] - jax_config.device_context.set_local(device_context.concrete_mesh) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 1f8378d8487e..6c632c6fc51e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -16,7 +16,6 @@ from collections import defaultdict from collections.abc import Callable, Sequence, Iterable -import contextlib import dataclasses from functools import partial import inspect @@ -187,7 +186,7 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs): try: # TODO(yashkatariya): Maybe thread this into pjit params like resource_env # and set the context manager down the stack? - with p.abstract_mesh: + with mesh_lib.set_abstract_mesh(p.abstract_mesh): if (core.trace_state_clean() and not config.debug_key_reuse.value and not config.data_dependent_tracing_fallback.value): @@ -645,9 +644,9 @@ def _infer_params_impl( attr_token = _attr_token(flat_fun, in_type) abstract_mesh = ( - get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None - else mesh_lib.abstract_mesh_context.mesh) - with abstract_mesh: + get_abstract_mesh_from_avals(in_type) + if not mesh_lib.get_abstract_mesh() else mesh_lib.get_abstract_mesh()) + with mesh_lib.set_abstract_mesh(abstract_mesh): jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, dbg, HashableFunction(res_paths, closure=()), @@ -694,9 +693,9 @@ def _infer_params_impl( attrs_tracked, abstract_mesh), args_flat -def get_abstract_mesh(in_avals): +def get_abstract_mesh_from_avals(in_avals): if not config.sharding_in_types.value: - return contextlib.nullcontext() + return None m = None for a in in_avals: # TODO(yashkatariya): Remove this when mesh context can be set by the user. @@ -1789,7 +1788,8 @@ def _pjit_lower( lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): if config.sharding_in_types.value: - mesh = mesh_lib.device_context.concrete_mesh + cur_mesh = mesh_lib.get_concrete_mesh() + mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None api_name = 'jit' else: mesh, api_name = ((resource_env.physical_mesh, 'pjit') diff --git a/jax/_src/stages.py b/jax/_src/stages.py index b6f3b63d3de4..db26813de8bc 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,7 +30,6 @@ """ from __future__ import annotations -import contextlib import functools from collections.abc import Sequence from dataclasses import dataclass @@ -44,6 +43,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util +from jax._src import mesh as mesh_lib from jax._src.sharding_impls import UnspecifiedValue, AUTO from jax._src.layout import Layout from jax._src.interpreters import mlir @@ -717,7 +717,7 @@ class Traced(Stage): "_args_flat", "_arg_names", "_num_consts"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, abstract_mesh=contextlib.nullcontext(), + lower_callable, abstract_mesh=None, args_flat=None, arg_names=None, num_consts: int = 0): self.jaxpr = jaxpr self.args_info = args_info @@ -747,7 +747,7 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, try: # TODO(yashkatariya): Maybe thread this into pjit params like resource_env # and set the context manager down the stack? - with self._abstract_mesh: + with mesh_lib.set_abstract_mesh(self._abstract_mesh): lowering = new_callable() except pxla.DeviceAssignmentMismatchError as e: fails, = e.args diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 07f631f6ec49..1c529b8938f1 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -46,7 +46,7 @@ from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer -from jax._src.mesh import AbstractMesh, Mesh, AxisTypes +from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, @@ -484,7 +484,7 @@ def _shard_map_staging( in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) with (core.extend_axis_env_nd(list(mesh.shape.items())), - pjit.get_abstract_mesh(in_avals_)): + set_abstract_mesh(pjit.get_abstract_mesh_from_avals(in_avals_))): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 52261ef025bd..7e66c41cca76 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5237,7 +5237,7 @@ def test_shard_map_full_manual(self, mesh): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) - self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective) return x * y @jax.jit @@ -5262,7 +5262,7 @@ def test_shard_map_dot(self, mesh): def g(x, y): self.assertTrue(x.sharding.mesh._are_all_axes_collective) self.assertTrue(y.sharding.mesh._are_all_axes_collective) - self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective) + self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y return jax.lax.psum(z, axis_name='y') From 5ade371c88a1f879556ec29867b173da49ae57f0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 4 Dec 2024 14:12:14 -0800 Subject: [PATCH 609/698] Remove obsolete deprecation Followup to https://github.com/jax-ml/jax/pull/24748, which removed the deprecated `reshape` parameter. PiperOrigin-RevId: 702856398 --- jax/_src/deprecations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index c7a956068981..33a813340be8 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -125,7 +125,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register("jax-numpy-astype-complex-to-real") register("jax-numpy-array-none") -register('jax-numpy-reshape-newshape') register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') From e05afefc97296a8c4fcdce3aba07d08d8392af81 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 4 Dec 2024 14:27:11 -0800 Subject: [PATCH 610/698] [Pallas] Pallas documentation cleanup --- docs/pallas/{ => design}/async_note.md | 0 docs/pallas/{ => design}/design.md | 8 ++--- docs/pallas/design/index.rst | 9 +++++ docs/pallas/grid_blockspec.md | 36 ++------------------ docs/pallas/index.rst | 5 ++- docs/pallas/quickstart.ipynb | 47 ++++++++++++++++---------- docs/pallas/quickstart.md | 45 ++++++++++++++---------- docs/pallas/tpu/pipelining.ipynb | 10 +++++- docs/pallas/tpu/pipelining.md | 4 +++ 9 files changed, 88 insertions(+), 76 deletions(-) rename docs/pallas/{ => design}/async_note.md (100%) rename docs/pallas/{ => design}/design.md (99%) create mode 100644 docs/pallas/design/index.rst diff --git a/docs/pallas/async_note.md b/docs/pallas/design/async_note.md similarity index 100% rename from docs/pallas/async_note.md rename to docs/pallas/design/async_note.md diff --git a/docs/pallas/design.md b/docs/pallas/design/design.md similarity index 99% rename from docs/pallas/design.md rename to docs/pallas/design/design.md index f6fc8f5926cb..17c7a6dbdc0f 100644 --- a/docs/pallas/design.md +++ b/docs/pallas/design/design.md @@ -94,7 +94,7 @@ Pallas kernels via JAX transformations.
-![Pallas lowering path](../_static/pallas/pallas_flow.png) +![Pallas lowering path](../../_static/pallas/pallas_flow.png) Visualization of Pallas lowering paths
@@ -413,10 +413,10 @@ verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU. -### Examples +### GPU Examples -Note all the following examples are for GPU only. They will require some small -changes to work on TPUs. +Note all the following examples are for GPU only. They will require tweaks to +the block sizes to work on TPUs. #### `add` diff --git a/docs/pallas/design/index.rst b/docs/pallas/design/index.rst new file mode 100644 index 000000000000..d11a13d39fe8 --- /dev/null +++ b/docs/pallas/design/index.rst @@ -0,0 +1,9 @@ +Pallas Design Notes +=================== + +.. toctree:: + :caption: Design + :maxdepth: 2 + + design + async_note diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index cde200528785..c1b2c2b95229 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -44,39 +44,7 @@ For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and You can also use {func}`jax.experimental.pallas.num_programs` to get the grid size for a given axis. -Here's an example kernel that uses a `grid` and `program_id`. - -```python ->>> import jax ->>> from jax.experimental import pallas as pl - ->>> def iota_kernel(o_ref): -... i = pl.program_id(0) -... o_ref[i] = i - -``` - -We now execute it using `pallas_call` with an additional `grid` argument. - -```python ->>> def iota(size: int): -... return pl.pallas_call(iota_kernel, -... out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), -... grid=(size,), interpret=True)() ->>> iota(8) -Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) - -``` - -On GPUs, each program is executed in parallel on separate thread blocks. -Thus, we need to think about race conditions on writes to HBM. -A reasonable approach is to write our kernels in such a way that different -programs write to disjoint places in HBM to avoid these parallel writes. - -On TPUs, programs are executed in a combination of parallel and sequential -(depending on the architecture) so there are slightly different considerations. - -See {ref}`pallas_tpu_noteworthy_properties`. +See {ref}`grids_by_example` for a simple kernel that uses this API. (pallas_blockspec)= @@ -131,6 +99,8 @@ shape `x_shape` are computed as in the function `slice_for_invocation` below: ```python +>>> import jax +>>> from jax.experimental import pallas as pl >>> def slices_for_invocation(x_shape: tuple[int, ...], ... x_spec: pl.BlockSpec, ... grid: tuple[int, ...], diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 5969349c962a..b2e2fca6c82e 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -22,7 +22,6 @@ See also the :class:`jax.experimental.pallas` module API documentation. :maxdepth: 2 quickstart - design grid_blockspec @@ -34,9 +33,9 @@ See also the :class:`jax.experimental.pallas` module API documentation. .. toctree:: :caption: Design Notes - :maxdepth: 1 + :maxdepth: 2 - async_note + design/index .. toctree:: :caption: Other diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index af34d167400b..11dd2108e405 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -72,8 +72,9 @@ "\n", "Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n", "it does not take in `jax.Array`s as inputs and doesn't return any values.\n", - "Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n", - "but we are given an `o_ref`, which corresponds to the desired output.\n", + "Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory.\n", + "Note that we also don't have any outputs but we are given an `o_ref`, which corresponds\n", + "to the desired output.\n", "\n", "**Reading from `Ref`s**\n", "\n", @@ -150,7 +151,8 @@ "**What's actually happening here?**\n", "\n", "Thus far we've described how to think about Pallas kernels but what we've actually\n", - "accomplished is we're writing a function that's executed very close to the compute units.\n", + "accomplished is we're writing a function that's executed very close to the compute units\n", + "since values are loaded into the innermost (fastest) portion of the memory hierarchy.\n", "\n", "On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when\n", "we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM)\n", @@ -195,6 +197,8 @@ "live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n", "that operate on \"blocks\" of those arrays that can fit in SRAM.\n", "\n", + "(grids_by_example)=\n", + "\n", "### Grids by example\n", "\n", "To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n", @@ -240,7 +244,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We now execute it using `pallas_call` with an additional `grid` argument." + "We now execute it using `pallas_call` with an additional `grid` argument.\n", + "On GPUs, we can call the kernel directly like so:" ] }, { @@ -260,6 +265,7 @@ } ], "source": [ + "# GPU version\n", "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", @@ -272,16 +278,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "On GPUs, each program is executed in parallel on separate threads.\n", - "Thus, we need to think about race conditions on writes to HBM.\n", - "A reasonable approach is to write our kernels in such a way that different\n", - "programs write to disjoint places in HBM to avoid these parallel writes.\n", - "On the other hand, parallelizing the computation is how we can execute\n", - "operations like matrix multiplications really quickly.\n", - "\n", - "On TPUs, programs are executed in a combination of parallel and sequential\n", - "(depending on the architecture) so there are slightly different considerations.\n", - "\n", + "TPUs distinguish between vector and scalar memory spaces and in this case the\n", + "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n", "To call the above kernel on TPU, run:" ] }, @@ -292,6 +291,7 @@ "metadata": {}, "outputs": [], "source": [ + "# TPU version\n", "from jax.experimental.pallas import tpu as pltpu\n", "\n", "def iota(size: int):\n", @@ -307,11 +307,22 @@ "id": "68f97b4e", "metadata": {}, "source": [ - "TPUs distinguish between vector and scalar memory spaces and in this case the\n", - "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", - "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n", + "### Grid semantics\n", + "\n", + "On GPUs, each program is executed in parallel on separate threads.\n", + "Thus, we need to think about race conditions on writes to HBM.\n", + "A reasonable approach is to write our kernels in such a way that different\n", + "programs write to disjoint locations in HBM to avoid these parallel writes.\n", + "On the other hand, parallelizing the computation is how we can execute\n", + "operations like matrix multiplications really quickly.\n", + "\n", + "In contrast, TPUs operate like a very wide SIMD machine.\n", + "Some TPU models contain multiple cores, but in many cases a TPU can be\n", + "treated as a single-threaded processor. The grid on a TPU can be\n", + "specified in a combination of parallel and sequential dimensions, where sequential\n", + "dimensions are guaranteed to run serially.\n", "\n", - "You can read more details at {ref}`pallas_grid`." + "You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`." ] }, { diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index e11868f5f671..fff1dcb730f3 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -53,8 +53,9 @@ def add_vectors_kernel(x_ref, y_ref, o_ref): Let's dissect this function a bit. Unlike most JAX functions you've probably written, it does not take in `jax.Array`s as inputs and doesn't return any values. -Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs -but we are given an `o_ref`, which corresponds to the desired output. +Instead, it takes in *`Ref`* objects as inputs, which represent mutable buffers in memory. +Note that we also don't have any outputs but we are given an `o_ref`, which corresponds +to the desired output. **Reading from `Ref`s** @@ -101,7 +102,8 @@ thereof). **What's actually happening here?** Thus far we've described how to think about Pallas kernels but what we've actually -accomplished is we're writing a function that's executed very close to the compute units. +accomplished is we're writing a function that's executed very close to the compute units +since values are loaded into the innermost (fastest) portion of the memory hierarchy. On GPU, `x_ref` corresponds to a value in high-bandwidth memory (HBM) and when we do `x_ref[...]` we are copying the value from HBM into static RAM (SRAM) @@ -134,6 +136,8 @@ Part of writing Pallas kernels is thinking about how to take big arrays that live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations that operate on "blocks" of those arrays that can fit in SRAM. +(grids_by_example)= + ### Grids by example To automatically "carve" up the inputs and outputs, you provide a `grid` and @@ -169,8 +173,10 @@ def iota_kernel(o_ref): ``` We now execute it using `pallas_call` with an additional `grid` argument. +On GPUs, we can call the kernel directly like so: ```{code-cell} ipython3 +# GPU version def iota(size: int): return pl.pallas_call(iota_kernel, out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), @@ -178,19 +184,13 @@ def iota(size: int): iota(8) ``` -On GPUs, each program is executed in parallel on separate threads. -Thus, we need to think about race conditions on writes to HBM. -A reasonable approach is to write our kernels in such a way that different -programs write to disjoint places in HBM to avoid these parallel writes. -On the other hand, parallelizing the computation is how we can execute -operations like matrix multiplications really quickly. - -On TPUs, programs are executed in a combination of parallel and sequential -(depending on the architecture) so there are slightly different considerations. - +TPUs distinguish between vector and scalar memory spaces and in this case the +output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +a scalar. For more details read {ref}`tpu_and_its_memory_spaces`. To call the above kernel on TPU, run: ```{code-cell} ipython3 +# TPU version from jax.experimental.pallas import tpu as pltpu def iota(size: int): @@ -201,11 +201,22 @@ def iota(size: int): iota(8) ``` -TPUs distinguish between vector and scalar memory spaces and in this case the -output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is -a scalar. For more details read {ref}`pallas_tpu_pipelining`. +### Grid semantics + +On GPUs, each program is executed in parallel on separate threads. +Thus, we need to think about race conditions on writes to HBM. +A reasonable approach is to write our kernels in such a way that different +programs write to disjoint locations in HBM to avoid these parallel writes. +On the other hand, parallelizing the computation is how we can execute +operations like matrix multiplications really quickly. + +In contrast, TPUs operate like a very wide SIMD machine. +Some TPU models contain multiple cores, but in many cases a TPU can be +treated as a single-threaded processor. The grid on a TPU can be +specified in a combination of parallel and sequential dimensions, where sequential +dimensions are guaranteed to run serially. -You can read more details at {ref}`pallas_grid`. +You can read more details at {ref}`pallas_grid` and {ref}`pallas_tpu_noteworthy_properties`. +++ diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 9774e08dcda8..10de587105f2 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -48,12 +48,20 @@ }, { "cell_type": "markdown", + "id": "0e212a5e", "metadata": { "id": "TWKESTKAlyjT" }, "source": [ - "## TPU and its memory spaces\n", + "(tpu_and_its_memory_spaces)=\n", "\n", + "## TPU and its memory spaces" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n", "registers (which temporarily store scalar and array values) and compute units\n", "(that do computation with values in registers).\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 21865430178d..df570cf0806c 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -38,8 +38,12 @@ import numpy as np +++ {"id": "TWKESTKAlyjT"} +(tpu_and_its_memory_spaces)= + ## TPU and its memory spaces ++++ + A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units (that do computation with values in registers). From 3990e05af798686b79ee2d750b33b5e469f3fd5b Mon Sep 17 00:00:00 2001 From: Naums Mogers Date: Wed, 4 Dec 2024 15:34:23 -0800 Subject: [PATCH 611/698] [Mosaic] Add extra memref_slice verification and a memory space check helper PiperOrigin-RevId: 702883469 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 19 +++++++++++++++++++ jaxlib/mosaic/dialect/tpu/util.cc | 6 ++++++ jaxlib/mosaic/dialect/tpu/util.h | 2 ++ 3 files changed, 27 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 2107cd7fcf82..07e2e3e19197 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -93,6 +93,25 @@ LogicalResult MemRefSliceOp::verify() { auto target_type = getType(); auto target_layout = target_type.getLayout(); auto target_memory_space = target_type.getMemorySpace(); + auto indices = getBaseIdx(); + auto slice_shape = getResult().getType().getShape(); + if (!source_type.hasStaticShape()) { + return emitOpError( + "Only slicing of memrefs with static shapes is supported."); + } + auto source_shape = source_type.getShape(); + bool is_semaphore = + HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem); + if (is_semaphore && + !isa(source_type.getElementType())) { + return emitOpError( + "References to semaphore memory space must have a semaphore element " + "type."); + } + if (indices.size() != slice_shape.size() || + indices.size() != source_shape.size()) { + return emitOpError("Indices and slice shapes must match."); + } // TODO(apaszke): Check that the result has a smaller shape. // TODO(apaszke): Check that strides are equivalent. // Source and target attributes may be different before propagation is done by diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index d434837efaf5..c5f9833761b9 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -141,4 +141,10 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, return *(tiled_layout.getTileStrides().end() - 1) == 1 && *(tiled_layout.getTileStrides().end() - 2) == 1; } + +bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) { + auto memory_space = + dyn_cast_or_null(ty.getMemorySpace()); + return memory_space && memory_space.getValue() == space; +} } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 7c602e9a0bc9..9052afad499a 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -115,6 +115,8 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, const std::array &target_shape, bool allow_minormost_padding = false); +// Determines whether the given MemRefType has the given memory space. +bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space); } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ From 208194f9a557eefa036ca21d63c0e8433c56daa1 Mon Sep 17 00:00:00 2001 From: Loren Maggiore Date: Wed, 4 Dec 2024 15:57:20 -0800 Subject: [PATCH 612/698] context manager methods for AbstractMesh to appease type checker. PiperOrigin-RevId: 702890537 --- jax/_src/mesh.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 87b931683215..c7b8f692055d 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -454,6 +454,12 @@ def local_devices(self): def local_mesh(self): _raise_value_error("local_mesh") + def __enter__(self): + _raise_value_error("__enter__") + + def __exit__(self, exc_type, exc_value, traceback): + _raise_value_error("__exit__") + @staticmethod def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh): jax_config.abstract_mesh_context_manager.set_local(mesh) From 28528d44d36c3419b5b480fad936ffd37f43e042 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Dec 2024 16:23:37 -0800 Subject: [PATCH 613/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/30f22f4d6cb523e035c237f30aeac4e00ae34821. PiperOrigin-RevId: 702898833 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 232d19cfbe26..db34354f42c5 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "05f004e8368c955b872126b1c978c60e33bbc5c8" -XLA_SHA256 = "f0bedada96f5f1d09f5047c7f9db32d460d147bd0f192607cfbbee9fe5ee2d5f" +XLA_COMMIT = "30f22f4d6cb523e035c237f30aeac4e00ae34821" +XLA_SHA256 = "546dc97a5bee684b3baf1c14c00ef6c73f18c717ebb97c000a35f683bf53c244" def repo(): tf_http_archive( From f160df04442a6d7c6edae090fa19d0fb6717091a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Dec 2024 21:05:45 -0800 Subject: [PATCH 614/698] More thorough propagation of host linear layout. Currently linear layout on host can only originate from entry computation. Propagation only goes strickly down/up. More needs to be done later if such layout can original from host compute itself. Removed the temporary pattern match solution. PiperOrigin-RevId: 702966364 --- tests/memories_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index ca676a2b1993..9c9b3a4ad2bf 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1578,7 +1578,7 @@ def test_fn(x_in, y_in): test_fn, out_shardings=( Layout(custom_dll, sharding), - Layout(custom_dll, p_sharding), + Layout(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) @@ -1621,7 +1621,7 @@ def test_fn(x_in, y_in): test_fn, out_shardings=( Layout(custom_dll, sharding), - Layout(custom_dll, p_sharding), + Layout(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) From 101168740ea0ef8cd9e6bce124d7ebe82cb834f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 4 Dec 2024 21:06:23 -0800 Subject: [PATCH 615/698] [Mosaic:TPU] Lift offset restrictions on single-row (1, 128) -> (8, 128) 32-bit replicated retiling PiperOrigin-RevId: 702966495 --- .../tpu/transforms/apply_vector_layout.cc | 48 ++++++++++++------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index cac197479f4a..c0b6df203be9 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6116,35 +6116,49 @@ FailureOr>> changeTiling( } const int packing = src.packing(); const int8_t bitwidth = src.bitwidth(); - VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling, - src.implicit_dim()); - if (!dst.isValid(target_shape)) { - return emitError(loc, "Not implemented: invalid offsets in tiling target"); - } - auto dst_tiles_shape = - dst.tileArrayImplicitShape(vty.getShape(), target_shape); // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating // sublanes. if (try_replicate_rows && packing == 1 && *(vregs.dimensions().end() - 2) == 1 && - src.offsets() == LayoutOffsets{0, 0} && src.tiling() == std::array{1, ctx.target_shape[1]} && dst_tiling == ctx.target_shape) { - xla::Array retiled(dst_tiles_shape); + DCHECK_EQ(src.offsets()[0].value_or(0), 0); + const LayoutOffset dst_minor_offset = + src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1]) + : std::nullopt; + const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset}, + dst_tiling, src.implicit_dim()); + xla::Array retiled( + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); retiled.Each([&](absl::Span idx, Value *tile) { SmallVector src_idx(idx.begin(), idx.end()); *(src_idx.end() - 2) *= target_shape[0]; - *(src_idx.end() - 1) /= target_shape[0]; - const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0]; - CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1); - *tile = - broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); + if (!src.offsets()[1].has_value()) { + // With (1, 128) tiling each vreg holds values from a single row. This + // means that if the columns are replicated, then the whole vreg is + // already replicated. + *(src_idx.end() - 1) = 0; + *tile = vregs(src_idx); + } else { + // The column (in units of sublanes) of the sublane we want: + const int64_t sublane_column = + *(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1]; + *(src_idx.end() - 1) = sublane_column / target_shape[0]; + const int64_t src_sl_idx = sublane_column % target_shape[0]; + *tile = + broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); + } }); - // We have successfully replicated sublanes. - dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling, - dst.implicit_dim()); + // We have successfully replicated sublanes return std::pair(dst, std::move(retiled)); } + VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling, + src.implicit_dim()); + if (!dst.isValid(target_shape)) { + return emitError(loc, "Not implemented: invalid offsets in tiling target"); + } + auto dst_tiles_shape = + dst.tileArrayImplicitShape(vty.getShape(), target_shape); // (8,128) -> (8 * packing,128) tiling change for packed type. if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape && dst_tiling == std::array{ctx.target_shape[0] * dst.packing(), From 6172a1f1d5434914bae45660ec04bdba1bc1d7f9 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 5 Dec 2024 05:44:40 +0000 Subject: [PATCH 616/698] remove vestigial ad.reducing_transposes table these were an xmap / avals-with-names named axis thing, but that stuff is gone so we can simplify --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/interpreters/ad.py | 8 +++----- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/pjit.py | 2 +- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index fc135ac8f28c..93376c7bd170 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -652,7 +652,7 @@ def remat_transpose(out_cts, *in_primals, jaxpr, **params): for x in in_primals] assert next(in_cts_nz_, None) is next(in_zeros_, None) is None return in_cts -ad.reducing_transposes[remat_p] = remat_transpose +ad.primitive_transposes[remat_p] = remat_transpose # TODO(mattjj): move this to ad.py def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool], diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 804df185d4ed..c5e78321f331 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -277,9 +277,6 @@ def write_primal(v, val): call_jaxpr = params.pop('call_jaxpr') cts_out = get_primitive_transpose(eqn.primitive)( params, call_jaxpr, invals, cts_in, cts_in_avals) - elif eqn.primitive in reducing_transposes: - cts_out = reducing_transposes[eqn.primitive]( - cts_in, *invals, **eqn.params) else: cts_out = get_primitive_transpose(eqn.primitive)( cts_in, *invals, **eqn.params) @@ -586,8 +583,6 @@ def to_concrete_value(self): primitive_jvps : dict[core.Primitive, Callable] = {} primitive_transposes: dict[core.Primitive, Callable] = {} -# transpose rules that internally perform reductions over the given named axes -reducing_transposes: dict[core.Primitive, Callable] = {} primitive_linearizations : dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): @@ -871,3 +866,6 @@ def __init__(self): "closed-over value into the custom_vjp function as an argument, and " "adapting the custom_vjp fwd and bwd rules.") super().__init__(msg) + +# TODO(mattjj): remove this vestigial dict +reducing_transposes: dict[core.Primitive, Callable] = {} diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 7a5e9e9dee8a..418240a4a86e 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -780,7 +780,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches): cond_p.def_impl(partial(dispatch.apply_primitive, cond_p)) cond_p.def_effectful_abstract_eval(_cond_abstract_eval) ad.primitive_jvps[cond_p] = _cond_jvp -ad.reducing_transposes[cond_p] = _cond_transpose +ad.primitive_transposes[cond_p] = _cond_transpose pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule xla.register_initial_style_primitive(cond_p) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 76132ccdc99a..f62ce2434755 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1228,7 +1228,7 @@ def arrange_jaxpr_args_for_wrapped(args): scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp -ad.reducing_transposes[scan_p] = _scan_transpose +ad.primitive_transposes[scan_p] = _scan_transpose pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6c632c6fc51e..196dd8b014ae 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2385,7 +2385,7 @@ def prune_type(ty, xs, maybe_zeros): _set_states(attrs_tracked, final_states) return tree_unflatten(cts_out_treedef, nz_cts_out) -ad.reducing_transposes[pjit_p] = _pjit_transpose +ad.primitive_transposes[pjit_p] = _pjit_transpose @weakref_lru_cache From 8163e74e453f71922ea804b11f1acd635dfd4909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 4 Dec 2024 23:12:52 -0800 Subject: [PATCH 617/698] [Mosaic:TPU] Add relayout for adding minor implicit dim and relax some offset restrictions on similar shape cast This factors out some logic from the apply-vector-layout shape cast rule where we insert a minor dimension, relaxes some offset restrictions on it, and uses it for the relayout. PiperOrigin-RevId: 702993092 --- jaxlib/mosaic/dialect/tpu/tpu.td | 7 +- .../tpu/transforms/apply_vector_layout.cc | 177 +++++++++++++----- 2 files changed, 135 insertions(+), 49 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 5dad1309ae91..c5142f48dc1d 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -342,8 +342,13 @@ def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { } def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { + let description = [{ + For each sublane `i`, broadcasts the value in lane `lane + i` along the entire + sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i` + is not defined (can be anything). + }]; let arguments = (ins - AnyVectorOfNonZeroRank:$source, // All sublanes should be equal. + TPU_Vreg:$source, // All sublanes should be equal. I32Attr:$lane // Coordinates of the first element to take. ); // Output shape should be the same, except for position dim which contains diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index c0b6df203be9..50a7d57346a6 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -655,6 +655,105 @@ FailureOr> getInLayouts( return in_layouts; } +// Insert a minor dimension to the implicit shape. The original minor dimension +// becomes the new second minor dimension, laid out across sublanes. +// +// The returned vreg array uses the original tiling and the offsets specified in +// new_offsets to hold the value with the new implicit shape. +// +// Args: +// vregs: The vreg array with *implicit* array shape. +// ishape: The implicit shape of the represented value. +// layout: The layout used for the represented value. The implicit +// dimension is ignored, since this function operates directly at +// the level of the implicit shape. +// new_offsets: The offsets to use for the layout of the returned vreg array. +FailureOr> insertImplicitMinorDimension( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + const xla::Array &vregs, const ArrayRef ishape, + const VectorLayout &layout, const LayoutOffsets new_offsets) { + if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) { + return emitError(loc, "Not implemented: Unsupported bitwidth or tiling"); + } + if (layout.offsets()[1].has_value()) { + if (!new_offsets[0]) { + // TODO(tlongeri): This can only be valid if the dim size is 1. + return emitError(loc, "Not implemented: Replication mismatch"); + } + if (*new_offsets[0] != *layout.offsets()[1] % ctx.target_shape[0] && + *layout.offsets()[1] + *(ishape.end() - 1) > ctx.target_shape[1]) { + // This requires blending data from different vregs. + return emitError(loc, + "Not implemented: Misaligned offsets and shape does not " + "fit in one vreg"); + } + } + // new_layout is only to get the new vreg array shape, the implicit dim is + // irrelevant (since we already have the implicit shape): + const VectorLayout new_layout(layout.bitwidth(), new_offsets, layout.tiling(), + VectorLayout::ImplicitDim::kNone); + SmallVector new_ishape(ishape); + new_ishape.push_back(1); + xla::Array new_vregs(new_layout.tileArrayShape( + /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(new_ishape), + ctx.target_shape)); + // Preallocate an indices vector to avoid repeated allocations: + SmallVector idxs; + new_vregs.Each([&](const absl::Span dst_idx, + Value *const dst_vreg) { + // Indices of the new vreg in the new vreg array: + const int64_t new_2nd_minor_idx = *(dst_idx.end() - 2); + const int64_t new_3rd_minor_idx = *(dst_idx.end() - 3); + idxs.assign(dst_idx.begin(), dst_idx.end()); + if (!layout.offsets()[0].has_value() && new_3rd_minor_idx != 0) { + // All vregs along that dimension are the same + *(idxs.end() - 3) = 0; + *dst_vreg = new_vregs(idxs); + } else if (!layout.offsets()[1].has_value() && new_2nd_minor_idx != 0) { + // All vregs along that dimension are the same + *(idxs.end() - 2) = 0; + *dst_vreg = new_vregs(idxs); + } else { + // dst_vreg will hold slice [row_idx, col_idx:(col_idx + target_shape[0])] + // of the after-offsets source shape + const int64_t row_idx = + layout.offsets()[0] ? new_3rd_minor_idx + *layout.offsets()[0] : 0; + const int64_t col_idx = layout.offsets()[1] + ? new_2nd_minor_idx * ctx.target_shape[0] + + *layout.offsets()[1] - *new_offsets[0] + : 0; + + idxs.pop_back(); + *(idxs.end() - 2) = row_idx / ctx.target_shape[0]; + *(idxs.end() - 1) = col_idx / ctx.target_shape[1]; + Value src_vreg = vregs(idxs); + // TODO(tlongeri): We can sometimes skip operations when dst_vreg will + // hold a single non-padding element (first or last) and we don't need + // replication in the output. + if (layout.offsets()[0].has_value()) { + // [ . . . . . . . . ] [ . . . . a b c d ] + // [ . . . . a b c d ] => [ . . . . a b c d ] + // [ . . . . . . . . ] [ . . . . a b c d ] + // [ . . . . . . . . ] [ . . . . a b c d ] + src_vreg = broadcastSublane( + builder, src_vreg, + /*sublane_idx=*/row_idx % ctx.target_shape[0], ctx.target_shape); + } + if (layout.offsets()[1].has_value()) { + // [ . . . . a b c d ] [ a a a a a a a a ] + // [ . . . . a b c d ] => [ b b b b b b b b ] + // [ . . . . a b c d ] [ c c c c c c c c ] + // [ . . . . a b c d ] [ d d d d d d d d ] + src_vreg = builder.create( + loc, src_vreg.getType(), src_vreg, + /*lane=*/col_idx % ctx.target_shape[1]); + } + *dst_vreg = src_vreg; + } + }); + return new_vregs; +} + LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -4155,54 +4254,16 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, layout_in.bitwidth() == 32 && layout_in.hasNativeTiling(ctx.target_shape) && layout_in.tiling() == layout_out.tiling() && - layout_in.offsets()[0].value_or(0) == 0 && - layout_in.offsets()[1] == 0 && layout_out.offsets()[0] == 0 - // layout_out.offsets[1] can be anything, as we produce a - // replicated result - ) { - // First, insert the new singleton lane dimension. - SmallVector s = layout_in.implicitShape(src_shape); - s.push_back(1); - xla::Array dst_vregs_local(layout_out.tileArrayShape( - /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(s), - ctx.target_shape)); - TPU_ASSERT_EQ_OP(dst_vregs_local.dimensions().back(), - 1); // We're inserting a singleton dimension - dst_vregs_local.Each( - [&](const absl::Span dst_idx, Value *const dst_vreg) { - const int64_t col_idx = *(dst_idx.end() - 2); - const int64_t row_idx = *(dst_idx.end() - 3); - auto [sublanes_in_lane, rem] = - std::div(ctx.target_shape[1], ctx.target_shape[0]); - CHECK_EQ(rem, 0); - if (!layout_in.offsets()[0].has_value() && row_idx != 0) { - return; // All vregs along that dimension are the same. - } - SmallVector src_idx(toArrayRef(dst_idx)); - src_idx.pop_back(); - *(src_idx.end() - 2) /= ctx.target_shape[0]; - *(src_idx.end() - 1) /= sublanes_in_lane; - Value col_vreg = src_vregs(src_idx); - // BroadcastInSublanesOp requires the sublanes to be replicated. - if (layout_in.offsets()[0].has_value()) { - const int32_t sublane = row_idx % ctx.target_shape[0]; - col_vreg = broadcastSublane(builder, col_vreg, sublane, - ctx.target_shape); - } - *dst_vreg = builder.create( - col_vreg.getType(), col_vreg, - /*lane=*/(col_idx % sublanes_in_lane) * ctx.target_shape[0]); - }); - if (!layout_in.offsets()[0].has_value()) { - // Broadcast the sublane vregs. - // TODO(tlongeri): This could be done more efficiently - dst_vregs_local.Each([&](const absl::Span dst_idx, - Value *const dst_vreg) { - SmallVector first_row_idx(toArrayRef(dst_idx)); - *(first_row_idx.end() - 3) = 0; - *dst_vreg = dst_vregs_local(first_row_idx); - }); - } + (!layout_in.offsets()[1].has_value() || + *layout_in.offsets()[1] % ctx.target_shape[0] == + layout_out.offsets()[0] || + *layout_in.offsets()[1] + src_tiled_dims[1] <= + ctx.target_shape[1])) { + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array dst_vregs_local, + insertImplicitMinorDimension(ctx, builder, op.getLoc(), src_vregs, + layout_in.implicitShape(src_shape), + layout_in, layout_out.offsets())); // Now, reshape the major axes of the vreg array. dst_vregs_local.Reshape( layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape)); @@ -6370,6 +6431,26 @@ FailureOr>> changeImplicitDim( }); return std::make_pair(dst, new_vregs); } + if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone && + dst_implicit_dim == VectorLayout::ImplicitDim::kMinor && + src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) { + // TODO(tlongeri): Make insertImplicitMinorDimension more flexible about + // offsets, then we can pass dst_offset_hints directly. + const LayoutOffset dst_2nd_minor_offset = + !src.offsets()[1] || *src.offsets()[1] + *(vty.getShape().end() - 1) <= + ctx.target_shape[1] + ? dst_offset_hints[0] + : LayoutOffset(*src.offsets()[1] % ctx.target_shape[0]); + VectorLayout dst(src.bitwidth(), + {dst_2nd_minor_offset, dst_offset_hints[1]}, src.tiling(), + VectorLayout::ImplicitDim::kMinor); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array dst_vregs, + insertImplicitMinorDimension(ctx, builder, loc, vregs, + src.implicitShape(vty.getShape()), src, + dst.offsets())); + return std::make_pair(dst, std::move(dst_vregs)); + } return emitError(loc, "Not implemented: Unsupported implicit dim change: from ") << src << " to " << dst_implicit_dim; From 7214a3a82a68bba806949c790124a2023cfc5c9e Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 5 Dec 2024 02:06:35 -0800 Subject: [PATCH 618/698] [AutoPGLE] Add multi-process test case PiperOrigin-RevId: 703031689 --- jax/_src/cache_key.py | 65 ++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 6e7a421482ce..2ec645cee407 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -238,6 +238,38 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): _hash_devices(hash_obj, accelerators) _hash_platform(hash_obj, backend) +# LINT.IfChange(xla_flags) +xla_flags_to_exclude_from_cache_key = [ + "--xla_dump_compress_protos", + "--xla_dump_module_metadata", + "--xla_dump_max_hlo_modules", + "--xla_dump_include_timestamp", + "--xla_dump_hlo_pass_re", + "--xla_dump_hlo_module_re", + "--xla_dump_hlo_snapshots", + "--xla_dump_fusion_visualization", + "--xla_dump_hlo_as_url", + "--xla_dump_hlo_as_proto", + "--xla_dump_hlo_as_text", + "--xla_dump_hlo_as_long_text", + "--xla_dump_hlo_as_html", + "--xla_dump_hlo_as_dot", + "--xla_dump_to", + "--xla_force_host_platform_device_count", + "--xla_dump_disable_metadata", + "--xla_dump_hlo_pipeline_re", + "--xla_tpu_sdc_checker_streamz_metric", + "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", + "--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks", + "--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present", + "--xla_gpu_cuda_data_dir", + "--xla_gpu_experimental_autotune_cache_mode", +] + +env_override_flags_to_exclude_from_cache_key = { + x.strip("-") for x in xla_flags_to_exclude_from_cache_key +} +# LINT.ThenChange(:debug_options) def _hash_serialized_compile_options(hash_obj, compile_options_obj, strip_device_assignment=False): @@ -284,6 +316,11 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_gpu_cuda_data_dir = "" # LINT.ThenChange(:xla_flags) + compile_options_copy.env_option_overrides = [ + flag_value + for flag_value in compile_options_copy.env_option_overrides + if flag_value[0] not in env_override_flags_to_exclude_from_cache_key + ] if strip_device_assignment and compile_options_copy.device_assignment: replica_count = compile_options_copy.device_assignment.replica_count() computation_count = compile_options_copy.device_assignment.computation_count() @@ -301,34 +338,6 @@ def _hash_platform(hash_obj, backend): def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]): - # LINT.IfChange(xla_flags) - xla_flags_to_exclude_from_cache_key = [ - "--xla_dump_compress_protos", - "--xla_dump_module_metadata", - "--xla_dump_max_hlo_modules", - "--xla_dump_include_timestamp", - "--xla_dump_hlo_pass_re", - "--xla_dump_hlo_module_re", - "--xla_dump_hlo_snapshots", - "--xla_dump_fusion_visualization", - "--xla_dump_hlo_as_url", - "--xla_dump_hlo_as_proto", - "--xla_dump_hlo_as_text", - "--xla_dump_hlo_as_long_text", - "--xla_dump_hlo_as_html", - "--xla_dump_hlo_as_dot", - "--xla_dump_to", - "--xla_force_host_platform_device_count", - "--xla_dump_disable_metadata", - "--xla_dump_hlo_pipeline_re", - "--xla_tpu_sdc_checker_streamz_metric", - "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", - "--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks", - "--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present", - "--xla_gpu_cuda_data_dir", - ] - # LINT.ThenChange(:debug_options) - xla_flags = [] xla_flags_env_var = os.getenv("XLA_FLAGS") From 4e17bea91a2dc514de105fa3e47403f322a536a6 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 4 Dec 2024 08:10:47 -0800 Subject: [PATCH 619/698] [shape_poly] Fix the handling of __pow__ for symbolic dimensions The code for handling exponentiation was wrong, and there were no tests. --- jax/_src/export/shape_poly.py | 24 +++++++++++++++++------- tests/shape_poly_test.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 15f99533d59e..3356d17c36d2 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -764,13 +764,23 @@ def __rmul__(self, other): return _DimExpr._linear_combination(self, other, 0, 0, self.scope) return _ensure_poly(other, "mul", self.scope).__mul__(self) - def __pow__(self, power, modulo=None): - assert modulo is None - try: - power = int(power) - except: - raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'") - return functools.reduce(op.mul, [self] * power) + def __pow__(self, power: core.DimSize, modulo=None): + if modulo is not None: + raise NotImplementedError("__pow__ modulo not implemented") + if is_symbolic_dim(power): + return power.__rpow__(self) # type: ignore + if power != int(power): + raise ValueError(f"Symbolic dimension cannot be raised to non-integer powers: '{self}' ** '{power}'") + if power >= 0: + return functools.reduce(op.mul, [self] * power, 1) + # We don't support negative powers, because JAX does not allow negative + # powers for integers + raise ValueError(f"Symbolic dimension cannot be raised to negative powers: '{self}' ** '{power}'") + + def __rpow__(self, other, modulo=None): + if modulo is not None: + raise NotImplementedError("__rpow__ modulo not implemented") + return self.__jax_array__().__rpow__(other) def __floordiv__(self, divisor): if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor): diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 668907ffee27..0d1da6ceaeef 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -128,11 +128,11 @@ def sampled_assertion(self, ): """Checks `assertion(e, fun(*operands))` symbolically and concretely. - For the concrete check, it will same the space of dimension variable + For the concrete check, it will sample the space of dimension variable assignments for the dimension variables in `e`. - This is useful when `fun` can operate both with polynomials and with - concrete values, and we want to double-check that the behavior is sound. + This is useful when `fun` can operate both with symbolic and with + concrete values, and we want to check that the behavior is sound. """ computed_sym = fun(*operands_sym) assertion_fun = { @@ -1429,6 +1429,29 @@ def test_non_trivial_dim_expr(self, expr=lambda d: d % -2): arg_descriptors=[RandArg((3,), np.int64)], polymorphic_shapes=["b"]) + @jtu.parameterized_filterable( + # The function `f` will be called with x: f32[b] + kwargs=[ + dict(testcase_name="cube", f=lambda x: x.shape[0] ** 3), + dict(testcase_name="zero", f=lambda x: x.shape[0] ** 0), + dict(testcase_name="rpow", f=lambda x: 2 ** x.shape[0]), + dict(testcase_name="negative", + f=lambda x: x.shape[0] ** -2, + expect_error=(ValueError, "cannot be raised to negative powers")), + dict(testcase_name="non_integer", + f=lambda x: x.shape[0] ** 1.5, + expect_error=(ValueError, "cannot be raised to non-integer powers")), + dict(testcase_name="sym_pow", + f=lambda x: x.shape[0] ** x.shape[0]), + ] + ) + def test_pow(self, f, expect_error: tuple[Exception, str] | None = None): + check_shape_poly(self, + f, + arg_descriptors=[RandArg((3,), np.float32)], + polymorphic_shapes=["b"], + expect_error=expect_error) + def test_static_shape_result(self): """The result has static shape.""" From c965ffbfc97947dc4b1b619124e6b385dd396096 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 5 Dec 2024 03:32:08 -0800 Subject: [PATCH 620/698] [Mosaic GPU] Remove expect_wait from Barrier.wait It looks like LLVM already moves the wait loops to the end of the program, so the whole optimization is no longer necessary and only adds unnecessary operations. PiperOrigin-RevId: 703052393 --- jax/experimental/mosaic/gpu/utils.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 2279df4f3984..f6cab5654e64 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -673,33 +673,17 @@ def __getitem__(self, offset: ir.Value | int) -> "BarrierRef": 1, ) - def wait_parity(self, parity, expect_wait=False): - i1 = ir.IntegerType.get_signless(1) + def wait_parity(self, parity): i32 = ir.IntegerType.get_signless(32) - ticks = c(10000000, i32) - address = self.get_ptr() + ticks = arith.constant(i32, 10000000) parity = arith.extui(i32, parity) - if expect_wait: - nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks) - return - barrier_ready = llvm.inline_asm( - i1, - [address, parity], - "mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;", - "=b,l,r", - has_side_effects=True, - ) - should_wait = arith.xori(barrier_ready, c(1, i1)) - should_wait = llvm.intr_expect(should_wait, c(0, i1)) - with ir.InsertionPoint(scf.IfOp(should_wait).then_block): - nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks) - scf.yield_([]) + nvvm.mbarrier_try_wait_parity_shared(self.get_ptr(), parity, ticks) - def wait(self, expect_wait=False): + def wait(self): parities = memref.load(self.phases, []) parity, new_parities = self.update_parities(parities) memref.store(new_parities, self.phases, []) - self.wait_parity(parity, expect_wait=expect_wait) + self.wait_parity(parity) def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: i32 = ir.IntegerType.get_signless(32) From 03861d43ec4063dd82850b6cdac4438e3d5d4e94 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 5 Dec 2024 03:38:24 -0800 Subject: [PATCH 621/698] [pallas:mosaic_gpu] Removed leftover debugging code PiperOrigin-RevId: 703054113 --- jax/_src/pallas/mosaic_gpu/lowering.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 2900da133158..79f8f116fbd4 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1046,8 +1046,6 @@ def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): x_smem, transforms = _handle_reshaping(x_smem, transforms) x_smem, transforms = _handle_indexing(x_smem, transforms) - print("ctx:", ctx) - print("transforms:", transforms) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): if tiling != (64, swizzle // x_aval.dtype.itemsize): From 569c2a3c6c4f47bc61c020196ba68b7d62315b6c Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 5 Dec 2024 07:00:58 -0800 Subject: [PATCH 622/698] Reverts 73962b740890a728295fa09f515dcf96cb820822 PiperOrigin-RevId: 703100851 --- jax/_src/lax/lax.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ad08e1335a40..9c8afe4f9292 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -923,8 +923,9 @@ def accumulation_type(self) -> DTypeLike | None: case _: return np.float32 - @property - def supported_output_types(self) -> tuple[DTypeLike, ...] | None: + def supported_output_types( + self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike + ) -> tuple[DTypeLike, ...] | None: match self: case ( DotAlgorithmPreset.ANY_F8_ANY_F8_F32 @@ -941,7 +942,17 @@ def supported_output_types(self) -> tuple[DTypeLike, ...] | None: dtypes.float8_e4m3b11fnuz, ) case DotAlgorithmPreset.F16_F16_F32: - return (np.float32, np.float16) + # F16 output is only supported with F16 inputs. + if dtypes.promote_types(lhs_dtype, rhs_dtype) == np.float16: + return (np.float32, np.float16) + else: + return (np.float32,) + case DotAlgorithmPreset.BF16_BF16_F32: + # BF16 output is only supported with BF16 inputs. + if dtypes.promote_types(lhs_dtype, rhs_dtype) == dtypes.bfloat16: + return (np.float32, dtypes.bfloat16) + else: + return (np.float32,) case _: accumulation_type = self.accumulation_type return None if accumulation_type is None else (accumulation_type,) @@ -3713,13 +3724,6 @@ def get_algorithm_compute_types( algorithm.accumulation_type, ) - supported_output_types = algorithm.supported_output_types - - if algorithm == DotAlgorithmPreset.BF16_BF16_F32: - # If dtype is anything other than float32, it will be cast to bfloat16. - if np.dtype(lhs_dtype) != np.float32: - supported_output_types = (np.float32, dtypes.bfloat16) - def maybe_convert_dtype(input_dtype, target_dtypes): if target_dtypes is None: return input_dtype @@ -3727,11 +3731,12 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return input_dtype return target_dtypes[0] - return ( - maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types), - maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types), - maybe_convert_dtype(out_dtype, supported_output_types), + lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types) + rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types) + out_type = maybe_convert_dtype( + out_dtype, algorithm.supported_output_types(lhs_dtype, rhs_dtype) ) + return lhs_dtype, rhs_dtype, out_type def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, From d5ead570bbb26b3b912cec0a9bfdafceacc72e67 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 5 Dec 2024 07:05:39 -0800 Subject: [PATCH 623/698] [Mosaic TPU] Add support for modeling loads/stores and fix minor issues in model extraction PiperOrigin-RevId: 703102072 --- jax/_src/pallas/mosaic/verification.py | 68 ++++++++++++++++++++- tests/pallas/tpu_pallas_distributed_test.py | 59 ++++++++++++++++++ 2 files changed, 125 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index 61caa4087d99..08ff58770804 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -145,6 +145,15 @@ def block(self, begin: str, end: str): self.level -= 1 self.locals.append(self._indent(end) + "\n") + @contextlib.contextmanager + def comment_if_emitted(self, comment): + self.comment(comment) + yield + self.comment(comment) + if self.locals[-1] == self.locals[-2]: + self.locals.pop() + self.locals.pop() + def get(self, value: ir.Value, default: Any = _UNSPECIFIED): if default is _UNSPECIFIED: return self.env[value] @@ -358,6 +367,17 @@ def _print_op(ctx, op): return bin_op(ctx, "int", "%", *op.operands) case "arith.divsi": return bin_op(ctx, "int", "/", *op.operands) + case "arith.andi": + return bin_op(ctx, _model_type(op.result.type), "&", *op.operands) + case "arith.select": + cond, if_true, if_false = map(lambda o: ctx.get(o, None), op.operands) + if cond is None or if_true is None or if_false is None: + return NotImplemented + result_ty = _model_type(op.result.type) + return ctx.emit(result_ty, f"({cond} -> {if_true} : {if_false})") + case "arith.index_cast": + model = ctx.get(op.operands[0], None) + return ctx.emit("int", model) if model is not None else NotImplemented case "arith.cmpi": match op.predicate.value: case arith.CmpIPredicate.eq: @@ -386,12 +406,44 @@ def _print_op(ctx, op): read_refs.append(model) with ctx.block("d_step {", "}"): # Start reading for r in read_refs: + for loc in r.written_at(None): + ctx.emit(None, f"assert(!{loc})") for loc in r.readers_at(None): ctx.emit(None, f"{loc}++") with ctx.block("d_step {", "}"): # Stop reading for r in read_refs: for loc in r.readers_at(None): ctx.emit(None, f"{loc}--") + case "vector.load": + ref = ctx.get(op.operands[0]) + assert isinstance(ref, GlobalRefModel) + if (first_idx := ctx.get(op.operands[1], None)) is not None: + leading_load_len = ir.VectorType(op.result.type).shape[0] + ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_load_len) + with ctx.block("d_step {", "}"): # Start reading + for loc in ref.written_at(None): + ctx.emit(None, f"assert(!{loc})") + for loc in ref.readers_at(None): + ctx.emit(None, f"{loc}++") + with ctx.block("d_step {", "}"): # Stop reading + for loc in ref.readers_at(None): + ctx.emit(None, f"{loc}--") + return NotImplemented # We don't model the result of the load. + case "vector.store": + ref = ctx.get(op.operands[1]) # Stored value goes first + assert isinstance(ref, GlobalRefModel) + if (first_idx := ctx.get(op.operands[2], None)) is not None: + leading_store_len = ir.VectorType(op.operands[0].type).shape[0] + ref = GlobalRefModel(f"{ref.base} + {first_idx}", leading_store_len) + with ctx.block("d_step {", "}"): # Start writing + for loc in ref.readers_at(None): + ctx.emit(None, f"assert(!{loc})") + for loc in ref.written_at(None): + ctx.emit(None, f"assert(!{loc})") + ctx.emit(None, f"{loc} = 1") + with ctx.block("d_step {", "}"): # Stop reading + for loc in ref.written_at(None): + ctx.emit(None, f"{loc} = 0") case "scf.for": carrys = [ ctx.emit("int", ctx.get(arg)) @@ -419,6 +471,7 @@ def _print_op(ctx, op): ctx.emit(None, f"{c} = {ctx.get(new)}") ctx.emit(None, f"{induction_var} = {induction_var} + {step}") ctx.emit(None, ":: else -> break") + ctx.emit(None, "skip") # To avoid "Jump into d_step sequence errors" if len(carrys) == 1: return carrys[0] else: @@ -450,16 +503,27 @@ def bin_op(ctx, result_ty, op, lhs, rhs): return ctx.emit(result_ty, f"{lhs} {op} {rhs}") +def _model_type(ty): + if ir.IntegerType.isinstance(ty): + if ir.IntegerType(ty).width == 1: + return "bool" + else: + return "int" + else: + raise NotImplementedError(ty) + + def _print_block(ctx, block): for op in block: try: - results = _print_op(ctx, op) + with ctx.comment_if_emitted(op.OPERATION_NAME): + results = _print_op(ctx, op) except Exception as e: raise RuntimeError(f"Failed to print op: {op}") from e if results is NotImplemented: continue if not op.results: - assert results is None + assert results is None or results == () elif len(op.results) > 1: raise NotImplementedError(op) else: diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index f7e3eaaf0736..7b3bd70efafe 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -15,6 +15,8 @@ """Tests for distributed pallas TPU operations.""" import functools +import os +import tempfile from absl.testing import absltest from absl.testing import parameterized import jax @@ -513,5 +515,62 @@ def _(): atol=1e-5, rtol=1e-3) + +class VerificationTest(jtu.JaxTestCase): + + def test_verification(self): + if (num_devices := jax.local_device_count()) <= 1: + self.skipTest('Test requires multiple devices.') + if not jtu.is_device_tpu_at_least(4) or jax.devices()[0].num_cores > 1: + self.skipTest('Test requires a new single-core TPU.') + def kernel_body(in_ref, out_ref, scratch_ref, send_sem, recv_sem, capacity_sem): + my_id = lax.axis_index('x') + dst_id = jnp.where(my_id == num_devices - 1, 0, my_id + 1) + src_id = jnp.where(my_id == 0, num_devices - 1, my_id - 1) + pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id) + out_ref[...] = jnp.zeros_like(out_ref) + scratch_ref[0] = in_ref[0] + + @functools.partial(lax.fori_loop, 0, num_devices - 1, init_val=None) + def _(i, _): + slot = i % 2 + next_slot = 1 - slot + pltpu.semaphore_wait(capacity_sem, 1) + copy = pltpu.async_remote_copy( + scratch_ref.at[slot], + scratch_ref.at[next_slot], + send_sem, + recv_sem, + device_id=dst_id, + ) + out_ref[...] += scratch_ref[slot] + copy.wait() + pltpu.semaphore_signal(capacity_sem, 1, device_id=src_id) + out_ref[...] += scratch_ref[(num_devices - 1) % 2] + pltpu.semaphore_wait(capacity_sem, 1) + + kernel = pl.pallas_call( + kernel_body, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + scratch_shapes=[ + pltpu.VMEM((2, 128, 128), jnp.float32), + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, + ], + ) + devices = mesh_utils.create_device_mesh((num_devices,)) + mesh = jax.sharding.Mesh(devices, ['x']) + # This is just a smoke test to ensure that the verification does not crash. + with tempfile.TemporaryDirectory() as tmpdir: + previous_config = jax.config.read('jax_pallas_dump_promela_to') + jax.config.update('jax_pallas_dump_promela_to', tmpdir) + shard_map.shard_map( + kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_rep=False + )(jnp.ones((8, 128, 128), jnp.float32)) + jax.config.update('jax_pallas_dump_promela_to', previous_config) + self.assertNotEmpty(os.listdir(tmpdir)) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From d034680f6da44a919cfbf92749cf36a3e6f18412 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 5 Dec 2024 07:20:00 -0800 Subject: [PATCH 624/698] [Mosaic GPU] Always annotate block initialization in the profiles This helps establish a shared timeline between different warpgroups and shows how expensive it really was. PiperOrigin-RevId: 703105898 --- jax/experimental/mosaic/gpu/core.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 4b4882f65f0c..d8774e932d04 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -806,16 +806,6 @@ def _launch( ) ) - smem_ref_tree = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers - ) - # TODO(apaszke): Skip the following if no barriers were initialized. - nvvm.fence_mbarrier_init() - if math.prod(cluster) != 1: - nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) - nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - gpu.barrier() - if profiler_spec: prof_smem = memref.view( ir.MemRefType.get( @@ -832,7 +822,19 @@ def _launch( ptr_ty = ir.Type.parse("!llvm.ptr") scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree + ctx = LaunchContext(launch_op, scratch_ptr, cluster, prof) + with ctx.named_region("Init"): + smem_ref_tree = _construct_smem_reftree( + cluster, dynamic_smem, smem_buffers + ) + # TODO(apaszke): Skip the following if no barriers were initialized. + nvvm.fence_mbarrier_init() + if math.prod(cluster) != 1: + nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) + nvvm.cluster_wait(aligned=ir.UnitAttr.get()) + gpu.barrier() + + yield ctx, smem_ref_tree if prof is not None: prof.finalize(grid=grid, block=block) gpu.terminator() From e5102957b0197d78a6ee2e181782737c77fb02d1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 5 Dec 2024 07:32:10 -0800 Subject: [PATCH 625/698] [pallas:mosaic_gpu] Do not store the grid mapping in `ModuleContext` We really only ever use the grid names. PiperOrigin-RevId: 703108864 --- jax/_src/pallas/mosaic_gpu/lowering.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 79f8f116fbd4..e04b0ff1adfe 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,13 +17,13 @@ from __future__ import annotations import collections -from collections.abc import MutableMapping, MutableSequence, Sequence +from collections.abc import Hashable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools import itertools as it import math -from typing import Any, Hashable, Protocol, cast +from typing import Any, Protocol, cast import jax from jax import lax @@ -192,7 +192,7 @@ def _reduce_sum_resource_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int @dataclasses.dataclass class ModuleContext: name: str - grid_mapping: pallas_core.GridMapping + grid_names: Sequence[Hashable] | None program_ids: Sequence[ir.Value] | None approx_math: bool runtime_smem: ir.Value # ir.MemRefType @@ -517,7 +517,7 @@ def make_program_ids(step: ir.Value): grouped_barriers[barrier].append(barrier_ref) module_ctx = ModuleContext( name_and_src_info.name, - grid_mapping, + grid_mapping.grid_names, None, approx_math, runtime_smem, @@ -1290,7 +1290,7 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): @register_lowering_rule(lax.axis_index_p) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): i32 = ir.IntegerType.get_signless(32) - grid_names = ctx.module_ctx.grid_mapping.grid_names + grid_names = ctx.module_ctx.grid_names squashed_dims = ctx.module_ctx.squashed_dims if squashed_dims: unsquashed_names = grid_names[-3:] From 5fe5206b6a7c20d0e556eec0a1888409928032fc Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 5 Dec 2024 08:01:47 -0800 Subject: [PATCH 626/698] [shape_poly] Remove some deprecated kwargs PiperOrigin-RevId: 703116755 --- CHANGELOG.md | 3 +++ jax/_src/export/shape_poly.py | 23 ----------------------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b5758d107077..258fad49b5f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. use `uses_global_constants`. * the `lowering_platforms` kwarg for {func}`jax.export.export`: use `platforms` instead. + * The kwargs `symbolic_scope` and `symbolic_constraints` from + {func}`jax.export.symbolic_args_specs` have been removed. They were + deprecated in June 2024. Use `scope` and `constraints` instead. * Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. * Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 15f99533d59e..cb9a99564093 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1198,12 +1198,6 @@ def is_symbolic_dim(p: DimSize) -> bool: """ return isinstance(p, _DimExpr) -def is_poly_dim(p: DimSize) -> bool: - # TODO: deprecated January 2024, remove June 2024. - warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim", - DeprecationWarning, stacklevel=2) - return is_symbolic_dim(p) - dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] def _einsum_contract_path(*operands, **kwargs): @@ -1413,8 +1407,6 @@ def symbolic_args_specs( shapes_specs, # prefix pytree of strings constraints: Sequence[str] = (), scope: SymbolicScope | None = None, - symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24 - symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24 ): """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. @@ -1435,25 +1427,10 @@ def symbolic_args_specs( arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). constraints: as for :func:`jax.export.symbolic_shape`. scope: as for :func:`jax.export.symbolic_shape`. - symbolic_constraints: DEPRECATED, use `constraints`. - symbolic_scope: DEPRECATED, use `scope`. Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes replaced with symbolic dimensions as specified by `shapes_specs`. """ - if symbolic_constraints: - warnings.warn("symbolic_constraints is deprecated, use constraints", - DeprecationWarning, stacklevel=2) - if constraints: - raise ValueError("Cannot use both symbolic_constraints and constraints") - constraints = symbolic_constraints - if symbolic_scope is not None: - warnings.warn("symbolic_scope is deprecated, use scope", - DeprecationWarning, stacklevel=2) - if scope is not None: - raise ValueError("Cannot use both symbolic_scope and scope") - scope = symbolic_scope - polymorphic_shapes = shapes_specs args_flat, args_tree = tree_util.tree_flatten(args) From 4a41aa0a46085f95437bf9853c1786836c3d2321 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 5 Dec 2024 08:03:37 -0800 Subject: [PATCH 627/698] [pallas:mosaic_gpu] Removed unnecessarily strict check in `emit_pipeline` PiperOrigin-RevId: 703117465 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 069b8d9e78d3..ee3f03f1849f 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -181,17 +181,6 @@ def emit_pipeline( delay_release = 0 # No need to delay anything. def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): - for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)): - if any( - spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore - for idx in range(1, len(grid) + 1) - if spec.block_shape is not None - ): - raise NotImplementedError( - f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block" - f" shape {spec.block_shape}." - ) - in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( [ From 3f5f3e1c47c230cc5d44841b08c3db9598442d13 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 5 Dec 2024 08:39:48 -0800 Subject: [PATCH 628/698] [export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility. This is because the underlying Triton IR does not guarantee compatibility. PiperOrigin-RevId: 703127711 --- CHANGELOG.md | 5 +++++ jax/_src/export/_export.py | 3 ++- tests/pallas/export_back_compat_pallas_test.py | 5 +++-- tests/pallas/export_pallas_test.py | 4 ++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 258fad49b5f4..b6d0f97f439d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. return NaN for negative integer inputs, to match the behavior of SciPy from https://github.com/scipy/scipy/pull/21827. * `jax.clear_backends` was removed after being deprecated in v0.4.26. + * We removed the custom call "__gpu$xla.gpu.triton" from the list of custom + call that we guarantee export stability. This is because this custom call + relies on Triton IR, which is not guaranteed to be stable. If you need + to export code that uses this custom call, you can use the `disabled_checks` + parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index ad2c7fdac2dc..e3508639fe15 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1005,7 +1005,8 @@ def _check_lowering(lowering) -> None: *_CPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", "cu_threefry2x32", "cu_threefry2x32_ffi", - "__gpu$xla.gpu.triton", # Pallas call on GPU + # Triton IR does not guarantee stability. + # "__gpu$xla.gpu.triton", # cholesky on CPU "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on TPU diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 1b810bcb6f26..462597e567f2 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -48,9 +48,10 @@ def setUp(self): self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() - @unittest.skip("TODO(necula): This test is checking backwards compatibility " + @unittest.skip("This test is checking backwards compatibility " "of Triton IR, but Triton doesn't promise backwards " - "compatibility for its IR.") + "compatibility for its IR, and we have since removed " + "the corresponding custom call from the guaranteed stable list.") def test_triton_add_one(self): def func(x): def add_one(x_ref, o_ref): diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py index 70e40e1f2801..8b18f706a1d0 100644 --- a/tests/pallas/export_pallas_test.py +++ b/tests/pallas/export_pallas_test.py @@ -50,6 +50,10 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: exp = export.export( add_vectors, platforms=["tpu", "cuda"], + # The Pallas GPU custom call is not enabled for export by default. + disabled_checks=[ + export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton") + ] )(a, a) if (jtu.device_under_test() == "tpu" or From 29a8cce66cfe35f216b958779e975910d280bb5b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 5 Dec 2024 09:27:19 -0800 Subject: [PATCH 629/698] jax.numpy: require boolean dtype for where argument --- jax/_src/deprecations.py | 1 + jax/_src/numpy/reductions.py | 33 +++++++++++++++++++++++++++++--- tests/lax_numpy_reducers_test.py | 28 +++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index c7a956068981..778a084e807a 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -130,5 +130,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') +register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') register('pallas-gpu-triton') diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 69d6843f5155..eea734420176 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -81,6 +81,20 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: return dtypes.int_ return dtype +def check_where(name: str, where: ArrayLike | None) -> Array | None: + if where is None: + return where + check_arraylike(name, where) + where_arr = lax_internal.asarray(where) + if where_arr.dtype != bool: + # Deprecation added 2024-12-05 + deprecations.warn( + 'jax-numpy-reduction-non-boolean-where', + f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.", + stacklevel=2) + return where_arr.astype(bool) + return where_arr + ReductionOp = Callable[[Any, Any], Any] @@ -101,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") check_arraylike(name, a) + where_ = check_where(name, where_) dtypes.check_user_dtype_supported(dtype, name) axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") @@ -730,6 +745,8 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") + check_arraylike("logsumexp", a) + where = check_where("logsumexp", where) a_arr, = promote_dtypes_inexact(a) pos_dims, dims = _reduction_dims(a_arr, axis) amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf) @@ -748,6 +765,8 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") + check_arraylike("logsumexp2", a) + where = check_where("logsumexp2", where) ln2 = float(np.log(2)) if initial is not None: initial *= ln2 @@ -850,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, upcast_f16_for_computation: bool = True, where: ArrayLike | None = None) -> Array: check_arraylike("mean", a) + where = check_where("mean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") @@ -1087,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("var", a) + where = check_where("var", where) dtypes.check_user_dtype_supported(dtype, "var") if out is not None: raise NotImplementedError("The 'out' argument to jnp.var is not supported.") @@ -1224,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("std", a) + where = check_where("std", where) dtypes.check_user_dtype_supported(dtype, "std") if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") @@ -1330,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None, def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], init_val: ArrayLike, nan_if_all_nan: bool, - axis: Axis = None, keepdims: bool = False, **kwargs) -> Array: + axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None, + **kwargs) -> Array: check_arraylike(name, a) + where = check_where(name, where) if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): - return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs) + return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs) out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a), - axis=axis, keepdims=keepdims, **kwargs) + axis=axis, keepdims=keepdims, where=where, **kwargs) if nan_if_all_nan: return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims), _lax_const(a, np.nan), out) @@ -1755,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out Array([[nan, nan, nan, nan]], dtype=float32) """ check_arraylike("nanmean", a) + where = check_where("nanmean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer): @@ -1848,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: [4. ]], dtype=float32) """ check_arraylike("nanvar", a) + where = check_where("nanvar", where) dtypes.check_user_dtype_supported(dtype, "nanvar") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") @@ -1943,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: Array([[0.5, 0.5, 0. , 0. ]], dtype=float32) """ check_arraylike("nanstd", a) + where = check_where("nanstd", where) dtypes.check_user_dtype_supported(dtype, "nanstd") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.") diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index be6208e6e305..2bef35fbdcef 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -448,6 +448,34 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product(rec=JAX_REDUCER_INITIAL_RECORDS) + def testReducerWhereNonBooleanErrorInitial(self, rec): + dtype = rec.dtypes[0] + x = jnp.zeros((10,), dtype) + where = jnp.ones(10, dtype=int) + func = getattr(jnp, rec.name) + def assert_warns_or_errors(msg): + if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + func(x, where=where, initial=jnp.array(0, dtype=dtype)) + + @jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS) + def testReducerWhereNonBooleanErrorNoInitial(self, rec): + dtype = rec.dtypes[0] + x = jnp.zeros((10,), dtype) + where = jnp.ones(10, dtype=int) + func = getattr(jnp, rec.name) + def assert_warns_or_errors(msg): + if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + func(x, where=where) + @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact, From aaaee63ac5887a8f7cf06655a42053eb9c06b91b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 5 Dec 2024 09:48:32 -0800 Subject: [PATCH 630/698] jnp.linalg.vector_norm: properly support multiple axes --- jax/_src/numpy/linalg.py | 68 ++++++++++++++++++---------------------- tests/linalg_test.py | 40 ++++++++++++++--------- 2 files changed, 57 insertions(+), 51 deletions(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index e7e2e369722d..c01a5d270f0f 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -1159,35 +1159,7 @@ def norm(x: ArrayLike, ord: int | str | None = None, num_axes = len(axis) if num_axes == 1: - if ord is None or ord == 2: - return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, - keepdims=keepdims)) - elif ord == jnp.inf: - return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif ord == -jnp.inf: - return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif ord == 0: - return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, - axis=axis, keepdims=keepdims) - elif ord == 1: - # Numpy has a special case for ord == 1 as an optimization. We don't - # really need the optimization (XLA could do it for us), but the Numpy - # code has slightly different type promotion semantics, so we need a - # special case too. - return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) - elif isinstance(ord, str): - msg = f"Invalid order '{ord}' for vector norm." - if ord == "inf": - msg += "Use 'jax.numpy.inf' instead." - if ord == "-inf": - msg += "Use '-jax.numpy.inf' instead." - raise ValueError(msg) - else: - abs_x = ufuncs.abs(x) - ord_arr = lax_internal._const(abs_x, ord) - ord_inv = lax_internal._const(abs_x, 1. / ord_arr) - out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) - return ufuncs.power(out, ord_inv) + return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims) elif num_axes == 2: row_axis, col_axis = axis # pytype: disable=bad-unpacking @@ -1632,7 +1604,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array: @export -def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False, +def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: """Compute the vector norm of a vector or batch of vectors. @@ -1668,13 +1640,35 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa Array([3.7416575, 9.486833 ], dtype=float32) """ check_arraylike('jnp.linalg.vector_norm', x) - if axis is None: - result = norm(jnp.ravel(x), ord=ord) - if keepdims: - result = lax.expand_dims(result, range(jnp.ndim(x))) - return result - return norm(x, axis=axis, keepdims=keepdims, ord=ord) - + if ord is None or ord == 2: + return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, + keepdims=keepdims)) + elif ord == jnp.inf: + return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif ord == -jnp.inf: + return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif ord == 0: + return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, + axis=axis, keepdims=keepdims) + elif ord == 1: + # Numpy has a special case for ord == 1 as an optimization. We don't + # really need the optimization (XLA could do it for us), but the Numpy + # code has slightly different type promotion semantics, so we need a + # special case too. + return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) + elif isinstance(ord, str): + msg = f"Invalid order '{ord}' for vector norm." + if ord == "inf": + msg += "Use 'jax.numpy.inf' instead." + if ord == "-inf": + msg += "Use '-jax.numpy.inf' instead." + raise ValueError(msg) + else: + abs_x = ufuncs.abs(x) + ord_arr = lax_internal._const(abs_x, ord) + ord_inv = lax_internal._const(abs_x, 1. / ord_arr) + out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) + return ufuncs.power(out, ord_inv) @export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 7c135b4ffeca..0da09e232deb 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,6 +16,8 @@ from functools import partial import itertools +from typing import Iterator +from unittest import skipIf import numpy as np import scipy @@ -54,6 +56,20 @@ def _is_required_cuda_version_satisfied(cuda_version): return int(version.split()[-1]) >= cuda_version +def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]: + """ + Generate a range of valid axis arguments for a reduction over + an array with a given number of dimensions. + """ + yield from (None, ()) + if ndim > 0: + yield from (0, (-1,)) + if ndim > 1: + yield from (1, (0, 1), (-1, 0)) + if ndim > 2: + yield (-1, 0, 1) + + def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: """scipy.linalg.toeplitz with v1.17+ batching semantics.""" if scipy_version >= (1, 17, 0): @@ -707,29 +723,25 @@ def testMatrixNorm(self, shape, dtype, keepdims, ord): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3) self._CompileAndCheck(jnp_fn, args_maker) + @skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0") @jtu.sample_product( - shape=[(3,), (3, 4), (2, 3, 4, 5)], + [ + dict(shape=shape, axis=axis) + for shape in [(3,), (3, 4), (2, 3, 4, 5)] + for axis in _axis_for_ndim(len(shape)) + ], dtype=float_types + complex_types, keepdims=[True, False], - axis=[0, None], ord=[1, -1, 2, -2, np.inf, -np.inf], ) def testVectorNorm(self, shape, dtype, keepdims, axis, ord): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - if jtu.numpy_version() < (2, 0, 0): - def np_fn(x, *, ord, keepdims, axis): - x = np.asarray(x) - if axis is None: - result = np_fn(x.ravel(), ord=ord, keepdims=False, axis=0) - return np.reshape(result, (1,) * x.ndim) if keepdims else result - return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis) - else: - np_fn = np.linalg.vector_norm - np_fn = partial(np_fn, ord=ord, keepdims=keepdims, axis=axis) + np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis) jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis) - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3) - self._CompileAndCheck(jnp_fn, args_maker) + tol = 1E-3 if jtu.test_device_matches(['tpu']) else None + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) + self._CompileAndCheck(jnp_fn, args_maker, tol=tol) # jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here. @jtu.sample_product( From e20a483befbb80bbf782b931ec57a44c78c313b8 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 5 Dec 2024 10:51:59 -0800 Subject: [PATCH 631/698] [JAX] Add end-to-end execution support in colocated Python API This change adds a capability to run colocated Python function calls through `PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested with a prototype of a colocated Python backend. The overall behavior remains the same for McJAX (running the user code inline when colocated Python is called); the new logic will be used once we introduce a colocated Python backend for McJAX. Key highlights: * Colocated Python is compiled into `PyLoadedExeutable` and uses the JAX C++ dispatch path. * `CustomCallProgram` for a colocated Python compilation nows includes specialization (input/output specs, devices). This information allows a colocated Python backend to transform input/outputs and validate PyTree/dtype/shape/sharding. * `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values. * Deserialization of devices now prefers the default backend. This improves the compatibility with an environment using both multi-platform backend as well as the standard "cpu" backend at the same time. * Several bugs have been fixed (e.g., correctly using `{}` for kwargs). PiperOrigin-RevId: 703172997 --- jax/experimental/colocated_python/func.py | 77 ++++++++++++++----- .../colocated_python/serialization.py | 20 ++++- tests/BUILD | 1 + tests/colocated_python_test.py | 45 +++++++---- 4 files changed, 109 insertions(+), 34 deletions(-) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 6639e7eefdd6..cba2a0f3801b 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -24,6 +24,7 @@ import jax from jax._src import api from jax._src import tree_util +from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc from jax._src.traceback_util import api_boundary from jax._src.util import wraps @@ -137,23 +138,54 @@ def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None: def _compile_to_executable( name: str, fun: Callable[..., Any], + in_specs_treedef: tree_util.PyTreeDef, in_specs_leaves: tuple[api.ShapeDtypeStruct, ...], + out_specs_treedef: tree_util.PyTreeDef, out_specs_leaves: tuple[api.ShapeDtypeStruct, ...], devices: xc.DeviceList, ) -> Callable[..., Any]: """Compiles a Python function into a runtime executable.""" - pickled_function = _serialize(fun) + fun_and_specialization = ( + fun, + in_specs_treedef, + in_specs_leaves, + out_specs_treedef, + out_specs_leaves, + devices, + ) + pickled_function = _serialize(fun_and_specialization) program = ifrt_programs.make_colocated_python_program( name, pickled_function, devices, in_specs_leaves, out_specs_leaves ) - # TODO(hyeontaek): Compile the program and use the executable. - del program + ifrt_client = devices[0].client + out_sdss = tuple( + jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves + ) + out_shardings = tuple(sds.sharding for sds in out_specs_leaves) + try: + compile_options = ifrt_programs.make_colocated_python_compile_options() + loaded_executable = ifrt_client.compile_ifrt_program( + program, compile_options + ) + out_handlers = pxla.global_avals_to_results_handler( + out_sdss, out_shardings, committed=True + ).handlers + + def call(*args, **kwargs): + args_leaves = tree_util.tree_leaves((args, kwargs)) + execute_result = loaded_executable.execute_sharded( + args_leaves, with_tokens=False + ) + results = execute_result.consume_with_handlers(out_handlers) + return tree_util.tree_unflatten(out_specs_treedef, results) - del name - del in_specs_leaves - del out_specs_leaves - del devices - return fun + return call + except jax.errors.JaxRuntimeError as e: + # TODO(hyeontaek): Implement colocated Python support in McJAX and remove + # this fallback path. + if "PjRtCompiler requires an HloProgram" in str(e): + return fun + raise def _make_output_specs_and_push_result_fun( @@ -170,12 +202,12 @@ def _make_output_specs_and_push_result_fun( def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]: result = info.fun(*args, **kwargs) - out_leaves, out_treedef = tree_util.tree_flatten(result) - out_spec_leaves = tuple(_get_spec(x) for x in out_leaves) - func_backend.SINGLETON_RESULT_STORE.push(uid, out_leaves) + result_leaves, out_treedef = tree_util.tree_flatten(result) + out_spec_leaves = tuple(_get_spec(x) for x in result_leaves) + func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves) return _serialize_specs(out_treedef, out_spec_leaves, devices) - out_specs_leaves, _ = tree_util.tree_flatten( + out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( _make_specs_for_serialized_specs(specialization.devices), ) name = getattr(info.fun, "__name__", "unknown") @@ -183,7 +215,9 @@ def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]: return _compile_to_executable( name=name, fun=lowered_fun, + in_specs_treedef=specialization.in_specs_treedef, in_specs_leaves=specialization.in_specs_leaves, + out_specs_treedef=out_specs_treedef, out_specs_leaves=tuple(out_specs_leaves), devices=specialization.devices, ) @@ -200,21 +234,23 @@ def _make_pop_result_fun( out_specs_treedef = specialization.out_specs_treedef def lowered_fun() -> Any: - flat_result = func_backend.SINGLETON_RESULT_STORE.pop(uid) - return tree_util.tree_unflatten(out_specs_treedef, flat_result) + result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid) + return tree_util.tree_unflatten(out_specs_treedef, result_leaves) - in_specs, _ = tree_util.tree_flatten(( + in_specs_leaves, in_specs_treedef = tree_util.tree_flatten(( # args (), # kwargs - (), + {}, )) name = getattr(info.fun, "__name__", "unknown") name = f"{name}_pop_result" return _compile_to_executable( name=name, fun=lowered_fun, - in_specs_leaves=tuple(in_specs), + in_specs_treedef=in_specs_treedef, + in_specs_leaves=tuple(in_specs_leaves), + out_specs_treedef=specialization.out_specs_treedef, out_specs_leaves=specialization.out_specs_leaves, devices=specialization.devices, ) @@ -234,7 +270,9 @@ def _make_async_execution_fun( return _compile_to_executable( name=name, fun=info.fun, + in_specs_treedef=specialization.in_specs_treedef, in_specs_leaves=specialization.in_specs_leaves, + out_specs_treedef=specialization.out_specs_treedef, out_specs_leaves=specialization.out_specs_leaves, devices=specialization.devices, ) @@ -283,7 +321,10 @@ def specialized_func(*args, **kwargs) -> Any: return _make_pop_result_fun(info, specialization, uid)() else: # Compute out_specs using out_specs_fn and inputs. - out_specs = specialization.out_specs_fn(*args, **kwargs) + args_specs, kwargs_specs = tree_util.tree_map( + _get_spec, (args, kwargs) + ) + out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs) # Type checking is ignored to silence mypy error: Incompatible types # in assignment (expression has type "list[Any]", variable has type # "tuple[ShapeDtypeStruct, ...]") [assignment] diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index ced50b6eee3c..7e7654d4642a 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -51,8 +51,22 @@ def _get_cpu_device_map() -> dict[int, jax.Device]: # associated with colocated_python. When deserializing on the colocated_python # executor, it should be the CPU backend visible to the user function running # under colocated_python. - for backed in xb.backends().values(): - for d in backed._get_all_devices(): # pylint: disable=protected-access + + # Look for CPU devices in the default backend. + for d in xb.local_devices()[0].client._get_all_devices(): # pylint: disable=protected-access + if d.device_kind == "cpu": + if d.id in cpu_device_map: + raise ValueError( + f"Multiple CPU devices with id {d.id} found:" + f" {cpu_device_map[d.id]} and {d}" + ) + cpu_device_map[d.id] = d + if cpu_device_map: + return cpu_device_map + + # Fall back to searching CPU devices in all backends. + for backend in xb.backends().values(): + for d in backend._get_all_devices(): # pylint: disable=protected-access if d.device_kind == "cpu": if d.id in cpu_device_map: raise ValueError( @@ -87,7 +101,7 @@ def make_device_list(device_ids: Sequence[int]) -> DeviceList: devices = np.vectorize(lambda device_id: cpu_device_map[device_id])( device_ids ) - return DeviceList(devices) + return DeviceList(tuple(devices)) device_ids = [d.id for d in device_list] return make_device_list, (device_ids,) diff --git a/tests/BUILD b/tests/BUILD index 97f8a3634d99..ce887181b3f3 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1562,6 +1562,7 @@ exports_files( "api_test.py", "array_test.py", "cache_key_test.py", + "colocated_python_test.py", "compilation_cache_test.py", "memories_test.py", "pmap_test.py", diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index f86a68a998f3..787d97613a15 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -34,15 +34,20 @@ def _colocated_cpu_devices( devices: Sequence[jax.Device], ) -> Sequence[jax.Device]: """Returns CPU devices colocated with the given devices.""" - # TODO(hyeontaek): Use `colocated_python.colocated_cpu_devices(devices)` once - # PjRt-IFRT prepares CPU devices by its own. - cpu_backend_devices = jax.local_devices(backend="cpu") - device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + try: + return colocated_python.colocated_cpu_devices(devices) + except (ValueError, AttributeError): + # PjRt-IFRT prepares CPU devices by its own. + # TODO(hyeontaek): Remove this fallback path once PjRt-IFRT prepares CPU + # devices by its own. + cpu_backend_devices = jax.local_devices(backend="cpu") + device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + + available_devices = devices[: min(len(cpu_backend_devices), len(devices))] + return [ + cpu_backend_devices[device_index_map[d.id]] for d in available_devices + ] - available_devices = devices[:min(len(cpu_backend_devices), len(devices))] - return [ - cpu_backend_devices[device_index_map[d.id]] for d in available_devices - ] @contextlib.contextmanager def _count_colocated_python_specialization_cache_miss() -> list[int]: @@ -79,8 +84,8 @@ class ColocatedPythonTest(jtu.JaxTestCase): def setUp(self): super().setUp() - if xla_extension_version < 298: - self.skipTest("Requires xla_extension_version >= 298") + if xla_extension_version < 300: + self.skipTest("Requires xla_extension_version >= 300") def testMakeColocatedPythonProgram(self): def add_one(x): @@ -88,11 +93,11 @@ def add_one(x): cpu_devices = _colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) - aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) + sds = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) pickled_function = serialization._serialize(add_one) program = ifrt_programs.make_colocated_python_program( - "add_one", pickled_function, [cpu_devices[0]], [aval], [aval] + "add_one", pickled_function, [cpu_devices[0]], [sds], [sds] ) del program @@ -107,10 +112,12 @@ def add_one(x): with _count_colocated_python_specialization_cache_miss() as count: out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, np.array(2)) self.assertEqual(count[0], 1) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, np.array(2)) self.assertEqual(count[0], 1) @@ -125,10 +132,12 @@ def add_one(x): with _count_colocated_python_specialization_cache_miss() as count: out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count[0], 1) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count[0], 1) @@ -154,10 +163,12 @@ def make_zero(): with _count_colocated_python_specialization_cache_miss() as count: make_zero = make_zero.specialize(devices=cpu_devices[:1]) out = make_zero() + out = jax.device_get(out) self.assertEqual(out, np.array(0)) self.assertEqual(count[0], 1) out = make_zero() + out = jax.device_get(out) self.assertEqual(out, np.array(0)) self.assertEqual(count[0], 1) @@ -172,10 +183,12 @@ def add_one(x): with _count_colocated_python_specialization_cache_miss() as count: out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, np.array(2)) self.assertEqual(count[0], 1) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, np.array(2)) self.assertEqual(count[0], 1) @@ -184,10 +197,12 @@ def add_one(x): x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count[0], 2) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count[0], 2) @@ -203,22 +218,26 @@ def add_one(x): with _count_colocated_python_specialization_cache_miss() as count: add_one = add_one.specialize(out_specs_fn=lambda x: x) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, np.array(2)) self.assertEqual(count[0], 1) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, np.array(2)) self.assertEqual(count[0], 1) # Different input tree structure and dtype/shape. - x = [np.array(1), (np.array(2), {"v": jnp.array(3)})] + x = [np.array(1), (np.array(2), {"v": np.array(3)})] x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count[0], 2) out = add_one(x) + out = jax.device_get(out) self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count[0], 2) From 2a4a0e8d6fb36b59f9c6f24e0018d42c8c8d8ee9 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 5 Dec 2024 11:32:43 -0800 Subject: [PATCH 632/698] [jax:custom_partitioning] Implement SdyShardingRule to support Shardy custom_partitioning. The parsing of the sharding rule string very closely follows how einops parses their rules in einops/parsing.py. When a SdyShardingRule object is constructed, we check the syntax of the Einsum like notation string and its consistency with the user provided factor_sizes, and report errors accordingly. This is done during f.def_partition. When SdyShardingRule.build is called, during JAX to MLIR lowering, we check the consistency between the Einsum like notation string, the factor_sizes and the MLIR operation, and report errors accordingly. PiperOrigin-RevId: 703187962 --- jax/BUILD | 1 + jax/_src/custom_partitioning_sharding_rule.py | 380 +++++++++++++++++ tests/BUILD | 10 + .../custom_partitioning_sharding_rule_test.py | 396 ++++++++++++++++++ 4 files changed, 787 insertions(+) create mode 100644 jax/_src/custom_partitioning_sharding_rule.py create mode 100644 tests/custom_partitioning_sharding_rule_test.py diff --git a/jax/BUILD b/jax/BUILD index 053b05027a2e..31020eb1d385 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -193,6 +193,7 @@ py_library_providing_imports_info( "_src/custom_batching.py", "_src/custom_derivatives.py", "_src/custom_partitioning.py", + "_src/custom_partitioning_sharding_rule.py", "_src/custom_transpose.py", "_src/debugging.py", "_src/dispatch.py", diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py new file mode 100644 index 000000000000..5193c9126bb7 --- /dev/null +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -0,0 +1,380 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implements SdyShardingRule.""" + +from collections import OrderedDict + +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import sdy + + +_CompoundFactor = tuple[str, ...] +_DimMapping = tuple[str | _CompoundFactor, ...] + +# A single character replacement for ... to simplify parsing. +_ELLIPSIS: str = "…" + +# A prefix for names of batching dimension factors, used for expanding the +# leading ... into factors. +_BATCHING_DIM_FACTOR_PREFIX = "?" + +def _get_batching_dim_factor_name(batch_dim_order : int): + """Constructs a factor name for a batching dimension. + + We expand the leading ... into factors representing the batching dimensions + to support building the MLIR representation for the sharding rule. For this + reason, we construct a factor name that won't be used by users for the + batching dimensions. + """ + return f"{_BATCHING_DIM_FACTOR_PREFIX}{batch_dim_order}" + +def _parse_values( + rule: str, +) -> tuple[_DimMapping, ...]: + """Parses the LHS or RHS of an Einsum notation like string. + + Converts each operand or result in the Einsum notation like string to a tuple + of _DimMapping. This very closely follows how einops parses their rules in + einops/parsing.py. + + Args: + rule: The Einsum notation for the operands or results of an operation. + + Returns: + The tuple of values. + + Raises: + ValueError: If the rule is not balanced or contains unknown characters. + """ + + # Remove unnecessary spaces in the rule to simplify the parsing process. + words = rule.split() + rule = " ".join(words) + + # Similar to einops rules, an empty LHS/RHS has a single scalar value. + if not rule: + return ((),) + + all_values = [] + # Represent all dimensions of an value. When an value[0]==_ELLIPSIS, the + # value may have 0 or more leading dimensions. + value = [] + current_factor = None + # A value of None indicates the current dimension is not a compound dimension, + # while a value of [] indicates that we have just started parsing a compound + # dimension. + current_compound_dim: list[str] | None = None + + def add_factor(x): + if current_compound_dim is None: + value.append(x) + else: + current_compound_dim.append(x) + + for char in rule: + if char == _ELLIPSIS: + if (current_factor is not None or current_compound_dim is not None + or value): + raise ValueError( + "Ellipsis can only be used at the beginning of a dimension") + add_factor(_ELLIPSIS) + continue + if char in "(), ": + if current_factor is not None: + add_factor(current_factor) + current_factor = None + if char == "(": + if current_compound_dim is not None: + raise ValueError( + "Compound factors should be one level, nested brackets are not" + " allowed") + current_compound_dim = [] + elif char == ")": + if current_compound_dim is None: + raise ValueError("Brackets are not balanced") + if len(current_compound_dim) <= 1: + raise ValueError("Brackets should contain at least two factors") + value.append(tuple(current_compound_dim)) + current_compound_dim = None + elif char == ",": + all_values.append(tuple(value)) + value = [] + elif char == "_" or char.isdigit() or char.isalpha(): + if current_factor is None: + if str.isdigit(char): + raise ValueError(f"Factor names have to start with a letter, but got '{char}'") + current_factor = char + else: + current_factor += char + else: + raise ValueError(f"Unknown character '{char}'") + + if current_compound_dim is not None: + raise ValueError(f"Brackets are not balanced in rule: '{rule}'") + if current_factor is not None: + add_factor(current_factor) + all_values.append(tuple(value)) + + return tuple(all_values) + + +class SdyShardingRule: + """A representation for Shardy sharding rule. + + A SdyShardingRule includes an Enisum notation like string and an optional + list of factor sizes. A factor is a name in the Einsum notation. If a factor + is only used in compound factors, its size must be specified. + + SdyShardingRule examples: + + * Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k') + * Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k') + * A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j') + * Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)`, i=4, j=2) + * An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...') + """ + + def __init__(self, rule: str, **factor_sizes): + """Constructs a SdyShardingRule object from the Einsum notation like string. + + This is done by verifying that the input Einsum notation like string and + with optional factor sizes represents a valid sharding rule and converting + it to an internal representation. + + Args: + rule: The Einsum notation like string for an operation. + **factor_sizes: The optional factor sizes. + + Raises: + ValueError: If there is any problem with the rule or factor_sizes. + """ + if not isinstance(rule, str): + raise TypeError(f"rule must be a str, but got {type(rule)}") + if not all(isinstance(size, int) for size in factor_sizes.values()): + raise TypeError( + f"factor_sizes must be a dict of str to int, but got {factor_sizes}") + + # Replace ... with a single char to simplify parsing. + if _ELLIPSIS in rule: + raise ValueError(f"Unknown character '{_ELLIPSIS}'") + if "." in rule: + rule = rule.replace("...", _ELLIPSIS) + if "." in rule: + raise ValueError("Character '.' must be used inside ellipsis '...'") + + try: + operands, results = rule.split("->") + except ValueError as e: + raise ValueError(f"There is no -> in rule: '{rule}'") from e + + self.operands = _parse_values(operands) + self.results = _parse_values(results) + + # Find all factors and mark whether their size can be inferred. + factors_inferrable = dict() + for value in self.operands + self.results: + for dim in value: + if dim == _ELLIPSIS: + continue + if isinstance(dim, str): + factors_inferrable[dim] = True + else: + for factor in dim: + if factor not in factors_inferrable.keys(): + factors_inferrable[factor] = False + + # Check that factors in factor_sizes are used in the rule. + for factor in factor_sizes: + if factor not in factors_inferrable: + raise ValueError( + f"Factor {factor} is not used in the rule, but size is provided") + + # Check that factors that are used for a whole dimension aren't in + # factor_sizes and factors that are never used for a whole dimension are + # in factor_sizes. + for factor, inferrable in factors_inferrable.items(): + if factor not in factor_sizes and not inferrable: + raise ValueError( + f"Factor {factor} is only used in compound factors; must specify" + " its size") + if factor in factor_sizes and inferrable: + raise ValueError( + f"Factor {factor} represents a whole dimension; do not specify its" + " size") + + self.factor_sizes = factor_sizes + + def __str__(self): + return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})" + + def build( + self, + operand_types: list[ir.Type], + result_types: list[ir.Type],) -> ir.Attribute: + """Builds the MLIR representation for the sharding rule. + + This is done by verifying that the rule is consistent with the types of + the operation and converting the Einsum notation like string to + OpShardingRuleAttr. + """ + if len(self.operands) != len(operand_types): + raise ValueError( + f"Sharding rule has {len(self.operands)} operands, but the operation" + f" has {len(operand_types)} operands" + ) + if len(self.results) != len(result_types): + raise ValueError( + f"Sharding rule has {len(self.results)} results, but the operation" + f" has {len(result_types)} results" + ) + + factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict() + types = operand_types + result_types + UNKNOWN = -1 # Representation for unknown factor size or factor index. + + def get_message_for_value(i): + if i >= len(operand_types): + return f"{i - len(operand_types)}th result" + else: + return f"{i}th operand" + + def get_rank_for_value(i): + return ir.ShapedType(types[i]).rank + + def get_size_for_value_dim(i, j): + return ir.ShapedType(types[i]).shape[j] + + def add_factor(factor, size): + """Adds a factor to factors_to_indices_sizes. + + `size` may be a dimensions size, a user specified factor size, or UNKNOWN + if a factor is first used as in a compound factor and then used for a + whole dimension. + """ + factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN]) + if factor_index != UNKNOWN: + # Not the first time seeing the factor. + if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size: + factor_or_batching_dim = ( + f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor + else f"Batching dimension {factor[1:]}") + raise ValueError( + f"{factor_or_batching_dim} corresponds to two sizes:" + f" {factor_size} and {size}") + if size != UNKNOWN and factor_size == UNKNOWN: + factors_to_indices_sizes[factor] = [factor_index, size] + else: + # First time seeing the factor. + factor_index = len(factors_to_indices_sizes) + factors_to_indices_sizes[factor] = [factor_index, size] + + def add_batching_dim_factor(batch_dim_order, factor_size): + ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order) + add_factor(ellipsis_batch_dim_name, factor_size) + + def build_dim_mapping_for_compound_factors(i, j, factors): + accumulated_size = 1 + all_indices = [] + for factor in factors: + factor_index, factor_size = factors_to_indices_sizes[factor] + accumulated_size *= factor_size + all_indices.append(factor_index) + + dim_size = get_size_for_value_dim(i, j) + if accumulated_size != dim_size: + raise ValueError( + f"{get_message_for_value(i)} actual size {dim_size} doesn't match" + f" the size {accumulated_size} derived from the compound factors" + f" {factors}") + + return sdy.DimMappingAttr.get(factor_indices=all_indices) + + # Add factors and their sizes in the order they appear in the rule, + # including the batching dimensions represented by ellipsis. + ellipsis_rank = None + for i, value in enumerate(self.operands + self.results): + if value and value[0] == _ELLIPSIS: + has_ellipsis = True + value = value[1:] + else: + has_ellipsis = False + rule_rank = len(value) + op_rank = get_rank_for_value(i) + # The number of dimensions represented by ellipsis. + current_ellipsis_rank = 0 + if has_ellipsis and op_rank > rule_rank: + current_ellipsis_rank = op_rank - rule_rank + if has_ellipsis: + if ellipsis_rank is None: + ellipsis_rank = current_ellipsis_rank + elif ellipsis_rank != current_ellipsis_rank: + raise ValueError( + "Ellipsis represents different number of leading dimensions" + f" {ellipsis_rank} and {current_ellipsis_rank}") + rule_rank += current_ellipsis_rank + if rule_rank != op_rank: + msg = get_message_for_value(i) + raise ValueError( + f"Sharding rule {msg} has rank {rule_rank}, but the operation" + f" {msg} has rank {op_rank}") + + for j in range(current_ellipsis_rank): + add_batching_dim_factor(j, get_size_for_value_dim(i, j)) + + for j, dim in enumerate(value): + if isinstance(dim, str): + add_factor( + dim, get_size_for_value_dim(i, j + current_ellipsis_rank)) + else: + for factor in dim: + add_factor(factor, self.factor_sizes.get(factor, UNKNOWN)) + + # Build the tensor mappings for each operand and result. + tensor_mappings = [] + for i, value in enumerate(self.operands + self.results): + dim_mappings = [] + + if value and value[0] == _ELLIPSIS: + value = value[1:] + if ellipsis_rank is None: + current_ellipsis_rank = 0 + else: + current_ellipsis_rank = ellipsis_rank + else: + current_ellipsis_rank = 0 + + for j in range(current_ellipsis_rank): + dim_mappings.append( + sdy.DimMappingAttr.get(factor_indices=[ + factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]])) + + for j, dim in enumerate(value): + if isinstance(dim, str): + dim_mappings.append( + sdy.DimMappingAttr.get( + factor_indices=[factors_to_indices_sizes[dim][0]])) + else: + dim_mappings.append( + build_dim_mapping_for_compound_factors( + i, j + current_ellipsis_rank, dim)) + + tensor_mappings.append( + sdy.TensorMappingAttr.get(dim_mappings=dim_mappings)) + + op_sharding_rule = sdy.OpShardingRuleAttr.get( + factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], + operand_mappings=tensor_mappings[0:len(operand_types)], + result_mappings=tensor_mappings[len(operand_types):]) + return op_sharding_rule diff --git a/tests/BUILD b/tests/BUILD index ce887181b3f3..f80f17e54455 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1557,6 +1557,16 @@ jax_multiplatform_test( tags = ["multiaccelerator"], ) +jax_py_test( + name = "custom_partitioning_sharding_rule_test", + srcs = ["custom_partitioning_sharding_rule_test.py"], + deps = [ + "//jax", + "//jax:experimental", + "//jax:test_util", + ], +) + exports_files( [ "api_test.py", diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py new file mode 100644 index 000000000000..2aac4e04862f --- /dev/null +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -0,0 +1,396 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from jax._src import test_util as jtu +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import sdy +from jax._src.custom_partitioning_sharding_rule import SdyShardingRule +from jax._src.lib.mlir.dialects import hlo as stablehlo + + +class SdyShardingRuleTest(jtu.JaxTestCase): + + def test_rule_is_not_a_str(self): + with self.assertRaisesRegex(TypeError, "rule must be a str"): + SdyShardingRule(1) + + def test_factor_sizes_is_not_a_proper_dict(self): + with self.assertRaisesRegex( + TypeError, "factor_sizes must be a dict of str to int"): + SdyShardingRule("i->j", i="j") + + def test_sharding_rule_ellipsis_not_complete(self): + with self.assertRaisesRegex( + ValueError, "Character '.' must be used inside ellipsis '...'"): + SdyShardingRule(".i -> j") + + def test_sharding_rule_invalid_factor_name(self): + with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"): + SdyShardingRule("2i -> j") + + def test_sharding_rule_missing_results(self): + with self.assertRaisesRegex(ValueError, "There is no -> in rule"): + SdyShardingRule("i") + + def test_sharding_rule_inbalenced_brackets(self): + with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): + SdyShardingRule("i j, k)->j") + + def test_sharding_rule_inbalenced_brackets2(self): + with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): + SdyShardingRule("i (j k->j") + + def test_sharding_rule_empty_compound_dim(self): + with self.assertRaisesRegex( + ValueError, "Brackets should contain at least two factors"): + SdyShardingRule("i ( ) j k->j") + + def test_sharding_rule_one_factorcompound_dim(self): + with self.assertRaisesRegex( + ValueError, "Brackets should contain at least two factors"): + SdyShardingRule("i (j ) k->j") + + def test_sharding_rule_nested_brackets(self): + with self.assertRaisesRegex( + ValueError, "Compound factors should be one level"): + SdyShardingRule("i (j (k))->j") + + def test_sharding_rule_unknown_char(self): + with self.assertRaisesRegex(ValueError, "Unknown character"): + SdyShardingRule("i; j->j") + + def test_sharding_rule_unknown_single_char_ellipse(self): + with self.assertRaisesRegex(ValueError, "Unknown character"): + SdyShardingRule("…j->…j") + + def test_sharding_rule_ellipsis_not_leading_dim(self): + with self.assertRaisesRegex( + ValueError, "Ellipsis can only be used at the beginning of a dimension"): + SdyShardingRule("i ... -> j") + + def test_sharding_rule_ellipsis_inside_compound_dim(self): + with self.assertRaisesRegex( + ValueError, "Ellipsis can only be used at the beginning of a dimension"): + SdyShardingRule("i, (..., j) -> j") + + def test_sharding_rule_scalar_operand_scalar_result(self): + rule = SdyShardingRule("->") + self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})") + + def test_sharding_rule_one_scalar_operand(self): + rule = SdyShardingRule("i j, , k->j") + self.assertEqual( + str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})") + + def test_sharding_rule_factor_size_not_used(self): + with self.assertRaisesRegex(ValueError, "Factor k is not used"): + SdyShardingRule("i->j", k=10) + + def test_sharding_rule_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule("i->j", i=10) + + def test_sharding_rule_compound_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule("(i j) -> i", i=10, j=20) + + def test_sharding_rule_factor_sizes_missing(self): + with self.assertRaisesRegex( + ValueError, + "Factor k is only used in compound factors; must specify its size"): + SdyShardingRule("i j -> (j k)") + + def test_sharding_rule_factor_elementwise_add(self): + rule = SdyShardingRule("... i j, ...i j -> ...i j") + self.assertEqual( + str(rule), + "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i'," + " 'j'),), {})") + + def test_sharding_rule_factor_vector_scalar_add(self): + rule = SdyShardingRule("...i, -> ...i") + self.assertEqual( + str(rule), + "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})") + + def test_sharding_rule_factor_reshape_combining(self): + rule = SdyShardingRule("i j -> (i j)") + self.assertEqual( + str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})") + + def test_sharding_rule_factor_reshape_reordering(self): + rule = SdyShardingRule("(j i) -> (i j)", i=10, j=20) + self.assertEqual( + str(rule), + "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':" + " 20})") + + def test_sharding_rule_factor_compound_then_individual(self): + rule = SdyShardingRule("(i j) (j k) i -> j k") + self.assertEqual( + str(rule), + "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})") + + def test_sharding_rule_factor_individual_then_compound(self): + rule = SdyShardingRule("i j k -> (i j) (j k)") + self.assertEqual( + str(rule), + "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})") + + def test_sharding_rule_factor_infer_k(self): + rule = SdyShardingRule("_i (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20) + self.assertEqual( + str(rule), + "SdyShardingRule((('_i', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" + ",), {'k': 10, 'm': 10, 'bar_24': 20})") + + +class SdyShardingRuleConversionTest(jtu.JaxTestCase): + + def run(self, result=None): + with ir.Context() as ctx, ir.Location.unknown(ctx): + sdy.register_dialect(ctx) + stablehlo.register_dialect(ctx) + module = ir.Module.create() + with ir.InsertionPoint(module.body): + super().run(result) + + def get_tensor_type(self, shape): + return ir.RankedTensorType.get(shape, ir.F32Type.get()) + + def create_tensor_value(self, shape): + return ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type(shape)], + attributes=dict(call_target_name=ir.StringAttr.get("dummy_target")) + ).result + + def test_conversion_rule_op_mismatch_in_operands_num(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("i j-> i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule has 1 operands, but the operation has 2 operands"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_operands_rank(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("i j, i j k-> i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 1th operand has rank 3, but the operation 1th " + "operand has rank 2"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_results_num(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, + opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("i j, i j -> i j, i j") + with self.assertRaisesRegex( + ValueError, + "Sharding rule has 2 results, but the operation has 1 results"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_rule_op_mismatch_in_results_dim(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j, i j -> i j k") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 0th result has rank 3, but the operation 0th " + "result has rank 2"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_factor_has_two_sizes(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 64))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j, i j -> i j") + with self.assertRaisesRegex( + ValueError, + "Factor j corresponds to two sizes: 32 and 64"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_batching_dim_has_two_sizes(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 64))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("..., ... -> ...") + with self.assertRaisesRegex( + ValueError, + "Batching dimension 1 corresponds to two sizes: 32 and 64"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,],) + + def test_conversion_compound_dimension_size_mismatch(self): + opnd = self.create_tensor_value((2, 4)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((9,))], + operands=[opnd,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j -> (i j)") + with self.assertRaisesRegex( + ValueError, + "0th result actual size 9 doesn't match the size 8 derived from the" + " compound factors"): + rule.build( + [result.operands[0].type], + [result.result.type,]) + + def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16,)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("..., ... -> ...") + with self.assertRaisesRegex( + ValueError, + "Ellipsis represents different number of leading dimensions 2 and 1"): + rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + + def test_conversion_elementwise_rule_scalar_instance(self): + opnd0 = self.create_tensor_value(()) + opnd1 = self.create_tensor_value(()) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type(())], + operands=[opnd0, opnd1], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("..., ... -> ...") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([], [])->([])>") + + def test_conversion_elementwise_rule_2D_instance(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("..., ... -> ...") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>") + + def test_conversion_vector_scalar_add_2D_instance(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value(()) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = SdyShardingRule("..., -> ...") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>") + + def test_conversion_reshape_rule(self): + opnd0 = self.create_tensor_value((2, 4)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((8,))], + operands=[opnd0,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("i j -> (i j)") + mlir_rule = rule.build( + [result.operands[0].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>") + + def test_conversion_contracting_dim_matmul(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((32, 8)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 8))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = SdyShardingRule("... contracting_dim, contracting_dim k -> ... k") + mlir_rule = rule.build( + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>") + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 23d5c10ff0704f66ad7ec65a8cdcd09bd2420591 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 5 Dec 2024 11:37:42 -0800 Subject: [PATCH 633/698] [Mosaic:TPU] Fix fully replicated relayout It was incorrect since batch dims are not replicated PiperOrigin-RevId: 703189919 --- .../tpu/transforms/apply_vector_layout.cc | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 50a7d57346a6..5cbb5e620c88 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -6588,15 +6588,20 @@ FailureOr> relayout(RewriteContext &ctx, /*use_implicit_shape=*/true); } if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() && - !src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) { + !src.offsets()[1].has_value()) { // A fully replicated value is always easy to relayout - // It would be nice to be able to assert this here, but given replicated - // values our rules can introduce equivalent expressions. - // assert all(t is src_tiles_list[0] for t in src_tiles_list) xla::Array dst_tiles( - /*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape), - /*value=*/src_tiles.data()[0]); - return assemble_with_mask_check(dst_tiles); + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + SmallVector idxs; + dst_tiles.Each([&](const absl::Span src_idx, Value *vreg) { + idxs.assign(src_idx.begin(), src_idx.end()); + dst.eraseImplicit(idxs); + src.insertImplicit(idxs, 0); + *(idxs.end() - 2) = 0; + *(idxs.end() - 1) = 0; + *vreg = src_tiles(idxs); + }); + return assemble_with_mask_check(dst_tiles, /*use_implicit_shape=*/true); } // Consider (1,128),-2 -> (8,128). In this case we can change the implicit From d88ef23a63c68dc148975c6a5f44b650594500f3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 5 Dec 2024 12:56:07 -0800 Subject: [PATCH 634/698] array API: improve test coverage --- .github/workflows/jax-array-api.yml | 4 ++-- tests/array_api_skips.txt | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 54a2bf469a38..92fef2cc29af 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -38,11 +38,11 @@ jobs: - name: Install dependencies run: | python -m pip install .[ci] - python -m pip install -r array-api-tests/requirements.txt + python -m pip install pytest-xdist -r array-api-tests/requirements.txt - name: Run the test suite env: ARRAY_API_TESTS_MODULE: jax.numpy JAX_ENABLE_X64: 'true' run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt + pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index e1d4c35eae68..2f8d4d1c666f 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -13,3 +13,9 @@ array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted + +# clip out dtype has ambiguous semantics (https://github.com/numpy/numpy/issues/24976) +array_api_tests/test_operators_and_elementwise_functions.py::test_clip + +# JAX raises a ValueError rather than the expected IndexError for out-of-bound axis +array_api_tests/test_manipulation_functions.py::test_expand_dims From 7e6620a57775084dfa8d438ae4fd27f3ef365018 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 4 Dec 2024 13:20:44 -0500 Subject: [PATCH 635/698] JAX release 0.4.36. --- jax/version.py | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/version.py b/jax/version.py index 3e8a8291ec8d..941b34f1226f 100644 --- a/jax/version.py +++ b/jax/version.py @@ -137,7 +137,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.35" +_minimum_jaxlib_version = "0.4.36" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index a3b54f7aa94f..ea42d625eadc 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.35' +_current_jaxlib_version = '0.4.36' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.35' -_libtpu_version = '0.0.2' +_libtpu_version = '0.0.5' _libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup' def load_version_module(pkg_path): From 259194a69f52a06847a9ff11eb268072e91fd65f Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 5 Dec 2024 11:44:15 -0800 Subject: [PATCH 636/698] [Pallas] Fix shard_axis in dma_start interpret mode rule. PiperOrigin-RevId: 703192497 --- jax/_src/pallas/mosaic/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 9ea2b59f66f1..0a7bd371a639 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals, if device_id_len > 1 or len(nonempty_axes) > 1: raise NotImplementedError("Meshes with more than 1 named dimension not " "implemented in dma_start_p") - shard_axis = nonempty_axes[0].name + shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) else: raise ValueError(f"Unknown device_id_type: {device_id_type}") From fd42b561d62194f0e04757561369f19e686e5858 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 5 Dec 2024 14:17:31 -0800 Subject: [PATCH 637/698] [Pallas] Fix shard_axis in dma_start interpret mode rule. PiperOrigin-RevId: 703249904 --- jax/_src/pallas/mosaic/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 9ea2b59f66f1..0a7bd371a639 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals, if device_id_len > 1 or len(nonempty_axes) > 1: raise NotImplementedError("Meshes with more than 1 named dimension not " "implemented in dma_start_p") - shard_axis = nonempty_axes[0].name + shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) else: raise ValueError(f"Unknown device_id_type: {device_id_type}") From 651ab1887453685ee21533bbb16c25fe63e8a166 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Thu, 5 Dec 2024 15:00:21 -0800 Subject: [PATCH 638/698] [Mosaic:TPU] Fix elementwise inference with i1s PiperOrigin-RevId: 703263310 --- .../tpu/transforms/infer_vector_layout.cc | 59 ++++++++++++------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index bc733742df19..dd63ba66cb9c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -187,7 +187,7 @@ class VectorLayoutInferer { false_ty.getElementTypeBitWidth() == kNativeBitwidth, "Only 32-bit select supported"); } - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + if (inferElementwise(&any_op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -198,7 +198,7 @@ class VectorLayoutInferer { auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth() : op.getIn().getType().getIntOrFloatBitWidth(); if (in_bitwidth == 1) { - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + if (inferElementwise(&any_op).failed()) { return failure(); } } else { @@ -214,7 +214,7 @@ class VectorLayoutInferer { TPU_CHECK_OP(static_cast(lhs_ty) == static_cast(rhs_ty), "Only one side of cmp is a vector?"); // TODO(tlongeri): Check that TPU generation supports comparison. - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + if (inferElementwise(&any_op).failed()) { return failure(); } } else if (auto op = dyn_cast(any_op)) { @@ -1726,7 +1726,7 @@ class VectorLayoutInferer { return success(); } - LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) { + LogicalResult inferElementwise(Operation *op) { TPU_CHECK_OP(op->getNumResults() == 1, "only one result supported"); TPU_CHECK_OP(op->getNumOperands() > 0, "elementwise ops with no operands unsupported"); @@ -1735,26 +1735,45 @@ class VectorLayoutInferer { std::optional out_layout_candidate; std::optional out_layout; SmallVector, 4> in_layouts; - int64_t bit_width = -1; + int64_t bitwidth = -1; + // Find the bitwidth of the operands/results. They must all be the same + // except for the case of i1s, which use a "fake" bitwidth for layouts. + // They can be relayouted (in principle) to any other fake bitwidth, so we + // don't commit to their bitwidth. See comments in VectorLayout class. + for (Value val : llvm::concat(op->getOperands(), op->getResults())) { + if (const VectorType vty = dyn_cast(val.getType())) { + const int64_t val_bitwidth = vty.getElementTypeBitWidth(); + if (val_bitwidth != 1) { + if (bitwidth == -1) { + bitwidth = val_bitwidth; + } else if (bitwidth != val_bitwidth) { + return op->emitOpError( + "Mismatched bitwidth in elementwise for non-i1 " + "operands/results"); + } + } + } + } for (int64_t i = 0; i < op->getNumOperands(); ++i) { if (auto vty = dyn_cast(op->getOperand(i).getType())) { - if (bit_width == -1) { - bit_width = vty.getElementTypeBitWidth(); - } - TPU_CHECK_OP( - !check_bitwidth || bit_width == vty.getElementTypeBitWidth(), - "Generic elementwise rule only supports operands of same width"); auto some_layout = getLayout(op->getOperand(i)); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); auto &layout = *some_layout; - // If the input is fully replicated, don't use it to commit to any - // layout. Replicated values are easy to relayout. - if (is_fully_replicated(some_layout)) { + if (bitwidth == -1) { + // All operands/results are i1s, just commit to the first bitwidth + DCHECK(!out_layout.has_value()); + bitwidth = layout.bitwidth(); + out_layout = layout; + in_layouts.push_back(layout); + } else if (bitwidth != layout.bitwidth()) { + DCHECK_EQ(vty.getElementTypeBitWidth(), 1); + in_layouts.push_back(std::nullopt); + } else if (is_fully_replicated(some_layout)) { + // If the input is fully replicated, don't use it to commit to any + // layout. Replicated values are easy to relayout. in_layouts.push_back(std::nullopt); out_layout_candidate = layout; - continue; - } - if (!out_layout) { + } else if (!out_layout) { // TODO(apaszke): There are probably smarter ways to choose layout. out_layout = layout; in_layouts.push_back(some_layout); @@ -1768,8 +1787,9 @@ class VectorLayoutInferer { // any replication bits that might have been present in out_layout, // since there is no guarantee that the conflicting inputs could // even become replicated. + DCHECK_EQ(out_layout->bitwidth(), bitwidth); out_layout = - VectorLayout(out_layout->bitwidth(), + VectorLayout(bitwidth, {out_layout->offsets()[0].value_or(0), out_layout->offsets()[1].value_or(0)}, out_layout->tiling(), out_layout->implicit_dim()); @@ -1784,9 +1804,6 @@ class VectorLayoutInferer { } Layout final_out_layout = std::nullopt; if (auto out_vty = dyn_cast(op->getResult(0).getType())) { - TPU_CHECK_OP( - !check_bitwidth || bit_width == out_vty.getElementTypeBitWidth(), - "Generic elementwise rule can't change element type width"); if (out_layout) { final_out_layout = *out_layout; } else if (out_layout_candidate) { From 84f3f992175964aa8298c4a87f778df9620ae098 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 5 Dec 2024 15:29:42 -0800 Subject: [PATCH 639/698] [pallas] fix jumble test flakiness * Enable interpret mode in tests * Ensure that the kernel is run multiple times where weve seen data corruption * Use masked comparison - prior comparison was reading garbage data as we were basically relying on past behavior of how uninitialized memory was behaving. * This was being hidden by a cache, where the interpret test, which always has 0.0 for uninitialized memory was being hit first, where TPU does not have the same behavior. PiperOrigin-RevId: 703272002 --- tests/pallas/pallas_jumble_test.py | 97 ++++++++++++++++-------------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py index 8452d1ee7264..509ef08a987f 100644 --- a/tests/pallas/pallas_jumble_test.py +++ b/tests/pallas/pallas_jumble_test.py @@ -41,6 +41,21 @@ floatx = dtypes.canonicalize_dtype(jnp.float64) +def _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res, ref +): + total_columns = col_grid_size * 128 + mask = jnp.zeros((len(ragged_shape), row_count, total_columns), dtype=bool) + + for i, r in enumerate(ragged_shape): + mask = mask.at[i, :, : r * 128].set(True) + + res_valid = jnp.where(mask, res, -1) + ref_valid = jnp.where(mask, ref, -1) + + np.testing.assert_allclose(res_valid, ref_valid) + + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -104,24 +119,16 @@ def invoke_kernel(x): axis_size=3, )(x) - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - - def correct(v): - return np.count_nonzero(v == jnp.sin(1.0)) - - for b, batch in enumerate(res): - ragged_val = ragged_shape[b] - for r, row in enumerate(batch): - row_total = ragged_val * 128 - self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}") + ref = jax.vmap( + jnp.sin, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) - self.assertEqual(correct(res), ragged_total) + _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res.data, ref.data + ) def test_vmap_jumble_over_add_kernel(self): if not jtu.test_device_matches(["tpu"]): @@ -156,36 +163,34 @@ def invoke_kernel(x, y): (8, col_grid_size * 128), dtype=jnp.float32 ), grid=(1, col_grid_size), - interpret=False, + interpret=self.INTERPRET, )(x, y) - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) + # We've had this test fail with data corruption due to multiple + # invocations, so we run it k times to make sure it's not setting up + # memory incorrectly for subsequent invocations. + for _ in range(4): + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - - def correct(v): - return np.count_nonzero(v == 2.0) - - for r, row in enumerate(res): - ragged_val = ragged_shape[r] - row_total = ragged_val * 128 * row_count - self.assertEqual(correct(row), row_total) - for col in row: - col_total = ragged_val * 128 - self.assertEqual(correct(col), col_total) - - self.assertEqual(np.count_nonzero(res == 2.0), ragged_total) + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + + ref = jax.vmap( + lambda x, y: x + y, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x, y) + _assert_ragged_equal_with_elementwise_mask( + row_count, col_grid_size, ragged_shape, res, ref.data + ) def test_vmap_jumble_over_sin_kernel_grid_remapping(self): if not jtu.test_device_matches(["tpu"]): @@ -212,7 +217,7 @@ def invoke_kernel(x): out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), grid=(1, 5), - interpret=False, + interpret=self.INTERPRET, )(x) with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): @@ -280,7 +285,7 @@ def matmul( ), grid=grid, input_output_aliases={2: 0}, - interpret=False, + interpret=self.INTERPRET, )(x, y, x_sentinel) # TODO(mvoz): parameterize this shape? From 1ca8903a3c3528da8c85f53dc1840ae9232ae1b8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 5 Dec 2024 15:45:56 -0800 Subject: [PATCH 640/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/fa3369103478bb0b98a900c21658f2aca2e73319. PiperOrigin-RevId: 703276684 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index db34354f42c5..bb4e14ee447b 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "30f22f4d6cb523e035c237f30aeac4e00ae34821" -XLA_SHA256 = "546dc97a5bee684b3baf1c14c00ef6c73f18c717ebb97c000a35f683bf53c244" +XLA_COMMIT = "fa3369103478bb0b98a900c21658f2aca2e73319" +XLA_SHA256 = "a9db6376115ae898c3eff4a2ca8f0e71e6eff79240d8b9c5929aaf923f7d86d0" def repo(): tf_http_archive( From ba626fa6501c228b51928e44dc5e8ec31da9f706 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 6 Dec 2024 05:57:30 -0800 Subject: [PATCH 641/698] Bump JAX version after release. PiperOrigin-RevId: 703472753 --- CHANGELOG.md | 4 +++- jax/_src/tree_util.py | 6 +++--- jax/version.py | 2 +- setup.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6d0f97f439d..2e411e27faee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,9 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.36 +## jax 0.4.37 + +## jax 0.4.36 (Dec 5, 2024) * Breaking Changes * This release lands "stackless", an internal change to JAX's tracing diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 73cff5aa8042..bb9924f8bb72 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -291,13 +291,13 @@ def register_pytree_node( """ if xla_extension_version >= 299: default_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] ) none_leaf_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] ) dispatch_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] ) else: default_registry.register_node(nodetype, flatten_func, unflatten_func) diff --git a/jax/version.py b/jax/version.py index 941b34f1226f..9da3d63f8708 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.36" +_version = "0.4.37" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index ea42d625eadc..dfe64c4d83ac 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ _current_jaxlib_version = '0.4.36' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.35' +_latest_jaxlib_version_on_pypi = '0.4.36' _libtpu_version = '0.0.5' _libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup' From 9081e85d6864db03b84e226b262f29ca44950999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Fri, 6 Dec 2024 06:49:15 -0800 Subject: [PATCH 642/698] Activate Schur Decomposition to XLA's FFI PiperOrigin-RevId: 703484916 --- jax/_src/export/_export.py | 1 + .../cpu_schur_lapack_gees.py | 215 ++++++++++++++++++ jax/_src/lax/linalg.py | 50 ++-- jaxlib/lapack.py | 141 +++++++----- tests/export_back_compat_test.py | 18 +- 5 files changed, 354 insertions(+), 71 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index e3508639fe15..c521becb76d2 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -998,6 +998,7 @@ def _check_lowering(lowering) -> None: "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", "lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi", + "lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py index d7e6e5a1bc48..309aa73f20ba 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py @@ -241,3 +241,218 @@ mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0bOO\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\x06\x06\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\x0b\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgees\x00", xla_call_module_version=6, ) # End paste + +data_2024_11_29 = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], + [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], + [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], + [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]]),), + expected_outputs=(array([[ 3.2464249196572972e+01+0.j, -1.3416407864998739e+01+0.j, + -1.2558842947806125e-14+0.j, -7.3490869705474997e-15+0.j], + [ 0.0000000000000000e+00+0.j, -2.4642491965729798e+00+0.j, + -2.5534994473279107e-15+0.j, -1.3671521621839345e-16+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + -1.8779126463272594e-15+0.j, 7.2486619604759691e-16+0.j], + [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, + 0.0000000000000000e+00+0.j, 4.8523679991768567e-16+0.j]]), array([[ 0.11417645138733863+0.j, -0.8288327563197511 +0.j, + 0.5401354211381763 +0.j, -0.09085002384085737+0.j], + [ 0.33000459866554743+0.j, -0.43714638836388686+0.j, + -0.6524649518290251 +0.j, 0.5237265380279561 +0.j], + [ 0.545832745943757 +0.j, -0.04546002040802424-0.j, + -0.31547635975648136+0.j, -0.774903004533341 +0.j], + [ 0.7616608932219662 +0.j, 0.346226347547838 +0.j, + 0.42780589044732925+0.j, 0.3420264903462419 +0.j]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xcomplex> loc("input")) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:5 = stablehlo.custom_call @lapack_zgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor<4xcomplex>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + return %6, %10 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0bO\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02>\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x0b\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_zgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], + [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], + [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], + [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]], dtype=complex64),), + expected_outputs=(array([[ 3.2464264e+01+0.j, -1.3416414e+01+0.j, -2.1337737e-06+0.j, + 1.8261760e-06+0.j], + [ 0.0000000e+00+0.j, -2.4642489e+00+0.j, -6.0543999e-07+0.j, + 4.8744488e-07+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, -6.5878328e-07+0.j, + 3.9895070e-07+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 0.0000000e+00+0.j, + 3.0199919e-07+0.j]], dtype=complex64), array([[ 0.11417647 +0.j, -0.8288329 +0.j, 0.5404726 +0.j, + -0.08882082 +0.j], + [ 0.3300045 +0.j, -0.4371462 +0.j, -0.6544272 +0.j, + 0.52127254 +0.j], + [ 0.54583293 +0.j, -0.045460045-0.j, -0.312564 +0.j, + -0.77608234 +0.j], + [ 0.76166105 +0.j, 0.34622625 +0.j, 0.42651838 +0.j, + 0.34363067 +0.j]], dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xcomplex> loc("input")) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:5 = stablehlo.custom_call @lapack_cgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor<4xcomplex>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc3) + return %6, %10 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa5e-\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x0b\x0b\x01\x05\x0b\x0f\x03)\x17\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x02\x1e\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19C\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f'\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f%\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03-\r\x01#\x19\x03\x0537\r\x03%5\x1d'\r\x03%9\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05EGIK\x1d/\x13\x11V\x1d1\x13\x11N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03[\x15\x03\x01\x01\x01\x03\x0b##_''\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x03\x1b\x13\x01)\x01\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\t\x1b)\x03\x11\t)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x046\x02\x05\x01Q\x03\x07\x01\x07\x04\x0e\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf3\x03%;\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\x0b\x05\x05\x1f\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03)\x05\x0f\x11\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x15\x07\x06\x01\x03\x05\x07\x19\x07\x17\x03F\x01\x0b\x03\x15\x03\x13\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x17\x03\x1d\x07\x06\x01\x03\x05\x07!\t\x1f\x0f\x04\x03\x05\x1b#\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_cgees_ffi\x00\x08=\x11\x05#\x01\x0b+/1;=\x03?\x03A\x11MOQSUWY]\x03!\x05ac\x03)", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]], dtype=float32),), + expected_outputs=(array([[ 3.2464233e+01, -1.3416398e+01, -1.6680369e-05, 4.0411728e-06], + [ 0.0000000e+00, -2.4642496e+00, -1.8640144e-06, 6.7429795e-07], + [ 0.0000000e+00, 0.0000000e+00, -7.2618576e-07, 3.9895073e-07], + [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 3.0443638e-07]], + dtype=float32), array([[-0.11417632 , 0.8288333 , -0.5413438 , 0.08334288 ], + [-0.33000442 , 0.43714583 , 0.65967286 , -0.5146185 ], + [-0.54583275 , 0.045459934, 0.30468878 , 0.7792079 ], + [-0.7616609 , -0.34622616 , -0.4230168 , -0.34793234 ]], + dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf32> loc("input")) -> (tensor<4x4xf32> {jax.result_info = "[0]"}, tensor<4x4xf32> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:6 = stablehlo.custom_call @lapack_sgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc3) + return %6, %10 : tensor<4x4xf32>, tensor<4x4xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\n\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\t\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_sgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_11_29["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgees_ffi'], + serialized_date=datetime.date(2024, 12, 2), + inputs=(array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]),), + expected_outputs=(array([[ 3.2464249196572979e+01, -1.3416407864998748e+01, + 4.7128510442204522e-15, -8.6687960588453852e-15], + [ 0.0000000000000000e+00, -2.4642491965729767e+00, + 1.8990547895861982e-15, -2.4680570671743780e-16], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -1.8780225147134376e-15, -7.2486619604759710e-16], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + 0.0000000000000000e+00, 4.8523923435746521e-16]]), array([[-0.1141764513873386 , 0.8288327563197505 , 0.5401360966805397 , + 0.09084600741204968], + [-0.3300045986655475 , 0.43714638836388714, -0.6524688462214561 , + -0.5237216863090944 ], + [-0.5458327459437569 , 0.04546002040802441, -0.31547059759870844, + 0.774905350382041 ], + [-0.7616608932219663 , -0.34622634754783793, 0.4278033471396243 , + -0.3420296714849957 ]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x4xf64> loc("input")) -> (tensor<4x4xf64> {jax.result_info = "[0]"}, tensor<4x4xf64> {jax.result_info = "[1]"}) { + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %0:6 = stablehlo.custom_call @lapack_dgees_ffi(%arg0) {mhlo.backend_config = {mode = 86 : ui8, sort = 78 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor, tensor) loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.compare EQ, %0#5, %1, SIGNED : (tensor, tensor) -> tensor loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#0, %4 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#1, %8 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc3) + return %6, %10 : tensor<4x4xf64>, tensor<4x4xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":631:13) +#loc3 = loc("jit(func)/jit(main)/schur"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.7.1\x00\x01!\x05\x01\x05\x11\x01\x03\x0b\x03\x0f\x0f\x13\x17\x1b\x1f#'\x03\xa3e+\x01!\x0f\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x17\x0b\x03E\x0fO\x0b/\x0fO\x0f\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b/\x1f\x1b\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x17#\x0b\x0b\x01\x05\x0b\x0f\x03'\x17\x0f\x07\x07\x07\x0f\x13\x07\x07\x17\x17\x1b\x07\x13\x13\x13\x13\x0f\x13\x02\x1a\x04\x1d\x1b\x1d\x1f\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x15\x11\x01\x00\x05\x17\x05\x19\x05\x1b\x1d\x15\x03\x05\x1d\x03\x03\x19E\x05\x1f\x05!\x17\x1f\xde\t\x1b\x05#\x1f%\x01\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d%\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03/\r\x01#\x1b\x03\x0559\r\x03%7\x1d'\r\x03%;\x1d)\x1d+\x1d-\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x07\t\x00\x00\x00\x00\r\x05GIKM\x1d/\x13\x13V\x1d1\x13\x13N\x0b\x03\x1d3\x1d5\x03\x01\x05\x01\x03\x03#\x03\x03]\x15\x03\x01\x01\x01\x03\r##''))\t\x07\x07\x01\x01\t\x01\x02\x02)\x05\x11\x11\t)\x01\x1d\x0b\x13\x01)\x01\t)\x03\x11\t!\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x05\x05\x05\x05\x1b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x0b)\x03\x01\x15)\x01\r)\x03\t\x15\x04:\x02\x05\x01Q\x03\x07\x01\x07\x04\x12\x02\x03\x01\x05\tP\x03\x03\x07\x04\xf5\x03';\x03\x0b\x13\x00\x05B\x03\x05\x03\x0f\x05B\x03\x07\x03\x07\x0bG\x01\x17\t\r\x05\x05\x11\x11\x07\x07\x03\x01\x03F\x01\x0b\x03\x07\x03\x05\rF\x01\r\x03'\x05\x11\x13\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x17\x07\x06\x01\x03\x05\x07\x1b\x07\x19\x03F\x01\x0b\x03\x17\x03\x15\x03F\x01\x0b\x03\x05\x03\x03\x03F\x01\x0f\x03\x19\x03\x1f\x07\x06\x01\x03\x05\x07#\t!\x0f\x04\x03\x05\x1d%\x06\x03\x01\x05\x01\x00\xe6\x057#\x03\x0b\x0b\x0f\x0b\t\t!i5)\r\x13%)9\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00input\x00mhlo.backend_config\x00jit(func)/jit(main)/schur\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00[0]\x00[1]\x00main\x00public\x00mode\x00sort\x00\x00lapack_dgees_ffi\x00\x08=\x11\x05#\x01\x0b-13=?\x03A\x03C\x11OQSUWY[_\x03!\x05ac\x03+", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 62cb72c69fd7..780759e69ee5 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -48,6 +48,7 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2616,37 +2617,54 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals, batch_dims = operand_aval.shape[:-2] a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - gees_result = lapack.gees_hlo(operand_aval.dtype, operand, + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else () + gees_result = lapack.gees_hlo(*ctx_args, operand_aval.dtype, operand, jobvs=compute_schur_vectors, sort=sort_eig_vals, select=select_callable, a_shape_vals=a_shape_vals) - - # Number of return values depends on value of sort_eig_vals. - T, vs, *_, info = gees_result + if jaxlib_version >= (0, 4, 37) and not ctx.is_forward_compat(): + schur_form, schur_vectors, _eig_vals, _selected_eig_vals, info = gees_result + else: + # Number of return values depends on value of sort_eig_vals. + schur_form, schur_vectors, *_, info = gees_result ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") - select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - T = _broadcasting_select_hlo( + select_schur_form_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) + schur_form = _broadcasting_select_hlo( ctx, - mlir.broadcast_in_dim(ctx, ok, select_T_aval, - broadcast_dimensions=range(len(batch_dims))), - select_T_aval, - T, ctx.avals_out[0],_nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]) - output = [T] + mlir.broadcast_in_dim( + ctx, + ok, + select_schur_form_aval, + broadcast_dimensions=range(len(batch_dims)), + ), + select_schur_form_aval, + schur_form, + ctx.avals_out[0], + _nan_like_hlo(ctx, ctx.avals_out[0]), + ctx.avals_out[0], + ) + output = [schur_form] if compute_schur_vectors: select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - vs = _broadcasting_select_hlo( + schur_vectors = _broadcasting_select_hlo( ctx, - mlir.broadcast_in_dim(ctx, ok, select_vs_aval, - broadcast_dimensions=range(len(batch_dims))), + mlir.broadcast_in_dim( + ctx, ok, select_vs_aval, broadcast_dimensions=range(len(batch_dims)) + ), select_vs_aval, - vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]) + schur_vectors, + ctx.avals_out[1], + _nan_like_hlo(ctx, ctx.avals_out[1]), + ctx.avals_out[1], + ) - output.append(vs) + output.append(schur_vectors) return output diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 9eef615ccc07..fa7ef99af7f3 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -27,6 +27,7 @@ from jaxlib import xla_client from .cpu import _lapack +from .cpu._lapack import schur from .cpu._lapack import eig from .hlo_helpers import ( custom_call, hlo_u8, hlo_s32, @@ -353,9 +354,9 @@ def geev_hlo(ctx, dtype, input, *, # # gees : Schur factorization -def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None, +def gees_hlo(ctx, dtype, a, *, jobvs=True, sort=False, select=None, a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() + fn_base = prepare_lapack_call(fn_base="gees", dtype=dtype) a_type = ir.RankedTensorType(a.type) etype = a_type.element_type assert len(a_shape_vals) >= 2 @@ -368,70 +369,108 @@ def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None, raise NotImplementedError( "The sort feature of LAPACK's gees routine is not implemented.") - jobvs = ord('V' if jobvs else 'N') - sort = ord('S' if sort else 'N') + mode = ( + schur.ComputationMode.kComputeSchurVectors + if jobvs + else schur.ComputationMode.kNoComputeSchurVectors + ) + sort = schur.Sort.kSortEigenvalues if sort else schur.Sort.kNoSortEigenvalues + if ctx.is_forward_compat(): + fn = fn_base + workspaces: list[ShapeTypePair] + eigvals: list[ShapeTypePair] + if not np.issubdtype(dtype, np.complexfloating): + workspaces = [(a_shape_vals, etype)] + workspace_layouts = [layout] + eigvals = [(batch_dims_vals + (n,), etype)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + else: + workspaces = [(a_shape_vals, etype), + ([n], ir.ComplexType(etype).element_type), + ] + workspace_layouts = [layout, [0]] + eigvals = [(batch_dims_vals + (n,), etype)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] - if dtype == np.float32: - fn = "lapack_sgees" - elif dtype == np.float64: - fn = "lapack_dgees" - elif dtype == np.complex64: - fn = "lapack_cgees" - elif dtype == np.complex128: - fn = "lapack_zgees" - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") + i32_type = ir.IntegerType.get_signless(32) - workspaces: list[ShapeTypePair] + scalar_layout = [] + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + shape_type_pairs = workspaces + eigvals + [ + (a_shape_vals, etype), + (batch_dims_vals, i32_type), + (batch_dims_vals, i32_type)] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + out = custom_call( + fn, + result_types=result_types, + operands=[ + batch_size_val, + ensure_hlo_s32(n), + hlo_u8(mode.value), + hlo_u8(sort.value), + # TODO: figure out how to put the callable select function here + a + ], + operand_layouts=[scalar_layout] * 4 + [layout], + result_layouts=workspace_layouts + eigvals_layouts + [ + layout, + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + ], + operand_output_aliases={4: 0}, + result_shapes=result_shapes, + ).results + if sort == schur.Sort.kSortEigenvalues: + return (out[0], out[3], out[4], out[5]) + else: + return (out[0], out[3], out[5]) + fn = fn_base + "_ffi" eigvals: list[ShapeTypePair] - if not np.issubdtype(dtype, np.complexfloating): - workspaces = [(a_shape_vals, etype)] - workspace_layouts = [layout] - eigvals = [(batch_dims_vals + (n,), etype)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - else: - workspaces = [(a_shape_vals, etype), - ([n], ir.ComplexType(etype).element_type), - ] - workspace_layouts = [layout, [0]] - eigvals = [(batch_dims_vals + (n,), etype)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] + is_complex = np.issubdtype(dtype, np.complexfloating) + eigvals = [(batch_dims_vals + (n,), etype)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + if not is_complex: + eigvals = eigvals * 2 + eigvals_layouts = eigvals_layouts * 2 i32_type = ir.IntegerType.get_signless(32) - - scalar_layout = [] - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - shape_type_pairs = workspaces + eigvals + [ + shape_type_pairs = [ + (a_shape_vals, etype), (a_shape_vals, etype), + *eigvals, (batch_dims_vals, i32_type), - (batch_dims_vals, i32_type)] + (batch_dims_vals, i32_type), + ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) out = custom_call( fn, result_types=result_types, - operands=[ - batch_size_val, - ensure_hlo_s32(n), - hlo_u8(jobvs), - hlo_u8(sort), - # TODO: figure out how to put the callable select function here - a - ], - operand_layouts=[scalar_layout] * 4 + [layout], - result_layouts=workspace_layouts + eigvals_layouts + [ - layout, - tuple(range(num_bd - 1, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), + operands=[a], + # TODO(paruzelp): Use FFI execution context to put `select` + operand_layouts=[layout], + result_layouts=[ + layout, + layout, + *eigvals_layouts, + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), ], - operand_output_aliases={4: 0}, + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={ + "mode": _enum_to_char_attr(mode), + "sort": _enum_to_char_attr(sort), + }, + api_version=4, ).results - if sort == ord('S'): - return (out[0], out[3], out[4], out[5]) + # out: Schur Form, Schur Vectors, Eigenvalues, Selected Eigenvalues, Info + if is_complex: + return out[0], out[1], out[2], out[3], out[4] else: - return (out[0], out[3], out[5]) + return out[0], out[1], (out[2], out[3]), out[4], out[5] # gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form. diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index c20fc95350c2..b16cdc787345 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -68,6 +68,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions +from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -119,6 +120,7 @@ def test_custom_call_coverage(self): cpu_eig_lapack_geev.data_2024_08_19, cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, + cpu_schur_lapack_gees.data_2024_11_29, cpu_svd_lapack_gesdd.data_2024_08_13, cpu_hessenberg_lapack_gehrd.data_2024_08_31, ] @@ -611,10 +613,10 @@ def compute_max_backward_error(operand, reconstructed_operand): self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False), np.asarray(out), atol=1e-4, rtol=1e-4)) - @jtu.parameterized_filterable( - kwargs=[ - dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) - for dtype_name in ("f32", "f64", "c64", "c128")]) + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", + dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) @jax.default_matmul_precision("float32") def test_cpu_schur_lapack_gees(self, dtype_name="f32"): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: @@ -640,6 +642,14 @@ def check_schur_results(res_run, res_expected, *, rtol, atol): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_schur_results) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 37) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_schur_results) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) From fac1b1a78084c93cb0b0c8314dbc264b3f0a422a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 6 Dec 2024 07:16:58 -0800 Subject: [PATCH 643/698] Set -Werror=mismatched-tags on clang. This means we see a helpful compiler error rather than a cryptic linker error if struct/class tags are mismatched. PiperOrigin-RevId: 703491110 --- .bazelrc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.bazelrc b/.bazelrc index 6ef7d4493937..8b53bd475e5b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -104,6 +104,8 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 build:clang --copt=-Wno-gnu-offsetof-extensions # Disable clang extention that rejects unknown arguments. build:clang --copt=-Qunused-arguments +# Error on struct/class mismatches, since this causes link failures on Windows. +build:clang --copt=-Werror=mismatched-tags # Configs for CUDA build:cuda --repo_env TF_NEED_CUDA=1 From bae660002a6bef54fe023bc1b7fb11f09fac583a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 6 Dec 2024 07:31:39 -0800 Subject: [PATCH 644/698] [pallas:mosaic_gpu] `FragmentedArray.reduce_sum` now returns a `FragmentedArray` This aligns it with the `reduce` method and also makes it clear that the reduction always produces a scalar. PiperOrigin-RevId: 703494443 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 +++----- jax/experimental/mosaic/gpu/fragmented_array.py | 4 ++-- tests/mosaic/gpu_test.py | 6 +----- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e04b0ff1adfe..ff5667d75e3b 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1256,13 +1256,11 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: case mgpu.WGStridedFragLayout(): - if axes != (0,): - raise NotImplementedError("No support for axes other than 0 yet") + if set(axes) != set(range(x_aval.ndim)): + raise NotImplementedError("No support for axes yet") scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype) with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]: - return mgpu.FragmentedArray.splat( - x.reduce_sum(scratch), (), is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) + return x.reduce_sum(scratch) case mgpu.WGMMA_LAYOUT: if axes != (x_aval.ndim - 1,): raise NotImplementedError diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 2da3de70658b..9e1585be27eb 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1129,7 +1129,7 @@ def upcast_to_bf16(reg, high): ) # NOTE: scratch can be reused immediately once this function returns. - def reduce_sum(self, scratch) -> ir.Value: + def reduce_sum(self, scratch): if ir.FloatType.isinstance(self.mlir_dtype): op = addf elif ir.IntegerType.isinstance(self.mlir_dtype): @@ -1168,7 +1168,7 @@ def reduce_sum(self, scratch) -> ir.Value: utils.warpgroup_barrier() result = memref.load(scratch, [zero_index]) utils.warpgroup_barrier() # Make sure everyone is done using scratch. - return result + return FragmentedArray.splat(result, (), is_signed=self.is_signed) def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): if isinstance(op, str): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 80c4048720a3..26d6bfafd84d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1481,11 +1481,7 @@ def kernel(ctx, src, dst, scratch): src = mgpu.FragmentedArray.load_strided( src, is_signed=utils.is_signed(dtype) ) - acc = mgpu.FragmentedArray.splat( - src.reduce_sum(scratch), - (m,), - is_signed=src.is_signed - ) + acc = src.reduce_sum(scratch).broadcast((m,)) acc.store_untiled(dst) in_shape = jax.ShapeDtypeStruct((m, n), dtype) From 08d31d0fe14e97893ac07c46b5d3543096fcea6d Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 6 Dec 2024 08:12:13 -0800 Subject: [PATCH 645/698] [mosaic_gpu] Emit a slightly more informative error message in `FragmentedArray._pointwise` PiperOrigin-RevId: 703504247 --- jax/experimental/mosaic/gpu/fragmented_array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 9e1585be27eb..7e40d86f2ae4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -662,7 +662,8 @@ def _pointwise(self, op, *other, output_is_signed: bool | None = None): if isinstance(o.layout, WGSplatFragLayout): if not o.layout.can_broadcast_to(self.shape): - raise ValueError("Can't broadcast shape.") + raise ValueError( + f"Cannot broadcast shape {self.shape} to layout {o.layout}") o = FragmentedArray.splat( o.registers.flat[0], shape=self.shape, From af5013568a90aa1d5daca8ea48f5bc8a3eee7b5b Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 2 Dec 2024 09:17:51 -0800 Subject: [PATCH 646/698] Fix error when swapping a ref with a trivial indexing transform. Without this fix, the added test case fails with: ``` ... jax/_src/state/discharge.py:416: in _swap_discharge_rule z, x_new = _swap_discharge(x, val, idx, tree) jax/_src/state/discharge.py:421: in _swap_discharge return transform_swap_array(x, transforms, val) jax/_src/state/discharge.py:396: in transform_swap_array result_val = lax_slicing.dynamic_update_slice( jax/_src/lax/slicing.py:215: in dynamic_update_slice start_indices = _dynamic_slice_indices(operand, start_indices) ... AttributeError: 'NoneType' object has no attribute 'ndim' ``` from encountering a None when computing the `result_val`. --- jax/_src/state/discharge.py | 2 +- tests/state_test.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 2c38878c7112..7bf0835fb156 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -364,7 +364,7 @@ def transform_swap_array(x, transforms, val): case indexing.NDIndexer(): indexer = transform if _is_trivial_indexer(indexer): - _results.append(None) + _results.append(_results[-1]) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. diff --git a/tests/state_test.py b/tests/state_test.py index 44caded0ca64..a930fe293709 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -639,6 +639,26 @@ def f(a_ref): refval, = core.eval_jaxpr(discharged_jaxpr, discharged_consts, inval) self.assertTrue((refval == inval.at[jnp.array([0, 1])].set(1.)).all()) + def test_discharge_swap(self): + def f(a_ref): + a = ref_swap( + a_ref.at[0:4, 0:3, 0:2].at[1:3, :, 0], + (slice(None), slice(1, 3)), + jnp.zeros((2, 2), jnp.float32)) + return [a + 1] + in_avals = [shaped_array_ref((4, 3, 2), jnp.float32)] + stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(f), in_avals) + + discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts) + self.assertLen(discharged_jaxpr.invars, 1) + self.assertLen(discharged_jaxpr.outvars, 2) + + inval = jnp.arange(24., dtype=jnp.float32).reshape((4, 3, 2)) + outval, refval = core.eval_jaxpr(discharged_jaxpr, (), inval) + self.assertArraysEqual(outval, inval[1:3, 1:3, 0] + 1) + self.assertArraysEqual(refval, inval.at[1:3, 1:3, 0].set(0)) + def test_discharge_addupdate(self): def f(a_ref, b): ref_addupdate(a_ref, (), b + 1) From 8b656206af1372fae670c10052e640f91c1e364e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 6 Dec 2024 09:15:48 -0800 Subject: [PATCH 647/698] [Pallas MGPU] Use multiple k/v_consumed_barriers in the attention kernel There's nothing technically preventing the compute threads from running ahead and signalling the consumption of k/v twice in case the memory thread ends up being temporarily starved. I don't think this was ever a problem in practice since the GPU hardware scheduler is surprisingly fair, but it's good not to have races :) PiperOrigin-RevId: 703520322 --- .../pallas/ops/gpu/attention_mgpu.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 294ef153ff93..e8c818b884b5 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -74,7 +74,7 @@ def kernel(q_ref, k_ref, v_ref, out_ref, scoped): wg_idx = lax.axis_index("wg") qo_smem2, k_smem, v_smem = smem_buffers k_barriers, v_barriers, q_barriers = buffer_barriers - k_consumed_barrier, v_consumed_barrier = consumed_barriers + k_consumed_barriers, v_consumed_barriers = consumed_barriers def perform_schedule_barrier(): plgpu.barrier_arrive(schedule_barrier) plgpu.barrier_wait(schedule_barrier) @@ -116,7 +116,7 @@ def compute_qk(acc_ref): perform_schedule_barrier() return acc_ref[...] qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32)) - plgpu.barrier_arrive(k_consumed_barrier) + plgpu.barrier_arrive(k_consumed_barriers.at[slot]) # Softmax # We keep m scaled by log2e to use FMA instructions when computing p. @@ -153,7 +153,7 @@ def compute_pv(acc_ref): def _wait(): plgpu.barrier_wait(k_barriers.at[wait_slot]) acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc)) - plgpu.barrier_arrive(v_consumed_barrier) + plgpu.barrier_arrive(v_consumed_barriers.at[slot]) return acc, m_i, l_i if kv_seq_len % block_kv: raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") @@ -184,17 +184,12 @@ def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) - plgpu.barrier_wait(k_consumed_barrier) + plgpu.barrier_wait(k_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) - plgpu.barrier_wait(v_consumed_barrier) + plgpu.barrier_wait(v_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) - def kv_epilogue(i, _): - plgpu.barrier_wait(k_consumed_barrier) - plgpu.barrier_wait(v_consumed_barrier) - lax.fori_loop(0, max_concurrent_steps, kv_epilogue, None) - def run(refs): q_ref, k_ref, v_ref, out_ref = refs @@ -210,7 +205,6 @@ def run(refs): @pl.core_map(mesh) def _kernel_entry(): compute_wgs = 2 - barrier_2wg = plgpu.Barrier(num_arrivals=compute_wgs) tiling = plgpu.TilingTransform((64, 64)) swizzle = plgpu.SwizzleTransform(128) qo_scratch = plgpu.SMEM( @@ -233,8 +227,8 @@ def _kernel_entry(): plgpu.Barrier(1, num_barriers=max_concurrent_steps), plgpu.Barrier(1, num_barriers=compute_wgs), ), - (barrier_2wg, barrier_2wg), - barrier_2wg, + (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2, + plgpu.Barrier(num_arrivals=compute_wgs), ) _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) From eda7506d6bcd4607c0fd38ba69991e3226677d0a Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 6 Dec 2024 09:18:28 -0800 Subject: [PATCH 648/698] [Pallas MGPU] Disable XLA:GPU autotuning in attention tests We don't care about performance of the reference impl, we only use it for correctness testing. More importantly, it works around a deadlock at compile time that sometimes happens when testing large batch sizes. PiperOrigin-RevId: 703521029 --- tests/pallas/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 50c1054ba9fd..fd1166d66df6 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -494,6 +494,7 @@ jax_multiplatform_test( srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"], enable_backends = [], enable_configs = ["gpu_h100_x32"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, tags = [ "manual", "notap", @@ -509,6 +510,7 @@ jax_multiplatform_test( srcs = ["mgpu_attention_test.py"], enable_backends = [], enable_configs = ["gpu_h100_x32"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, deps = [ "//jax:pallas", "//jax:pallas_experimental_gpu_ops", From 0c6b967e86c2a87a2cdfe20d6c5e1fc5950db5da Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 4 Dec 2024 12:30:45 -0600 Subject: [PATCH 649/698] Don't look for CUDA files when building the ROCm wheel --- jaxlib/tools/build_gpu_kernels_wheel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 9a47c6ad5409..36c1b4d2cbfc 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -145,7 +145,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", - f"__main__/jaxlib/cuda/_hybrid.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", From 048dc296b4d817ead7676c47ce1ed3b62d9bb91a Mon Sep 17 00:00:00 2001 From: charleshofer Date: Fri, 6 Dec 2024 11:33:12 -0600 Subject: [PATCH 650/698] Don't look for CUDA files when building the ROCm wheel (#173) --- jaxlib/tools/build_gpu_kernels_wheel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 9a47c6ad5409..36c1b4d2cbfc 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -145,7 +145,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", - f"__main__/jaxlib/cuda/_hybrid.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", From 641a1d53ce89b06376c8646d6896155cc78644d2 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 6 Dec 2024 10:35:18 -0800 Subject: [PATCH 651/698] [Pallas] Add support for run_state to cost estimator. PiperOrigin-RevId: 703543961 --- jax/_src/pallas/cost_estimate.py | 19 +++++++++++++++---- tests/pallas/pallas_cost_estimate_test.py | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index b83c36159555..5b322eedc837 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -13,14 +13,17 @@ # limitations under the License. """Helper tool for automatic cost estimation.""" import dataclasses +import functools import math from typing import Any, Sequence import jax +from jax._src import api_util from jax._src import core as jax_core from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src import pjit +from jax._src.state import discharge from jax._src.pallas import core as pallas_core from jax._src.interpreters import partial_eval as pe from jax._src.util import safe_map @@ -87,10 +90,9 @@ def estimate_cost(fun, *args, **kwargs) -> pallas_core.CostEstimate: A pallas_core.CostEstimate object containing the cost estimate. """ flattened_args, treedef = jax.tree.flatten(args) - def _partial_fun(*flat_args): - return fun(*jax.tree.unflatten(treedef, flat_args), **kwargs) - wrapped_fun = lu.wrap_init( - lambda *args, **kwargs: (_partial_fun(*args, **kwargs),)) + partial_fun = functools.partial(fun, **kwargs) + wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(partial_fun), + treedef) avals = [jax_core.ShapedArray(a.shape, a.dtype) for a in flattened_args] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals) estimate = cost_estimate_jaxpr(jax_core.ClosedJaxpr(jaxpr, consts)) @@ -243,3 +245,12 @@ def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): bytes_accessed=inner_cost.bytes_accessed, ) register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) + +def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2): + inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr)) + return CostEstimate( + flops=inner_cost.flops, + transcendentals=inner_cost.transcendentals, + bytes_accessed=inner_cost.bytes_accessed, + ) +register_cost_rule(discharge.run_state_p, _run_state_rule) diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py index fcdeac4cab82..d9eb18e6f540 100644 --- a/tests/pallas/pallas_cost_estimate_test.py +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -19,6 +19,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.pallas import cost_estimate +from jax._src.state import discharge config.parse_flags_with_absl() @@ -91,5 +92,23 @@ def test_integer_pow(self, power, expected_flops_per_element): self.assertEqual(cost.transcendentals, 0) self.assertEqual(cost.bytes_accessed, 80) + def test_run_state(self): + def add_refs(refs): + x_ref, y_ref, z_ref = refs + x = x_ref[:] + y = y_ref[:] + z = x + y + z_ref[:] = z + input_shape = jax.ShapeDtypeStruct((100,), jnp.float32) + cost = cost_estimate.estimate_cost( + discharge.run_state(add_refs), + (input_shape, input_shape, input_shape)) + self.assertEqual(cost.flops, 100) + self.assertEqual(cost.transcendentals, 0) + # TODO(justinfu): This is off by a factor of 2 because run_state + # has all inputs/outputs as both arguments and return values. + self.assertEqual(cost.bytes_accessed / 2, 3 * 4 * 100) + + if __name__ == "__main__": absltest.main() From a13b618c98109aa61b435f432d162bdbd6dc73fd Mon Sep 17 00:00:00 2001 From: Frederic Bastien Date: Thu, 21 Nov 2024 13:25:07 -0500 Subject: [PATCH 652/698] Document cudaMallocAsync as an experimental feature. --- docs/gpu_memory_allocation.rst | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index dac52c194603..6667589e7b72 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -70,3 +70,31 @@ Common causes of OOM failures memory. Note however, that the algorithm is basic and you can often get better trade-off between compute and memory by disabling the automatic remat pass and doing it manually with `the jax.remat API `_ + + +Experimental features +--------------------- + +Features here are experimental and must be tried with caution. + +``TF_GPU_ALLOCATOR=cuda_malloc_async`` + This replace XLA's own BFC memory allocator with `cudaMallocAsync + `_. + This will remove the big fixed pre-allocation and use a memory pool that grows. + The expected benefit is no need to set `XLA_PYTHON_CLIENT_MEM_FRACTION`. + + The risk are: + + - that memory fragmentation is different, so if you are close to the + limit, the exact OOM case due to fragmentation will be different. + - The allocation time won't be all paid at the start, but be incurred + when the memory pool need to be increased. So you could + experience less speed stability at the start and for benchmarks + it will be even more important to ignore the first few iterations. + + The risks can be mitigated by pre-allocating a signigicant chunk and + still get the benefit of having a growing memory pool. This can be + done with `TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N`. If N is `-1` + it will preallocate the same as what was allocatedy by + default. Otherwise, it is the size in bytes that you want to + preallocate. From 2b2d7cda985f358d7b7e89ca9983ca0c0339bdc7 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 4 Dec 2024 16:38:42 -0800 Subject: [PATCH 653/698] [Pallas] Update TPU documentation --- docs/_static/pallas/vector_layout_example.svg | 1 + docs/pallas/tpu/details.rst | 66 +++++++++++++------ docs/pallas/tpu/sparse.ipynb | 2 + docs/pallas/tpu/sparse.md | 2 + 4 files changed, 52 insertions(+), 19 deletions(-) create mode 100644 docs/_static/pallas/vector_layout_example.svg diff --git a/docs/_static/pallas/vector_layout_example.svg b/docs/_static/pallas/vector_layout_example.svg new file mode 100644 index 000000000000..f1c9403573d8 --- /dev/null +++ b/docs/_static/pallas/vector_layout_example.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index b7ce10d564f6..0575806e6037 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -119,24 +119,44 @@ The output reference can be then used as an accumulator for partial results. spilled vector registers) exceeds the size of VMEM. In this case, you will likely see a low-level compiler error message complaining about an out-of-memory error. -Dimension ordering is meaningful -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Array Layouts +^^^^^^^^^^^^^ +Dimension ordering of arrays is meaningful in Pallas. In JAX programs, the ordering of intermediate arrays inside ``jax.jit`` usually has no impact on performance, as the compiler is free to rearrange them. However, as Pallas is meant to expose lower-level capabilities, the dimension order can have great impact on the quality of generated code. -Recall that the TPUs perform bulk of the computation on 2D vector registers. -Pallas TPU will only ever consider mapping the last two dimensions of -intermediate arrays to those vector register dimensions (sublanes and lanes -respectively). An array of shape ``(n, 1, 1)`` is guaranteed to require at least -``n`` vector registers to represent. If ``n`` becomes too large, this can lead -to spills, and potential VMEM OOM errors due to an overly large memory footprint. -But it also might not --- the low-level compiler is free to rearrange the -instructions to lower the register pressure, and is in fact very good at it. -Still, it is a good rule of thumb to keep the last two dimensions large -(especially the last dimension), while keeping the leading dimensions small. +TPUs perform bulk of the computation on 2D vector registers, which are typically of +size 8x128 for 32-bit values (as of TPU v6). +When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``), +the last two dimensions of the array will be tiled into the registers. +Pallas will only ever consider mapping the last two dimensions of +intermediate arrays to the 8x128 vector register dimensions (sublanes and lanes +respectively). + +Here is a graphical example of how a 12x320 array can be tiled using 6 8x128 +tiles: + +.. image:: ../../_static/pallas/vector_layout_example.svg + +Tiled layouts have several import ramifications for kernel writers: + +* The last two axes of an array are treated differently than other + axes. For example, reductions, reshapes, and transposes are generally + more expensive when involving the last two axes. Some reshapes + involving the last two dimensions are not supported and will result in a compiler + error, but are "free" and performed at compile time for other dimensions. +* While sometimes unavoidable, it is generally wasteful to have singleton + dimensions in the last two axes, since they will occupy 1 element out of + the entire tile dimension. Consuming too many registers can + also potentially cause register spills into VMEM which degrades kernel + performance. +* Related to the above point, all vector computation is padded up to the tile + size. Adding a two 1x1 arrays costs as much as adding two 8x128 arrays, and + adding two 8x128x1x1 arrays will be 1024 times as expensive as adding two + 8x128 arrays, since the 8x128x1x1 array will be padded to 8x128x8x128. Multicore TPU configurations ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -196,18 +216,19 @@ for those arguments. But, the ``BlockSpec``\s for all subsequent arguments will receive not only the grid indices, but also the SMEM references to the leading operands. -.. note:: - We are working on implementing examples for this feature. Stay tuned! +See :ref:`pallas_scalar_prefetch_guide` for examples on using this +feature. Supported data types ^^^^^^^^^^^^^^^^^^^^ -At the moment Pallas TPU only supports the following data types: +At the moment Pallas TPU supports the following data types: * ``jnp.float32`` * ``jnp.bfloat16`` * ``jnp.int*`` (all precisions, except for ``jnp.int4``) * ``jnp.uint*`` (all precisions) +* ``jnp.bool_`` Computation placement ^^^^^^^^^^^^^^^^^^^^^ @@ -306,14 +327,13 @@ Array constructors ^^^^^^^^^^^^^^^^^^ All constant array constructors are supported (``jnp.ones``, ``jnp.zeros``, -``jnp.full``). Notably, the ``jax.random`` module is **not** compatible with -Pallas as of today. +``jnp.full``). Reductions ^^^^^^^^^^ -Sum, maximum and minimum reductions are supported, but only on a single array -axis at a time. +``sum``, ``max``, ``min`` (for floating point values) reductions are supported, as well +as ``any`` and ``all`` for boolean values. Integer reductions are not supported. Reductions over the last array dimension are generally the slowest. Reductions over the second last dimension are faster, but still slower than @@ -338,6 +358,14 @@ of an array is when (1) some leading dimensions are flattened onto the second to last dimension, or (2) it adds a dimension that was just removed by a reduction. +Random Number Generation +^^^^^^^^^^^^^^^^^^^^^^^^ + +Pallas supports the most commonly used functions from the ``jax.random`` module, +such as ``uniform``, ``normal``, and ``bernoulli``. The key should be a ``threefry2x32`` key, +which is the default setting in JAX. Keys can be directly passed into a kernel, +or generated inside of a kernel. + Control flow ^^^^^^^^^^^^ diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index a80ba4ebedbb..5b37e7b0574b 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -6,6 +6,8 @@ "id": "ZHuzXqQ-9JUQ" }, "source": [ + "(pallas_scalar_prefetch_guide)=\n", + "\n", "# Scalar Prefetch and Block-Sparse Computation\n", "\n", "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory." diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 2ac25edb5064..36a6e07e9192 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -14,6 +14,8 @@ kernelspec: +++ {"id": "ZHuzXqQ-9JUQ"} +(pallas_scalar_prefetch_guide)= + # Scalar Prefetch and Block-Sparse Computation In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory. From 83c64b23799c281add6d24a1651e96856dfa7f21 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 6 Dec 2024 12:31:34 -0800 Subject: [PATCH 654/698] Add a flag to enable detailed timestamped logging of subprocess commands. This adds a new command-line flag, `--detailed_timestamped_log`, that enables detailed logging of Bazel build commands. When disabled (the default), logging mirrors the output you'd see when running the command directly in your terminal. When this flag is enabled: - Bazel's output is captured line by line. - Each line is timestamped for improved traceability. - The complete log is stored for potential use as an artifact. The flag is disabled by default and only enabled in the CI builds. If you're running locally and enable `detailed_timestamped_log`, you might notice that Bazel's output is not colored. To force a color output, include `--bazel_options=--color=yes` in your command. PiperOrigin-RevId: 703581368 --- build/build.py | 13 +++++++++++-- build/tools/command.py | 13 +++++++------ ci/build_artifacts.sh | 2 +- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/build/build.py b/build/build.py index 25a873d89e24..a6c1a7922b0e 100755 --- a/build/build.py +++ b/build/build.py @@ -123,6 +123,15 @@ def add_global_arguments(parser: argparse.ArgumentParser): help="Produce verbose output for debugging.", ) + parser.add_argument( + "--detailed_timestamped_log", + action="store_true", + help=""" + Enable detailed logging of the Bazel command with timestamps. The logs + will be stored and can be accessed as artifacts. + """, + ) + def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): """Adds all the arguments that applies to the artifact subcommands.""" @@ -399,7 +408,7 @@ async def main(): else: requirements_command.append("//build:requirements.update") - result = await executor.run(requirements_command.get_command_as_string(), args.dry_run) + result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") else: @@ -597,7 +606,7 @@ async def main(): wheel_build_command.append(f"--jaxlib_git_hash={git_hash}") - result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run) + result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) # Exit with error if any wheel build fails. if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") diff --git a/build/tools/command.py b/build/tools/command.py index 48a9bfc1c0d6..cc95d7eea4af 100644 --- a/build/tools/command.py +++ b/build/tools/command.py @@ -75,7 +75,7 @@ def __init__(self, environment: Dict[str, str] = None): """ self.environment = environment or dict(os.environ) - async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: + async def run(self, cmd: str, dry_run: bool = False, detailed_timestamped_log: bool = False) -> CommandResult: """ Executes a subprocess command. @@ -96,14 +96,15 @@ async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: process = await asyncio.create_subprocess_shell( cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE if detailed_timestamped_log else None, + stderr=asyncio.subprocess.PIPE if detailed_timestamped_log else None, env=self.environment, ) - await asyncio.gather( - _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) - ) + if detailed_timestamped_log: + await asyncio.gather( + _process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result) + ) result.return_code = await process.wait() result.end_time = datetime.datetime.now() diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 9f8d54401691..698de38418b7 100644 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -69,7 +69,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then fi # Build the artifact. - python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose + python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. From baedb62b71d9cf32d1922254d8faa3b03903ad77 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 6 Dec 2024 15:50:31 -0800 Subject: [PATCH 655/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/1ff335df59d91650c4c6c9b7d215ed03ad6bf7e9. PiperOrigin-RevId: 703640058 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index bb4e14ee447b..943c60de4479 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fa3369103478bb0b98a900c21658f2aca2e73319" -XLA_SHA256 = "a9db6376115ae898c3eff4a2ca8f0e71e6eff79240d8b9c5929aaf923f7d86d0" +XLA_COMMIT = "1ff335df59d91650c4c6c9b7d215ed03ad6bf7e9" +XLA_SHA256 = "5b039608ca8d903adcdba28f0d88c47dc97d76107b59223ab8ff854cdae43683" def repo(): tf_http_archive( From 861115ad4bf0f57e53f61d4d083cd2bda6877ab5 Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Fri, 6 Dec 2024 17:44:52 -0800 Subject: [PATCH 656/698] Support transfer guard in broadcast_one_to_all(). Fixes https://github.com/jax-ml/jax/issues/25325 PiperOrigin-RevId: 703666450 --- jax/experimental/multihost_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index b38edcaba10a..79989583fc28 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -75,7 +75,7 @@ def pre_jit(x): return host_local_array_to_global_array(inp, global_mesh, pspec) def post_jit(x): - return np.asarray(x.addressable_data(0)) + return jax.device_get(x.addressable_data(0)) in_tree = jax.tree.map(pre_jit, in_tree) out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding( From 1f4d184ac8897270fb9248250d9b8473e8b64f2e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Sat, 7 Dec 2024 11:13:37 -0800 Subject: [PATCH 657/698] Temporarily allow bfloat16 dot algorithms on CPU. Since XLA:CPU doesn't (yet!) support explicit algorithms for controlling the precision of dot products we have a check in JAX that fails when a non-trivial algorithm is specified on CPU. In order to support downstream use cases, this change allows some bfloat16 algorithms to pass through. XLA:CPU "emulates" these algorithms using `F32_F32_F32` with the appropriate casting, so that means that CPU numerics will be different than on other platforms with explicit algorithm support, but it is useful to be able to use these algorithms with the correct input and output casting without requiring platform dependent logic in user code. PiperOrigin-RevId: 703834889 --- jax/_src/lax/lax.py | 2 ++ tests/lax_test.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9c8afe4f9292..972beb04cab4 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3773,6 +3773,8 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): if platform == "cpu" and precision not in { DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16, DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64, + DotAlgorithmPreset.BF16_BF16_F32, DotAlgorithmPreset.BF16_BF16_F32_X3, + DotAlgorithmPreset.BF16_BF16_F32_X6, }: raise ValueError( f"The precision '{precision}' is not supported by dot_general on CPU") diff --git a/tests/lax_test.py b/tests/lax_test.py index 9ef13efc2ed3..360525efab0b 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1082,6 +1082,9 @@ def testDotAlgorithm(self, algorithm, dtype): lax.DotAlgorithmPreset.F16_F16_F16, lax.DotAlgorithmPreset.F32_F32_F32, lax.DotAlgorithmPreset.F64_F64_F64, + lax.DotAlgorithmPreset.BF16_BF16_F32, + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, }: raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") From ad00ee1e06eb8063b8ed081ef410dbf75fd246a3 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 7 Dec 2024 14:59:59 -0800 Subject: [PATCH 658/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ace1e90f1bb5335e1bcb6898d860a1a98d15b358. PiperOrigin-RevId: 703866489 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 943c60de4479..770339520efe 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "1ff335df59d91650c4c6c9b7d215ed03ad6bf7e9" -XLA_SHA256 = "5b039608ca8d903adcdba28f0d88c47dc97d76107b59223ab8ff854cdae43683" +XLA_COMMIT = "ace1e90f1bb5335e1bcb6898d860a1a98d15b358" +XLA_SHA256 = "285607be2f3c3915aa4230ab9d72d8fe58be11e9f7b6b287310f706d1b8d2040" def repo(): tf_http_archive( From cc73c50c41dea532ff768e2ca8523c4935da0b96 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sun, 8 Dec 2024 17:39:24 +0100 Subject: [PATCH 659/698] [export] Improved the documentation. In particular added the docstring for `Exported.call` method, and fixed the formatting for `Exported.in_shardings_jax`. --- docs/developer.md | 11 +++--- docs/jax.export.rst | 7 ++-- jax/_src/export/_export.py | 71 +++++++++++++++++++++++--------------- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/docs/developer.md b/docs/developer.md index e6bdf53f1112..e8069b3b5fe6 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -689,22 +689,21 @@ minimization phase. ### Doctests JAX uses pytest in doctest mode to test the code examples within the documentation. -You can run this using +You can find the up-to-date command to run doctests in +[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml). +E.g., you can run: ``` -pytest docs +JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst ``` Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in function docstrings will run correctly. You can run this locally using, for example: ``` -pytest --doctest-modules jax/_src/numpy/lax_numpy.py +JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py ``` -Keep in mind that there are several files that are marked to be skipped when the -doctest command is run on the full package; you can see the details in -[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml) ## Type checking diff --git a/docs/jax.export.rst b/docs/jax.export.rst index 2095758dd3b3..c8feb1d169bd 100644 --- a/docs/jax.export.rst +++ b/docs/jax.export.rst @@ -14,8 +14,11 @@ Classes .. autosummary:: :toctree: _autosummary - Exported - DisabledSafetyCheck +.. autoclass:: Exported + :members: + +.. autoclass:: DisabledSafetyCheck + :members: Functions --------- diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index c521becb76d2..ade75e7446e6 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -203,6 +203,7 @@ class Exported: _get_vjp: Callable[[Exported], Exported] | None def mlir_module(self) -> str: + """A string representation of the `mlir_module_serialized`.""" return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized) def __str__(self): @@ -211,8 +212,8 @@ def __str__(self): return f"Exported(fun_name={self.fun_name}, ...)" def in_shardings_jax( - self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + self, + mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: """Creates Shardings corresponding to self.in_shardings_hlo. The Exported object stores `in_shardings_hlo` as HloShardings, which are @@ -221,30 +222,31 @@ def in_shardings_jax( `jax.device_put`. Example usage: - >>> from jax import export - >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) - >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), - ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) - ... )(np.arange(jax.device_count())) - >>> exp.in_shardings_hlo - ({devices=[8]<=[8]},) - - # Create a mesh for running the exported object - >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) - >>> - # Put the args and kwargs on the appropriate devices - >>> run_arg = jax.device_put(np.arange(jax.device_count()), - ... exp.in_shardings_jax(run_mesh)[0]) - >>> res = exp.call(run_arg) - >>> res.addressable_shards - [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), - Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), - Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), - Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), - Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), - Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), - Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), - Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])] + + >>> from jax import export + >>> # Prepare the exported object: + >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) + >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), + ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a"))) + ... )(np.arange(jax.device_count())) + >>> exp.in_shardings_hlo + ({devices=[8]<=[8]},) + >>> # Create a mesh for running the exported object + >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",)) + >>> # Put the args and kwargs on the appropriate devices + >>> run_arg = jax.device_put(np.arange(jax.device_count()), + ... exp.in_shardings_jax(run_mesh)[0]) + >>> res = exp.call(run_arg) + >>> res.addressable_shards + [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]), + Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]), + Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]), + Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]), + Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]), + Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]), + Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]), + Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])] + """ return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh) for s in self.in_shardings_hlo) @@ -252,7 +254,7 @@ def in_shardings_jax( def out_shardings_jax( self, mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: - """Creates Shardings corresponding to self.out_shardings_hlo. + """Creates Shardings corresponding to `self.out_shardings_hlo`. See documentation for in_shardings_jax. """ @@ -289,6 +291,21 @@ def serialize(self, return serialize(self, vjp_order=vjp_order) def call(self, *args, **kwargs): + """Call an exported function from a JAX program. + + Args: + args: the positional arguments to pass to the exported function. This + should be a pytree of arrays with the same pytree structure as the + arguments for which the function was exported. + kwargs: the keyword arguments to pass to the exported function. + + Returns: a pytree of result array, with the same structure as the + results of the exported function. + + The invocation supports reverse-mode AD, and all the features supported + by exporting: shape polymorphism, multi-platform, device polymorphism. + See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html). + """ return call_exported(self)(*args, **kwargs) From 7062325521f7df18561b0909fe84537134ca7ca7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 8 Dec 2024 14:08:43 -0800 Subject: [PATCH 660/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/fb46636c1f8f88b3bd69bf0523b76c904191d1ad. PiperOrigin-RevId: 704067165 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 770339520efe..e07d72edc453 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "ace1e90f1bb5335e1bcb6898d860a1a98d15b358" -XLA_SHA256 = "285607be2f3c3915aa4230ab9d72d8fe58be11e9f7b6b287310f706d1b8d2040" +XLA_COMMIT = "fb46636c1f8f88b3bd69bf0523b76c904191d1ad" +XLA_SHA256 = "20c5009feca949739a89b4f0077caac9345b623fafbc91f154d59085c4193e23" def repo(): tf_http_archive( From efa35ea9f90629693105ffc382f86c0cec3ae886 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Sun, 8 Dec 2024 20:11:06 -0500 Subject: [PATCH 661/698] Fix type annotation for numpy.linalg.matrix_norm argument 'ord'. --- jax/_src/numpy/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index c01a5d270f0f..8e35560a52ed 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -1517,7 +1517,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: @export -def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array: +def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str | int = 'fro') -> Array: """Compute the norm of a matrix or stack of matrices. JAX implementation of :func:`numpy.linalg.matrix_norm` From 3ec55c772370bb30df6c3e6ea86a9f4313009216 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 9 Dec 2024 02:52:14 -0800 Subject: [PATCH 662/698] [pallas:triton] Add support for `DotAlgorithmPreset` `precision` arguments to `dot`. PiperOrigin-RevId: 704208558 --- jax/_src/pallas/triton/lowering.py | 152 +++++++++++------------------ tests/pallas/pallas_test.py | 40 ++++++++ 2 files changed, 97 insertions(+), 95 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index fe641ba29494..e2376a457cdf 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1983,81 +1983,6 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): return tt_dialect.trans(x, permutation) -def _check_dot_operands( - x_type: ir.RankedTensorType, y_type: ir.RankedTensorType, options: Any -): - # TODO(slebedev): Ensure that the dtypes are supported by CUDA. - return - - -def _dot( - x: ir.Value, - y: ir.Value, - acc: ir.Value | None = None, - *, - allow_tf32: bool = True, - max_num_imprecise_acc: int | None = None, - out_type: ir.Type | None = None, -) -> ir.Value: - if out_type is None: - out_type = ir.F32Type.get() - elif isinstance(out_type, ir.BF16Type): - raise NotImplementedError(f"unsupported output type: {out_type}") - - x_type = ir.RankedTensorType(x.type) - y_type = ir.RankedTensorType(y.type) - if min(*x_type.shape, *y_type.shape) < 16: - raise ValueError("all dimensions of x and y must be >= 16 ") - if x_type.element_type != y_type.element_type: - raise ValueError( - "x and y must have the same element type, but got:" - f" {x_type.element_type} and {y_type.element_type}" - ) - - _check_dot_operands(x_type, y_type, object()) - - element_type = x_type.element_type - if isinstance(element_type, ir.IntegerType): - if element_type.width != 8: - raise TypeError(f"unsupported element type: {element_type}") - element_type = ir.IntegerType.get_signless(32) - elif isinstance(element_type, (ir.F32Type, ir.BF16Type)): - element_type = ir.F32Type.get() - else: - element_type = out_type - - if element_type != out_type: - raise TypeError( - f"output type {out_type} does not match element type {element_type}" - ) - - m, _ = x_type.shape - _, n = y_type.shape - - if acc is None: - acc = _full(ir.RankedTensorType.get([m, n], element_type), 0) - - if max_num_imprecise_acc is None: - if isinstance(element_type, ir.FloatType) and element_type.width == 8: - # TODO(slebedev): Fill in from options. - raise NotImplementedError - else: - max_num_imprecise_acc = 0 - - # Ideally, replace all allow_tf32 usages with InputPrecision directly. - input_precision = tt_dialect.InputPrecision.IEEE - if allow_tf32: - input_precision = tt_dialect.InputPrecision.TF32 - - return tt_dialect.dot( - x, - y, - acc, - max_num_imprecise_acc=max_num_imprecise_acc, - input_precision=input_precision - ) - - _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) @@ -2081,27 +2006,63 @@ def _dot_general_lowering( if b_contract_dim == 1: b = tt_dialect.trans(b, (1, 0)) - if precision is None: - allow_tf32 = True + a_aval, b_aval = ctx.avals_in + [out_aval] = ctx.avals_out + + if precision is None or (precision == lax.DotAlgorithmPreset.DEFAULT): + precision = (lax.Precision.DEFAULT, lax.Precision.DEFAULT) + + if isinstance(precision, lax.DotAlgorithmPreset): + match precision: + case lax.DotAlgorithmPreset.TF32_TF32_F32: + input_precision = tt_dialect.InputPrecision.TF32 + case lax.DotAlgorithmPreset.TF32_TF32_F32_X3: + input_precision = tt_dialect.InputPrecision.TF32x3 + case lax.DotAlgorithmPreset.F32_F32_F32: + input_precision = tt_dialect.InputPrecision.IEEE + case ( + lax.DotAlgorithmPreset.F16_F16_F16 + | lax.DotAlgorithmPreset.F16_F16_F32 + | lax.DotAlgorithmPreset.BF16_BF16_BF16 + | lax.DotAlgorithmPreset.BF16_BF16_F32 + ): + input_precision = None + case _: + raise NotImplementedError(f"Unsupported dot algorithm: {precision}.") + + a = _cast(a, a_aval.dtype, precision.supported_lhs_types[0]) + b = _cast(b, b_aval.dtype, precision.supported_rhs_types[0]) + acc_dtype = precision.accumulation_type + elif isinstance(precision, tuple): + a_precision, b_precision = precision + if a_precision in _TF32_PRECISIONS or b_precision in _TF32_PRECISIONS: + input_precision = tt_dialect.InputPrecision.TF32 + elif a_aval.dtype == jnp.float32: + input_precision = tt_dialect.InputPrecision.IEEE + else: + input_precision = None + + acc_dtype = out_aval.dtype + if acc_dtype != jnp.int32 and acc_dtype != jnp.float16: + acc_dtype = jnp.float32 else: - prec_a, prec_b = precision - allow_tf32 = prec_a in _TF32_PRECISIONS or prec_b in _TF32_PRECISIONS + raise NotImplementedError(f"Unsupported dot precision: {precision}.") - [out_aval] = ctx.avals_out - out_dtype = acc_dtype = out_aval.dtype - if acc_dtype != jnp.int32 and acc_dtype != jnp.float16: - acc_dtype = jnp.dtype(jnp.float32) - - return _cast( - _dot( - a, - b, - allow_tf32=allow_tf32, - out_type=_dtype_to_ir_type(acc_dtype), - ), - acc_dtype, - out_dtype, - ) + a_type = ir.RankedTensorType(a.type) + b_type = ir.RankedTensorType(b.type) + if min(*a_type.shape, *b_type.shape) < 16: + raise ValueError("all dimensions of a and b must be >= 16 ") + if a_type.element_type != b_type.element_type: + raise ValueError( + "a and b must have the same element type, but got:" + f" {a_type.element_type} and {b_type.element_type}" + ) + + m, _ = a_type.shape + _, n = b_type.shape + acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) + return _cast(acc, acc_dtype, out_aval.dtype) def _reduction_lowering(body, ctx: LoweringRuleContext, a, axes): @@ -2623,7 +2584,8 @@ def _i64_constant(v: int) -> ir.Value: return arith_dialect.constant(ir.IntegerType.get_signless(64), v) -def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type: +def _dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type: + dtype = jnp.dtype(dtype) if jnp.issubdtype(dtype, np.integer): # All integer types in Triton are signless. return ir.IntegerType.get_signless(dtype.itemsize * 8) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index bc2c237ffa94..8c5a0a99c279 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -687,6 +687,46 @@ def f(x): self.assertEqual(f(x), 2.) self.assertEqual(trace_count, 1) + @parameterized.parameters( + ("float32", None), + ("float32", jax.lax.Precision.DEFAULT), + ("float32", jax.lax.Precision.HIGH), + ("float32", jax.lax.Precision.HIGHEST), + ("float32", jax.lax.DotAlgorithmPreset.DEFAULT), + ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), + ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), + ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), + ("bfloat16", None), + ("bfloat16", jax.lax.Precision.DEFAULT), + ("bfloat16", jax.lax.Precision.HIGHEST), + ("bfloat16", jax.lax.DotAlgorithmPreset.DEFAULT), + ("bfloat16", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ) + def test_dot_precision(self, dtype, precision): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("`DotAlgorithmPreset` only supported on GPU.") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((32, 64), jnp.float32), + grid=1, + ) + def dot_kernel(x_ref, y_ref, o_ref): + o_ref[()] = pl.dot(x_ref[()], y_ref[()], precision=precision) + + key0, key1 = random.split(random.key(0)) + x = random.normal(key0, (32, 16), dtype=dtype) + y = random.normal(key1, (16, 64), dtype=dtype) + expected = jnp.dot( + x, + y, + precision=jax.lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + class PallasCallInterpretTest(PallasCallTest): INTERPRET = True From a94474d016658460a443e4c712200b1d329084a4 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 9 Dec 2024 03:25:39 -0800 Subject: [PATCH 663/698] [pallas] Add `DotAlgorithmPreset` note to CHANGELOG. PiperOrigin-RevId: 704216341 --- docs/pallas/CHANGELOG.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 2687cbc909fb..94dbeb3aa70d 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -11,7 +11,16 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c Remember to align the itemized text with the first line of an item within a list. --> -## Released with jax 0.4.35 +## Released with jax 0.4.37 + +* New functionality + + * Added support for `DotAlgorithmPreset` precision arguments for `dot` + lowering on Triton backend. + +## Released with jax 0.4.36 (December 6, 2024) + +## Released with jax 0.4.35 (October 22, 2024) * Removals From adb2bf629c6d3adc815d32ef3040c502928c5de4 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 9 Dec 2024 04:32:07 -0800 Subject: [PATCH 664/698] [Mosaic TPU] Allow downgrading the IR during serialization for forward compat This is to uphold the monthly stability promise made by jax.export. PiperOrigin-RevId: 704233290 --- jax/_src/tpu_custom_call.py | 20 ++++- jaxlib/mosaic/dialect/tpu/tpu.td | 5 +- jaxlib/mosaic/dialect/tpu/transforms/serde.cc | 84 ++++++++++++++++--- 3 files changed, 96 insertions(+), 13 deletions(-) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 10a979dffc6a..ccd77af5bef0 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -62,6 +62,11 @@ help="Allow hlo dialects in Mosaic", ) + +# This tracks the latest Mosaic IR version with a monthly delay. +FWD_COMPAT_IR_VERSION = 3 + + tpu_custom_call_p = core.Primitive("tpu_custom_call") tpu_custom_call_p.def_impl( functools.partial(xla.apply_primitive, tpu_custom_call_p)) @@ -407,6 +412,7 @@ def _lower_mosaic_module_to_asm( backend: str, device_type: str | None, kernel_name: str | None, + ir_version: int | None = None, ) -> tuple[ir.Module, tuple[bool, bool, bool, bool]]: has_communication, has_custom_barrier = tpu.private_has_communication( module.operation @@ -438,8 +444,17 @@ def _lower_mosaic_module_to_asm( module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True + # TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37. + if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37): + target_version = "" + else: + target_version = ( + f"target-version={ir_version}" if ir_version is not None else "" + ) try: - pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})") + pipeline = PassManager.parse( + "builtin.module(mosaic-serde{serialize=true " + target_version + "})" + ) pipeline.run(module_op) finally: ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects @@ -506,6 +521,7 @@ def _lower_to_custom_call_config( serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, kernel_name: str | None = None, + ir_version: int | None = None, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -517,6 +533,7 @@ def _lower_to_custom_call_config( backend=backend, device_type=device_type, kernel_name=kernel_name, + ir_version=ir_version, ) return _lowered_to_custom_call_config( lowered_module_asm, @@ -617,6 +634,7 @@ def lower_module_to_custom_call( serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, + ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, ) return _tpu_custom_call_lowering( ctx, diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index c5142f48dc1d..0019581921c4 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -780,7 +780,10 @@ def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::Fun } def MosaicSerdePass : Pass<"mosaic-serde", "::mlir::ModuleOp"> { - let options = [Option<"serialize", "serialize", "bool", "", "">]; + let options = [ + Option<"serialize", "serialize", "bool", "", "">, + Option<"target_version", "target-version", "int", "", ""> // Only used when serialize=true. + ]; } def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index fd68c9e6c95e..6717e3a3e8ec 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -43,6 +43,8 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; +// When this is bumped, we should file a TODO to update the forward-compatible +// version in tpu_custom_call.py in a month! constexpr int kVersion = 3; StringRef mangle(StringRef name, std::string* storage) { @@ -63,7 +65,7 @@ std::optional demangle(StringRef name) { using rule_type = std::function; -LogicalResult enqueue_dma_rule(Operation* op, int version) { +LogicalResult enqueue_dma_upgrade(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. if (version < 2) { if (op->getNumOperands() == 3) { // Local DMA. @@ -84,7 +86,14 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) { return success(); } -LogicalResult semaphore_signal_rule(Operation* op, int version) { +LogicalResult enqueue_dma_downgrade(Operation* op, int version) { + if (version < 2) { + return op->emitError("Downgrade to version ") << version << " unsupported"; + } + return success(); +} + +LogicalResult semaphore_signal_upgrade(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. if (version < 2) { if (op->getNumOperands() == 2) { // Local signal. @@ -92,9 +101,6 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) { EnqueueDMAOp>::getOperandSegmentSizeAttr(), mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); } else if (op->getNumOperands() == 3) { // Remote signal. - // Hardcoding that one optional value is device_id, not core_id. This - // could misinterpret sem_signals where core_id is specified, but - // device_id isn't. op->setAttr(OpTrait::AttrSizedOperandSegments< EnqueueDMAOp>::getOperandSegmentSizeAttr(), mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); @@ -105,7 +111,25 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) { return success(); } -LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) { +LogicalResult semaphore_signal_downgrade(Operation* op, int version) { + if (version < 2) { + auto operands = op->getAttrOfType( + OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr()); + if (!operands || operands.size() != 4) { + return op->emitError("Missing or invalid AttrSizedOperandSegments"); + } + if (operands[3]) { + return op->emitError("Downgrade to version ") + << version << " impossible: core_id is set"; + } + op->removeAttr(OpTrait::AttrSizedOperandSegments< + EnqueueDMAOp>::getOperandSegmentSizeAttr()); + } + return success(); +} + +LogicalResult vector_multi_dim_reduce_upgrade(Operation* op, int version) { // Changed reductions_dims from ArrayAttr of IntegerAttrs to DenseI64ArrayAttr // in version 3. if (version < 3) { @@ -133,21 +157,49 @@ LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) { return success(); } +LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) { + if (version < 3) { + return op->emitError("Downgrade to version ") << version << " unsupported"; + } + return success(); +} + const llvm::StringMap& upgrade_rules() { static auto rules = new llvm::StringMap{ - {EnqueueDMAOp::getOperationName(), enqueue_dma_rule}, - {SemaphoreSignalOp::getOperationName(), semaphore_signal_rule}, + {EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade}, + {SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade}, {vector::MultiDimReductionOp::getOperationName(), - vector_multi_dim_reduce_rule} + vector_multi_dim_reduce_upgrade} }; return *rules; } +const llvm::StringMap& downgrade_rules() { + static auto rules = new llvm::StringMap{ + {EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade}, + {SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_dim_reduce_downgrade}}; + return *rules; +} + struct MosaicSerdePass : public impl::MosaicSerdePassBase { using Base::Base; void runOnOperation() override { ModuleOp module = getOperation(); + if (!serialize.hasValue()) { + module.emitError("serialize option must be specified"); + return signalPassFailure(); + } + int serialize_version = + target_version.hasValue() ? target_version : kVersion; + if (serialize && serialize_version > kVersion) { + module.emitError("The highest supported version is ") + << kVersion << " but requested serialization at version " + << serialize_version; + return signalPassFailure(); + } if (serialize && !module->getContext()->allowsUnregisteredDialects()) { module.emitError() << "Cannot serialize within a context that does not " "allow unregistered dialects."; @@ -159,7 +211,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { module->setAttr( kVersionAttrName, IntegerAttr::get(IntegerType::get(module->getContext(), 64), - kVersion)); + serialize_version)); } else { IntegerAttr version_attr = module->getAttrOfType(kVersionAttrName); @@ -178,7 +230,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { module->removeAttr(kVersionAttrName); } std::string name_storage; - auto result = module.walk([this, &name_storage, version](Operation* op) { + auto result = module.walk([&](Operation* op) { if (isa(op)) { // Don't mangle the ModuleOp itself. return WalkResult::advance(); } @@ -210,6 +262,16 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { auto new_op = Operation::create( op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(), op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions()); + // Downgrade the op to the target version, if needed. + if (serialize && kVersion != serialize_version) { + if (const auto rule = + downgrade_rules().find(op->getName().getStringRef()); + rule != downgrade_rules().end()) { + if (rule->second(new_op, serialize_version).failed()) { + return WalkResult::interrupt(); + } + } + } op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op); op->replaceAllUsesWith(new_op->getResults()); op->erase(); From d474feda9e080a340ffbadcbfd3c2ede84281630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Mon, 9 Dec 2024 04:36:21 -0800 Subject: [PATCH 665/698] Activate Tridiagonal Reduction to XLA's FFI Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction. PiperOrigin-RevId: 704234350 --- jax/_src/export/_export.py | 3 + .../cpu_tridiagonal_lapack_sytrd_hetrd.py | 844 ++++++++++++++++++ jax/_src/lax/linalg.py | 29 +- jaxlib/lapack.py | 110 ++- tests/export_back_compat_test.py | 37 + 5 files changed, 975 insertions(+), 48 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index c521becb76d2..6102269fceac 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -997,6 +997,7 @@ def _check_lowering(lowering) -> None: "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi", "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", + "lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi", "lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi", "lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi", ] @@ -1022,6 +1023,8 @@ def _check_lowering(lowering) -> None: "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", # schur on CPU "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", + # tridiagonal on CPU + "lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd", # hessenberg on CPU "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd", # lu on GPU diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py new file mode 100644 index 000000000000..9e245052e03a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py @@ -0,0 +1,844 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, float32, complex64 + +data_2024_09_03 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zhetrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[-1.6782909868280393 +0.j , + -0.44670237330570184+4.847000766107959j , + 2.05945450900321 -2.2848432268240106j , + -1.852046418980849 +1.672382006137275j ], + [ 8.516713699516982 +0.j , + -2.7881860505313174 +0.j , + 0.9238284715039695 -2.3790501284019947j , + 0.5005102262291599 -1.30066052934836j ], + [-0.12132810525381293-0.2963030371159077j , + -3.6374350042782893 +0.j , + 0.5605752523031344 +0.j , + -2.9865099107523174 +0.5492956557924651j ], + [-0.40379248092949666-0.7813328344426929j , + -0.07101654492399719-0.27208840961051617j, + -7.4654253782049285 +0.j , + -8.172380353916964 +0.j ]], + + [[-3.996403598623405 +0.j , + 0.59408630943699 +2.531609474375295j , + -1.789098034543644 -2.538389274566601j , + -1.291106590337488 +3.1576544511573843j ], + [10.8950662522622 +0.j , + -2.8151642043836693 +0.j , + 6.18998567202382 +1.1866537964613415j , + 3.1900218245393352 +2.7291222716752372j ], + [-0.3142889671188478 -0.37781876498252764j, + 3.049208563595754 +0.j , + -2.4383044880335487 +0.j , + 4.075435464493341 -0.6653616942280807j ], + [ 0.32757687545025194+0.565870910342534j , + 0.8177026465997795 -0.15906305615104555j, + 3.3415143060767125 +0.j , + 4.094619408678314 +0.j ]]]), array([[-1.6782909868280393, -2.7881860505313174, 0.5605752523031344, + -8.172380353916964 ], + [-3.996403598623405 , -2.8151642043836693, -2.4383044880335487, + 4.094619408678314 ]]), array([[ 8.516713699516982 , -3.6374350042782893, -7.4654253782049285], + [10.8950662522622 , 3.049208563595754 , 3.3415143060767125]]), array([[1.0626274644222748+0.06050271598884928j, + 1.834630852474663 +0.18575551495730305j, + 1.981584368497257 +0.19102912741736966j], + [1.0365789616521406-0.40942548304121656j, + 1.0872592163018966-0.3187050677167622j , + 1.0458498304770472-0.9989483435319496j ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(-1.6782909868280393,-0.44303325034407437), (-0.44670237330570184,4.8470007661079588), (2.0594545090032099,-2.2848432268240106), (-1.852046418980849,1.6723820061372749)], [(-0.53338018421119981,-0.5152843101202178), (-8.6208093221459947,-1.4723511111926109), (0.92382847150396952,-2.3790501284019947), (0.50051022622915986,-1.30066052934836)], [(0.94535043721506584,2.744088772946665), (-5.9178492824175759,-4.3744650461123786), (1.8341291553102983,-4.8378584827626838), (-2.9865099107523174,0.54929565579246509)], [(3.2517513113853891,7.2792034361133062), (-0.09841002311276037,0.88008791818205689), (-0.035759860211603468,2.4677764344580244), (-3.6133109853094476,-2.2833696560058976)]], [[(-3.996403598623405,2.42308766118121), (0.59408630943699003,2.531609474375295), (-1.789098034543644,-2.538389274566601), (-1.2911065903374881,3.1576544511573843)], [(-0.39853021063902833,4.4607177630985086), (1.0742061295773189,-2.6002112528615386), (6.1899856720238198,1.1866537964613415), (3.1900218245393352,2.7291222716752372)], [(5.2347956435718022,2.8649782894514577), (2.3527586611916762,2.4688953673448575), (-2.317572140163894,4.3609023810820053), (4.0754354644933413,-0.66536169422808067)], [(-6.2237114632988675,-4.9294897244018943), (4.2994486027667103,-1.3300494261380422), (-0.51942958410141249,0.60038999428238982), (0.084516726847668963,-7.2944134049318752)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_zhetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo/O/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x10\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\x12\x10\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x08d\x91Y\xa6G\xda\xfa\xbf$-Q"\xa8Z\xdc\xbfL0\x19\x8d\xc5\x96\xdc\xbf\x86{8+Tc\x13@\xf0%\x1eI\xc3y\x00@\xe4\x91\xbd\xe2[G\x02\xc0\x85%\x03m\xfb\xa1\xfd\xbf\x9atl\xa2\x13\xc2\xfa?\x9c\xb0\xf0Qs\x11\xe1\xbf\xd8v\x83\x855}\xe0\xbf\x84V/\xb8\xda=!\xc0\n\xd3\xec\t\xc0\x8e\xf7\xbf\x98$\x07\xba\x00\x90\xed?\xd5?\x08oK\x08\x03\xc0>\xf8\x9e\x05.\x04\xe0?\xf2\xfcKj\x81\xcf\xf4\xbf\xe4"c\x8fO@\xee?y\x03\x89\xd0\xe4\xf3\x05@\xee\x8f\xaa\xae\xe0\xab\x17\xc0\xf20\xda\xc3s\x7f\x11\xc0V*+\xd0\x97X\xfd?P\x91\xf8\x92\xf7Y\x13\xc0\x7f\xe3\xdeN_\xe4\x07\xc0\x14\xd5\xae{\xd4\x93\xe1?\xbc\x00\t1\x96\x03\n@`&l\x81\xe7\x1d\x1d@X/\xde6f1\xb9\xbf\x06KF#\xae)\xec?\xcd\x9a<\xcc\x1dO\xa2\xbf\x91\xb1>\x92\x01\xbe\x03@\xf2s\x01\x97\x0f\xe8\x0c\xc0\xf5\xcaiOWD\x02\xc0F\xa2-s\xa2\xf8\x0f\xc0X\xea\xa0\xc8{b\x03@\x0b\x10\xc1J\xc1\x02\xe3?2|\xd5w\xbc@\x04@\xca>\xbbB%\xa0\xfc\xbf\xe8>6\t\x9fN\x04\xc0\xafdRb_\xa8\xf4\xbf\x80Q>V\xe0B\t@UhJ\xdb\x84\x81\xd9\xbf\t\xc7\xb4e\xc6\xd7\x11@<(;\xc4\xf2/\xf1?\x1a\xda\xad\x8e;\xcd\x04\xc0\x1c4\xa0\x9a\x8b\xc2\x18@z\x9c\xf7\xb0\x88\xfc\xf2?\xaea\x8f)*\x85\t@\x00\x0b\xbd\x0e>\xd5\x05@b\x89\xe9Dn\xf0\x14@a\x8d\xc7\xbcy\xeb\x06@\x8a\x97\t"s\xd2\x02@\xc2\xef\xdf6L\xc0\x03@J\xff Cc\x8a\x02\xc0\xd7.\xcfd\x90q\x11@s\xd4S\xf4>M\x10@t\x10\x97\x9b\xa4J\xe5\xbf\x8eo*\x9e\x14\xe5\x18\xc0\xc5\x18\x81\'\xcc\xb7\x13\xc0\x19\xdd\x8e\xa7\xa22\x11@-95\xe8\xe1G\xf5\xbfZK\x89\xca*\x9f\xe0\xbfR;\xc9\x13e6\xe3?\x7f\x94\xc6a\xe3\xa2\xb5?\xe2\xbe&\xb5z-\x1d\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\x0b\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_zhetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_chetrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[ 3.3228416 +0.j , -1.9756439 +4.593356j , + 7.367708 +0.88518727j , -8.659938 +1.6132793j ], + [-6.9206004 +0.j , -3.6362798 +0.j , + 3.3011198 -4.644362j , -4.8589935 -0.61439794j ], + [ 0.64957 +0.060723424j, 6.620491 +0.j , + 0.2882607 +0.j , -1.0288142 +1.8544064j ], + [-0.05458622 +0.10473086j , -0.15611424 +0.06925995j , + -4.431866 +0.j , 2.364208 +0.j ]], + + [[-4.1803885 +0.j , 0.5670845 +0.6913016j , + 2.675204 -0.23881845j , -0.41825035 -1.4060576j ], + [ 8.33625 +0.j , 2.6144838 +0.j , + -2.4941807 -1.9316154j , 0.6687787 -2.209776j ], + [ 0.019031923+0.17462212j , 2.7034955 +0.j , + -0.70924187 +0.j , 2.7962255 +1.5316825j ], + [-0.057821754+0.023692288j, -0.62805307 -0.0882424j , + 6.6364865 +0.j , -1.698973 +0.j ]]], + dtype=complex64), array([[ 3.3228416 , -3.6362798 , 0.2882607 , 2.364208 ], + [-4.1803885 , 2.6144838 , -0.70924187, -1.698973 ]], + dtype=float32), array([[-6.9206004, 6.620491 , -4.431866 ], + [ 8.33625 , 2.7034955, 6.6364865]], dtype=float32), array([[1.360567 +0.1977107j , 1.7586378-0.56989706j, + 1.5772758-0.8165493j ], + [1.9152443-0.1834492j , 1.1593437+0.55631363j, + 1.6889225-0.724835j ]], dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(3.32284164,1.14621949), (-1.97564387,4.59335613), (7.36770821,0.885187268), (-8.65993785,1.61327934)], [(2.495340e+00,1.36827672), (-3.96969199,-0.636681795), (3.3011198,-4.64436197), (-4.85899353,-0.614397943)], [(6.03322554,1.46055949), (-3.89591122,-4.1833396), (-1.46423841,-0.106284566), (-1.0288142,1.85440636)], [(-0.657281339,0.911450386), (3.18693113,-2.02812219), (-2.64483237,0.351429433), (4.45011663,-1.79112875)]], [[(-4.18038845,-3.65238023), (0.567084491,0.691301584), (2.67520404,-0.238818452), (-0.418250352,-1.4060576)], [(-7.62970591,1.5292784), (0.269325763,2.48722434), (-2.49418068,-1.93161535), (0.668778717,-2.20977592)], [(-0.570908666,-2.75890398), (-0.235837936,3.45861554), (-0.946199476,0.23120968), (2.79622555,1.53168249)], [(0.886947453,-0.466695577), (-3.194850e+00,-0.0176551137), (-4.37602425,-3.7703948), (0.883143305,-4.70016575)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_chetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo//\x1f\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\xe2\x0b\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\xc0\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x04p\xa9T@R\xb7\x92?\xe6\xe1\xfc\xbf\xc6\xfc\x92@D\xc4\xeb@\xa2\x9bb?\x1b\x8f\n\xc1\xf0\x7f\xce?\xa7\xb3\x1f@\xb1#\xaf?o\x0f~\xc0\x94\xfd"\xbf\x8cES@\x9d\x9e\x94\xc0\xe0|\x9b\xc0/I\x1d\xbf/\x10\xc1@\x9d\xf3\xba?\x9cVy\xc0\xeb\xdd\x85\xc0*l\xbb\xbf\xb9\xab\xd9\xbd/\xb0\x83\xbf0]\xed?\x97C(\xbf\xd0Ti?\xae\xf6K@\xc1\xcc\x01\xc0\xefD)\xc0\x8f\xee\xb3>[g\x8e@\xb5C\xe5\xbf\xbe\xc5\x85\xc0\x99\xc0i\xc0s,\x11?$\xf90?\x8b6+@\xd3\x8ct\xbe\xe9$\xd6\xbe\xb2\xf9\xb3\xbf\x8d&\xf4\xc0e\xbf\xc3?\x11\xe5\x89>\xaf.\x1f@\xa8\xa0\x1f\xc0,?\xf7\xbf\x155+?\xf8l\r\xc0\x12\'\x12\xbf\xe2\x910\xc0\x80\x7fq\xbe\xf5Y]@!:r\xbf;\xc2l>\\\xf52@,\x0e\xc4?\xfd\x0ec?\xb9\xf2\xee\xbelxL\xc0u\xa1\x90\xbcd\x08\x8c\xc0&Nq\xc0\xae\x15b?\xc2g\x96\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\t\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_chetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_ssytrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[-0.8395241 , 0.156272 , -1.6810869 , 0.23832119], + [-2.985257 , -5.571 , -0.22652794, -0.83806676], + [ 0.27237308, -1.6295947 , 2.0042834 , -1.148861 ], + [-0.17183593, 0.57464546, 0.5536146 , -4.206357 ]], + + [[ 1.7666914 , 2.569005 , -0.86576384, -0.1617768 ], + [-5.143918 , 5.0426254 , -3.7237067 , 4.383015 ], + [ 0.33311516, -1.5299042 , -8.854181 , -2.896776 ], + [ 0.3419102 , 0.2669245 , -2.8250606 , 5.752488 ]]], + dtype=float32), array([[-0.8395241, -5.571 , 2.0042834, -4.206357 ], + [ 1.7666914, 5.0426254, -8.854181 , 5.752488 ]], dtype=float32), array([[-2.985257 , -1.6295947, 0.5536146], + [-5.143918 , -1.5299042, -2.8250606]], dtype=float32), array([[1.8120625, 1.5035137, 0. ], + [1.6288393, 1.8669801, 0. ]], dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[-0.83952409, 1.562720e-01, -1.6810869, 0.238321185], [2.42421508, -5.17118931, -0.226527944, -0.838066756], [1.47339451, -1.32866347, -3.3505435, -1.14886105], [-0.929541587, -0.955984473, 2.71886253, 0.748659431]], [[1.76669145, 2.56900501, -0.865763843, -0.161776796], [3.23469758, -0.362713158, -3.72370672, 4.38301516], [2.79104376, 7.36582708, -3.04437494, -2.89677596], [2.86473417, 0.981746375, -2.13533139, 5.34802151]]]> : tensor<2x4x4xf32> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_ssytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf32>, tensor) -> tensor<2x4x4xf32> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc7) + return %2 : tensor<2x4x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b/\x1f\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x04\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02b\t\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x02\r\xebV\xbf\xc4\x05 >\xdb-\xd7\xbfx\nt>W&\x1b@bz\xa5\xc0\xf1\xf6g\xbe\x8b\x8bV\xbf1\x98\xbc?\xa5\x11\xaa\xbfNoV\xc0\xe1\r\x93\xbfp\xf6m\xbff\xbbt\xbf\xd8\x01.@%\xa8??\xf2"\xe2?\x94j$@\xb3\xa2]\xbf\xd1\xa8%\xbeI\x05O@\x8a\xb5\xb9\xbe6Qn\xc0\xa9A\x8c@v\xa02@\xdb\xb4\xeb@\n\xd7B\xc0\xc7d9\xc0\xceW7@\xbbS{?E\xa9\x08\xc0\xfe"\xab@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\t)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_ssytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_03["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dsytrd'], + serialized_date=datetime.date(2024, 9, 3), + inputs=(), + expected_outputs=(array([[[ 0.8251247184208595 , -2.6963562039892532 , + 0.8082445002373937 , -1.551980329390836 ], + [-2.629505060186711 , 4.427374205796291 , + -2.2111093161901074 , 7.552489598405787 ], + [ 0.2269453213819231 , 0.3650586474106988 , + -3.5933639667756205 , 4.828829679372501 ], + [-0.6415372293575187 , -0.2519326897319508 , + -1.7607827845801751 , -3.381311711243865 ]], + + [[-4.000421911405985 , 3.6303350337601055 , + 2.8066821235532355 , 1.099224389184342 ], + [-4.141622408467332 , -5.276404169116551 , + -0.8496056221591237 , -2.275319346221659 ], + [ 0.5828958067901202 , 0.9351254869793256 , + 2.7765603683442177 , -4.339686212557215 ], + [-0.6391146585297987 , 0.3129920702652711 , + -0.25441692469349864, -1.4155240723557498 ]]]), array([[ 0.8251247184208595, 4.427374205796291 , -3.5933639667756205, + -3.381311711243865 ], + [-4.000421911405985 , -5.276404169116551 , 2.7765603683442177, + -1.4155240723557498]]), array([[-2.629505060186711 , 0.3650586474106988 , -1.7607827845801751 ], + [-4.141622408467332 , 0.9351254869793256 , -0.25441692469349864]]), array([[1.3669846724688552, 1.8806358893589366, 0. ], + [1.1440109149169537, 1.8215532880266878, 0. ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[0.82512471842085955, -2.6963562039892532, 0.80824450023739369, -1.5519803293908361], [0.96498805326781766, -4.1313349231964409, -2.2111093161901074, 7.5524895984057867], [0.81575339483804743, 1.0647235400727899, -1.0064296232364345, 4.8288296793725012], [-2.3060011529502993, -2.9182106402942192, -1.7781896154088577, 2.5904630742096817]], [[-4.0004219114059847, 3.6303350337601055, 2.8066821235532355, 1.0992243891843421], [0.59643883228393779, -1.5243235004961249, -0.84960562215912372, -2.275319346221659], [2.7617960295487092, -0.57538970930521982, 0.12559406141906576, -4.3396862125572149], [-3.0281643919760217, 0.38177997229319849, 3.860398204232184, -2.5166384340510231]]]> : tensor<2x4x4xf64> loc(#loc) + %c = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) + %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) + %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) + %0:6 = stablehlo.custom_call @lapack_dsytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<128xf64>) loc(#loc2) + %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_5 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf64>, tensor) -> tensor<2x4x4xf64> loc(#loc5) + %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc7) + return %2 : tensor<2x4x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b//\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02r\x0b\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x04A\xa4\x17\xf4kg\xea?\x1f\x01\x943#\x92\x05\xc0\x86 \xf6\x91#\xdd\xe9?\x9dMlS\xe9\xd4\xf8\xbf\x88\x1c:\xa0.\xe1\xee?8\xce\x7f\xa9|\x86\x10\xc0\xe8V\xc7\x14Z\xb0\x01\xc0\xd2!R\xd5\xbf5\x1e@\xbf\xc5\r\xdd\xa6\x1a\xea?\xbcM\xfe\x8c\x1b\t\xf1?\xdbj\xd8\xf2U\x1a\xf0\xbf\xado;\xba\xb8P\x13@\xbb\xad\x83\xbb\xb0r\x02\xc0\x1f9\xf7\xd1~X\x07\xc0)ID\xf4vs\xfc\xbfD\xcfI\xb4D\xb9\x04@\x16\xc3\xfe\x99n\x00\x10\xc0\x82.\x1c\x18\xed\n\r@\x8cn\xd7\xc1\x15t\x06@|2(Pl\x96\xf1?\x88*\xd7\xe3\x06\x16\xe3?F{\xf2\t\xa1c\xf8\xbf8z5!\xf8/\xeb\xbf4\xd3\x1f\xa1\xda3\x02\xc0)\x13I\x84(\x18\x06@\xbcw\xfd\xad\x97i\xe2\xbf\x1e\xf0.Yw\x13\xc0?dW\xd7\xb3\xd6[\x11\xc0\x04\x97\xb3@\xae9\x08\xc0\xbc\x17\xd1C\x15o\xd8?\x02\xb7%t\x18\xe2\x0e@\xac\xd8\xd0T\x13"\x04\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\x0b)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_dsytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', + xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_12_01 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zhetrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[-1.6782909868280393 +0.j , + -0.44670237330570184+4.847000766107959j , + 2.05945450900321 -2.2848432268240106j , + -1.852046418980849 +1.672382006137275j ], + [ 8.516713699516982 +0.j , + -2.7881860505313174 +0.j , + 0.9238284715039695 -2.3790501284019947j , + 0.5005102262291599 -1.30066052934836j ], + [-0.12132810525381293-0.2963030371159077j , + -3.6374350042782893 +0.j , + 0.5605752523031344 +0.j , + -2.9865099107523174 +0.5492956557924651j ], + [-0.40379248092949666-0.7813328344426929j , + -0.07101654492399719-0.27208840961051617j, + -7.4654253782049285 +0.j , + -8.172380353916964 +0.j ]], + + [[-3.996403598623405 +0.j , + 0.59408630943699 +2.531609474375295j , + -1.789098034543644 -2.538389274566601j , + -1.291106590337488 +3.1576544511573843j ], + [10.8950662522622 +0.j , + -2.8151642043836693 +0.j , + 6.18998567202382 +1.1866537964613415j , + 3.1900218245393352 +2.7291222716752372j ], + [-0.3142889671188478 -0.37781876498252764j, + 3.049208563595754 +0.j , + -2.4383044880335487 +0.j , + 4.075435464493341 -0.6653616942280807j ], + [ 0.32757687545025194+0.565870910342534j , + 0.8177026465997795 -0.15906305615104555j, + 3.3415143060767125 +0.j , + 4.094619408678314 +0.j ]]]), array([[-1.6782909868280393, -2.7881860505313174, 0.5605752523031344, + -8.172380353916964 ], + [-3.996403598623405 , -2.8151642043836693, -2.4383044880335487, + 4.094619408678314 ]]), array([[ 8.516713699516982 , -3.6374350042782893, -7.4654253782049285], + [10.8950662522622 , 3.049208563595754 , 3.3415143060767125]]), array([[1.0626274644222748+0.06050271598884928j, + 1.834630852474663 +0.18575551495730305j, + 1.981584368497257 +0.19102912741736966j], + [1.0365789616521406-0.40942548304121656j, + 1.0872592163018966-0.3187050677167622j , + 1.0458498304770472-0.9989483435319496j ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(-1.6782909868280393,-0.44303325034407437), (-0.44670237330570184,4.8470007661079588), (2.0594545090032099,-2.2848432268240106), (-1.852046418980849,1.6723820061372749)], [(-0.53338018421119981,-0.5152843101202178), (-8.6208093221459947,-1.4723511111926109), (0.92382847150396952,-2.3790501284019947), (0.50051022622915986,-1.30066052934836)], [(0.94535043721506584,2.744088772946665), (-5.9178492824175759,-4.3744650461123786), (1.8341291553102983,-4.8378584827626838), (-2.9865099107523174,0.54929565579246509)], [(3.2517513113853891,7.2792034361133062), (-0.09841002311276037,0.88008791818205689), (-0.035759860211603468,2.4677764344580244), (-3.6133109853094476,-2.2833696560058976)]], [[(-3.996403598623405,2.42308766118121), (0.59408630943699003,2.531609474375295), (-1.789098034543644,-2.538389274566601), (-1.2911065903374881,3.1576544511573843)], [(-0.39853021063902833,4.4607177630985086), (1.0742061295773189,-2.6002112528615386), (6.1899856720238198,1.1866537964613415), (3.1900218245393352,2.7291222716752372)], [(5.2347956435718022,2.8649782894514577), (2.3527586611916762,2.4688953673448575), (-2.317572140163894,4.3609023810820053), (4.0754354644933413,-0.66536169422808067)], [(-6.2237114632988675,-4.9294897244018943), (4.2994486027667103,-1.3300494261380422), (-0.51942958410141249,0.60038999428238982), (0.084516726847668963,-7.2944134049318752)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_zhetrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf5\x99G\x011\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03i\x0f\x0b\x0b\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0boO/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x10\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03C\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0f\x07\x07\x13\x0b\x1b\x07\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02\x9a\x0f\x1d\x1d\t\x1f\x1d!\t\x1d-\t\x17\x1f\xde\n\x1b\x1d#\t\x1d/\t\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'\x81\x05\'\x1d+\t\x05)\x05+\x05-\x1f5\x01\x1d/\x1d1\x1d3\x03\x07;;;\r\x0335\x03\x03;\x1d5\x1f\x17\t\x00\x00\x00\x00\t\x07\x07\x01\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fA!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1d;\x1d=\x1f?1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f#)\x03\taeim\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\r\x057o35\x1dE\x1dG\x1dI#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x02\x08d\x91Y\xa6G\xda\xfa\xbf$-Q"\xa8Z\xdc\xbfL0\x19\x8d\xc5\x96\xdc\xbf\x86{8+Tc\x13@\xf0%\x1eI\xc3y\x00@\xe4\x91\xbd\xe2[G\x02\xc0\x85%\x03m\xfb\xa1\xfd\xbf\x9atl\xa2\x13\xc2\xfa?\x9c\xb0\xf0Qs\x11\xe1\xbf\xd8v\x83\x855}\xe0\xbf\x84V/\xb8\xda=!\xc0\n\xd3\xec\t\xc0\x8e\xf7\xbf\x98$\x07\xba\x00\x90\xed?\xd5?\x08oK\x08\x03\xc0>\xf8\x9e\x05.\x04\xe0?\xf2\xfcKj\x81\xcf\xf4\xbf\xe4"c\x8fO@\xee?y\x03\x89\xd0\xe4\xf3\x05@\xee\x8f\xaa\xae\xe0\xab\x17\xc0\xf20\xda\xc3s\x7f\x11\xc0V*+\xd0\x97X\xfd?P\x91\xf8\x92\xf7Y\x13\xc0\x7f\xe3\xdeN_\xe4\x07\xc0\x14\xd5\xae{\xd4\x93\xe1?\xbc\x00\t1\x96\x03\n@`&l\x81\xe7\x1d\x1d@X/\xde6f1\xb9\xbf\x06KF#\xae)\xec?\xcd\x9a<\xcc\x1dO\xa2\xbf\x91\xb1>\x92\x01\xbe\x03@\xf2s\x01\x97\x0f\xe8\x0c\xc0\xf5\xcaiOWD\x02\xc0F\xa2-s\xa2\xf8\x0f\xc0X\xea\xa0\xc8{b\x03@\x0b\x10\xc1J\xc1\x02\xe3?2|\xd5w\xbc@\x04@\xca>\xbbB%\xa0\xfc\xbf\xe8>6\t\x9fN\x04\xc0\xafdRb_\xa8\xf4\xbf\x80Q>V\xe0B\t@UhJ\xdb\x84\x81\xd9\xbf\t\xc7\xb4e\xc6\xd7\x11@<(;\xc4\xf2/\xf1?\x1a\xda\xad\x8e;\xcd\x04\xc0\x1c4\xa0\x9a\x8b\xc2\x18@z\x9c\xf7\xb0\x88\xfc\xf2?\xaea\x8f)*\x85\t@\x00\x0b\xbd\x0e>\xd5\x05@b\x89\xe9Dn\xf0\x14@a\x8d\xc7\xbcy\xeb\x06@\x8a\x97\t"s\xd2\x02@\xc2\xef\xdf6L\xc0\x03@J\xff Cc\x8a\x02\xc0\xd7.\xcfd\x90q\x11@s\xd4S\xf4>M\x10@t\x10\x97\x9b\xa4J\xe5\xbf\x8eo*\x9e\x14\xe5\x18\xc0\xc5\x18\x81\'\xcc\xb7\x13\xc0\x19\xdd\x8e\xa7\xa22\x11@-95\xe8\xe1G\xf5\xbfZK\x89\xca*\x9f\xe0\xbfR;\xc9\x13e6\xe3?\x7f\x94\xc6a\xe3\xa2\xb5?\xe2\xbe&\xb5z-\x1d\xc0\r\x03\x83\x85\x1dK\x13=L\x0b\x03\x1dM\x1dO\x05\x01\x03\x03W\x03\x03\x93\x15\x03\x01\x01\x01\x03\x0bWKKK\x97\x1fC\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x1f)\x05\t\x05\x13)\x05\t\x11\x19)\x05\t\r\x19)\x05\t\r\x1f)\x01\x1f)\x01\x19\x01)\x03\t\')\x01\'\x0b\x1d)\x03\t\x13\x03\x19)\x07\t\x05\x05\x13\x13)\x05\t\r\x13\x1b\x11\x01\t\x05\t\x0b\r\x11\x07!\x05\x0f\x03\x05\x11\x07\x07\t\x11\x03\t\x11\x07\x07\x0b\x11\x03\x0b\x11\x07\x07\r\x0f\x03\r)\x03\t\x1b)\x03\x01\x1b)\x05\t\x11\x13)\x07\t\x11\x11\x13)\x03\r\x1b!)\x03\r#)\x03\t#)\x03\x05#)\x03\x05\x1b\x04\xbe\x06\x05\x01Q\x03\x11\x01\x07\x04\x96\x06\x03\x01\x15\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\x05\x11G)%\x07\x0b\x05\t\x0b\r\x15\x03\x01\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\r\rF\x07\r\x03\x1d\x05\x0b\x0f\x03F\r\x0f\x03!\x03\x11\x05B\x03\x11\x03\x0f\x0fF\x01\x13\x03\x05\x07\x13\x03\x15\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\x19\rF\x07\r\x03\x1d\x05\x0b\x1b\x03F\r\x0f\x03\x07\x03\x1d\x05B\x03\x15\x03\x11\x0fF\x01\x17\x03\t\x07\x1f\x05!\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03%\rF\x07\r\x03\x1d\x05\x0b\'\x03F\r\x0f\x03\x07\x03)\x05B\x03\x15\x03\x11\x0fF\x01\x19\x03\x0b\x07+\x07-\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x031\rF\x07\r\x03\x1d\x05\x0b3\x03F\r\x0f\x03\x07\x035\x05B\x03\x11\x03\x0f\x0fF\x01\x1b\x03\r\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x1d\x07\x04S\x03\r\x13\x07C\x01\x0b\x01\x1f\x01\x00\x03F\x05\x1f\x039\x03\x01\x03F\x05\x0b\x03\x05\x03\x05\x0b\x06\x0b\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x0f\x01\x13\x01#\x01\x00\x03F\x05#\x037\x03\x01\x03F\x05\x0b\x03\t\x03\x05\x0b\x06\x0b\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x0f\x01\x17\x01#\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\x0b\x03\x05\x0b\x06\x0b\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\'\x07\x04S\x03\r\x13\x07\x0f\x01\x1b\x01\x1f\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\r\x03\x05\x0b\x06\x0b\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x12\nQ%\x03\x0b\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/A)Sci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_zhetrd_ffi\x00\x08\x8d)\x057\x01\x0bM]_qs\x03\x7f\x11\x87\x89\x8bM\x8d\x8f\x91\x95\x03A\x031\x05CE\x03G\x03Y\x03O\x03[\x03Q\x03S\x03U\x0b9u=O?\x03}\x0b9w=Q?\x03I\x0b9y=S?\x0b9{=U?', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_chetrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[ 3.3228416 +0.j , -1.9756439 +4.593356j , + 7.367708 +0.88518727j , -8.659938 +1.6132793j ], + [-6.9206004 +0.j , -3.6362798 +0.j , + 3.3011198 -4.644362j , -4.8589935 -0.61439794j ], + [ 0.64957 +0.060723424j, 6.620491 +0.j , + 0.2882607 +0.j , -1.0288142 +1.8544064j ], + [-0.05458622 +0.10473086j , -0.15611424 +0.06925995j , + -4.431866 +0.j , 2.364208 +0.j ]], + + [[-4.1803885 +0.j , 0.5670845 +0.6913016j , + 2.675204 -0.23881845j , -0.41825035 -1.4060576j ], + [ 8.33625 +0.j , 2.6144838 +0.j , + -2.4941807 -1.9316154j , 0.6687787 -2.209776j ], + [ 0.019031923+0.17462212j , 2.7034955 +0.j , + -0.70924187 +0.j , 2.7962255 +1.5316825j ], + [-0.057821754+0.023692288j, -0.62805307 -0.0882424j , + 6.6364865 +0.j , -1.698973 +0.j ]]], + dtype=complex64), array([[ 3.3228416 , -3.6362798 , 0.2882607 , 2.364208 ], + [-4.1803885 , 2.6144838 , -0.70924187, -1.698973 ]], + dtype=float32), array([[-6.9206004, 6.620491 , -4.431866 ], + [ 8.33625 , 2.7034955, 6.6364865]], dtype=float32), array([[1.360567 +0.1977107j , 1.7586378-0.56989706j, + 1.5772758-0.8165493j ], + [1.9152443-0.1834492j , 1.1593437+0.55631363j, + 1.6889225-0.724835j ]], dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[(3.32284164,1.14621949), (-1.97564387,4.59335613), (7.36770821,0.885187268), (-8.65993785,1.61327934)], [(2.495340e+00,1.36827672), (-3.96969199,-0.636681795), (3.3011198,-4.64436197), (-4.85899353,-0.614397943)], [(6.03322554,1.46055949), (-3.89591122,-4.1833396), (-1.46423841,-0.106284566), (-1.0288142,1.85440636)], [(-0.657281339,0.911450386), (3.18693113,-2.02812219), (-2.64483237,0.351429433), (4.45011663,-1.79112875)]], [[(-4.18038845,-3.65238023), (0.567084491,0.691301584), (2.67520404,-0.238818452), (-0.418250352,-1.4060576)], [(-7.62970591,1.5292784), (0.269325763,2.48722434), (-2.49418068,-1.93161535), (0.668778717,-2.20977592)], [(-0.570908666,-2.75890398), (-0.235837936,3.45861554), (-0.946199476,0.23120968), (2.79622555,1.53168249)], [(0.886947453,-0.466695577), (-3.194850e+00,-0.0176551137), (-4.37602425,-3.7703948), (0.883143305,-4.70016575)]]]> : tensor<2x4x4xcomplex> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_chetrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) + %16 = call @_where_2(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) + return %2 : tensor<2x4x4xcomplex> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) + return %2 : tensor<2x3xcomplex> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf5\x99G\x011\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03i\x0f\x0b\x0b\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0bo/\x1f\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03C\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0f\x07\x07\x13\x0b\x1b\x07\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02j\x0b\x1d\x1d\t\x1f\x1d!\t\x1d-\t\x17\x1f\xde\n\x1b\x1d#\t\x1d/\t\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'\x81\x05\'\x1d+\t\x05)\x05+\x05-\x1f5\x01\x1d/\x1d1\x1d3\x03\x07;;;\r\x0335\x03\x03;\x1d5\x1f\x17\t\x00\x00\x00\x00\t\x07\x07\x01\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fA!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1d;\x1d=\x1f?1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x11\t\x00\x00\xc0\x7f#)\x03\taeim\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\r\x057o35\x1dE\x1dG\x1dI#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x02\x04p\xa9T@R\xb7\x92?\xe6\xe1\xfc\xbf\xc6\xfc\x92@D\xc4\xeb@\xa2\x9bb?\x1b\x8f\n\xc1\xf0\x7f\xce?\xa7\xb3\x1f@\xb1#\xaf?o\x0f~\xc0\x94\xfd"\xbf\x8cES@\x9d\x9e\x94\xc0\xe0|\x9b\xc0/I\x1d\xbf/\x10\xc1@\x9d\xf3\xba?\x9cVy\xc0\xeb\xdd\x85\xc0*l\xbb\xbf\xb9\xab\xd9\xbd/\xb0\x83\xbf0]\xed?\x97C(\xbf\xd0Ti?\xae\xf6K@\xc1\xcc\x01\xc0\xefD)\xc0\x8f\xee\xb3>[g\x8e@\xb5C\xe5\xbf\xbe\xc5\x85\xc0\x99\xc0i\xc0s,\x11?$\xf90?\x8b6+@\xd3\x8ct\xbe\xe9$\xd6\xbe\xb2\xf9\xb3\xbf\x8d&\xf4\xc0e\xbf\xc3?\x11\xe5\x89>\xaf.\x1f@\xa8\xa0\x1f\xc0,?\xf7\xbf\x155+?\xf8l\r\xc0\x12\'\x12\xbf\xe2\x910\xc0\x80\x7fq\xbe\xf5Y]@!:r\xbf;\xc2l>\\\xf52@,\x0e\xc4?\xfd\x0ec?\xb9\xf2\xee\xbelxL\xc0u\xa1\x90\xbcd\x08\x8c\xc0&Nq\xc0\xae\x15b?\xc2g\x96\xc0\r\x03\x83\x85\x1dK\x13=L\x0b\x03\x1dM\x1dO\x05\x01\x03\x03W\x03\x03\x93\x15\x03\x01\x01\x01\x03\x0bWKKK\x97\x1fC\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x1f)\x05\t\x05\x13)\x05\t\x11\x19)\x05\t\r\x19)\x05\t\r\x1f)\x01\x1f)\x01\x19\x01)\x03\t\')\x01\'\t\x1d)\x03\t\x13\x03\x19)\x07\t\x05\x05\x13\x13)\x05\t\r\x13\x1b\x11\x01\t\x05\t\x0b\r\x11\x07!\x05\x0f\x03\x05\x11\x07\x07\t\x11\x03\t\x11\x07\x07\x0b\x11\x03\x0b\x11\x07\x07\r\x0f\x03\r)\x03\t\x1b)\x03\x01\x1b)\x05\t\x11\x13)\x07\t\x11\x11\x13)\x03\r\x1b!)\x03\r#)\x03\t#)\x03\x05#)\x03\x05\x1b\x04\xbe\x06\x05\x01Q\x03\x11\x01\x07\x04\x96\x06\x03\x01\x15\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\x05\x11G)%\x07\x0b\x05\t\x0b\r\x15\x03\x01\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\r\rF\x07\r\x03\x1d\x05\x0b\x0f\x03F\r\x0f\x03!\x03\x11\x05B\x03\x11\x03\x0f\x0fF\x01\x13\x03\x05\x07\x13\x03\x15\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03\x19\rF\x07\r\x03\x1d\x05\x0b\x1b\x03F\r\x0f\x03\x07\x03\x1d\x05B\x03\x15\x03\x11\x0fF\x01\x17\x03\t\x07\x1f\x05!\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x03%\rF\x07\r\x03\x1d\x05\x0b\'\x03F\r\x0f\x03\x07\x03)\x05B\x03\x15\x03\x11\x0fF\x01\x19\x03\x0b\x07+\x07-\x05B\x03\t\x03\x17\x03F\x07\x0b\x03\x15\x031\rF\x07\r\x03\x1d\x05\x0b3\x03F\r\x0f\x03\x07\x035\x05B\x03\x11\x03\x0f\x0fF\x01\x1b\x03\r\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x1d\x07\x04S\x03\r\x13\x07C\x01\x0b\x01\x1f\x01\x00\x03F\x05\x1f\x039\x03\x01\x03F\x05\x0b\x03\x05\x03\x05\x0b\x06\x0b\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x0f\x01\x13\x01#\x01\x00\x03F\x05#\x037\x03\x01\x03F\x05\x0b\x03\t\x03\x05\x0b\x06\x0b\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x0f\x01\x17\x01#\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\x0b\x03\x05\x0b\x06\x0b\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\'\x07\x04S\x03\r\x13\x07\x0f\x01\x1b\x01\x1f\x01\x00\x03F\x05#\x03%\x03\x01\x03F\x05\x0b\x03\r\x03\x05\x0b\x06\x0b\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x12\nQ%\x03\x0b\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/A)Sci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_chetrd_ffi\x00\x08\x8d)\x057\x01\x0bM]_qs\x03\x7f\x11\x87\x89\x8bM\x8d\x8f\x91\x95\x03A\x031\x05CE\x03G\x03Y\x03O\x03[\x03Q\x03S\x03U\x0b9u=O?\x03}\x0b9w=Q?\x03I\x0b9y=S?\x0b9{=U?', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_ssytrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[-0.8395241 , 0.156272 , -1.6810869 , 0.23832119], + [-2.985257 , -5.571 , -0.22652794, -0.83806676], + [ 0.27237308, -1.6295947 , 2.0042834 , -1.148861 ], + [-0.17183593, 0.57464546, 0.5536146 , -4.206357 ]], + + [[ 1.7666914 , 2.569005 , -0.86576384, -0.1617768 ], + [-5.143918 , 5.0426254 , -3.7237067 , 4.383015 ], + [ 0.33311516, -1.5299042 , -8.854181 , -2.896776 ], + [ 0.3419102 , 0.2669245 , -2.8250606 , 5.752488 ]]], + dtype=float32), array([[-0.8395241, -5.571 , 2.0042834, -4.206357 ], + [ 1.7666914, 5.0426254, -8.854181 , 5.752488 ]], dtype=float32), array([[-2.985257 , -1.6295947, 0.5536146], + [-5.143918 , -1.5299042, -2.8250606]], dtype=float32), array([[1.8120625, 1.5035137, 0. ], + [1.6288393, 1.8669801, 0. ]], dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[-0.83952409, 1.562720e-01, -1.6810869, 0.238321185], [2.42421508, -5.17118931, -0.226527944, -0.838066756], [1.47339451, -1.32866347, -3.3505435, -1.14886105], [-0.929541587, -0.955984473, 2.71886253, 0.748659431]], [[1.76669145, 2.56900501, -0.865763843, -0.161776796], [3.23469758, -0.362713158, -3.72370672, 4.38301516], [2.79104376, 7.36582708, -3.04437494, -2.89677596], [2.86473417, 0.981746375, -2.13533139, 5.34802151]]]> : tensor<2x4x4xf32> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_ssytrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xf32>, tensor) -> tensor<2x4x4xf32> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc7) + return %2 : tensor<2x4x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) + return %2 : tensor<2x4xf32> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) + return %2 : tensor<2x3xf32> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe7\x93?\x011\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03c\x0f\x0b\x0b\x0b\x13\x1f\x0b\x0b/\x1f\x17\x0f\x0b\x0bO\x0b\x0b\x0bOo\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x04\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03;\x17\x0f\x1b\x17\x17\x07\x13\x0f\x07\x07\x13\x1b\x07\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02\xea\x08\x1d\x1d\x07\x1f\x1d-\x07\x17\x1f\xde\n\x1b\x1d!\x07\x1d/\x07\x1d#\x07\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'{\x05\'\x1d+\x07\x05)\x05+\x05-\x1f-\x01\x1d/\x1d1\x1d3\r\x0335\x1f\x13\t\x00\x00\x00\x00\t\x07\x07\x01\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\xc0\x7f\x03\x07999\x03\x039\x1d5\x1d7\x1f9!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d9\x1d;\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f71\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t]aei\r\x057_35\x1d=\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\x1dE\x1dG###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\x02\x02\r\xebV\xbf\xc4\x05 >\xdb-\xd7\xbfx\nt>W&\x1b@bz\xa5\xc0\xf1\xf6g\xbe\x8b\x8bV\xbf1\x98\xbc?\xa5\x11\xaa\xbfNoV\xc0\xe1\r\x93\xbfp\xf6m\xbff\xbbt\xbf\xd8\x01.@%\xa8??\xf2"\xe2?\x94j$@\xb3\xa2]\xbf\xd1\xa8%\xbeI\x05O@\x8a\xb5\xb9\xbe6Qn\xc0\xa9A\x8c@v\xa02@\xdb\xb4\xeb@\n\xd7B\xc0\xc7d9\xc0\xceW7@\xbbS{?E\xa9\x08\xc0\xfe"\xab@\r\x03}\x7f\x1dI\x135L\x0b\x03\x1dK\x1dM\x05\x01\x03\x03W\x03\x03\x8d\x15\x03\x01\x01\x01\x03\x0bWMMM\x91\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\t\r\x15)\x01\x15)\x07\t\x11\x11\x15)\x05\t\x11\x15)\x05\t\x05\x0f\x01)\x03\t\x1f)\x01\x1f\t\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f\x13\x1b\x11\x01\t\t\x0b\x05\x05\x11\x07\x1b\t\x07\x03\t\x11\x07\r\x0b\x07\x03\x0b\x11\x07\r\x05\x07\x03\x05)\x05\t\r\x0f)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x0f)\x07\t\x11\x11\x0f)\x03\r\x17!)\x03\r\x1d)\x03\t\x1d)\x03\x05\x1d)\x03\x05\x17\x04\xfe\x05\x05\x01Q\x03\x11\x01\x07\x04\xd6\x05\x03\x01\x11\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\t\x11G)%\x07\x0b\t\x0b\x05\x05\x11\x03\x01\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\r\x0bF\x05\r\x03\x19\x05\x0b\x0f\x03F\x0b\x0f\x03\x1b\x03\x11\x05B\x03\x11\x03\x07\rF\x01\x13\x03\t\x07\x13\x03\x15\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\x19\x0bF\x05\r\x03\x19\x05\x0b\x1b\x03F\x0b\x0f\x03\r\x03\x1d\x05B\x03\x11\x03\x07\rF\x01\x15\x03\x0b\x07\x1f\x05!\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03%\x0bF\x05\r\x03\x19\x05\x0b\'\x03F\x0b\x0f\x03\r\x03)\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x07+\x07-\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x031\x0bF\x05\r\x03\x19\x05\x0b3\x03F\x0b\x0f\x03\r\x035\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x19\x07\x04S\x03\r\x13\x077\x01\x13\x01\x0f\x01\x00\x03F\t\x1b\x031\x03\x01\x03F\t\x0b\x03\t\x03\x05\x0f\x06\r\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\x1d\x07\x04S\x03\r\x13\x07\x1b\x01\x17\x01\x0f\x01\x00\x03F\t\x1f\x03/\x03\x01\x03F\t\x0b\x03\x0b\x03\x05\x0f\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x1b\x01\x0b\x01\x0f\x01\x00\x03F\t\x1f\x03)\x03\x01\x03F\t\x0b\x03\x05\x03\x05\x0f\x06\r\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\xea\tO%\x03\x0b\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/A)Sci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_ssytrd_ffi\x00\x08y#\x057\x01\x0bOY[mo\x03y\x11\x81\x83\x85O\x87\x89\x8b\x8f\x03;\x031\x05=?\x03A\x03C\x03Q\x03S\x03K\x0bEqGQI\x03w\x0bEsGSI\x03U\x0bEuGKI', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_12_01["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dsytrd_ffi'], + serialized_date=datetime.date(2024, 12, 1), + inputs=(), + expected_outputs=(array([[[ 0.8251247184208595 , -2.6963562039892532 , + 0.8082445002373937 , -1.551980329390836 ], + [-2.629505060186711 , 4.427374205796291 , + -2.2111093161901074 , 7.552489598405787 ], + [ 0.2269453213819231 , 0.3650586474106988 , + -3.5933639667756205 , 4.828829679372501 ], + [-0.6415372293575187 , -0.2519326897319508 , + -1.7607827845801751 , -3.381311711243865 ]], + + [[-4.000421911405985 , 3.6303350337601055 , + 2.8066821235532355 , 1.099224389184342 ], + [-4.141622408467332 , -5.276404169116551 , + -0.8496056221591237 , -2.275319346221659 ], + [ 0.5828958067901202 , 0.9351254869793256 , + 2.7765603683442177 , -4.339686212557215 ], + [-0.6391146585297987 , 0.3129920702652711 , + -0.25441692469349864, -1.4155240723557498 ]]]), array([[ 0.8251247184208595, 4.427374205796291 , -3.5933639667756205, + -3.381311711243865 ], + [-4.000421911405985 , -5.276404169116551 , 2.7765603683442177, + -1.4155240723557498]]), array([[-2.629505060186711 , 0.3650586474106988 , -1.7607827845801751 ], + [-4.141622408467332 , 0.9351254869793256 , -0.25441692469349864]]), array([[1.3669846724688552, 1.8806358893589366, 0. ], + [1.1440109149169537, 1.8215532880266878, 0. ]])), + mlir_module_text=r""" +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) +#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { + %cst = stablehlo.constant dense<[[[0.82512471842085955, -2.6963562039892532, 0.80824450023739369, -1.5519803293908361], [0.96498805326781766, -4.1313349231964409, -2.2111093161901074, 7.5524895984057867], [0.81575339483804743, 1.0647235400727899, -1.0064296232364345, 4.8288296793725012], [-2.3060011529502993, -2.9182106402942192, -1.7781896154088577, 2.5904630742096817]], [[-4.0004219114059847, 3.6303350337601055, 2.8066821235532355, 1.0992243891843421], [0.59643883228393779, -1.5243235004961249, -0.84960562215912372, -2.275319346221659], [2.7617960295487092, -0.57538970930521982, 0.12559406141906576, -4.3396862125572149], [-3.0281643919760217, 0.38177997229319849, 3.860398204232184, -2.5166384340510231]]]> : tensor<2x4x4xf64> loc(#loc) + %0:5 = stablehlo.custom_call @lapack_dsytrd_ffi(%cst) {mhlo.backend_config = {uplo = 76 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>) loc(#loc2) + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) + %cst_0 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %4 = call @_where(%3, %0#0, %cst_0) : (tensor<2x1x1xi1>, tensor<2x4x4xf64>, tensor) -> tensor<2x4x4xf64> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc) + %5 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_2 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %8 = call @_where_0(%7, %0#1, %cst_2) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) + %c_3 = stablehlo.constant dense<0> : tensor loc(#loc) + %9 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_4 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %12 = call @_where_1(%11, %0#2, %cst_4) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + %c_5 = stablehlo.constant dense<0> : tensor loc(#loc) + %13 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) + %cst_6 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) + %16 = call @_where_1(%15, %0#3, %cst_6) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) + return %4, %8, %12, %16 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc) + } loc(#loc) + func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc7) + return %2 : tensor<2x4x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) + return %2 : tensor<2x4xf64> loc(#loc5) + } loc(#loc5) + func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) + %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) + %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) + return %2 : tensor<2x3xf64> loc(#loc5) + } loc(#loc5) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) +#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) +#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe7\x93?\x011\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x03c\x0f\x0b\x0b\x0b\x13\x1f\x0b\x0b//\x17\x0f\x0b\x0bO\x0b\x0b\x0bOo\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1f/\x01\x05\x0b\x0f\x03;\x17\x0f\x1b\x17\x17\x07\x13\x0f\x07\x07\x13\x1b\x07\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x07\x13\x13\x13\x13\x02\xfa\n\x1d\x1d\x07\x1f\x1d-\x07\x17\x1f\xde\n\x1b\x1d!\x07\x1d/\x07\x1d#\x07\x11\x03\x05\x03\x07\x13\x15\x17\x0f\x19\x0f\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x03\x03\'{\x05\'\x1d+\x07\x05)\x05+\x05-\x1f-\x01\x1d/\x1d1\x1d3\r\x0335\x1f\x13\t\x00\x00\x00\x00\t\x07\x07\x01\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x03\x07999\x03\x039\x1d5\x1d7\x1f9!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d9\x1d;\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f71\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t]aei\r\x057_35\x1d=\r\x057c35\x1d?\r\x057g35\x1dA\r\x057k35\x1dC\x1dE\x1dG###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\x02\x04A\xa4\x17\xf4kg\xea?\x1f\x01\x943#\x92\x05\xc0\x86 \xf6\x91#\xdd\xe9?\x9dMlS\xe9\xd4\xf8\xbf\x88\x1c:\xa0.\xe1\xee?8\xce\x7f\xa9|\x86\x10\xc0\xe8V\xc7\x14Z\xb0\x01\xc0\xd2!R\xd5\xbf5\x1e@\xbf\xc5\r\xdd\xa6\x1a\xea?\xbcM\xfe\x8c\x1b\t\xf1?\xdbj\xd8\xf2U\x1a\xf0\xbf\xado;\xba\xb8P\x13@\xbb\xad\x83\xbb\xb0r\x02\xc0\x1f9\xf7\xd1~X\x07\xc0)ID\xf4vs\xfc\xbfD\xcfI\xb4D\xb9\x04@\x16\xc3\xfe\x99n\x00\x10\xc0\x82.\x1c\x18\xed\n\r@\x8cn\xd7\xc1\x15t\x06@|2(Pl\x96\xf1?\x88*\xd7\xe3\x06\x16\xe3?F{\xf2\t\xa1c\xf8\xbf8z5!\xf8/\xeb\xbf4\xd3\x1f\xa1\xda3\x02\xc0)\x13I\x84(\x18\x06@\xbcw\xfd\xad\x97i\xe2\xbf\x1e\xf0.Yw\x13\xc0?dW\xd7\xb3\xd6[\x11\xc0\x04\x97\xb3@\xae9\x08\xc0\xbc\x17\xd1C\x15o\xd8?\x02\xb7%t\x18\xe2\x0e@\xac\xd8\xd0T\x13"\x04\xc0\r\x03}\x7f\x1dI\x135L\x0b\x03\x1dK\x1dM\x05\x01\x03\x03W\x03\x03\x8d\x15\x03\x01\x01\x01\x03\x0bWMMM\x91\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\t\r\x15)\x01\x15)\x07\t\x11\x11\x15)\x05\t\x11\x15)\x05\t\x05\x0f\x01)\x03\t\x1f)\x01\x1f\x0b\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f\x13\x1b\x11\x01\t\t\x0b\x05\x05\x11\x07\x1b\t\x07\x03\t\x11\x07\r\x0b\x07\x03\x0b\x11\x07\r\x05\x07\x03\x05)\x05\t\r\x0f)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x0f)\x07\t\x11\x11\x0f)\x03\r\x17!)\x03\r\x1d)\x03\t\x1d)\x03\x05\x1d)\x03\x05\x17\x04\xfe\x05\x05\x01Q\x03\x11\x01\x07\x04\xd6\x05\x03\x01\x11\x07P\x03\x03\x07\x04j\x03\x03=m\x05B\x03\x05\x03\t\x11G)%\x07\x0b\t\x0b\x05\x05\x11\x03\x01\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\r\x0bF\x05\r\x03\x19\x05\x0b\x0f\x03F\x0b\x0f\x03\x1b\x03\x11\x05B\x03\x11\x03\x07\rF\x01\x13\x03\t\x07\x13\x03\x15\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03\x19\x0bF\x05\r\x03\x19\x05\x0b\x1b\x03F\x0b\x0f\x03\r\x03\x1d\x05B\x03\x11\x03\x07\rF\x01\x15\x03\x0b\x07\x1f\x05!\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x03%\x0bF\x05\r\x03\x19\x05\x0b\'\x03F\x0b\x0f\x03\r\x03)\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x07+\x07-\x05B\x03\t\x03\x13\x03F\x05\x0b\x03\x11\x031\x0bF\x05\r\x03\x19\x05\x0b3\x03F\x0b\x0f\x03\r\x035\x05B\x03\x11\x03\x07\rF\x01\x17\x03\x05\x077\t9\t\x04\x03\t\x17#/;\x07P\x01\x19\x07\x04S\x03\r\x13\x077\x01\x13\x01\x0f\x01\x00\x03F\t\x1b\x031\x03\x01\x03F\t\x0b\x03\t\x03\x05\x0f\x06\r\x03\t\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01\x1d\x07\x04S\x03\r\x13\x07\x1b\x01\x17\x01\x0f\x01\x00\x03F\t\x1f\x03/\x03\x01\x03F\t\x0b\x03\x0b\x03\x05\x0f\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01!\x07\x04S\x03\r\x13\x07\x1b\x01\x0b\x01\x0f\x01\x00\x03F\t\x1f\x03)\x03\x01\x03F\t\x0b\x03\x05\x03\x05\x0f\x06\r\x03\x05\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\xea\tO%\x03\x0b\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/A)Sci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00mhlo.backend_config\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00uplo\x00\x00lapack_dsytrd_ffi\x00\x08y#\x057\x01\x0bOY[mo\x03y\x11\x81\x83\x85O\x87\x89\x8b\x8f\x03;\x031\x05=?\x03A\x03C\x03Q\x03S\x03K\x0bEqGQI\x03w\x0bEsGSI\x03U\x0bEuGKI', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 780759e69ee5..a352cee757ca 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2837,24 +2837,35 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) - return tridiagonal(x), 0 + return tridiagonal(x, lower=lower), 0 batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule -def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower): +def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower, platform): a_aval, = ctx.avals_in - a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower) + cpu_args = [] + if platform == "cpu": + # TODO(b/344892332): Remove the conditional after the compatibility period. + ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else () + cpu_args.extend(ctx_args) + a, d, e, taus, info = sytrd_impl(*cpu_args, a_aval.dtype, a, lower=lower) return a, d, e, taus, info mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo), - platform='cpu') + tridiagonal_p, + partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo, platform="cpu"), + platform="cpu", +) mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd), - platform='cuda') + tridiagonal_p, + partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd, platform="cuda"), + platform="cuda", +) mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd), - platform='rocm') + tridiagonal_p, + partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd, platform="rocm"), + platform="rocm", +) # Utilities diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index fa7ef99af7f3..5c1d316cf255 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -548,8 +548,9 @@ def gehrd_hlo(ctx, dtype, a): # sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. -def sytrd_hlo(dtype, a, *, lower): - _lapack.initialize() +def sytrd_hlo(ctx, dtype, a, *, lower): + fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy" + fn_base = prepare_lapack_call(fn_base=fn_base + "trd", dtype=dtype) a_type = ir.RankedTensorType(a.type) dims = a_type.shape assert len(dims) >= 2 @@ -557,52 +558,83 @@ def sytrd_hlo(dtype, a, *, lower): assert m == n, (m, n) batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - b = 1 - for d in batch_dims: - b *= d + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + i32_type = ir.IntegerType.get_signless(32) - if dtype == np.float32: - fn = "lapack_ssytrd" - lwork = _lapack.lapack_ssytrd_workspace(n, n) - diag_type = a_type.element_type - elif dtype == np.float64: - fn = "lapack_dsytrd" - lwork = _lapack.lapack_dsytrd_workspace(n, n) - diag_type = a_type.element_type - elif dtype == np.complex64: - fn = "lapack_chetrd" - lwork = _lapack.lapack_chetrd_workspace(n, n) + if ctx.is_forward_compat(): + fn = fn_base + b = 1 + for d in batch_dims: + b *= d + + if dtype == np.float32: + lwork = _lapack.lapack_ssytrd_workspace(n, n) + diag_type = a_type.element_type + elif dtype == np.float64: + lwork = _lapack.lapack_dsytrd_workspace(n, n) + diag_type = a_type.element_type + elif dtype == np.complex64: + lwork = _lapack.lapack_chetrd_workspace(n, n) + diag_type = ir.F32Type.get() + elif dtype == np.complex128: + lwork = _lapack.lapack_zhetrd_workspace(n, n) + diag_type = ir.F64Type.get() + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + return custom_call( + fn, + result_types=[ + a.type, + ir.RankedTensorType.get(batch_dims + (n,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), + ir.RankedTensorType.get(batch_dims, i32_type), + ir.RankedTensorType.get([lwork], a_type.element_type), + ], + operands=[hlo_s32(n), hlo_s32(1 if lower else 0), hlo_s32(max(1, n)), + hlo_s32(b), hlo_s32(lwork), a], + operand_layouts=[[]] * 5 + [layout], + result_layouts=[ + layout, + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + [0], + ], + operand_output_aliases={5: 0}, + ).results[:5] + fn = fn_base + "_ffi" + if dtype == np.float32 or dtype == np.complex64: diag_type = ir.F32Type.get() - elif dtype == np.complex128: - fn = "lapack_zhetrd" - lwork = _lapack.lapack_zhetrd_workspace(n, n) + elif dtype == np.float64 or dtype == np.complex128: diag_type = ir.F64Type.get() else: raise NotImplementedError(f"Unsupported dtype {dtype}") - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - i32_type = ir.IntegerType.get_signless(32) - out = custom_call( + # Returns x_out, on_diag, off_diag, tau, info + return custom_call( fn, result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (n,), diag_type), - ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type), - ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), - ir.RankedTensorType.get(batch_dims, i32_type), - ir.RankedTensorType.get([lwork], a_type.element_type), + a.type, + ir.RankedTensorType.get(batch_dims + (n,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type), + ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type), + ir.RankedTensorType.get(batch_dims, i32_type), ], - operands=[hlo_s32(n), hlo_s32(1 if lower else 0), hlo_s32(max(1, n)), - hlo_s32(b), hlo_s32(lwork), a], - operand_layouts=[[]] * 5 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ - layout, - (num_bd,) + tuple(range(num_bd - 1, -1, -1)), - (num_bd,) + tuple(range(num_bd - 1, -1, -1)), - (num_bd,) + tuple(range(num_bd - 1, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), ], - operand_output_aliases={5: 0}, + operand_output_aliases={0: 0}, + backend_config={ + "uplo": _matrix_uplo_attr(lower=lower), + }, + api_version=4, ).results - return out[:5] diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index b16cdc787345..ae0848a74a37 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -44,6 +44,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm from jax._src.internal_test_util.export_back_compat_test_data import cpu_hessenberg_lapack_gehrd +from jax._src.internal_test_util.export_back_compat_test_data import cpu_tridiagonal_lapack_sytrd_hetrd from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32 from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf @@ -123,6 +124,7 @@ def test_custom_call_coverage(self): cpu_schur_lapack_gees.data_2024_11_29, cpu_svd_lapack_gesdd.data_2024_08_13, cpu_hessenberg_lapack_gehrd.data_2024_08_31, + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01, ] # Add here all the testdatas that should cover the targets guaranteed # stable @@ -145,6 +147,7 @@ def test_custom_call_coverage(self): cpu_svd_lapack_gesdd.data_2023_06_19, cpu_triangular_solve_blas_trsm.data_2023_07_16, cpu_hessenberg_lapack_gehrd.data_2024_08_30, + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16, @@ -769,6 +772,40 @@ def func(): ) self.run_one_test(func, data, rtol=rtol, atol=atol) + + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + @jax.default_matmul_precision("float32") + def test_cpu_tridiagonal_lapack_sytrd_hetrd(self, dtype_name="f32"): + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + shape = (2, 4, 4) + input_data = jtu.rand_default(self.rng())(shape, dtype) + # del input_data # Input is in the testdata, here for readability + def func(): + return lax.linalg.tridiagonal(input_data, lower=True) + + rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] + atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] + + data = self.load_testdata( + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) + # TODO(b/344892332): Remove the check after the compatibility period. + has_xla_ffi_support = jaxlib_version >= (0, 4, 37) + if has_xla_ffi_support: + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) + def test_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) From 1ac6b762ddcb16af4853c0a3ced151d582455bef Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 9 Dec 2024 06:52:25 -0800 Subject: [PATCH 666/698] Ensured that JAX type checks under pytype on Python 3.12 Some errors uncovered by pytype look genuine and need to be revisited in the in the future. PiperOrigin-RevId: 704268742 --- jax/_src/array.py | 2 +- jax/_src/export/shape_poly.py | 9 +++++---- jax/_src/interpreters/mlir.py | 1 + jax/_src/interpreters/partial_eval.py | 10 ++++++++-- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/experimental/shard_map.py | 2 +- 6 files changed, 17 insertions(+), 9 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index d8182976254e..d5f742915284 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1120,7 +1120,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): bufs.append(buf) break else: - bufs.append(buf) + bufs.append(candidates_list[-1]) return pxla.batched_device_put(x.aval, sharding, bufs, devices) diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 1c4671ee6451..010edef1e54a 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1992,7 +1992,8 @@ def compute_dim_vars_from_arg_shapes( generate the code for computing the dimension variables. It also generates the shape assertions. - Returns: the values of the dimension variables, in the order determined by + Returns: + The values of the dimension variables, in the order determined by `all_dim_vars(args_avals)`. """ dim_vars = all_dim_vars(args_avals) @@ -2006,8 +2007,7 @@ def compute_dim_vars_from_arg_shapes( } synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) - dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars] - return tuple(dim_values) + return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) def _solve_dim_equations( eqns: list[_DimEquation], @@ -2141,7 +2141,8 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)] if not eqns: add_explicit_symbolic_constraints(shape_env) - return shape_env, shape_constraints # SUCCESS + # SUCCESS + return shape_env, shape_constraints # pytype: disable=bad-return-type elif len(eqns) >= nr_eqns: break diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 102e4f490b5c..531177b7244c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1699,6 +1699,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). # The below custom call achieves the sharding like above example. + assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) if config.use_shardy_partitioner.value: physical_ndim = core.physical_aval(aval).ndim s = sharding_impls.SdyArraySharding( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 6e2f11833b9d..4b4f8f7eddee 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -177,9 +177,12 @@ def new_arg(self, pval: PartialVal) -> JaxprTracer: if const is None: aval = pval.get_aval() if type(aval) is DShapedArray: + # TODO(dougalm): Fix the type error and remove the pytype pragmas. + # pytype: disable=attribute-error shape = [self.new_instantiated_const(d) if isinstance(d, Tracer) and d._trace.level < self.level else d for d in aval.shape] + # pytype: enable=attribute-error aval = aval.update(shape=tuple(shape)) return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding()) else: @@ -1776,6 +1779,9 @@ def lit(a: Atom) -> Literal | None: newvars: dict[Var, Var] = {} newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) + lit_or_var = ( + lambda a: a if isinstance(a, Literal) else (lit(a) or var(a)) + ) dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: @@ -1794,10 +1800,10 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: new_invars = [var(v) for v in jaxpr.invars] new_eqns = [] for eqn in jaxpr.eqns: - invars = [lit(x) or var(x) for x in eqn.invars] + invars = [lit_or_var(x) for x in eqn.invars] outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) - new_outvars = [lit(v) or var(v) for v in jaxpr.outvars] + new_outvars = [lit_or_var(v) for v in jaxpr.outvars] jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, new_eqns) new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 418240a4a86e..547415c098b4 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -513,7 +513,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # jaxpr for each branch. branches_known_ : list[core.ClosedJaxpr] = [] branches_staged_: list[core.ClosedJaxpr] = [] - branch_res_avals: list[core.AbstractValue] = [] + branch_res_avals: list[list[core.AbstractValue]] = [] for jaxpr in branches: jaxpr_known, jaxpr_staged, _, inst_out, num_res = \ pe.partial_eval_jaxpr_custom( diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 1c529b8938f1..b4609282e2f8 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1651,7 +1651,7 @@ def _partial_eval_jaxpr_custom_rule( def _add_reshapes(which, jaxpr_known, jaxpr_staged): # add singleton axes to residuals which are from jaxpr_known and are scalars - which_ = [w and not v.aval.shape + which_ = [w and not v.aval.shape # pytype: disable=attribute-error for w, v in zip(which, jaxpr_staged.invars[:len(which)])] if not any(which_): return jaxpr_known, jaxpr_staged assert not jaxpr_known.constvars and not jaxpr_staged.constvars From 79318a08cfab02ee42f8bc8e6417ad8665c3e22d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Dec 2024 07:34:26 -0800 Subject: [PATCH 667/698] Remove dead code after minimum jaxlib version bump to v0.4.36. New minimum xla_extension_version is 299, and the new mlir_api_version is 57. PiperOrigin-RevId: 704280856 --- jax/_src/array.py | 9 +- jax/_src/cache_key.py | 6 +- jax/_src/compiler.py | 6 +- jax/_src/config.py | 521 ++++++----------------- jax/_src/interpreters/pxla.py | 3 - jax/_src/tree_util.py | 156 ++----- jax/experimental/colocated_python/api.py | 4 - tests/api_test.py | 5 +- tests/memories_test.py | 5 - tests/pjit_test.py | 2 - tests/tree_util_test.py | 8 - 11 files changed, 154 insertions(+), 571 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index d5f742915284..7c5385f97e40 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -40,7 +40,6 @@ from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe -from jax._src.lib import xla_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, NamedSharding, @@ -1169,12 +1168,8 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): results.append( shard_sharded_device_array_slow_path(x, devices, indices, sharding)) - if xla_extension_version >= 296: - copy_outs = xc.batched_copy_array_to_devices_with_sharding( - batch_xs, batch_devs, batch_shardings, batch_cs) - else: - copy_outs = xc.batched_copy_array_to_devices_with_sharding( # pytype: disable=missing-parameter - batch_xs, batch_devs, batch_shardings) + copy_outs = xc.batched_copy_array_to_devices_with_sharding( + batch_xs, batch_devs, batch_shardings, batch_cs) for i, copy_out in safe_zip(batch_indices, copy_outs): assert results[i] is None results[i] = copy_out diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 2ec645cee407..e4b6e7a2669c 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -24,7 +24,6 @@ from jax._src import config from jax._src.lib import version_str as jaxlib_version_str from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir import passmanager as pm import numpy as np @@ -301,10 +300,7 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_dump_hlo_as_long_text = False debug_options.xla_dump_disable_metadata = False debug_options.xla_dump_hlo_pipeline_re = "" - - # "Requires jaxlib 0.4.36+" - if xla_extension_version > 296: - debug_options.xla_gpu_experimental_autotune_cache_mode = 0 + debug_options.xla_gpu_experimental_autotune_cache_mode = 0 # Optional way to specify the cuda install path to be used by the compiler. # This could possibly affect the cuda version compiled with, but this should diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 2d032069eb02..16fbb890956c 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -36,7 +36,6 @@ from jax._src.interpreters import mlir from jax._src.lib import version as jaxlib_version from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -192,9 +191,8 @@ def get_compile_options( assert device_assignment.computation_count() == num_partitions compile_options.device_assignment = device_assignment - if xla_extension_version >= 294: - build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value - build_options.memory_fitting_effort = config.memory_fitting_effort.value + build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value + build_options.memory_fitting_effort = config.memory_fitting_effort.value if env_options_overrides is not None: # Some overrides are passed directly on build_options. diff --git a/jax/_src/config.py b/jax/_src/config.py index 34a1e152c0d1..f1b170050a6b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -14,30 +14,21 @@ from __future__ import annotations -from collections.abc import Callable, Hashable, Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence import contextlib import functools import itertools import logging import os import sys -import threading -from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast +from typing import Any, Generic, NoReturn, Optional, Protocol, TypeVar, cast -from jax._src import lib from jax._src.lib import guard_lib from jax._src.lib import jax_jit from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version from jax._src import logging_config -# TODO(phawkins): reenable pytype after xla_extension_version >= 295 -# pytype: skip-file - -if xla_extension_version >= 295: - config_ext = xla_client._xla.config -else: - config_ext = None +config_ext = xla_client._xla.config logger = logging.getLogger(__name__) @@ -200,91 +191,38 @@ def parse_flags_with_absl(self): already_configured_with_absl = True -if xla_extension_version >= 295: - def trace_context(): - """Returns a tuple of configuration values that affect tracing. +def trace_context(): + """Returns a tuple of configuration values that affect tracing. - These values are included in the cache key for linear_util.cache. + These values are included in the cache key for linear_util.cache. - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately. - """ - return (axis_env_state.value, mesh_context_manager.value, - xla_metadata_context_manager.value, - abstract_mesh_context_manager.value, - device_context.value, - compute_on_context_manager.value, enable_x64.value, - numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, - eager_constant_folding.value, - numpy_dtype_promotion.value, - default_device.value, random_seed_offset.value, - threefry_partitionable.value, - threefry_gpu_kernel_lowering.value, - sharding_in_types.value, - use_direct_linearize.value, - softmax_custom_jvp.value, - enable_memories.value, - disable_jit.value, - debug_key_reuse.value, - jax_xla_profile_version.value, - # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value, - pgle_profiling_runs.value, - enable_pgle.value, - use_shardy_partitioner.value) -else: - def trace_context(): - """Returns a tuple of configuration values that affect tracing. - - These values are included in the cache key for linear_util.cache. - - Values included in this set should also most likely be included in - the C++ JIT state, which is handled separately. - """ - tls = jax_jit.thread_local_state() - axis_env_state = () - mesh_context_manager = () - abstract_mesh_context_manager = () - device_context = () - xla_metadata_context_manager = () - compute_on_context_manager = () - - context: Any = tls.extra_jit_context - if context and context.axis_env_state is not None: - axis_env_state = context.axis_env_state - if context and context.mesh_context_manager: - mesh_context_manager = context.mesh_context_manager - if context and context.abstract_mesh_context_manager: - abstract_mesh_context_manager = context.abstract_mesh_context_manager - if context and context.device_context: - device_context = context.device_context - if context and context.xla_metadata_context_manager: - xla_metadata_context_manager = context.xla_metadata_context_manager - if context and context.compute_on_context_manager: - compute_on_context_manager = context.compute_on_context_manager - return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager, - device_context, xla_metadata_context_manager, - compute_on_context_manager, enable_x64.value, - numpy_rank_promotion.value, default_matmul_precision.value, - dynamic_shapes.value, - eager_constant_folding.value, - numpy_dtype_promotion.value, - default_device.value, random_seed_offset.value, - threefry_partitionable.value, - threefry_gpu_kernel_lowering.value, - sharding_in_types.value, - use_direct_linearize.value, - softmax_custom_jvp.value, - enable_memories.value, - disable_jit.value, - debug_key_reuse.value, - jax_xla_profile_version.value, - # Technically this affects jaxpr->stablehlo lowering, not tracing. - hlo_source_file_canonicalization_regex.value, - pgle_profiling_runs.value, - enable_pgle.value, - use_shardy_partitioner.value) + Values included in this set should also most likely be included in + the C++ JIT state, which is handled separately. + """ + return (axis_env_state.value, mesh_context_manager.value, + xla_metadata_context_manager.value, + abstract_mesh_context_manager.value, + device_context.value, + compute_on_context_manager.value, enable_x64.value, + numpy_rank_promotion.value, default_matmul_precision.value, + dynamic_shapes.value, + eager_constant_folding.value, + numpy_dtype_promotion.value, + default_device.value, random_seed_offset.value, + threefry_partitionable.value, + threefry_gpu_kernel_lowering.value, + sharding_in_types.value, + use_direct_linearize.value, + softmax_custom_jvp.value, + enable_memories.value, + disable_jit.value, + debug_key_reuse.value, + jax_xla_profile_version.value, + # Technically this affects jaxpr->stablehlo lowering, not tracing. + hlo_source_file_canonicalization_regex.value, + pgle_profiling_runs.value, + enable_pgle.value, + use_shardy_partitioner.value) config = Config() @@ -296,185 +234,85 @@ def trace_context(): class NoDefault: pass no_default = NoDefault() -if xla_extension_version >= 295: - class State(config_ext.Config[_T]): +class State(config_ext.Config[_T]): - __slots__ = ( - '_name', '_update_thread_local_hook', '_update_global_hook', - '_validator', '_default_context_manager_value', '__doc__', '__name__', - ) + __slots__ = ( + '_name', '_update_thread_local_hook', '_update_global_hook', + '_validator', '_default_context_manager_value', '__doc__', '__name__', + ) - def __init__( - self, - name: str, - default: _T, - help, - update_global_hook: Callable[[_T], None] | None = None, - update_thread_local_hook: Callable[[_T | None], None] | None = None, - validator: Callable[[Any], None] | None = None, - extra_description: str = '', - default_context_manager_value: Any = no_default, - include_in_jit_key: bool = False, - ): - super().__init__(default, include_in_jit_key) - self._name = name - self.__name__ = name[4:] if name.startswith('jax_') else name - self.__doc__ = (f"Context manager for `{name}` config option" - f"{extra_description}.\n\n{help}") - self._update_global_hook = update_global_hook - self._update_thread_local_hook = update_thread_local_hook - self._validator = validator - self._default_context_manager_value = default_context_manager_value - if self._validator: - self._validator(default) - if self._update_global_hook: - self._update_global_hook(default) - - def __bool__(self) -> NoReturn: - raise TypeError( - "bool() not supported for instances of type '{0}' " - "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) - - def _set(self, value: _T) -> None: - if self._validator: - self._validator(value) - self.set_global(value) - if self._update_global_hook: - self._update_global_hook(value) - - @contextlib.contextmanager - def __call__(self, new_val: Any = no_default): - if new_val is no_default: - if self._default_context_manager_value is not no_default: - new_val = self._default_context_manager_value # default_context_manager_value provided to constructor - else: - # no default_value provided to constructor and no value provided as an - # argument, so we raise an error - raise TypeError(f"Context manager for {self.__name__} config option " - "requires an argument representing the new value for " - "the config option.") - if self._validator: - self._validator(new_val) - prev_val = self.swap_local(new_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(new_val) - try: - yield - finally: - self.set_local(prev_val) - if self._update_thread_local_hook: - if prev_val is config_ext.unset: - self._update_thread_local_hook(None) - else: - self._update_thread_local_hook(cast(Optional[Any], prev_val)) - - def _add_hooks(self, update_global_hook, update_thread_local_hook): - """Private method that adds hooks to an existing context-manager. - - Used to avoid cyclic import dependencies.""" - self._update_thread_local_hook = update_thread_local_hook - self._update_global_hook = update_global_hook - update_global_hook(self.get_global()) - -else: - class _Unset: pass - unset = _Unset() - - _thread_local_state = threading.local() - - class State(Generic[_T]): # type: ignore[no-redef] - - __slots__ = ( - '_name', '_value', '_update_thread_local_hook', '_update_global_hook', - '_validator', '_default_context_manager_value', '__doc__', '__name__', - ) + def __init__( + self, + name: str, + default: _T, + help, + update_global_hook: Callable[[_T], None] | None = None, + update_thread_local_hook: Callable[[_T | None], None] | None = None, + validator: Callable[[Any], None] | None = None, + extra_description: str = '', + default_context_manager_value: Any = no_default, + include_in_jit_key: bool = False, + ): + super().__init__(default, include_in_jit_key) + self._name = name + self.__name__ = name[4:] if name.startswith('jax_') else name + self.__doc__ = (f"Context manager for `{name}` config option" + f"{extra_description}.\n\n{help}") + self._update_global_hook = update_global_hook + self._update_thread_local_hook = update_thread_local_hook + self._validator = validator + self._default_context_manager_value = default_context_manager_value + if self._validator: + self._validator(default) + if self._update_global_hook: + self._update_global_hook(default) - def __init__( - self, - name: str, - default: _T, - help, - update_global_hook: Callable[[_T], None] | None = None, - update_thread_local_hook: Callable[[_T | None], None] | None = None, - validator: Callable[[Any], None] | None = None, - extra_description: str = '', - default_context_manager_value: Any = no_default, - include_in_jit_key: bool = False, - ): - self._name = name - self.__name__ = name[4:] if name.startswith('jax_') else name - self.__doc__ = (f"Context manager for `{name}` config option" - f"{extra_description}.\n\n{help}") - if include_in_jit_key: - assert update_global_hook is None - assert update_thread_local_hook is None - update_global_hook = lambda val: _update_global_jit_state( - **{self.__name__: val}) - update_thread_local_hook = lambda val: update_thread_local_jit_state( - **{self.__name__: val}) - self._update_global_hook = update_global_hook - self._update_thread_local_hook = update_thread_local_hook - self._validator = validator - self._default_context_manager_value = default_context_manager_value - self._set(default) - def __bool__(self) -> NoReturn: - raise TypeError( - "bool() not supported for instances of type '{0}' " - "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) - - def _set(self, value: _T) -> None: - if self._validator: - self._validator(value) - self._value = value - if self._update_global_hook: - self._update_global_hook(value) - - @property - def value(self) -> _T: - val = _thread_local_state.__dict__.get(self._name, unset) - return cast(_T, val) if val is not unset else self._value - - def get_local(self) -> Any: - return _thread_local_state.__dict__.get(self._name, unset) - - @contextlib.contextmanager - def __call__(self, new_val: Any = no_default): - if new_val is no_default: - if self._default_context_manager_value is not no_default: - new_val = self._default_context_manager_value # default_context_manager_value provided to constructor - else: - # no default_value provided to constructor and no value provided as an - # argument, so we raise an error - raise TypeError(f"Context manager for {self.__name__} config option " - "requires an argument representing the new value for " - "the config option.") - if self._validator: - self._validator(new_val) - prev_val = getattr(_thread_local_state, self._name, unset) - setattr(_thread_local_state, self._name, new_val) + def __bool__(self) -> NoReturn: + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format( + type(self).__name__)) + + def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) + self.set_global(value) + if self._update_global_hook: + self._update_global_hook(value) + + @contextlib.contextmanager + def __call__(self, new_val: Any = no_default): + if new_val is no_default: + if self._default_context_manager_value is not no_default: + new_val = self._default_context_manager_value # default_context_manager_value provided to constructor + else: + # no default_value provided to constructor and no value provided as an + # argument, so we raise an error + raise TypeError(f"Context manager for {self.__name__} config option " + "requires an argument representing the new value for " + "the config option.") + if self._validator: + self._validator(new_val) + prev_val = self.swap_local(new_val) + if self._update_thread_local_hook: + self._update_thread_local_hook(new_val) + try: + yield + finally: + self.set_local(prev_val) if self._update_thread_local_hook: - self._update_thread_local_hook(new_val) - try: - yield - finally: - if prev_val is unset: - delattr(_thread_local_state, self._name) - if self._update_thread_local_hook: - self._update_thread_local_hook(None) + if prev_val is config_ext.unset: + self._update_thread_local_hook(None) else: - setattr(_thread_local_state, self._name, prev_val) - if self._update_thread_local_hook: - self._update_thread_local_hook(cast(_T, prev_val)) + self._update_thread_local_hook(cast(Optional[Any], prev_val)) - def _add_hooks(self, update_global_hook, update_thread_local_hook): - """Private method that adds hooks to an existing context-manager. + def _add_hooks(self, update_global_hook, update_thread_local_hook): + """Private method that adds hooks to an existing context-manager. - Used to avoid cyclic import dependencies.""" - self._update_thread_local_hook = update_thread_local_hook - self._update_global_hook = update_global_hook - update_global_hook(self._value) + Used to avoid cyclic import dependencies.""" + self._update_thread_local_hook = update_thread_local_hook + self._update_global_hook = update_global_hook + update_global_hook(self.get_global()) UPGRADE_BOOL_HELP = ( @@ -975,132 +813,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: already_configured_with_absl = False -if xla_extension_version >= 295: - trace_state = config_ext.Config(None, include_in_jit_key=True) - axis_env_state = config_ext.Config((), include_in_jit_key=True) - mesh_context_manager = config_ext.Config((), include_in_jit_key=True) - abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True) - device_context = config_ext.Config((), include_in_jit_key=True) - compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) - xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) -else: - # The C++ JIT maintains its own copy of several configuration items as - # a global/thread-local state. These methods allow updates to part of the - # state when a configuration value changes. - class _GlobalExtraJitContext(NamedTuple): - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool = False - eager_constant_folding: bool = False - random_seed_offset: int = 0 - threefry_partitionable: bool = False - threefry_gpu_kernel_lowering: bool = False - sharding_in_types: bool = False - use_direct_linearize: bool = False - softmax_custom_jvp: bool = False - xla_profile_version: int = 0 - pgle_profiling_runs: int = 0 - enable_pgle: bool = False - use_shardy_partitioner: bool = False - - - def _update_global_jit_state(**kw): - gs = jax_jit.global_state() - context = gs.extra_jit_context or _GlobalExtraJitContext() - gs.extra_jit_context = context._replace(**kw) - - - class _ThreadLocalExtraJitContext(NamedTuple): - """A namedtuple containing states to add to the cache key. - - Just in time compilation (for jit, pmap, etc) behavior is configurable through - global and thread-local options, used in the cache key. - - The initialization, which uses both config.py and core.py is done using - `_update_thread_local_jit_state` in core.py to prevent circular imports. - """ - trace_state: Any | None = None - axis_env_state: Hashable = () - mesh_context_manager: Hashable = () - abstract_mesh_context_manager: Hashable = () - device_context: Hashable = () - compute_on_context_manager: Hashable = () - xla_metadata_context_manager: Hashable = () - - # Values set by _StateContextManager context managers. - # CAUTION: these must be initialized to `None`! The state context manager - # restores these to None on exit. If the object default is not `None`, the - # context manager is not a no-op, which leads to problems with stale state - # (e.g. spurious cache misses in tests). - numpy_rank_promotion: str | None = None - numpy_dtype_promotion: str | None = None - default_matmul_precision: Any | None = None - dynamic_shapes: bool | None = None - eager_constant_folding : bool | None = None - random_seed_offset: int | None = None - threefry_partitionable: bool | None = None - threefry_gpu_kernel_lowering: bool | None = None - sharding_in_types: bool | None = None - use_direct_linearize: bool | None = None - softmax_custom_jvp: bool | None = None - xla_profile_version: int | None = None - pgle_profiling_runs: int | None = None - enable_pgle: bool | None = None - use_shardy_partitioner: bool | None = None - - - class _ThreadLocalStateCache(threading.local): - """"A thread local cache for _ThreadLocalExtraJitContext - - The extra_jit_context in jax_jit.thread_local_state() may get updated and thus - incurring dispatch overhead for comparing this python object during jit calls. - We want to deduplicate the objects that have the same hash/equality to also - have the same object ID, since the equality check is much faster if the object - IDs match. - """ - def __init__(self): - self.canonicalize = functools.lru_cache(128)(lambda x: x) - - - _thread_local_state_cache = _ThreadLocalStateCache() - - - def update_thread_local_jit_state(**kw): - tls = jax_jit.thread_local_state() - # After xla_client._version >= 70, the thread_local object will necessarily - # be initialized when accessed. The following line can be removed when the - # minimum jaxlib version is past version 70 - context = tls.extra_jit_context or _ThreadLocalExtraJitContext() - tmp = context._replace(**kw) - tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) - - class JitConfig: - def __init__(self, name): - self._name = name - - @property - def value(self): - return self.get_local() - - def get_local(self): - return getattr(jax_jit.thread_local_state().extra_jit_context, self._name) - - def set_local(self, value): - update_thread_local_jit_state(**{self._name: value}) - - def swap_local(self, new_value): - prev_value = self.value - self.set_local(new_value) - return prev_value - - trace_state = JitConfig('trace_state') - axis_env_state = JitConfig('axis_env_state') - mesh_context_manager = JitConfig('mesh_context_manager') - abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager') - device_context = JitConfig('device_context') - compute_on_context_manager = JitConfig('compute_on_context_manager') - xla_metadata_context_manager = JitConfig('xla_metadata_context_manager') +trace_state = config_ext.Config(None, include_in_jit_key=True) +axis_env_state = config_ext.Config((), include_in_jit_key=True) +mesh_context_manager = config_ext.Config((), include_in_jit_key=True) +abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True) +device_context = config_ext.Config((), include_in_jit_key=True) +compute_on_context_manager = config_ext.Config((), include_in_jit_key=True) +xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True) # TODO(b/214340779): remove flag when XLA:CPU is improved. @@ -1254,10 +973,10 @@ def swap_local(self, new_value): help='If True, pmap and shard_map API will be merged.') def _update_jax_memories_global(val): - lib.jax_jit.global_state().enable_memories = val + jax_jit.global_state().enable_memories = val def _update_jax_memories_thread_local(val): - lib.jax_jit.thread_local_state().enable_memories = val + jax_jit.thread_local_state().enable_memories = val enable_memories = bool_state( 'jax_enable_memories', @@ -1576,10 +1295,10 @@ def _update_jax_memories_thread_local(val): ) def _update_x64_global(val): - lib.jax_jit.global_state().enable_x64 = val + jax_jit.global_state().enable_x64 = val def _update_x64_thread_local(val): - lib.jax_jit.thread_local_state().enable_x64 = val + jax_jit.thread_local_state().enable_x64 = val enable_x64 = bool_state( name='jax_enable_x64', @@ -1594,11 +1313,11 @@ def _update_x64_thread_local(val): setattr(Config, "x64_enabled", property(lambda _: enable_x64.value)) def _update_default_device_global(val): - lib.jax_jit.global_state().default_device = val + jax_jit.global_state().default_device = val def _update_default_device_thread_local(val): - lib.jax_jit.thread_local_state().default_device = val + jax_jit.thread_local_state().default_device = val def _validate_default_device(val): @@ -1632,10 +1351,10 @@ def _validate_default_device(val): validator=_validate_default_device) def _update_disable_jit_global(val): - lib.jax_jit.global_state().disable_jit = val + jax_jit.global_state().disable_jit = val def _update_disable_jit_thread_local(val): - lib.jax_jit.thread_local_state().disable_jit = val + jax_jit.thread_local_state().disable_jit = val disable_jit = bool_state( name='jax_disable_jit', diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 41f91202e1de..d48e81b9092c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -61,7 +61,6 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -108,8 +107,6 @@ class WeakRefList(list): def to_xc_copy_semantics(copy_semantics): - if xla_extension_version < 296: - return [None] * len(copy_semantics) out = [] for cs in copy_semantics: if cs is None or cs == dispatch.CopySemantics.ALIAS: diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index bb9924f8bb72..77871f3a908f 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -16,7 +16,6 @@ import collections from collections.abc import Callable, Hashable, Iterable, Sequence import dataclasses -from dataclasses import dataclass import difflib import functools from functools import partial @@ -26,7 +25,6 @@ from jax._src import traceback_util from jax._src.lib import pytree -from jax._src.lib import xla_extension_version from jax._src.util import safe_zip, set_module from jax._src.util import unzip2 @@ -289,20 +287,15 @@ def register_pytree_node( >>> jax.jit(f)(m) Array([1., 2., 3., 4., 5.], dtype=float32) """ - if xla_extension_version >= 299: - default_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] - ) - none_leaf_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] - ) - dispatch_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] - ) - else: - default_registry.register_node(nodetype, flatten_func, unflatten_func) - none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) - dispatch_registry.register_node(nodetype, flatten_func, unflatten_func) + default_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + ) + none_leaf_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + ) + dispatch_registry.register_node( # type: ignore[call-arg] + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + ) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) @@ -710,47 +703,10 @@ def _equality_errors(path, t1, t2, is_leaf): yield from _equality_errors((*path, k), c1, c2, is_leaf) -@export -@dataclass(frozen=True) -class SequenceKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - idx: int - def __str__(self): - return f'[{self.idx!r}]' - - -@export -@dataclass(frozen=True) -class DictKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - key: Hashable - def __str__(self): - return f'[{self.key!r}]' - - -@export -@dataclass(frozen=True) -class GetAttrKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - name: str - def __str__(self): - return f'.{self.name}' - - -@export -@dataclass(frozen=True) -class FlattenedIndexKey(): - """Struct for use with :func:`jax.tree_util.register_pytree_with_keys`.""" - key: int - def __str__(self): - return f'[]' - - -if xla_extension_version >= 299: - SequenceKey = pytree.SequenceKey # type: ignore - DictKey = pytree.DictKey # type: ignore - GetAttrKey = pytree.GetAttrKey # type: ignore - FlattenedIndexKey = pytree.FlattenedIndexKey # type: ignore +SequenceKey: Any = pytree.SequenceKey # type: ignore +DictKey: Any = pytree.DictKey # type: ignore +GetAttrKey: Any = pytree.GetAttrKey # type: ignore +FlattenedIndexKey: Any = pytree.FlattenedIndexKey # type: ignore @export @@ -1098,38 +1054,21 @@ def flatten_func(x): return nodetype -if xla_extension_version >= 299: - register_pytree_with_keys( - collections.OrderedDict, - lambda x: (tuple((DictKey(k), x[k]) for k in x.keys()), tuple(x.keys())), - lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), - ) - - def _flatten_defaultdict_with_keys(d): - keys = tuple(sorted(d)) - return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys) - - register_pytree_with_keys( - collections.defaultdict, - _flatten_defaultdict_with_keys, - lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), - ) -else: - register_pytree_node( - collections.OrderedDict, - lambda x: (tuple(x.values()), tuple(x.keys())), - lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), - ) +register_pytree_with_keys( + collections.OrderedDict, + lambda x: (tuple((DictKey(k), x[k]) for k in x.keys()), tuple(x.keys())), + lambda keys, values: collections.OrderedDict(safe_zip(keys, values)), +) - def _flatten_defaultdict(d): - keys = tuple(sorted(d)) - return tuple(d[k] for k in keys), (d.default_factory, keys) +def _flatten_defaultdict_with_keys(d): + keys = tuple(sorted(d)) + return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys) - register_pytree_node( - collections.defaultdict, - _flatten_defaultdict, - lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), - ) +register_pytree_with_keys( + collections.defaultdict, + _flatten_defaultdict_with_keys, + lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)), +) @export @@ -1184,10 +1123,7 @@ def tree_flatten_with_path( which contains a leaf and its key path. The second element is a treedef representing the structure of the flattened tree. """ - if xla_extension_version >= 299: - return default_registry.flatten_with_path(tree, is_leaf) - _, tree_def = tree_flatten(tree, is_leaf) - return _generate_key_paths(tree, is_leaf), tree_def + return default_registry.flatten_with_path(tree, is_leaf) @export @@ -1213,46 +1149,10 @@ def tree_leaves_with_path( def generate_key_paths( tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[tuple[KeyPath, Any]]: - if xla_extension_version >= 299: - return tree_leaves_with_path(tree, is_leaf) - return list(_generate_key_paths_((), tree, is_leaf)) + return tree_leaves_with_path(tree, is_leaf) _generate_key_paths = generate_key_paths # alias for backward compat -# The overall logic should be same as PyTreeDef::FlattenIntoImpl -def _generate_key_paths_( - key_path: KeyPath, - tree: Any, - is_leaf: Callable[[Any], bool] | None = None, -) -> Iterable[tuple[KeyPath, Any]]: - if is_leaf and is_leaf(tree): - yield key_path, tree - return - key_handler = _registry_with_keypaths.get(type(tree)) - if key_handler: - key_children, _ = key_handler.flatten_with_keys(tree) - for k, c in key_children: - yield from _generate_key_paths_((*key_path, k), c, is_leaf) - return - - flat = default_registry.flatten_one_level(tree) - if flat is None: - yield key_path, tree # strict leaf type - return - - if (isinstance(tree, tuple) and hasattr(tree, '_fields') and - flat[1] == type(tree)): - # handle namedtuple as a special case, based on heuristic - key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields] - for k, c in key_children: - yield from _generate_key_paths_((*key_path, k), c, is_leaf) - return - - for i, c in enumerate(flat[0]): - k = FlattenedIndexKey(i) - yield from _generate_key_paths_((*key_path, k), c, is_leaf) - - @export def tree_map_with_path(f: Callable[..., Any], tree: Any, *rest: Any, diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index abf92306ef0f..770820b39222 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -20,7 +20,6 @@ import jax from jax._src import api_util -from jax._src.lib import xla_extension_version from jax.experimental.colocated_python.func import make_callable @@ -28,9 +27,6 @@ def colocated_cpu_devices( devices: Sequence[jax.Device], ) -> Sequence[jax.Device]: """Finds CPU devices colocated with the given devices.""" - if xla_extension_version < 290: - raise NotImplementedError("Requires xla_extension_version >= 290") - cpu_devices_by_colocation_id = collections.defaultdict(list) for device in devices[0].client._get_all_devices(): # pylint: disable=protected-access if device.device_kind == "cpu": diff --git a/tests/api_test.py b/tests/api_test.py index a6c1c5d53d91..398070764a23 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -60,7 +60,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import xla_extension import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -1388,9 +1388,6 @@ def f(x): })(1.0) # doesn't crash. def test_exec_time_optimization_effort_compiler_option(self): - if xla_extension_version < 294: - raise unittest.SkipTest("test requires newer xla extension version") - def f(x): return jnp.sqrt(x ** 2) + 1. diff --git a/tests/memories_test.py b/tests/memories_test.py index 9c9b3a4ad2bf..1547ccd1fc5a 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -25,7 +25,6 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.lib import xla_extension_version from jax._src.layout import DeviceLocalLayout as DLL, Layout from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint @@ -697,8 +696,6 @@ def foo(x): self.assertIn('custom_call_target="AllocateBuffer"', compiled_text) def test_disallow_alias_copies_arrays(self): - if xla_extension_version < 296: - self.skipTest("Requires xla_extension_version >= 296") mesh = jtu.create_mesh((2,), ("x",)) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P("x"), memory_kind="pinned_host") @@ -712,8 +709,6 @@ def test_disallow_alias_copies_arrays(self): jax.block_until_ready(inp_host_copy) def test_disallow_alias_copies_arrays_with_donated_input(self): - if xla_extension_version < 296: - self.skipTest("Requires xla_extension_version >= 296") mesh = jtu.create_mesh((2,), ("x",)) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P("x"), memory_kind="pinned_host") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7e66c41cca76..72c43d6c2222 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -59,7 +59,6 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -3817,7 +3816,6 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P())) self.assertEqual(out.sharding.memory_kind, 'device') - @unittest.skipIf(xla_extension_version < 297, "Requires jaxlib 0.4.36+") def test_jit_static_argnames_non_interned(self): def do_nothing(foobar: int): return foobar diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index bd0497a33820..1b921121e27d 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -25,7 +25,6 @@ from jax import flatten_util from jax import tree_util from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -800,8 +799,6 @@ def testTreeFlattenWithPathCustom(self): ) def testFlattenWithPathDefaultDict(self): - if xla_extension_version < 299: - self.skipTest("Skipping for Python-based with path APIs.") d = collections.defaultdict(int, {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) leaves, treedef = tree_util.tree_flatten_with_path(d) self.assertEqual( @@ -819,8 +816,6 @@ def testFlattenWithPathDefaultDict(self): self.assertEqual(treedef, from_flatten) def testFlattenWithPathOrderedDict(self): - if xla_extension_version < 299: - self.skipTest("Skipping for Python-based with path APIs.") d = collections.OrderedDict({"b": 2, "a": 1, "c": {"b": 2, "a": 1}}) leaves, treedef = tree_util.tree_flatten_with_path(d) self.assertEqual( @@ -920,9 +915,6 @@ def testBadFlattenNonIterableLeaves(self): class TreeKeyTest(absltest.TestCase): def testBasic(self): - if xla_extension_version < 299: - self.skipTest("Skipping for Python-based with path APIs.") - def assert_equal_and_hash_equal(a, b): self.assertEqual(a, b) self.assertEqual(hash(a), hash(b)) From 9c98c0cbbf0698cb131bb4c6e5ec5807b78aca0d Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 9 Dec 2024 08:22:56 -0800 Subject: [PATCH 668/698] [Pallas TPU] Improve lowerings for boolean comparison operations The error when negating a boolean value (https://github.com/jax-ml/jax/issues/24243) has been fixed, so we can lower the boolean comparison operations using boolean algebra instead of using the previous workaround. Besides, the original tests uses `allclose` on boolean arrays, which is wrong. I have changed them to `assertArraysEqual`. PiperOrigin-RevId: 704294742 --- jax/_src/pallas/mosaic/lowering.py | 100 +++++++++++++++-------------- tests/pallas/ops_test.py | 24 ++++--- 2 files changed, 67 insertions(+), 57 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index af5aa66a3851..c3211efa2031 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -63,7 +63,7 @@ from jax._src.state import primitives as state_primitives from jax._src.state.types import RefBitcaster, RefReshaper from jax._src.state.utils import dtype_bitwidth -from jax._src.typing import DTypeLike +from jax._src.typing import Array, DTypeLike from jax._src.util import safe_map from jax._src.util import safe_zip from jax._src.util import split_list @@ -2295,7 +2295,49 @@ def _population_count_lowering_rule(ctx: LoweringRuleContext, x): } -def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): +# The relationship between comparison operations on booleans and boolean +# algebra is as follows: +# eq(x, y) = !(x ^ y) +# ne(x, y) = x ^ y +# lt(x, y) = !x && y +# le(x, y) = !x || y +# gt(x, y) = x && !y +# ge(x, y) = x || !y +def _cmp_boolean_lowering_helper(primitive, x: Array, y: Array): + """A helper function for lowering comparison operations for boolean inputs. + + Args: + primitive: A JAX primitive representing a comparison operation, which is + one of the following: `lax.eq_p` (equals), `lax.ne_p` (not equals), + `lax.lt_p` (less than), `lax.le_p` (less than or equal to), + `lax.gt_p` (greater than), or `lax.ge_p` (greater than or equal to). + x: A boolean array representing the first operand in the comparison. + y: A boolean array representing the second operand in the comparison. + + Returns: + A boolean array that is the result of applying the comparison operation + between `x` and `y` based on the given primitive. + + Raises: + ValueError: If an unsupported comparison primitive is provided. + """ + if primitive == lax.eq_p: + return jnp.logical_not(jnp.logical_xor(x, y)) + elif primitive == lax.ne_p: + return jnp.logical_xor(x, y) + elif primitive == lax.lt_p: + return jnp.logical_and(jnp.logical_not(x), y) + elif primitive == lax.le_p: + return jnp.logical_or(jnp.logical_not(x), y) + elif primitive == lax.gt_p: + return jnp.logical_and(x, jnp.logical_not(y)) + elif primitive == lax.ge_p: + return jnp.logical_or(x, jnp.logical_not(y)) + else: + raise ValueError(f"Unsupported comparison primitive: {primitive}") + + +def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) x_aval, y_aval = ctx.avals_in if x_aval.dtype != y_aval.dtype: @@ -2304,60 +2346,22 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): ) dtype = x_aval.dtype - # For boolean comparisons, we handle them in two different ways. For `ne`, - # we directly use the xor operation since they are equivalent. For all - # other comparisons, we convert the boolean values to `int32` and use select - # operations to perform the comparison. - # - # The relationship between comparison operations on booleans and boolean - # algebra is as follows: - # - # eq(a, b) = !(a ^ b) - # ne(a, b) = a ^ b - # lt(a, b) = !a && b - # le(a, b) = !a || b - # gt(a, b) = a && !b - # ge(a, b) = a || !b - # - # However, except for `ne`, all other operations require negation, which is - # currently not supported. At present, even if negation were supported, - # it would still need to be implemented using `select` operations, making - # it equivalent to our current approach. For more details on negation support, - # see https://github.com/jax-ml/jax/issues/24243. if jnp.issubdtype(dtype, jnp.bool_): - if prim == lax.ne_p: - return arith.xori(x, y) - - i32 = ir.IntegerType.get_signless(32) - vtype = ir.VectorType.get(x_aval.shape, i32) - - # Convert `x` and `y` from `bool` to `int32` for comparison, with 2 - # for true and 0 for false. For example, comparing `x > y` is equivalent - # to `(x ? 2 : 0) > (y ? 2 : 0)`. - # - # Note that we cannot use 1 for true because the select operation will be - # misteriously eliminated. - two = arith.constant(i32, 2) - zero = arith.constant(i32, 0) - - out_aval, = ctx.avals_out - if out_aval.shape != (): - # broadcast to vectors if we are comparing vectors - two = vector.broadcast(vtype, two) - zero = vector.broadcast(vtype, zero) - - x = arith.select(x, two, zero) - y = arith.select(y, two, zero) - dtype = jnp.int32 + return lower_fun( + functools.partial(_cmp_boolean_lowering_helper, primitive), + multiple_results=False, + )(ctx, x, y) if jnp.issubdtype(dtype, jnp.integer): is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger) - pred = (_cmpui_lowering_types if is_uint else _cmpsi_lowering_types)[prim] + pred = ( + _cmpui_lowering_types if is_uint else _cmpsi_lowering_types + )[primitive] predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) return arith.cmpi(predicate, x, y) if jnp.issubdtype(dtype, jnp.floating): - pred = _cmpf_lowering_types[prim] + pred = _cmpf_lowering_types[primitive] predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) return arith.cmpf(predicate, x, y) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 8586ae346654..12a2ad49306c 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -951,16 +951,17 @@ def kernel(x_ref, y_ref, o_ref): ] @parameterized.named_parameters( - (f"{fn.__name__}_{dtype}", fn, dtype) + (f"{fn.__name__}_{dtype.__name__}", fn, dtype) for fn, dtype in itertools.product( - COMPARISON_OPS, ["int32", "uint32", "float16", "float32", "bool"] + COMPARISON_OPS, + (jnp.int32, jnp.uint32, jnp.float16, jnp.float32, jnp.bool_), ) ) def test_comparison(self, fn, dtype): - if jtu.test_device_matches(["gpu"]) and dtype == "bool": + if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: self.skipTest("Not implemented on GPU.") - if jtu.test_device_matches(["tpu"]) and dtype == "float16": + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") @functools.partial( @@ -973,16 +974,19 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), fn(x, y)) + out = kernel(x, y) + expected = fn(x, y) + self.assertArraysEqual(out, expected) @parameterized.named_parameters( - (f"{fn.__name__}_{dtype}", fn, dtype) + (f"{fn.__name__}_{dtype.__name__}", fn, dtype) for fn, dtype in itertools.product( - COMPARISON_OPS, ["int32", "uint32", "float16", "float32", "bool"] + COMPARISON_OPS, + (jnp.int32, jnp.uint32, jnp.float16, jnp.float32, jnp.bool_), ) ) def test_comparison_scalar(self, fn, dtype): - if jtu.test_device_matches(["tpu"]) and dtype == "float16": + if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") if ( @@ -1007,7 +1011,9 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), fn(x, y)) + out = kernel(x, y) + expected = fn(x, y) + self.assertArraysEqual(out, expected) def test_isnan(self): @functools.partial( From dd74394e63cfd4d6d7285c10677405f39d92b5fe Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 9 Dec 2024 11:24:11 -0500 Subject: [PATCH 669/698] Use private names for args in api_util to avoid shadowing kwargs keys. This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util. We probably shouldn't use linear_util... --- jax/_src/api_util.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 1bfce85d592c..eb5e7e8bf8de 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @lu.transformation2 -def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs): +def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): sentinel = object() - args = [sentinel] * (len(fixed_args) + len(dyn_args)) - for i, arg in zip(dyn_argnums, dyn_args): + args = [sentinel] * (len(_fixed_args) + len(dyn_args)) + for i, arg in zip(_dyn_argnums, dyn_args): args[i] = arg - fixed_args_ = iter(fixed_args) + fixed_args_ = iter(_fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel - return f(*args, **kwargs) + return _fun(*args, **kwargs) def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], kwargs: dict[str, Any]): @@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs @lu.transformation2 -def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): - kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs) - return f(*args, **kwargs) +def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): + kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs) + return _fun(*args, **kwargs) @lru_cache(maxsize=4096) @@ -438,9 +438,9 @@ def flat_out_axes( return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) @lu.transformation_with_aux2 -def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): - ans = f(*args, **kwargs) - spec = tree_unflatten(treedef, leaves) +def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs): + ans = _fun(*args, **kwargs) + spec = tree_unflatten(_treedef, _leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) except ValueError: @@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None - store.store(spec_flat) + _store.store(spec_flat) return ans def check_callable(fun): @@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, for path, l in generate_key_paths(x) if l is not static) @lu.transformation_with_aux2 -def result_paths(f, store, *args, **kwargs): +def result_paths(_fun, _store, *args, **kwargs): "linear_util transform to get output pytree paths of pre-flattened function." - ans = f(*args, **kwargs) - store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + ans = _fun(*args, **kwargs) + _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) return ans def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, From b76d264fe725474d83b7ae0059ea7af21da4c63f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 9 Dec 2024 08:34:28 -0800 Subject: [PATCH 670/698] [Mosaic:TPU][NFC] In ext and trunc rules, avoid vreg array reshape by always using implicit shapes PiperOrigin-RevId: 704297805 --- .../tpu/transforms/apply_vector_layout.cc | 51 +++++++------------ 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 5cbb5e620c88..5c9b3d178c15 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -850,20 +850,13 @@ FailureOr> ext_op_rule_impl(RewriteContext &ctx, const VectorLayout &layout_out) { const auto result_ty = cast(op.getResult().getType()); auto source = cast>(op.getIn()); - const auto source_ty = source.getType(); auto output_vregs_shape = - layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape); + layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( xla::Array input_vregs, - disassemble(builder, layout_in, source, ctx.target_shape)); + disassemble(builder, layout_in, source, ctx.target_shape, + /*use_implicit_shape=*/true)); xla::Array output_vregs(output_vregs_shape); - // TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble? - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(), - ctx.target_shape)); - output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), - ctx.target_shape)); - } const VectorType res_vreg_ty = getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_in.implicit_dim() != layout_out.implicit_dim()) { @@ -900,9 +893,6 @@ FailureOr> ext_op_rule_impl(RewriteContext &ctx, op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); }); } - if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - output_vregs.Reshape(output_vregs_shape); - } return output_vregs; } @@ -925,8 +915,9 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, *layouts_out.front())); const auto result_ty = cast(extf_op.getResult().getType()); extf_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), ctx.target_shape) - .getResult()); + std::move(output_vregs), ctx.target_shape, + /*use_implicit_shape=*/true) + .getResult()); extf_op.erase(); return success(); } @@ -946,8 +937,10 @@ LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op, *layouts_out.front())); const auto result_ty = cast(extsi_op.getResult().getType()); extsi_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), ctx.target_shape) - .getResult()); + std::move(output_vregs), + ctx.target_shape, + /*use_implicit_shape=*/true) + .getResult()); extsi_op.erase(); return success(); } @@ -998,8 +991,10 @@ LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op, *v = builder.create(op.getLoc(), res_vreg_ty, unpacked); }); extui_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), - std::move(output_vregs), ctx.target_shape) - .getResult()); + std::move(output_vregs), + ctx.target_shape, + /*use_implicit_shape=*/true) + .getResult()); extui_op.erase(); return success(); } @@ -1010,13 +1005,13 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorLayout &layout_out) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); auto source = cast>(op.getIn()); - const auto source_ty = source.getType(); auto result_ty = cast(op.getResult().getType()); auto output_vregs_shape = - layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape); + layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( xla::Array input_vregs, - disassemble(builder, layout_in, source, ctx.target_shape)); + disassemble(builder, layout_in, source, ctx.target_shape, + /*use_implicit_shape=*/true)); xla::Array output_vregs(output_vregs_shape); if (layout_in.bitwidth() != 32) { return op.emitOpError("Not implemented: Only 32-bit truncation supported"); @@ -1031,12 +1026,6 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, if (layout_in.tiling() != ctx.target_shape) { return op.emitOpError("Not implemented: Only (8,128) tiling supported"); } - if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(), - ctx.target_shape)); - output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), - ctx.target_shape)); - } VectorType res_vreg_ty = getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_out.tiling() == ctx.target_shape) { @@ -1081,11 +1070,9 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, } else { return op.emitOpError("Not implemented: unsupported output tiling"); } - if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - output_vregs.Reshape(output_vregs_shape); - } op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, - std::move(output_vregs), ctx.target_shape) + std::move(output_vregs), ctx.target_shape, + /*use_implicit_shape=*/true) .getResult()); op.erase(); return success(); From f17b2bc2d3d73a91ae149f6ac51b85c2877e2167 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Mon, 9 Dec 2024 08:37:46 -0800 Subject: [PATCH 671/698] Reenable for_loop_test on TPU v5p. PiperOrigin-RevId: 704298792 --- tests/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index f80f17e54455..c25d10f460aa 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1341,9 +1341,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], - disable_configs = [ - "tpu_v5p_1x1", # TODO(b/377666550): enable once XLA is fixed. - ], shard_count = { "cpu": 20, "gpu": 10, From 6f69774c00a44836a52771f20ef25e6fdf4510b3 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 9 Dec 2024 08:46:17 -0800 Subject: [PATCH 672/698] #sdy enable `test_compute_offload_mesh_with_linear_layout` for Shardy. PiperOrigin-RevId: 704301465 --- tests/memories_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index 1547ccd1fc5a..fcb1d6bdc2b1 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1581,10 +1581,6 @@ def test_fn(x_in, y_in): self.assertArraysEqual(y_out, y1 + y1) def test_compute_offload_mesh_with_linear_layout(self): - if config.use_shardy_partitioner.value: - self.skipTest( - "Shardy inlines the host compute. Remove when that's fixed." - ) mesh = jtu.create_mesh((2, 2), ("x", "y")) sharding = NamedSharding(mesh, P("x", "y")) p_sharding = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") From b6863dfcb54eac003da08f5a192578dc412dc320 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:14:14 +0000 Subject: [PATCH 673/698] Bump actions/cache from 4.1.2 to 4.2.0 Bumps [actions/cache](https://github.com/actions/cache) from 4.1.2 to 4.2.0. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/6849a6489940f00c2f30c0fb92c6274307ccb58a...1bd1e32a3bdc45362d1e726936510720a7c30a57) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/ci-build.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 9828a160f2e1..b83c4a34295e 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -35,7 +35,7 @@ jobs: with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} @@ -77,7 +77,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -124,7 +124,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -167,7 +167,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -202,7 +202,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }} @@ -243,7 +243,7 @@ jobs: python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - name: pip cache - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} From 296d1670bff5c29366e958bcf7ab1755a669e2d1 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Mon, 9 Dec 2024 10:42:52 -0800 Subject: [PATCH 674/698] [JAX] Add concurrent execution support in colocated Python This change makes asynchronous execution run without holding a mutex. This allows colocated Python executions from multiple Python threads to run concurrently. PiperOrigin-RevId: 704340663 --- jax/experimental/colocated_python/func.py | 4 +- tests/colocated_python_test.py | 81 +++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index cba2a0f3801b..5567f2f765c1 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -343,7 +343,9 @@ def specialized_func(*args, **kwargs) -> Any: async_execution_func = _make_async_execution_fun(info, specialization) # Fall-through. - return async_execution_func(*args, **kwargs) + # Asynchronous execution runs outside of the mutex to allow concurrent + # execution for inline executors. + return async_execution_func(*args, **kwargs) return specialized_func diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 787d97613a15..bbd5c38068f3 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -13,9 +13,12 @@ # limitations under the License. import contextlib +import threading +import time from typing import Sequence from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import config from jax._src import test_util as jtu @@ -241,6 +244,84 @@ def add_one(x): self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count[0], 2) + @parameterized.named_parameters( + ("on_main_thread", True), + ("on_non_main_thread", False), + ) + def testSequentialExecution(self, on_main_thread: bool): + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + # Make sure that this input array is ready for use by the colocated Python + # function and does not disrupt elapsed time measurement. + jax.block_until_ready(x) + + @colocated_python.colocated_python + def sleep(x: jax.Array) -> jax.Array: + time.sleep(5) + return x + + # Specify out_specs_fn so that all executions are asynchronously dispatched. + sleep = sleep.specialize(out_specs_fn=lambda x: x) + + def sleep_twice_and_wait(x: jax.Array) -> None: + _ = sleep(x) + jax.block_until_ready(sleep(x)) + + start_time = time.time() + + # Two executions of `sleep` within `sleep_twice_and_wait` should run + # sequentially. + if on_main_thread: + sleep_twice_and_wait(x) + else: + t = threading.Thread(target=sleep_twice_and_wait, args=(x,)) + t.start() + t.join() + + elapsed_time = time.time() - start_time + + # If sequential execution did not happen, elapsed time typically will be + # around 5 seconds. + self.assertGreaterEqual(elapsed_time, 10) + + def testConcurrentExecution(self): + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + x = np.array(1) + x = jax.device_put(x, cpu_devices[0]) + # Make sure that this input array is ready for use by the colocated Python + # function and does not disrupt elapsed time measurement. + jax.block_until_ready(x) + + @colocated_python.colocated_python + def sleep(x: jax.Array) -> jax.Array: + time.sleep(5) + return x + + # Specify out_specs_fn so that all executions are asynchronously dispatched. + sleep = sleep.specialize(out_specs_fn=lambda x: x) + + def sleep_and_wait(x: jax.Array) -> None: + jax.block_until_ready(sleep(x)) + + start_time = time.time() + + # All three executions of `sleep_and_wait` should run concurrently. + t1 = threading.Thread(target=sleep_and_wait, args=(x,)) + t2 = threading.Thread(target=sleep_and_wait, args=(x,)) + t1.start() + t2.start() + sleep_and_wait(x) + t1.join() + t2.join() + + elapsed_time = time.time() - start_time + + self.assertGreaterEqual(elapsed_time, 5) + # If concurrent execution did not happen, elapsed time typically will be + # around 15 seconds. + self.assertLess(elapsed_time, 10) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From bd77a703fd1c2b151423b5dde8f727684e900bca Mon Sep 17 00:00:00 2001 From: IvyZX Date: Mon, 9 Dec 2024 10:44:28 -0800 Subject: [PATCH 675/698] Avoid index out of range error in carry structure check --- CHANGELOG.md | 5 +++++ jax/_src/lax/control_flow/loops.py | 6 ++++++ tests/lax_control_flow_test.py | 13 +++++++++++++ 3 files changed, 24 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e411e27faee..92dcfe6cc3f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.37 +* Bug fixes + * Fix a bug that will throw `index out of range` error in + {func}`jax.lax.while_loop` if the user register pytree node class with + different aux data for the flatten and flatten_with_path. + ## jax 0.4.36 (Dec 5, 2024) * Breaking Changes diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index f62ce2434755..9b2d688c322b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'of the carry output is a {thing2}, so {explanation}' for path, thing1, thing2, explanation in equality_errors(in_carry, out_carry)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: @@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d383e4c6ac20..4b0420fda8f9 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -322,6 +322,19 @@ def testWhileTypeErrors(self): lax.while_loop(lambda c: True, lambda c: (True, True), (np.bool_(True), np.float32(0.))) + def testWhileLoopCustomPytreeDiffAuxData(self): + class Node: + def __init__(self, x, y): + self.x = x + self.y = y + tree_util.register_pytree_with_keys( + Node, + lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys + lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved) + lambda o: ((o.x, o.y), 'without_keys'), # flatten + ) + lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.)) + def testNestedWhileWithDynamicUpdateSlice(self): num = 5 From 092d2a0db598a180b3355108a1bfbed6d8612cf9 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 9 Dec 2024 11:17:08 -0800 Subject: [PATCH 676/698] Add error message when using custom_vmap with reverse-mode AD, and add docstrings. The `custom_vmap` API is discussed in https://github.com/jax-ml/jax/issues/9073, and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs. One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem. PiperOrigin-RevId: 704353963 --- docs/jax.rst | 13 +++++ jax/_src/custom_batching.py | 112 ++++++++++++++++++++++++++++++++++-- jax/_src/interpreters/ad.py | 6 +- tests/api_test.py | 61 ++++++++++++++++++++ 4 files changed, 186 insertions(+), 6 deletions(-) diff --git a/docs/jax.rst b/docs/jax.rst index a5e0dcad5b50..042804792f8a 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -102,6 +102,9 @@ Automatic differentiation closure_convert checkpoint +Customization +------------- + ``custom_jvp`` ~~~~~~~~~~~~~~ @@ -121,6 +124,16 @@ Automatic differentiation custom_vjp custom_vjp.defvjp +``custom_batching`` +~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: _autosummary + + custom_batching.custom_vmap + custom_batching.custom_vmap.def_vmap + custom_batching.sequential_vmap + jax.Array (:code:`jax.Array`) ----------------------------- diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index afeef1e18456..74ad261b3218 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Callable +from typing import Any import functools import operator @@ -48,17 +49,93 @@ @custom_api_util.register_custom_decorator_type class custom_vmap: - fun: Callable - vmap_rule: Callable | None - - def __init__(self, fun: Callable): + """Customize the vmap behavior of a JAX-transformable function. + + This decorator is used to customize the behavior of a JAX function under the + :func:`jax.vmap` transformation. A ``custom_vmap``-decorated function will + mostly (see below for caveats) have the same behavior as the underlying + function, except when batched using :py:func:`jax.vmap`. When batched, the + rule defined using :py:func:`~jax.custom_batching.custom_vmap.def_vmap` will + be used. + + For example: + + >>> @jax.custom_batching.custom_vmap + ... def f(x, y): + ... return x + y + ... + >>> @f.def_vmap + ... def f_vmap_rule(axis_size, in_batched, xs, ys): + ... assert all(in_batched) + ... assert xs.shape[0] == axis_size + ... assert ys.shape[0] == axis_size + ... out_batched = True + ... return xs * ys, out_batched + ... + >>> xs = jnp.arange(3) + >>> ys = jnp.arange(1, 4) + >>> jax.vmap(f)(xs, ys) # prints xs * ys instead of xs + ys + Array([0, 2, 6], dtype=int32) + + Of note, ``custom_vmap`` functions do not support reverse-mode autodiff. To + customize both vmap and reverse-mode autodiff, combine ``custom_vmap`` with + :py:class:`jax.custom_vjp`. For example: + + >>> @jax.custom_vjp + ... @jax.custom_batching.custom_vmap + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> @f.def_vmap + ... def f_vmap_rule(axis_size, in_batched, xs, ys): + ... return jnp.cos(xs) * ys, True + ... + >>> def f_fwd(x, y): + ... return f(x, y), (jnp.cos(x), jnp.sin(x), y) + ... + >>> def f_bwd(res, g): + ... cos_x, sin_x, y = res + ... return (cos_x * g * y, sin_x * g) + ... + >>> f.defvjp(f_fwd, f_bwd) + >>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3)) + Array([1., 1., 1.], dtype=float32) + >>> jax.grad(f)(jnp.zeros(()), jnp.ones(())) + Array(1., dtype=float32) + + Note that the :py:class:`jax.custom_vjp` must be on the ouside, wrapping the + ``custom_vmap``-decorated function. + """ + + fun: Callable[..., Any] + vmap_rule: Callable[..., tuple[Any, Any]] | None + + def __init__(self, fun: Callable[..., Any]): functools.update_wrapper(self, fun) self.fun = fun self.vmap_rule = None __getattr__ = custom_api_util.forward_attr - def def_vmap(self, vmap_rule: Callable) -> Callable: + def def_vmap( + self, + vmap_rule: Callable[..., tuple[Any, Any]], + ) -> Callable[..., tuple[Any, Any]]: + """Define the vmap rule for this custom_vmap function. + + Args: + vmap_rule: A function that implements the vmap rule. This function should + accept the following arguments: (1) an integer ``axis_size`` as its + first argument, (2) a pytree of booleans with the same structure as the + inputs to the function, specifying whether each argument is batched, + and (3) the batched arguments. It should return a tuple of the batched + output and a pytree of booleans with the same structure as the output, + specifying whether each output element is batched. See the documentation + for :py:func:`jax.custom_batching.custom_vmap` for some examples. + + Returns: + This method passes the rule through, returning ``vmap_rule`` unchanged. + """ self.vmap_rule = vmap_rule return vmap_rule @@ -272,6 +349,31 @@ def tree_merge(mask, lhs_tree, rhs_tree): mask, lhs_tree, rhs_tree) def sequential_vmap(f): + """A special case of ``custom_vmap`` that uses a loop. + + A function decorated with ``sequential_vmap`` will be called sequentially + within a loop when batched. This is useful for functions that don't natively + support batch dimensions. + + For example: + + >>> @jax.custom_batching.sequential_vmap + ... def f(x): + ... jax.debug.print("{}", x) + ... return x + 1 + ... + >>> jax.vmap(f)(jnp.arange(3)) + 0 + 1 + 2 + Array([1, 2, 3], dtype=int32) + + Where the print statements demonstrate that this :py:func:`~jax.vmap` is being + generated using a loop. + + See the documentation for :py:class:`~jax.custom_batching.custom_vmap` for + more details. + """ f = custom_vmap(f) @f.def_vmap diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index c5e78321f331..290dd38ac21b 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -175,7 +175,11 @@ def linearize(traceable, *primals, **kwargs): jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree) jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals) - assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals) + if any(not out_primal_pval.is_known() for out_primal_pval in out_primals_pvals): + raise ValueError( + "Linearization failed to produce known values for all output primals. " + "This is typically caused by attempting to differentiate a function " + "uses an operation that does not support reverse-mode autodiff.") out_primals_consts = [pval.get_known() for pval in out_primals_pvals] if not has_aux: return out_primals_consts, out_tangents_pvals, jaxpr, consts diff --git a/tests/api_test.py b/tests/api_test.py index 398070764a23..38467809e9d3 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10844,6 +10844,67 @@ def rule(axis_size, in_batched, xs): ys = api.vmap(f)(x=xs) self.assertAllClose(ys, jnp.cos(xs)) + def test_partial_eval_raises(self): + @jax.custom_batching.custom_vmap + def f(x): + return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + del axis_size # unused + return jnp.cos(xs), in_batched[0] + + with self.assertRaisesRegex( + ValueError, + "Linearization failed to produce known values for all output primals", + ): + jax.grad(f)(0.5) + + def test_compose_custom_vjp(self): + @jax.custom_vjp + @jax.custom_batching.custom_vmap + def f(x, y): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + return jnp.cos(xs) * ys, True + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd) + + xs = jnp.linspace(0, 1, 5) + ys = jnp.linspace(-0.1, 0.1, 5) + self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) + jax.grad(f)(xs[0], ys[0]) # Doesn't crash. + + def test_compose_custom_vjp_bwd_rule(self): + # This tests the case where both the forward and backward rules are wrapped + # in custom_vmap. + @jax.custom_batching.sequential_vmap + def fun_fwd(x, y): + return jnp.sin(x) * y, (x, y) + + @jax.custom_batching.sequential_vmap + def fun_bwd(res, ct): + x, y = res + return x * ct, y * ct + + fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) + fun.defvjp(fun_fwd, fun_bwd) + + xs = jnp.linspace(0, 1, 5) + y = jnp.array(0.5, dtype=xs.dtype) + f = jax.vmap(jax.jit(fun), in_axes=(0, None)) + out, f_vjp = jax.vjp(f, xs, y) + f_vjp(out) # Doesn't crash. + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" From 66b900540a74a4e2a2ad0f03fdf39496e3dd4436 Mon Sep 17 00:00:00 2001 From: Kanglan Tang Date: Mon, 9 Dec 2024 12:29:57 -0800 Subject: [PATCH 677/698] Disable pjit ArrayPjitTest.test_device_put_grad test on TPU v5e PiperOrigin-RevId: 704378732 --- tests/pjit_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 72c43d6c2222..5ca87aae6561 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3411,6 +3411,9 @@ def test_device_assignment_mismatch_apply_primitive(self): def test_device_put_grad(self): if jax.device_count() < 8: self.skipTest("Requires >=8 devices.") + if jtu.is_device_tpu(5, 'e'): + self.skipTest('TPU v5e does not support computations that run on a ' + 'non-singleton subset of cores.') def _test(fun, inp, np_inp, in_s): out = fun(inp) From 65b60884114261549ffc2eb937162bdeaa493928 Mon Sep 17 00:00:00 2001 From: IvyZX Date: Mon, 9 Dec 2024 10:44:28 -0800 Subject: [PATCH 678/698] Avoid index out of range error in carry structure check --- CHANGELOG.md | 9 ++++++++- jax/_src/lax/control_flow/loops.py | 6 ++++++ tests/lax_control_flow_test.py | 13 +++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6d0f97f439d..92dcfe6cc3f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,14 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.36 +## jax 0.4.37 + +* Bug fixes + * Fix a bug that will throw `index out of range` error in + {func}`jax.lax.while_loop` if the user register pytree node class with + different aux data for the flatten and flatten_with_path. + +## jax 0.4.36 (Dec 5, 2024) * Breaking Changes * This release lands "stackless", an internal change to JAX's tracing diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index f62ce2434755..9b2d688c322b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'of the carry output is a {thing2}, so {explanation}' for path, thing1, thing2, explanation in equality_errors(in_carry, out_carry)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: @@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d383e4c6ac20..4b0420fda8f9 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -322,6 +322,19 @@ def testWhileTypeErrors(self): lax.while_loop(lambda c: True, lambda c: (True, True), (np.bool_(True), np.float32(0.))) + def testWhileLoopCustomPytreeDiffAuxData(self): + class Node: + def __init__(self, x, y): + self.x = x + self.y = y + tree_util.register_pytree_with_keys( + Node, + lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys + lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved) + lambda o: ((o.x, o.y), 'without_keys'), # flatten + ) + lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.)) + def testNestedWhileWithDynamicUpdateSlice(self): num = 5 From 95892fdac86524151b6dadd7d8bedbf915f1500f Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 9 Dec 2024 11:24:11 -0500 Subject: [PATCH 679/698] Use private names for args in api_util to avoid shadowing kwargs keys. This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util. We probably shouldn't use linear_util... --- jax/_src/api_util.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 1bfce85d592c..eb5e7e8bf8de 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @lu.transformation2 -def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs): +def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): sentinel = object() - args = [sentinel] * (len(fixed_args) + len(dyn_args)) - for i, arg in zip(dyn_argnums, dyn_args): + args = [sentinel] * (len(_fixed_args) + len(dyn_args)) + for i, arg in zip(_dyn_argnums, dyn_args): args[i] = arg - fixed_args_ = iter(fixed_args) + fixed_args_ = iter(_fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel - return f(*args, **kwargs) + return _fun(*args, **kwargs) def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], kwargs: dict[str, Any]): @@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs @lu.transformation2 -def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): - kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs) - return f(*args, **kwargs) +def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): + kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs) + return _fun(*args, **kwargs) @lru_cache(maxsize=4096) @@ -438,9 +438,9 @@ def flat_out_axes( return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) @lu.transformation_with_aux2 -def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): - ans = f(*args, **kwargs) - spec = tree_unflatten(treedef, leaves) +def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs): + ans = _fun(*args, **kwargs) + spec = tree_unflatten(_treedef, _leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) except ValueError: @@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None - store.store(spec_flat) + _store.store(spec_flat) return ans def check_callable(fun): @@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, for path, l in generate_key_paths(x) if l is not static) @lu.transformation_with_aux2 -def result_paths(f, store, *args, **kwargs): +def result_paths(_fun, _store, *args, **kwargs): "linear_util transform to get output pytree paths of pre-flattened function." - ans = f(*args, **kwargs) - store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + ans = _fun(*args, **kwargs) + _store.store([keystr(path) for path, _ in generate_key_paths(ans)]) return ans def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, From ffb07cdadb5dc3bc43485cf041dbc2b43136109e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 9 Dec 2024 15:38:37 -0500 Subject: [PATCH 680/698] Update versions for v0.4.37 release. --- CHANGELOG.md | 4 ++++ jax/version.py | 2 +- setup.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92dcfe6cc3f6..c2e237dd6f4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.37 +This is a patch release of jax 0.4.36. Only "jax" was released at this version. + * Bug fixes + * Fixed a bug where `jit` would error if an argument was named `f` (#25329). * Fix a bug that will throw `index out of range` error in {func}`jax.lax.while_loop` if the user register pytree node class with different aux data for the flatten and flatten_with_path. + * Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e. ## jax 0.4.36 (Dec 5, 2024) diff --git a/jax/version.py b/jax/version.py index 941b34f1226f..9da3d63f8708 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.36" +_version = "0.4.37" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index ea42d625eadc..65b07f931a71 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.35' -_libtpu_version = '0.0.5' +_libtpu_version = '0.0.6' _libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup' def load_version_module(pkg_path): From 32df37e6e464d09b03cfaf2f47bbdc18686c1a6a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 9 Dec 2024 12:40:44 -0800 Subject: [PATCH 681/698] Port symmetric tridiagonal reduction GPU kernel to FFI. PiperOrigin-RevId: 704382200 --- jaxlib/gpu/gpu_kernels.cc | 6 ++ jaxlib/gpu/solver.cc | 1 + jaxlib/gpu/solver_interface.cc | 28 +++++++++ jaxlib/gpu/solver_interface.h | 25 ++++++-- jaxlib/gpu/solver_kernels_ffi.cc | 98 +++++++++++++++++++++++++++++++- jaxlib/gpu/solver_kernels_ffi.h | 1 + 6 files changed, 152 insertions(+), 7 deletions(-) diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 62977c5f57a1..a1e59385e6fa 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -60,8 +60,14 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", SyevdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA", + SytrdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA", + GesvdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvdj_ffi", "CUDA", + GesvdjFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_cholesky_update_ffi", "CUDA", CholeskyUpdateFfi); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 38936ee497cf..c74d9a1476c2 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -482,6 +482,7 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi); dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi); + dict[JAX_GPU_PREFIX "solver_sytrd_ffi"] = EncapsulateFfiHandler(SytrdFfi); #ifdef JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index 4d1af3c50d76..d93d049d41db 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -317,6 +317,34 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched); #endif // JAX_GPU_CUDA +// Symmetric tridiagonal reduction: sytrd + +#define JAX_GPU_DEFINE_SYTRD(Type, Name) \ + template <> \ + absl::StatusOr SytrdBufferSize(gpusolverDnHandle_t handle, \ + gpusolverFillMode_t uplo, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \ + handle, uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, \ + /*E=*/nullptr, /*tau=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Sytrd(gpusolverDnHandle_t handle, \ + gpusolverFillMode_t uplo, int n, Type *a, \ + RealType::value *d, RealType::value *e, \ + Type *tau, Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, uplo, n, a, n, d, e, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_SYTRD(float, gpusolverDnSsytrd); +JAX_GPU_DEFINE_SYTRD(double, gpusolverDnDsytrd); +JAX_GPU_DEFINE_SYTRD(gpuComplex, gpusolverDnChetrd); +JAX_GPU_DEFINE_SYTRD(gpuDoubleComplex, gpusolverDnZhetrd); +#undef JAX_GPU_DEFINE_SYTRD + } // namespace solver } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h index 336480e2e13b..e84a688a6081 100644 --- a/jaxlib/gpu/solver_interface.h +++ b/jaxlib/gpu/solver_interface.h @@ -188,8 +188,8 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBufferSize); #define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \ gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ - Type *a, Real *s, Type *u, Type *v, Type *workspace, \ - int lwork, int *info, gesvdjInfo_t params + Type *a, Real *s, Type *u, Type *v, Type *workspace, int lwork, \ + int *info, gesvdjInfo_t params JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj); #undef JAX_GPU_SOLVER_Gesvdj_ARGS @@ -199,15 +199,28 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBatchedBufferSize); #undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS -#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \ - gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \ - Real *s, Type *u, Type *v, Type *workspace, int lwork, \ - int *info, gpuGesvdjInfo_t params, int batch +#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \ + Real *s, Type *u, Type *v, Type *workspace, int lwork, int *info, \ + gpuGesvdjInfo_t params, int batch JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched); #undef JAX_GPU_SOLVER_GesvdjBatched_ARGS #endif // JAX_GPU_CUDA +// Symmetric tridiagonal reduction: sytrd + +#define JAX_GPU_SOLVER_SytrdBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpublasFillMode_t uplo, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SytrdBufferSize); +#undef JAX_GPU_SOLVER_SytrdBufferSize_ARGS + +#define JAX_GPU_SOLVER_Sytrd_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpublasFillMode_t uplo, int n, Type *a, Real *d, \ + Real *e, Type *tau, Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Sytrd); +#undef JAX_GPU_SOLVER_Sytrd_ARGS + #undef JAX_GPU_SOLVER_EXPAND_DEFINITION } // namespace solver diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index b5742b5a7972..7e6f14ed4717 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -915,7 +915,8 @@ ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto s_data = static_cast::value*>(s->untyped_data()); + auto s_data = + static_cast::value*>(s->untyped_data()); auto u_data = static_cast(u->untyped_data()); auto v_data = static_cast(v->untyped_data()); auto info_data = info->typed_data(); @@ -1014,6 +1015,101 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch, #endif // JAX_GPU_CUDA +// Symmetric tridiagonal reduction: sytrd + +template +ffi::Error SytrdImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result d, + ffi::Result e, + ffi::Result tau, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverFillMode_t uplo = + lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + FFI_ASSIGN_OR_RETURN(int lwork, + solver::SytrdBufferSize(handle.get(), uplo, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "sytrd")); + + auto* a_data = static_cast(a.untyped_data()); + auto* out_data = static_cast(out->untyped_data()); + auto* d_data = + static_cast::value*>(d->untyped_data()); + auto* e_data = + static_cast::value*>(e->untyped_data()); + auto* tau_data = static_cast(tau->untyped_data()); + auto* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + int out_step = n * n; + for (int64_t i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Sytrd(handle.get(), uplo, n, out_data, + d_data, e_data, tau_data, + workspace, lwork, info_data)); + out_data += out_step; + d_data += n; + e_data += n - 1; + tau_data += n - 1; + ++info_data; + } + return ffi::Error::Success(); +} + +ffi::Error SytrdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool lower, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result d, + ffi::Result e, + ffi::Result tau, + ffi::Result> info) { + auto dataType = a.element_type(); + if (out->element_type() != dataType || + d->element_type() != ffi::ToReal(dataType) || + e->element_type() != ffi::ToReal(dataType) || + tau->element_type() != dataType) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to sytrd must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to sytrd must be square"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "sytrd")); + FFI_RETURN_IF_ERROR(CheckShape(d->dimensions(), {batch, cols}, "d", "sytrd")); + FFI_RETURN_IF_ERROR( + CheckShape(e->dimensions(), {batch, cols - 1}, "e", "sytrd")); + FFI_RETURN_IF_ERROR( + CheckShape(tau->dimensions(), {batch, cols - 1}, "tau", "sytrd")); + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "sytrd")); + SOLVER_DISPATCH_IMPL(SytrdImpl, batch, rows, stream, scratch, lower, a, out, + d, e, tau, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in sytrd", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SytrdFfi, SytrdDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("lower") + .Arg() // a + .Ret() // out + .Ret() // d + .Ret() // e + .Ret() // tau + .Ret>() // info +); + #undef SOLVER_DISPATCH_IMPL #undef SOLVER_BLAS_DISPATCH_IMPL diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 022564eb108c..2f9494d7fb38 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -36,6 +36,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi); #ifdef JAX_GPU_CUDA XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); From 71c48cba1cd8435b753ac197cd822d3cb15f4ec9 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 9 Dec 2024 13:55:35 -0800 Subject: [PATCH 682/698] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a041e1b15524cd15751e9a5b5dc581b9f276958f. PiperOrigin-RevId: 704406817 --- third_party/xla/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index e07d72edc453..110b5e055b31 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "fb46636c1f8f88b3bd69bf0523b76c904191d1ad" -XLA_SHA256 = "20c5009feca949739a89b4f0077caac9345b623fafbc91f154d59085c4193e23" +XLA_COMMIT = "a041e1b15524cd15751e9a5b5dc581b9f276958f" +XLA_SHA256 = "39ea15ad645a2973efbfe7d1b4761d114cb688b5d2934561009aab7c911473da" def repo(): tf_http_archive( From 978d35f69704ce95a9d792f9ca9c7e3ee356417f Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 9 Dec 2024 14:01:22 -0800 Subject: [PATCH 683/698] Fix expected exception type in pallas grad tests. PiperOrigin-RevId: 704408603 --- tests/pallas/pallas_test.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 8c5a0a99c279..39bd279e8bce 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1340,11 +1340,7 @@ def if_true(z): np.testing.assert_allclose(f(jnp.bool_(False), arg), -arg) - # We actually expect the assertion failure in linearize, but this also - # covers another case where an effect was causing an earlier assertion - # failure. - with self.assertRaises(AssertionError): - # Notably, we should not have a ValueError for mismatched Read effect. + with self.assertRaisesRegex(ValueError, "Linearization failed"): _ = jax.grad(lambda x: jnp.sum(f(jnp.bool_(True), x)**2))(arg) # np.testing.assert_allclose( # dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14])) @@ -1397,7 +1393,7 @@ def body_fn(i, args): 16 * x * params[4, 2]) np.testing.assert_allclose(f(program, params, x), expected) - with self.assertRaises(AssertionError): + with self.assertRaisesRegex(ValueError, "Linearization failed"): jax.value_and_grad(lambda params, x: f(program, params, x).sum())( params, x) @@ -1451,7 +1447,7 @@ def body_fn(i, args): 16 * x * params[4, 2]) np.testing.assert_allclose(f(program, params, x), expected) - with self.assertRaises(AssertionError): + with self.assertRaisesRegex(ValueError, "Linearization failed"): jax.value_and_grad(lambda params, x: f(program, params, x).sum())( params, x) From fc2edbfac8739f561e26a71ff6087b7f5d08ca4d Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 9 Dec 2024 09:54:58 -0500 Subject: [PATCH 684/698] Add a `freeze` primitive to delimit ref lifetimes for AD. Also some basic AD through mutable_array/freeze. Co-authored-by: Matthew Johnson --- jax/_src/core.py | 24 ++++++++++++++++++++---- jax/_src/interpreters/ad.py | 14 ++++++++++++++ jax/_src/interpreters/partial_eval.py | 6 +++++- jax/_src/lax/lax.py | 3 +++ jax/_src/pjit.py | 14 ++++++++++++-- jax/_src/state/discharge.py | 4 ++++ jax/_src/state/primitives.py | 5 +++++ tests/mutable_array_test.py | 13 +++++++++++++ 8 files changed, 76 insertions(+), 7 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 30893ce99ce4..0c2949de07af 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -431,6 +431,8 @@ class Primitive: call_primitive: bool = False # set for map primitives processed in final style. map_primitive: bool = False + # set for ref primitives + ref_primitive: bool = False def __init__(self, name: str): self.name = name @@ -1882,6 +1884,7 @@ def __repr__(self) -> str: return 'Mutable' + repr(self[...]) def mutable_array(init_val): return mutable_array_p.bind(init_val) mutable_array_p = Primitive('mutable_array') +mutable_array_p.ref_primitive = True class InternalMutableArrayEffect(effects.Effect): pass @@ -1899,6 +1902,18 @@ def _mutable_array_impl(init_val): aval = get_aval(init_val) return MutableArray(AbstractRef(aval), init_val) +def freeze(ref): + return freeze_p.bind(ref) +freeze_p = Primitive('freeze') +freeze_p.ref_primitive = True + +@freeze_p.def_effectful_abstract_eval +def freeze_abstract_eval(ref_aval): + return ref_aval.inner_aval, {internal_mutable_array_effect} + +@freeze_p.def_impl +def _freeze_impl(ref): + return ref[()] class AbstractToken(AbstractValue): def str_short(self, short_dtypes=False): return 'Tok' @@ -2516,10 +2531,11 @@ def write(v: Var, a: AbstractValue) -> None: # Check the computed effect type matches the eqn's annotation, and is # included in the jaxpr's annotation. - if prim is mutable_array_p: - outvar, = eqn.outvars - in_idx[outvar] = None # type: ignore - mut_arrays.add(outvar) + if prim.ref_primitive: + if prim is mutable_array_p: + outvar, = eqn.outvars + in_idx[outvar] = None # type: ignore + mut_arrays.add(outvar) if eqn.effects != eqn_effects: raise JaxprTypeError("Inferred effects do not match equation effects. " f"Equation effects: {eqn.effects}. " diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index c5e78321f331..c43d87942050 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -263,6 +263,20 @@ def write_primal(v, val): with ctx: map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) for eqn in jaxpr.eqns[::-1]: + if eqn.primitive.ref_primitive: + if eqn.primitive is core.mutable_array_p: + val_var, = eqn.invars + ref_var, = eqn.outvars + ref = read_primal(ref_var) + ct_out = core.freeze(ref) + write_cotangent(eqn.primitive, val_var, ct_out) + elif eqn.primitive is core.freeze_p: + val_var, = eqn.outvars + ref_var, = eqn.invars + ct_in = instantiate_zeros(read_cotangent(val_var)) + write_primal(ref_var, core.mutable_array(ct_in)) + continue + invals = map(read_primal, eqn.invars) if eqn.primitive.multiple_results: cts_in = map(read_cotangent, eqn.outvars) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 6e2f11833b9d..880475cb81c9 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1006,7 +1006,7 @@ def partial_eval_jaxpr_stateful( in_inst: bool | Sequence[bool], ensure_out_unknowns: bool | Sequence[bool], ensure_out_inst: bool | Sequence[bool], - saveable: Callable[..., RematCases_], + saveable: Callable[..., RematCases_] | None, ) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int, int]: if type(in_inst) is bool: in_inst = (in_inst,) * len(jaxpr.invars) @@ -1014,6 +1014,8 @@ def partial_eval_jaxpr_stateful( ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars) if type(ensure_out_inst) is bool: ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars) + if saveable is None: + saveable = everything_saveable jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ _partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns), tuple(in_inst), @@ -1021,6 +1023,8 @@ def partial_eval_jaxpr_stateful( tuple(ensure_out_inst), saveable) return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref +everything_saveable = lambda *_, **__: True + @weakref_lru_cache def _partial_eval_jaxpr_custom_cached( jaxpr: Jaxpr, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ad08e1335a40..dd908e4b2468 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1719,6 +1719,9 @@ def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray: val = ad_util.zeros_like_aval(aval.inner_aval) return core.mutable_array(val) +# TODO(dougalm): this is nonsense but it's here because in places like +# custom_vjp we assume that all arguments have tangent spaces. We could have +# a distinct NotATangentType value instead. ad_util.aval_zeros_likers[state.AbstractRef] = zeros_like_abstract_ref # type: ignore def iota(dtype: DTypeLike, size: int) -> Array: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 196dd8b014ae..5f5a2b8b7692 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2156,8 +2156,18 @@ def _pjit_partial_eval(trace, *in_tracers, known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ - pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) + if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect)) + for e in jaxpr.effects): + known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \ + pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins, + False, False, None) + if num_res_ref: raise NotImplementedError + known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts) + unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts) + res_avals = unknown_jaxpr.in_avals[:num_res_val] + else: + known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ + pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) known_outs = tuple(not uk for uk in unknown_outs) num_residuals = len(res_avals) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 2c38878c7112..1ab28435bd37 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -153,6 +153,10 @@ def _eval_jaxpr_discharge_state( [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) refs_to_discharge.add(id(outvar.aval)) + elif eqn.primitive is core.freeze_p: + [invar], [outvar] = eqn.invars, eqn.outvars + ans = env.read(invar) + refs_to_discharge.remove(id(invar.aval)) elif (any(should_discharge) or core.internal_mutable_array_effect in eqn.effects ): diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 14d42ad0809c..8b8d189b3e97 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -654,3 +654,8 @@ def _broadcast_to_abstract_eval(aval, *, shape): mlir.register_lowering( broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False) ) + +# === AD rules for mutable arrays === + +ad.defjvp(core.mutable_array_p, lambda g, _: core.mutable_array(g)) +ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g)) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index d3e32873c597..f1b80f32446a 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -232,5 +232,18 @@ def f(): x = f() self.assertArraysEqual(x, jnp.zeros(8)) + def test_grad_mutable_array(self): + @jax.jit + def f(x): + x_ = core.mutable_array(x) + x_[()] = x_[()] + x_[()] + y = core.freeze(x_) + return y + + ans = jax.grad(f)(1.) + expected = 2.0 + self.assertAllClose(ans, expected, check_dtypes=False) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From a36af966fd0ecd7bd8ba908aca984b4e58a125c0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 9 Dec 2024 19:01:46 -0800 Subject: [PATCH 685/698] CI: temporarily pin numpy version for mypy check --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 87f706d3a404..ed38faa6774b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib] + additional_dependencies: [types-requests==2.31.0, jaxlib, numpy~=2.1.0] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext From 944d822ce64450f698bd9b4e8236421ade401e84 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 9 Dec 2024 19:20:04 -0800 Subject: [PATCH 686/698] Add a no-op batching rule for optimization_barrier_p PiperOrigin-RevId: 704507586 --- jax/_src/lax/lax.py | 4 ++++ tests/lax_test.py | 2 +- tests/lax_vmap_test.py | 19 +++++++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 972beb04cab4..4f316fe89633 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6496,3 +6496,7 @@ def _optimization_barrier_lowering_rule(ctx, *args): optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) mlir.register_lowering(optimization_barrier_p, _optimization_barrier_lowering_rule) + +def _optimization_barrier_batcher(batched_args, batch_dims, **params): + return optimization_barrier_p.bind(*batched_args, **params), batch_dims +batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher diff --git a/tests/lax_test.py b/tests/lax_test.py index 360525efab0b..2bebbe9db763 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3551,7 +3551,7 @@ def testAsarray(self, typ): with jax.transfer_guard('disallow'): jax.jit(asarray_closure)() - def testOptimizationBarrier(self): + def test_optimization_barrier(self): x = lax.optimization_barrier((2, 3)) self.assertEqual((2, 3), x) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 83d4d657751b..bfe9fecd6c7e 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -691,6 +691,25 @@ def testTopK(self, shape, dtype, k, bdims): op2 = lambda x: lax.top_k(x, k=k)[1] self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng) + @jtu.sample_product( + [dict(shape=shape, bdims=bdims) + for shape in [(8,), (3, 4, 5)] + for bdims in lax_test_util.all_bdims(shape)], + dtype=lax_test_util.default_dtypes, + ) + def test_optimization_barrier_vmap(self, shape, dtype, bdims): + rng = jtu.rand_small(self.rng()) + self._CheckBatching(lax.optimization_barrier, 5, bdims, (shape,), (dtype,), + rng) + + def test_optimization_barrier_vmap_out_axes(self): + x = jnp.arange(8) + y = x.reshape(1, 8) + out = jax.vmap(lax.optimization_barrier, in_axes=((0, 1),), + out_axes=(0, 1))((x, y)) + self.assertArraysEqual(out[0], x) + self.assertArraysEqual(out[1], y) + @jtu.sample_product( [dict(shape=shape, bdims=bdims, dimension=dimension, arity=arity) for shape in [(2, 3)] From 12c30578b2da2a9b9a2f9d46de2e32fb71b9e0a0 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 9 Dec 2024 22:19:04 -0800 Subject: [PATCH 687/698] Introduce `lax.ragged_all_to_all` primitive This version emits a StableHLO custom call. The test outputs the following MLIR module: ``` module @jit_ragged_all_to_all { func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) { %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32> return %0 : tensor<6xf32> } } ``` For now, the API assumes `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dim, and `axis_index_groups` is default to all replicas (e.g. there is only one group and covers all axis indices aka iota like the example above). The current API is inspired from https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html which essentially also does a ragged all to all. PiperOrigin-RevId: 704550890 --- jax/_src/lax/parallel.py | 107 ++++++++++++++++++ .../jax2tf/tests/primitives_test.py | 2 + jax/extend/core/primitives.py | 1 + jax/lax/__init__.py | 2 + tests/lax_test.py | 50 ++++++++ 5 files changed, 162 insertions(+) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c8cea6a9df5b..6ae2d02f82b7 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -457,6 +457,55 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): return tree_util.tree_map(bind, x) +def ragged_all_to_all(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes): + """Ragged version of :func:`all_to_all`. + + For now, ``split_axis`` and ``concat_axis`` from `all_to_all` are equivalent + and the outermost (ragged) dimension. ``axis_index_groups`` is default to all + replicas (e.g. there is only one group and covers all axis indices). + + Ragged arrays are defined by a set of three arrays: + * ``data``: the ``data`` array is "ragged" along its outermost dimension, + along which each indexed element has variable size. + * ``offsets``: the ``offsets`` array indexes the outermost dimension of the + ``data`` array, and represents the starting offset of each ragged element of + the ``data`` array. + * ``sizes``: the ``sizes`` array represents the size of each ragged element of + the ``data`` array, where the size is specified in units of sub-elements. A + sub-element is defined as the suffix of the ``data`` array shape obtained by + removing the outermost "ragged" dimension. + The ``offsets`` and ``sizes`` arrays must have the same size. + + # Example ragged tensor + data: [8,3] = {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}} + offsets: [3] = {0, 1, 4} + sizes: [3] = {1, 3, 4} + + # Index 'data' at 'offsets'[0], 'sizes'[0]' + {a,b,c} + + # Index 'data' at 'offsets'[1], 'sizes'[1]' + {d,e,f},{g,h,i},{j,k,l} + + # Index 'data' at 'offsets'[2], 'sizes'[2]' + {m,n,o},{p,q,r},{s,t,u},{v,w,x} + + Args: + operand: array with ragged dimension along its outermost dimension. + output: array of ragged input offsets. + input_offsets: array of ragged input send sizes. + send_sizes: array of ragged output data. + output_offsets: array of ragged output offsets. + recv_sizes: array of ragged output receive sizes. + Returns: + array with shape equal to ``output``. + """ + return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes, + output_offsets, recv_sizes) + +ragged_all_to_all_p = core.Primitive('ragged_all_to_all') + + def axis_index(axis_name): """Return the index along the mapped axis ``axis_name``. @@ -1052,6 +1101,64 @@ def _all_to_all_effectful_abstract_eval( batching.skippable_batchers[all_to_all_p] = partial(_names_in_param, 'axis_name') +def _ragged_all_to_all_lowering(ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes): + N = input_offsets.type.shape[0] + backend_config = ir.DictAttr.get({ + 'replica_groups': ir.DenseIntElementsAttr.get( + np.arange(0, N, 1, dtype=np.int64), shape=[1, N] + ) + }) + return hlo.CustomCallOp( + result=[output.type], + inputs=[operand, output, input_offsets, send_sizes, output_offsets, + recv_sizes], + call_target_name=ir.StringAttr.get('ragged_all_to_all'), + backend_config=backend_config, + api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4), + ).results + +@ragged_all_to_all_p.def_abstract_eval +def _ragged_all_to_all_abstract_eval(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes): + if operand.shape != output.shape: + raise ValueError('ragged_all_to_all input and output shapes must be equal.') + if not dtypes.issubdtype(input_offsets.dtype, np.integer): + raise ValueError("ragged_all_to_all input_offsets must be integer type.") + if not dtypes.issubdtype(send_sizes.dtype, np.integer): + raise ValueError("ragged_all_to_all send_sizes must be integer type.") + if not dtypes.issubdtype(output_offsets.dtype, np.integer): + raise ValueError("ragged_all_to_all output_offsets must be integer type.") + if not dtypes.issubdtype(recv_sizes.dtype, np.integer): + raise ValueError("ragged_all_to_all recv_sizes must be integer type.") + if len(input_offsets.shape) != 1 or input_offsets.shape[0] < 1: + raise ValueError( + "ragged_all_to_all input_offsets must be rank 1 with positive dimension" + " size, but got shape {}".format(input_offsets.shape) + ) + if len(send_sizes.shape) != 1 or send_sizes.shape[0] < 1: + raise ValueError( + "ragged_all_to_all send_sizes must be rank 1 with positive dimension" + " size, but got shape {}".format(send_sizes.shape) + ) + if len(output_offsets.shape) != 1 or output_offsets.shape[0] < 1: + raise ValueError( + "ragged_all_to_all output_offsets must be rank 1 with positive" + " dimension size, but got shape {}".format(output_offsets.shape) + ) + if len(recv_sizes.shape) != 1 or recv_sizes.shape[0] < 1: + raise ValueError( + "ragged_all_to_all recv_sizes must be rank 1 with positive dimension" + " size, but got shape {}".format(recv_sizes.shape) + ) + return output.update( + shape=list(output.shape), + dtype=output.dtype, + weak_type=output.weak_type, + ) + +ragged_all_to_all_p.def_impl(partial(dispatch.apply_primitive, ragged_all_to_all_p)) +mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) + + def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): """Gather values of x across all replicas. diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 2863ca4ed616..76d5b4cde6c7 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -183,6 +183,8 @@ def test_primitive_coverage(self): continue if p.name == "pallas_call": continue + if p.name == "ragged_all_to_all": + continue if p.name == "ffi_call": continue if p.name == "tpu_custom_call": diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index 02f0657cc371..d8a10154cf4a 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -204,6 +204,7 @@ pmin_p as pmin_p, ppermute_p as ppermute_p, psum_p as psum_p, + ragged_all_to_all_p as ragged_all_to_all_p, ) from jax._src.lax.ann import ( diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index d569ed641138..321b1dda19cf 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -362,6 +362,8 @@ psum_p as psum_p, psum_scatter as psum_scatter, pswapaxes as pswapaxes, + ragged_all_to_all as ragged_all_to_all, + ragged_all_to_all_p as ragged_all_to_all_p, ) from jax._src.lax.other import ( conv_general_dilated_local as conv_general_dilated_local, diff --git a/tests/lax_test.py b/tests/lax_test.py index 2bebbe9db763..89d41d0b9312 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1346,6 +1346,56 @@ def testDotGeneralAgainstNumpy(self, lhs_shape, rhs_shape, dtype, numpy_op = lambda x, y: lax_reference.dot_general(x, y, dimension_numbers) self._CheckAgainstNumpy(numpy_op, op, args_maker) + def testRaggedAllToAllErrors(self): + operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32) + output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32) + input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32) + send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32) + output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32) + recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32) + + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input and output shapes must be equal."): + jax.jit(lax.ragged_all_to_all).lower(operand, jnp.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32), input_offsets, send_sizes, output_offsets, recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be integer type."): + jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), send_sizes, output_offsets, recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be integer type."): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), output_offsets, recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be integer type."): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([0.0, 1.0, 3.0], dtype=jnp.float32), recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be integer type."): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([[0, 1, 3]], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all input_offsets must be rank 1 with positive dimension size, but got shape (0,)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, jnp.array([], dtype=jnp.int32), send_sizes, output_offsets, recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32), output_offsets, recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all send_sizes must be rank 1 with positive dimension size, but got shape (0,)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, jnp.array([], dtype=jnp.int32), output_offsets, recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (1, 3)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([[0, 1, 3]], dtype=jnp.int32), recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all output_offsets must be rank 1 with positive dimension size, but got shape (0,)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, jnp.array([], dtype=jnp.int32), recv_sizes) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (1, 3)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([[1, 2, 3]], dtype=jnp.int32)) + with self.assertRaisesWithLiteralMatch(ValueError, "ragged_all_to_all recv_sizes must be rank 1 with positive dimension size, but got shape (0,)"): + jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, jnp.array([], dtype=jnp.int32)) + + def testRaggedAllToAll(self): + operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32) + output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32) + input_offsets = jnp.array([0, 1, 3], dtype=jnp.int32) + send_sizes = jnp.array([1, 2, 3], dtype=jnp.int32) + output_offsets = jnp.array([0, 1, 3], dtype=jnp.int32) + recv_sizes = jnp.array([1, 2, 3], dtype=jnp.int32) + mlir_module = jax.jit(lax.ragged_all_to_all).lower(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes).as_text() + self.assertIn("stablehlo.custom_call @ragged_all_to_all", mlir_module) + self.assertIn( + "backend_config = {replica_groups = dense<[[0, 1, 2]]> :" + " tensor<1x3xi64>}}", + mlir_module, + ) + @jtu.sample_product( [ {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, From 09309e64521acd200de2660cc3eb2ee109ce08fd Mon Sep 17 00:00:00 2001 From: Silvio Traversaro Date: Tue, 10 Dec 2024 12:11:28 +0100 Subject: [PATCH 688/698] Update conda-forge installation docs after CUDA 12 upgrade --- docs/installation.md | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 78a6a5a5a444..6686eac41186 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -253,18 +253,14 @@ simply run: conda install jax -c conda-forge ``` -To install it on a machine with an NVIDIA GPU, run: +If you run this command on machine with an NVIDIA GPU, this should install a CUDA-enabled package of `jaxlib`. + +To ensure that the jax version you are installing is indeed CUDA-enabled, run: ```bash -conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia +conda install "jaxlib=*=*cuda*" jax -c conda-forge ``` -Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which -JAX requires. You must therefore either install the `cuda-nvcc` package from -the `nvidia` channel, or install CUDA on your machine separately so that `ptxas` -is in your path. The channel order above is important (`conda-forge` before -`nvidia`). - If you would like to override which release of CUDA is used by JAX, or to install the CUDA build on a machine without GPUs, follow the instructions in the [Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch) From acae2f054690659980b1a7a65dc088b0308e8100 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 10 Dec 2024 10:12:10 -0500 Subject: [PATCH 689/698] Remove code in jax2tf for compatibility with TF 2.10 or earlier. --- jax/experimental/jax2tf/tests/jax2tf_test.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 8993d044cb3b..7d3313be6c92 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -44,20 +44,8 @@ import numpy as np import tensorflow as tf -# pylint: disable=g-direct-tensorflow-import -from tensorflow.compiler.tf2xla.python import xla as tfxla -# pylint: enable=g-direct-tensorflow-import config.parse_flags_with_absl() -_exit_stack = contextlib.ExitStack() - -# TODO(necula): Remove once tensorflow is 2.10.0 everywhere. -def setUpModule(): - if not hasattr(tfxla, "optimization_barrier"): - _exit_stack.enter_context(jtu.global_config_context(jax_remat_opt_barrier=False)) - -def tearDownModule(): - _exit_stack.close() class Jax2TfTest(tf_test_util.JaxToTfTestCase): From 8813973d9686f3bcc6cea1a21b145b740f02cb24 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 10 Dec 2024 09:36:44 -0800 Subject: [PATCH 690/698] [AutoPGLE] Cleanup compiler code. PiperOrigin-RevId: 704741308 --- jax/_src/compiler.py | 166 +++++++++++++++++++++++++++---------------- 1 file changed, 103 insertions(+), 63 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 16fbb890956c..f7d427f89ef1 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -374,60 +374,24 @@ def compile_or_get_cached( host_callbacks) is_multi_process = ( - len({device.process_index for device in devices.flatten()}) > 1) - min_device_process_id = ( - min(devices.flatten(), key=lambda device: device.id).process_index) - - # When PGLE is enabled there might be 3 types of situations: - # 1. PGLE profiled module (the one which was recompiled with FDO profile) is - # in the persistent cache. In this case the module should be returned from - # cache and PGLE should be disabled for this module. Is module is stored in - # the persistent cache under the "pgle_profiled_module_key" which calculated - # with replacing FDO profile with flag which identify that module were PGLE - # profiled. - # 2. PGLE profiled module is not in the persistent cache and the module is - # getting built with an FDO profile. In this case we need to share FDO profile - # with other processes and store the result under the - # "pgle_profiled_module_key" so later in case 1 we will be able to find the - # module. - # 3. PGLE profiled module is not in the persistent cache and the module is - # getting compiled to be PGLEd (FDO profile is empty). In this case we need to - # simply return the non-PGLE profiled module from the persistent cache. - if (config.enable_pgle.value - and config.pgle_profiling_runs.value > 0): - fdo_profile = compile_options.executable_build_options.fdo_profile - compile_options.executable_build_options.fdo_profile = b"pgle profiled" - - pgle_profiled_module_key = compilation_cache.get_cache_key( + len({device.process_index for device in devices.flatten()}) > 1 + ) + min_device_process_id = min( + devices.flatten(), key=lambda device: device.id + ).process_index + + if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: + cache_key = _resolve_pgle_module_cache_key( computation, devices, compile_options, backend, - cache_key_type.IgnoreCallbacks.ALL, + pgle_profiler, + is_multi_process, + cache_key, + module_name, + min_device_process_id, ) - compile_options.executable_build_options.fdo_profile = fdo_profile - - if _is_executable_in_cache(backend, pgle_profiled_module_key): - # Load PGLE profiled module from the persistent cache. - cache_key = pgle_profiled_module_key - if pgle_profiler is not None: - pgle_profiler.disable() - elif fdo_profile is not None and len(fdo_profile) > 0: - # Store module under PGLE profiled module cache key. - cache_key = pgle_profiled_module_key - if is_multi_process and distributed.global_state.client is not None: - compile_options.executable_build_options.fdo_profile = _share_fdo_profiles( - computation, devices, compile_options, backend, - distributed.global_state.client, - min_device_process_id - ) - else: - compile_options.executable_build_options.fdo_profile = fdo_profile - logger.debug( - "Compiling module %s with FDO profile: %s", - module_name, - compile_options.executable_build_options.fdo_profile, - ) cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( @@ -493,6 +457,75 @@ def compile_or_get_cached( cache_key, ) + +# When PGLE is enabled there might be 3 types of situations: +# 1. PGLE profiled module (the one which was recompiled with FDO profile) is +# in the persistent cache. In this case the module should be returned from +# cache and PGLE should be disabled for this module. Is module is stored in +# the persistent cache under the "pgle_profiled_module_key" which calculated +# with replacing FDO profile with flag which identify that module were PGLE +# profiled. +# 2. PGLE profiled module is not in the persistent cache and the module is +# getting built with an FDO profile. In this case we need to share FDO profile +# with other processes and store the result under the +# "pgle_profiled_module_key" so later in case 1 we will be able to find the +# module. +# 3. PGLE profiled module is not in the persistent cache and the module is +# getting compiled to be PGLEd (FDO profile is empty). In this case we need to +# simply return the non-PGLE profiled module from the persistent cache. +def _resolve_pgle_module_cache_key( + computation: ir.Module, + devices: np.ndarray, + compile_options: xc.CompileOptions, + backend: xc.Client, + pgle_profiler: profiler.PGLEProfiler | None, + is_multi_process: bool, + cache_key: str, + module_name: str, + min_device_process_id: int, +) -> str: + fdo_profile = compile_options.executable_build_options.fdo_profile + compile_options.executable_build_options.fdo_profile = b"pgle profiled" + + pgle_profiled_module_key = compilation_cache.get_cache_key( + computation, + devices, + compile_options, + backend, + cache_key_type.IgnoreCallbacks.ALL, + ) + compile_options.executable_build_options.fdo_profile = fdo_profile + + result_key = cache_key + if _is_executable_in_cache(backend, pgle_profiled_module_key): + # Load PGLE profiled module from the persistent cache. + result_key = pgle_profiled_module_key + if pgle_profiler is not None: + pgle_profiler.disable() + elif fdo_profile is not None and len(fdo_profile) > 0: + # Store module under PGLE profiled module cache key. + result_key = pgle_profiled_module_key + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, + ) + ) + else: + compile_options.executable_build_options.fdo_profile = fdo_profile + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(compile_options.executable_build_options.fdo_profile), + ) + return result_key + + # The process that has the lowest device ID should share FDO profile before # compilation with other processes. def _share_fdo_profiles( @@ -510,32 +543,39 @@ def _share_fdo_profiles( return fdo_profile compile_options.executable_build_options.fdo_profile = b"" - profile_key = ( - compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - cache_key_type.IgnoreCallbacks.ALL, - ) - + "_fdo_sync" - ) + try: + profile_key = ( + compilation_cache.get_cache_key( + computation, + devices, + compile_options, + backend, + cache_key_type.IgnoreCallbacks.ALL, + ) + + "_fdo_sync" + ) + except xc._xla.XlaRuntimeError as ex: + logger.error( + "compile_or_get_cached: unable to generate cache key, " + "skipping the fdo profile sharing: %s", + ex, + ) + return fdo_profile + if profile_key in _share_fdo_profiles.modules_profiles: return _share_fdo_profiles.modules_profiles[profile_key] share_timeout = config.share_binary_between_hosts_timeout_ms.value if distributed.global_state.process_id == min_process_id: logger.debug( - "Sharing FDO profile: %s. For module %s. Process %d.", - fdo_profile, + "Module %s. Sharing FDO profile. Process %d.", module_name, min_process_id, ) global_client.key_value_set_bytes(profile_key, fdo_profile) else: logger.debug( - "Waiting for FDO profile: %s. For module %s. Should be set by process %d.", - fdo_profile, + "Module %s. Waiting for FDO profile which should be set by process %d.", module_name, min_process_id, ) From 6dbafed7bce8f8329afc16d0285a2b59acdc83a8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 10 Dec 2024 10:01:00 -0800 Subject: [PATCH 691/698] Fix mypy failure PiperOrigin-RevId: 704748889 --- jax/_src/interpreters/ad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 9ac3bcfb54f3..3f6c5ee5b043 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -275,8 +275,8 @@ def write_primal(v, val): ct_out = core.freeze(ref) write_cotangent(eqn.primitive, val_var, ct_out) elif eqn.primitive is core.freeze_p: - val_var, = eqn.outvars - ref_var, = eqn.invars + val_var, = eqn.outvars # type: ignore + ref_var, = eqn.invars # type: ignore ct_in = instantiate_zeros(read_cotangent(val_var)) write_primal(ref_var, core.mutable_array(ct_in)) continue From e6dfe8f3806161cc54a163ae4f920dc1a5e10b03 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 10 Dec 2024 10:23:07 -0800 Subject: [PATCH 692/698] [AutoPGLE] Share FDO profile even when compilation cache disabled. PiperOrigin-RevId: 704757991 --- jax/_src/compiler.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index f7d427f89ef1..6fbd9ab4e3a5 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -348,7 +348,33 @@ def compile_or_get_cached( use_compilation_cache = compilation_cache.is_cache_used(backend) + is_multi_process = ( + len({device.process_index for device in devices.flatten()}) > 1 + ) + min_device_process_id = min( + devices.flatten(), key=lambda device: device.id + ).process_index + is_auto_pgle_used = ( + config.enable_pgle.value and config.pgle_profiling_runs.value > 0 + ) + if not use_compilation_cache: + if ( + is_multi_process + and is_auto_pgle_used + and distributed.global_state.client is not None + ): + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, + ) + ) + return backend_compile(backend, computation, compile_options, host_callbacks) @@ -373,14 +399,7 @@ def compile_or_get_cached( return backend_compile(backend, computation, compile_options, host_callbacks) - is_multi_process = ( - len({device.process_index for device in devices.flatten()}) > 1 - ) - min_device_process_id = min( - devices.flatten(), key=lambda device: device.id - ).process_index - - if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: + if is_auto_pgle_used: cache_key = _resolve_pgle_module_cache_key( computation, devices, From cb6881d9e8a854190d1bacdb43af0456dc8efe72 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Tue, 10 Dec 2024 10:23:25 -0800 Subject: [PATCH 693/698] Reverts bdadc53ebcd40a5091d66d2586deba82fe5e01ca PiperOrigin-RevId: 704758075 --- tests/aot_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/aot_test.py b/tests/aot_test.py index 194982e046ba..62fecfaf48a4 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -19,6 +19,7 @@ import jax from jax._src import core from jax._src import test_util as jtu +import jax._src.lib from jax._src.lib import xla_client as xc from jax.experimental import topologies from jax.experimental.pjit import pjit @@ -62,7 +63,11 @@ def verify_serialization(lowered): jax.pmap(lambda x: x * x).lower( np.zeros((len(jax.devices()), 4), dtype=np.float32))) - @jtu.skip_on_devices('gpu') # Test fails in CI + @unittest.skipIf( + jax._src.lib.xla_extension_version < 300, + 'AOT compiler registration was broken in XLA extension version below' + ' 300.', + ) def test_topology_pjit_serialize(self): try: aot_topo = topologies.get_topology_desc( From 6541a62099edebca4d74ff71f537f5d21e76c3a9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 10 Dec 2024 11:11:32 -0800 Subject: [PATCH 694/698] jax.core: deprecate a number of APIs --- CHANGELOG.md | 8 +++ docs/contributor_guide.rst | 1 - docs/jax.extend.core.rst | 18 +++++++ docs/jax.extend.rst | 1 + docs/jax_internal_api.rst | 14 ----- jax/_src/cudnn/fused_attention_stablehlo.py | 2 +- jax/_src/cudnn/fusion.py | 2 +- .../pallas/mosaic/pallas_call_registration.py | 5 +- jax/_src/pallas/mosaic_gpu/BUILD | 1 + .../mosaic_gpu/pallas_call_registration.py | 2 +- jax/_src/pallas/triton/BUILD | 1 + .../pallas/triton/pallas_call_registration.py | 2 +- jax/_src/pallas/triton/primitives.py | 2 +- jax/_src/tpu_custom_call.py | 2 +- jax/core.py | 53 +++++++++++++++---- jax/experimental/mosaic/gpu/core.py | 2 +- jax/experimental/sparse/_lowerings.py | 2 +- jax/experimental/sparse/nm.py | 2 +- tests/key_reuse_test.py | 2 +- tests/pjit_test.py | 2 +- 20 files changed, 84 insertions(+), 40 deletions(-) create mode 100644 docs/jax.extend.core.rst delete mode 100644 docs/jax_internal_api.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index e351f64d0af0..d8bb1478ad0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.38 +* Deprecations + * a number of APIs in the internal `jax.core` namespace have been deprecated, including + `ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`, + `Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by + APIs of the same name in {mod}`jax.extend.core`; see the documentation for + {mod}`jax.extend` for information on the compatibility guarantees of these + semi-public extensions. + ## jax 0.4.37 (Dec 9, 2024) This is a patch release of jax 0.4.36. Only "jax" was released at this version. diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index 55094fc88958..f89122f944cc 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -25,4 +25,3 @@ some of JAX's (extensible) internals. autodidax jep/index - jax_internal_api diff --git a/docs/jax.extend.core.rst b/docs/jax.extend.core.rst new file mode 100644 index 000000000000..5f3ff0558af6 --- /dev/null +++ b/docs/jax.extend.core.rst @@ -0,0 +1,18 @@ +``jax.extend.core`` module +========================== + +.. automodule:: jax.extend.core + +.. autosummary:: + :toctree: _autosummary + + ClosedJaxpr + Jaxpr + JaxprEqn + Literal + Primitive + Token + Var + array_types + jaxpr_as_fun + primitives diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 9cbee08e8e50..0d68013c9261 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -11,6 +11,7 @@ Modules .. toctree:: :maxdepth: 1 + jax.extend.core jax.extend.ffi jax.extend.linear_util jax.extend.mlir diff --git a/docs/jax_internal_api.rst b/docs/jax_internal_api.rst deleted file mode 100644 index 1ece596d88ef..000000000000 --- a/docs/jax_internal_api.rst +++ /dev/null @@ -1,14 +0,0 @@ -Internal API reference -====================== - -core ----- - -.. currentmodule:: jax.core -.. automodule:: jax.core - -.. autosummary:: - :toctree: _autosummary - - Jaxpr - ClosedJaxpr diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index ef4e33ad0665..c45bb8a9efd8 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -18,8 +18,8 @@ import math import jax -from jax import core from jax import dtypes +from jax._src import core from jax._src import dispatch from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index 8a13399e3d63..f320672463cb 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -14,7 +14,7 @@ import functools import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax.interpreters import mlir from jax.interpreters.mlir import hlo from jax.interpreters.mlir import ir diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 4382cea914f0..ec9500c67cd7 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -21,10 +21,9 @@ from typing import Any import jax -from jax import core as jax_core from jax import dtypes from jax._src import config -from jax._src import core as jax_src_core +from jax._src import core as jax_core from jax._src import sharding_impls from jax._src import tpu_custom_call from jax._src.interpreters import mlir @@ -189,7 +188,7 @@ def lower_module(for_verification: bool): # Replace in_avals to physical avals. # This step is required for mapping logical types to physical types. # (e.g. PRNG key -> uint32[2]) - physical_avals = [jax_src_core.physical_aval(aval) for aval in ctx.avals_in] + physical_avals = [jax_core.physical_aval(aval) for aval in ctx.avals_in] ctx = ctx.replace(avals_in=physical_avals) # Booleans are loaded into the kernel as integers. diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 3d6e82d443b4..e9461a5ceba0 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -44,6 +44,7 @@ pytype_strict_library( deps = [ ":lowering", "//jax", + "//jax:core", "//jax:mlir", "//jax:mosaic_gpu", "//jax/_src/pallas", diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 05785cb511ea..18d8baf6e95e 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,7 +23,7 @@ import warnings import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index a9babcba0577..84fae3913491 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -76,6 +76,7 @@ pytype_strict_library( ":lowering", "//jax", "//jax:config", + "//jax:core", "//jax:mlir", "//jax:util", "//jax/_src/lib", diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 67b0bd326616..1805f8c0923a 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -19,7 +19,7 @@ import io from typing import Any -from jax import core as jax_core +import jax._src.core as jax_core from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 23fce50dc4f9..b845a4079ff4 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -19,7 +19,7 @@ from collections.abc import Sequence import jax -from jax import core as jax_core +from jax._src import core as jax_core from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas.triton import lowering diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index ccd77af5bef0..9e54f62d9ea0 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -29,7 +29,7 @@ from typing import Any import jax -from jax import core +from jax._src import core from jax._src import config from jax._src import sharding_impls from jax._src.interpreters import mlir diff --git a/jax/core.py b/jax/core.py index 8d7c546f0754..4d1742bc28ea 100644 --- a/jax/core.py +++ b/jax/core.py @@ -23,7 +23,6 @@ AxisSize as AxisSize, AxisName as AxisName, CallPrimitive as CallPrimitive, - ClosedJaxpr as ClosedJaxpr, ConcretizationTypeError as ConcretizationTypeError, DShapedArray as DShapedArray, DropVar as DropVar, @@ -34,23 +33,18 @@ InDBIdx as InDBIdx, InconclusiveDimensionOperation as InconclusiveDimensionOperation, InputType as InputType, - Jaxpr as Jaxpr, JaxprDebugInfo as JaxprDebugInfo, - JaxprEqn as JaxprEqn, JaxprPpContext as JaxprPpContext, JaxprPpSettings as JaxprPpSettings, JaxprTypeError as JaxprTypeError, - Literal as Literal, MapPrimitive as MapPrimitive, nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 OpaqueTraceState as OpaqueTraceState, OutDBIdx as OutDBIdx, OutputType as OutputType, ParamDict as ParamDict, - Primitive as Primitive, ShapedArray as ShapedArray, TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING, - Token as Token, Trace as Trace, Tracer as Tracer, unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 @@ -59,7 +53,6 @@ unsafe_get_current_trace as unsafe_get_current_trace_DO_NOT_USE, # noqa: F401 UnshapedArray as UnshapedArray, Value as Value, - Var as Var, abstract_token as abstract_token, aval_mapping_handlers as aval_mapping_handlers, call as call, @@ -78,7 +71,6 @@ eval_jaxpr as eval_jaxpr, extend_axis_env_nd as extend_axis_env_nd, find_top_trace as find_top_trace, - full_lower as full_lower, gensym as gensym, get_aval as get_aval, get_type as get_type, @@ -86,10 +78,8 @@ is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, - jaxpr_as_fun as jaxpr_as_fun, jaxprs_in_params as jaxprs_in_params, join_effects as join_effects, - lattice_join as lattice_join, leaked_tracer_error as leaked_tracer_error, literalable_types as literalable_types, mapped_aval as mapped_aval, @@ -101,7 +91,6 @@ no_effects as no_effects, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, pytype_aval_mappings as pytype_aval_mappings, - raise_to_shaped as raise_to_shaped, raise_to_shaped_mappings as raise_to_shaped_mappings, reset_trace_state as reset_trace_state, set_current_trace as set_current_trace, @@ -124,6 +113,37 @@ from jax._src import core as _src_core _deprecations = { + # Added 2024-12-10 + "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.ClosedJaxpr), + "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.full_lower), + "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Jaxpr), + "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.JaxprEqn), + "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.jaxpr_as_fun), + "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.lattice_join), + "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Literal), + "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Primitive), + "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", + _src_core.raise_to_shaped), + "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Token), + "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, " + "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", + _src_core.Var), # Added 2024-08-14 "check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn), "check_type": ("jax.core.check_type is deprecated.", _src_core.check_type), @@ -152,10 +172,21 @@ import typing if typing.TYPE_CHECKING: + ClosedJaxpr = _src_core.ClosedJaxpr + Jaxpr = _src_core.Jaxpr + JaxprEqn = _src_core.JaxprEqn + Literal = _src_core.Literal + Primitive = _src_core.Primitive + Token = _src_core.Token + Var = _src_core.Var check_eqn = _src_core.check_eqn check_type = _src_core.check_type check_valid_jaxtype = _src_core.check_valid_jaxtype + full_lower = _src_core.full_lower + jaxpr_as_fun = _src_core.jaxpr_as_fun + lattice_join = _src_core.lattice_join non_negative_dim = _src_core.non_negative_dim + raise_to_shaped = _src_core.raise_to_shaped else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index d8774e932d04..b03c3a5b54fc 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -77,7 +77,7 @@ os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) -mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index f4fe0b9040e6..6962ef78bcff 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -19,7 +19,7 @@ from functools import partial -from jax import core +from jax._src import core from jax._src import dispatch from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py index 6c827325befc..f9d28f5ff83c 100644 --- a/jax/experimental/sparse/nm.py +++ b/jax/experimental/sparse/nm.py @@ -14,7 +14,7 @@ """N:M-sparsity associated primitives.""" -from jax import core +from jax._src import core from jax._src import dispatch from jax._src.lax.lax import DotDimensionNumbers from jax._src.lib import gpu_sparse diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 286088eebe48..3364c9be91dd 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -18,8 +18,8 @@ import numpy as np import jax -from jax import core import jax.numpy as jnp +from jax._src import core from jax._src import prng from jax._src import random from jax._src import test_util as jtu diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5ca87aae6561..5bb09043568c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3837,7 +3837,7 @@ def trace_to_jaxpr(x): constant_values= ((0.0, 0.0), (0.0, 0.0))) jaxpr = jax.make_jaxpr(trace_to_jaxpr)(x) - jax.core.jaxpr_as_fun(jaxpr)(x) + jax._src.core.jaxpr_as_fun(jaxpr)(x) jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') jnp.pad(x, [(0, 1), (0, 0)], mode= 'wrap') # doesn't crash From 593143e17e746812b25d7e302e165d986a039c7e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 10 Dec 2024 11:30:10 -0800 Subject: [PATCH 695/698] Deduplicate some GPU plugin definition code. The `jaxlib/cuda_plugin_extension.cc` and `jaxlib/rocm_plugin_extension.cc` files were nearly identical so this change consolidates the shared implementation into a single target. PiperOrigin-RevId: 704785926 --- jaxlib/BUILD | 52 ++++++---- jaxlib/cuda_plugin_extension.cc | 146 +------------------------- jaxlib/gpu_plugin_extension.cc | 178 ++++++++++++++++++++++++++++++++ jaxlib/gpu_plugin_extension.h | 27 +++++ jaxlib/rocm_plugin_extension.cc | 148 +------------------------- 5 files changed, 246 insertions(+), 305 deletions(-) create mode 100644 jaxlib/gpu_plugin_extension.cc create mode 100644 jaxlib/gpu_plugin_extension.h diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 987fe24a8008..e7ba1dd3de16 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -208,27 +208,47 @@ pybind_extension( ], ) -pybind_extension( - name = "cuda_plugin_extension", - srcs = ["cuda_plugin_extension.cc"], - module_name = "cuda_plugin_extension", +cc_library( + name = "gpu_plugin_extension", + srcs = ["gpu_plugin_extension.cc"], + hdrs = ["gpu_plugin_extension.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], deps = [ + ":kernel_nanobind_helpers", "@com_google_absl//absl/status", - "@nanobind", - "//jaxlib:kernel_nanobind_helpers", - "@xla//third_party/python_runtime:headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@tsl//tsl/platform:statusor", "@xla//xla:util", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", - # TODO(jieying): move to jaxlib after py_client_gpu is separated from py_client "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +pybind_extension( + name = "cuda_plugin_extension", + srcs = ["cuda_plugin_extension.cc"], + module_name = "cuda_plugin_extension", + deps = [ + ":gpu_plugin_extension", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/pjrt:status_casters", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -237,20 +257,12 @@ pybind_extension( srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/status", + ":gpu_plugin_extension", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@xla//third_party/python_runtime:headers", - "@xla//xla:util", - "@xla//xla/ffi/api:c_api", - "@xla//xla/pjrt:status_casters", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_helpers", - "@xla//xla/python:py_client_gpu", - "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda_plugin_extension.cc index ea81109b36c0..34cf462d623e 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda_plugin_extension.cc @@ -12,135 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include -#include -#include #include "nanobind/nanobind.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/ffi/api/c_api.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" -#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "jaxlib/gpu_plugin_extension.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" -#include "xla/tsl/python/lib/core/numpy.h" -#include "xla/util.h" namespace nb = nanobind; namespace xla { namespace { -absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, - const char* fn_name_c_str, - size_t fn_name_size, nb::object fn, - int api_version, - XLA_FFI_Handler_Traits traits) { - if (c_api->extension_start == nullptr) { - return Unimplemented("The plugin does not have extension."); - } - const PJRT_Extension_Base* next = - reinterpret_cast(c_api->extension_start); - while (next != nullptr && - next->type != - PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { - next = next->next; - } - if (next == nullptr) { - return Unimplemented("The plugin does not have a custom call extension."); - } - PJRT_Gpu_Register_Custom_Call* register_custom_call = - reinterpret_cast(next)->custom_call; - - if (traits != 0) { - return Unimplemented("The plugin does not support custom call traits."); - } - - PJRT_Gpu_Register_Custom_Call_Args args; - args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name_c_str; - args.function_name_size = fn_name_size; - -#if PJRT_API_GPU_EXTENSION_VERSION >= 1 - args.api_version = api_version; -#endif - - auto as_capsule = [](nb::object obj) -> absl::StatusOr { - nb::capsule capsule; - if (!nb::try_cast(obj, capsule)) { - return absl::InvalidArgumentError( - "Custom call target registration requires handlers as PyCapsules"); - } - return capsule; - }; - -#if PJRT_API_GPU_EXTENSION_VERSION <= 1 - TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); - args.custom_call_function = fn_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); -#else - args.handler_instantiate = nullptr; - args.handler_prepare = nullptr; - args.handler_initialize = nullptr; - args.handler_execute = nullptr; - - // Register legacy custom call target (untyped void* API). - if (api_version == 0) { - TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); - args.handler_execute = capsule_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - // Register XLA FFI handler (typed API with explicit function signatures). - if (api_version == 1) { - auto capsule_execute = as_capsule(fn); - if (capsule_execute.ok()) { - args.handler_execute = capsule_execute->data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - nb::dict bundle; - if (nb::try_cast(fn, bundle)) { - auto handler = [&](const char* name) -> absl::StatusOr { - if (!bundle.contains(name)) return nullptr; - TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); - return capsule.data(); - }; - - TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); - TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); - TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); - TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - return absl::InvalidArgumentError( - "Unsupported custom call target type for api_version=1"); - } - - return absl::UnimplementedError(absl::StrFormat( - "API version %d is not supported by RegisterCustomCallTarget. " - "Supported versions are 0 and 1.", - api_version)); -#endif -} - -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -155,31 +41,7 @@ static std::string ToString(CUresult result) { } // namespace NB_MODULE(cuda_plugin_extension, m) { - tsl::ImportNumpy(); - m.def( - "register_custom_call_target", - [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, - nb::str xla_platform_name, int api_version, - XLA_FFI_Handler_Traits traits) { - const char* fn_name_c_str; - size_t fn_name_size; - nb::str fn_name_bn_str; - if (nb::try_cast(fn_name_py, fn_name_bn_str)) { - fn_name_c_str = fn_name_bn_str.c_str(); - fn_name_size = nb::len(fn_name_bn_str); - } else{ - nb::bytes bytes = nb::cast(fn_name_py); - fn_name_c_str = bytes.c_str(); - fn_name_size = bytes.size(); - } - xla::ThrowIfError(RegisterCustomCallTarget( - static_cast(c_api.data()), fn_name_c_str, - fn_name_size, std::move(fn), api_version, traits)); - }, - nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), - nb::arg("xla_platform_name"), nb::arg("api_version") = 0, - nb::arg("traits") = 0); - m.def("registrations", &Registrations); + BuildGpuPluginExtension(m); m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu_plugin_extension.cc new file mode 100644 index 000000000000..ba7896aa5dfe --- /dev/null +++ b/jaxlib/gpu_plugin_extension.cc @@ -0,0 +1,178 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu_plugin_extension.h" + +#include +#include + +#include "nanobind/nanobind.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/c_api.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/py_client_gpu.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, + const char* fn_name_c_str, + size_t fn_name_size, nb::object fn, + int api_version, + XLA_FFI_Handler_Traits traits) { + if (c_api->extension_start == nullptr) { + return Unimplemented("The plugin does not have extension."); + } + const PJRT_Extension_Base* next = + reinterpret_cast(c_api->extension_start); + while (next != nullptr && + next->type != + PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { + next = next->next; + } + if (next == nullptr) { + return Unimplemented("The plugin does not have a custom call extension."); + } + PJRT_Gpu_Register_Custom_Call* register_custom_call = + reinterpret_cast(next)->custom_call; + + if (traits != 0) { + return Unimplemented("The plugin does not support custom call traits."); + } + + PJRT_Gpu_Register_Custom_Call_Args args; + args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; + args.function_name = fn_name_c_str; + args.function_name_size = fn_name_size; + +#if PJRT_API_GPU_EXTENSION_VERSION >= 1 + args.api_version = api_version; +#endif + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration requires handlers as PyCapsules"); + } + return capsule; + }; + +#if PJRT_API_GPU_EXTENSION_VERSION <= 1 + TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); + args.custom_call_function = fn_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); +#else + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = nullptr; + + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); + args.handler_execute = capsule_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + auto capsule_execute = as_capsule(fn); + if (capsule_execute.ok()) { + args.handler_execute = capsule_execute->data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); + return capsule.data(); + }; + + TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); +#endif +} + +nb::dict Registrations() { + nb::dict dict; + dict["xla_python_gpu_callback"] = + jax::EncapsulateFunction(xla::XlaPythonGpuCallback); + return dict; +} + +} // namespace + +void BuildGpuPluginExtension(nanobind::module_& m) { + tsl::ImportNumpy(); + m.def( + "register_custom_call_target", + [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, + nb::str xla_platform_name, int api_version, + XLA_FFI_Handler_Traits traits) { + const char* fn_name_c_str; + size_t fn_name_size; + nb::str fn_name_bn_str; + if (nb::try_cast(fn_name_py, fn_name_bn_str)) { + fn_name_c_str = fn_name_bn_str.c_str(); + fn_name_size = nb::len(fn_name_bn_str); + } else { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name_c_str = bytes.c_str(); + fn_name_size = bytes.size(); + } + xla::ThrowIfError(RegisterCustomCallTarget( + static_cast(c_api.data()), fn_name_c_str, + fn_name_size, std::move(fn), api_version, traits)); + }, + nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), + nb::arg("xla_platform_name"), nb::arg("api_version") = 0, + nb::arg("traits") = 0); + m.def("registrations", &Registrations); +} + +} // namespace xla diff --git a/jaxlib/gpu_plugin_extension.h b/jaxlib/gpu_plugin_extension.h new file mode 100644 index 000000000000..ae8cd73dbcfb --- /dev/null +++ b/jaxlib/gpu_plugin_extension.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_ +#define JAXLIB_GPU_PLUGIN_EXTENSION_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildGpuPluginExtension(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_ diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc index c6855879e8be..f28b5c9b4e53 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm_plugin_extension.cc @@ -12,134 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include -#include -#include #include "nanobind/nanobind.h" -#include "absl/status/status.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/ffi/api/c_api.h" -#include "xla/pjrt/c/pjrt_c_api.h" -#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" -#include "xla/pjrt/c/pjrt_c_api_helpers.h" -#include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" -#include "xla/tsl/python/lib/core/numpy.h" -#include "xla/util.h" +#include "jaxlib/gpu_plugin_extension.h" namespace nb = nanobind; namespace xla { namespace { -absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, - const char* fn_name_c_str, size_t fn_name_size, - nb::object fn, int api_version, - XLA_FFI_Handler_Traits traits) { - if (c_api->extension_start == nullptr) { - return Unimplemented("The plugin does not have extension."); - } - const PJRT_Extension_Base* next = - reinterpret_cast(c_api->extension_start); - while (next != nullptr && - next->type != - PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { - next = next->next; - } - if (next == nullptr) { - return Unimplemented("The plugin does not have a custom call extension."); - } - PJRT_Gpu_Register_Custom_Call* register_custom_call = - reinterpret_cast(next)->custom_call; - - if (traits != 0) { - return Unimplemented("The plugin does not support custom call traits."); - } - - PJRT_Gpu_Register_Custom_Call_Args args; - args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; - args.function_name = fn_name_c_str; - args.function_name_size = fn_name_size; - -#if PJRT_API_GPU_EXTENSION_VERSION >= 1 - args.api_version = api_version; -#endif - - auto as_capsule = [](nb::object obj) -> absl::StatusOr { - nb::capsule capsule; - if (!nb::try_cast(obj, capsule)) { - return absl::InvalidArgumentError( - "Custom call target registration requires handlers as PyCapsules"); - } - return capsule; - }; - -#if PJRT_API_GPU_EXTENSION_VERSION <= 1 - TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); - args.custom_call_function = fn_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); -#else - args.handler_instantiate = nullptr; - args.handler_prepare = nullptr; - args.handler_initialize = nullptr; - args.handler_execute = nullptr; - - // Register legacy custom call target (untyped void* API). - if (api_version == 0) { - TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); - args.handler_execute = capsule_execute.data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - // Register XLA FFI handler (typed API with explicit function signatures). - if (api_version == 1) { - auto capsule_execute = as_capsule(fn); - if (capsule_execute.ok()) { - args.handler_execute = capsule_execute->data(); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - nb::dict bundle; - if (nb::try_cast(fn, bundle)) { - auto handler = [&](const char* name) -> absl::StatusOr { - if (!bundle.contains(name)) return nullptr; - TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); - return capsule.data(); - }; - - TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); - TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); - TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); - TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); - RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); - return absl::OkStatus(); - } - - return absl::InvalidArgumentError( - "Unsupported custom call target type for api_version=1"); - } - - return absl::UnimplementedError(absl::StrFormat( - "API version %d is not supported by RegisterCustomCallTarget. " - "Supported versions are 0 and 1.", - api_version)); -#endif -} - -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - std::string ToString(hipError_t result) { #define OSTREAM_ROCM_ERROR(__name) \ case hipError##__name: \ @@ -179,31 +65,7 @@ std::string ToString(hipError_t result) { } // namespace NB_MODULE(rocm_plugin_extension, m) { - tsl::ImportNumpy(); - m.def( - "register_custom_call_target", - [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, - nb::str xla_platform_name, int api_version, - XLA_FFI_Handler_Traits traits) { - const char* fn_name_c_str; - size_t fn_name_size; - nb::str fn_name_bn_str; - if (nb::try_cast(fn_name_py, fn_name_bn_str)) { - fn_name_c_str = fn_name_bn_str.c_str(); - fn_name_size = nb::len(fn_name_bn_str); - } else{ - nb::bytes bytes = nb::cast(fn_name_py); - fn_name_c_str = bytes.c_str(); - fn_name_size = bytes.size(); - } - xla::ThrowIfError(RegisterCustomCallTarget( - static_cast(c_api.data()), fn_name_c_str, - fn_name_size, std::move(fn), api_version, traits)); - }, - nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"), - nb::arg("xla_platform_name"), nb::arg("api_version") = 0, - nb::arg("traits") = 0); - m.def("registrations", &Registrations); + BuildGpuPluginExtension(m); m.def( "get_device_ordinal", [](std::intptr_t data_value) { From e418e88321f861fd7395fb60223d542edd1efc83 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 10 Dec 2024 11:37:29 -0800 Subject: [PATCH 696/698] [Pallas] Add non-square pl.dot test cases. PiperOrigin-RevId: 704788500 --- tests/pallas/ops_test.py | 57 +++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 12a2ad49306c..38e359aef3a1 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -17,6 +17,7 @@ from collections.abc import Sequence import functools import itertools +import math import sys from typing import Any import unittest @@ -62,6 +63,10 @@ floatx = dtypes.canonicalize_dtype(jnp.float64) +def is_power_of_two(n: int) -> bool: + return (n > 0) and (n & (n - 1) == 0) + + def smem_on_tpu(): if jtu.test_device_matches(["tpu"]): return pltpu.SMEM @@ -1410,12 +1415,45 @@ def f(x_ref, o_ref): np.testing.assert_allclose(f(x), expected) @parameterized.product( - size=[16, 32, 64, 128, 256], + lhs_and_rhs_shape=[ + ((16, 16), (16, 16)), + ((32, 32), (32, 32)), + ((64, 64), (64, 64)), + ((128, 128), (128, 128)), + ((256, 256), (256, 256)), + ((8, 128), (128, 256)), + ((8, 128), (256, 128)), + ((8, 256), (256, 128)), + ((16, 128), (128, 256)), + ((16, 128), (256, 128)), + ((16, 256), (256, 128)), + ((24, 128), (128, 256)), + ((24, 128), (256, 128)), + ((24, 256), (256, 128)), + ((128, 8), (128, 256)), + ((128, 8), (256, 128)), + ((256, 8), (256, 128)), + ((128, 16), (128, 256)), + ((128, 16), (256, 128)), + ((256, 16), (256, 128)), + ((128, 24), (128, 256)), + ((128, 24), (256, 128)), + ((256, 24), (256, 128)), + ], dtype=[jnp.float32, jnp.float16, jnp.bfloat16], trans_x=[False, True], trans_y=[False, True], ) - def test_dot(self, size, dtype, trans_x, trans_y): + def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): + lhs_shape, rhs_shape = lhs_and_rhs_shape + + final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape + final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape + if final_lhs_shape[1] != final_rhs_shape[0]: + self.skipTest("Contraction dimensions do not match") + + out_shape = (final_lhs_shape[0], final_rhs_shape[1]) + if jtu.test_device_matches(["tpu"]): if dtype == jnp.float16: self.skipTest("float16 type is not supported on TPU") @@ -1427,12 +1465,19 @@ def test_dot(self, size, dtype, trans_x, trans_y): if jtu.test_device_matches(["gpu"]): if dtype == jnp.bfloat16: self.skipTest("bfloat16 type are not supported on GPU") - if size > 128: + if ( + math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape) + > (256 * 256) * 2 + ): self.skipTest("Shared memory size limit exceeded") + if min(*lhs_shape, *rhs_shape) < 16: + self.skipTest("All dimensions of lhs and rhs must be >= 16") + if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape): + self.skipTest("All dimensions of lhs and rhs must be power of two") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((size, size), dtype), + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), grid=1, ) def dot(x_ref, y_ref, o_ref): @@ -1441,8 +1486,8 @@ def dot(x_ref, y_ref, o_ref): o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype) k1, k2 = random.split(random.key(0)) - x = random.normal(k1, (size, size), dtype=dtype) - y = random.normal(k2, (size, size), dtype=dtype) + x = random.normal(k1, lhs_shape, dtype=dtype) + y = random.normal(k2, rhs_shape, dtype=dtype) out = dot(x, y) expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y) np.testing.assert_allclose( From d4899f7b9badb3a622d8e79405339d8f0796e149 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 10 Dec 2024 13:05:05 -0800 Subject: [PATCH 697/698] [jax:custom_partitioning] Make SdyShardingRule a user facing class. Move the parsing of a sharding rule string to a free function str_to_sdy_sharding_rule. Move the building of the MLIR sharding rule to a free function sdy_sharding_rule_to_mlir. PiperOrigin-RevId: 704818640 --- jax/_src/custom_partitioning_sharding_rule.py | 529 ++++++++++-------- .../custom_partitioning_sharding_rule_test.py | 218 +++++--- 2 files changed, 437 insertions(+), 310 deletions(-) diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py index 5193c9126bb7..1e3e7fe60683 100644 --- a/jax/_src/custom_partitioning_sharding_rule.py +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -20,16 +20,124 @@ from jax._src.lib.mlir.dialects import sdy -_CompoundFactor = tuple[str, ...] -_DimMapping = tuple[str | _CompoundFactor, ...] - # A single character replacement for ... to simplify parsing. -_ELLIPSIS: str = "…" +BATCHING: str = "…" # A prefix for names of batching dimension factors, used for expanding the # leading ... into factors. _BATCHING_DIM_FACTOR_PREFIX = "?" +def _check_factor(factor:str): + """Validates a factor. + + A factor is a string starting with a letter and containing only letters, + digits, or underscores. + """ + if not factor[0].isalpha(): + raise ValueError(f"Factor names have to start with a letter, but got '{factor[0]}'") + for char in factor[1:]: + if char != "_" and not char.isdigit() and not char.isalpha(): + raise ValueError(f"Unknown character '{char}'") + +class CompoundFactor(tuple): + """Describes the factors for a compound factor. + + A compound factor should contain at least two factors, e.g. + * CompoundFactor('b', 'c'). + """ + def __init__(self, *factors): + if len(factors) < 2: + raise ValueError("A compound factor should contain at least two factors") + for factor in factors: + if not isinstance(factor, str): + raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}") + if factor == BATCHING: + raise ValueError("Ellipsis can't be used in a compound factor") + else: + _check_factor(factor) + + def __new__(cls, *factors): + return tuple.__new__(CompoundFactor, factors) + + +class ArrayMapping(tuple): + """Describes the factors for an operand or result. + + Each element is either a factor or a CompoundFactor. A leading element can + also be BATCHING, which represents batching dimensions. examples: + * ArrayMapping('a') + * ArrayMapping('b', 'c') + * ArrayMapping(CompoundFactor('b', 'c'), 'd') + * ArrayMapping(BATCHING, CompoundFactor('b', 'c'), 'd') + """ + def __init__(self, *dim_mappings): + for i, d in enumerate(dim_mappings): + if not isinstance(d, str) and not isinstance(d, CompoundFactor): + raise ValueError( + "Each element of ArrayMapping must be a str or CompoundFactor, but" + f" got {type(d)}") + if isinstance(d, str): + if d == BATCHING: + if i != 0: + raise ValueError("Ellipsis can only be used at the beginning of a dimension") + else: + _check_factor(d) + + def __new__(cls, *dim_mappings): + return tuple.__new__(ArrayMapping, dim_mappings) + + +class SdyShardingRule: + """Represents a Shardy sharding rule. + + An SdyShardingRule contains the ArrayMappings for operands and results, and an + optional list of factor sizes. A factor is a name used in the ArrayMappings. + If a factor is only used in CompoundFactors, its size must be specified. + """ + operand_mappings: tuple[ArrayMapping, ...] + result_mappings: tuple[ArrayMapping, ...] + factor_sizes: dict[str, int] + + def __init__(self, operand_mappings: tuple[ArrayMapping, ...], + result_mappings: tuple[ArrayMapping, ...], **factor_sizes): + # Find all factors and mark whether their size can be inferred. + factors_inferrable = dict() + for value in operand_mappings + result_mappings: + for dim in value: + if isinstance(dim, str): + factors_inferrable[dim] = True + else: + for factor in dim: + if factor not in factors_inferrable.keys(): + factors_inferrable[factor] = False + + # Check that factors in factor_sizes are used in the rule. + for factor in factor_sizes: + if factor not in factors_inferrable: + raise ValueError( + f"Factor {factor} is not used in the rule, but size is provided") + + # Check that factors that are used for a whole dimension aren't in + # factor_sizes and factors that are never used for a whole dimension are + # in factor_sizes. + for factor, inferrable in factors_inferrable.items(): + if factor not in factor_sizes and not inferrable: + raise ValueError( + f"Factor {factor} is only used in compound factors; must specify" + " its size") + if factor in factor_sizes and inferrable: + raise ValueError( + f"Factor {factor} represents a whole dimension; do not specify its" + " size") + + self.operand_mappings = operand_mappings + self.result_mappings = result_mappings + self.factor_sizes = factor_sizes + + def __str__(self): + return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})" + + def _get_batching_dim_factor_name(batch_dim_order : int): """Constructs a factor name for a batching dimension. @@ -42,18 +150,18 @@ def _get_batching_dim_factor_name(batch_dim_order : int): def _parse_values( rule: str, -) -> tuple[_DimMapping, ...]: +) -> tuple[ArrayMapping, ...]: """Parses the LHS or RHS of an Einsum notation like string. Converts each operand or result in the Einsum notation like string to a tuple - of _DimMapping. This very closely follows how einops parses their rules in + of ArrayMapping. This very closely follows how einops parses their rules in einops/parsing.py. Args: rule: The Einsum notation for the operands or results of an operation. Returns: - The tuple of values. + The tuple of ArrayMapping. Raises: ValueError: If the rule is not balanced or contains unknown characters. @@ -65,10 +173,10 @@ def _parse_values( # Similar to einops rules, an empty LHS/RHS has a single scalar value. if not rule: - return ((),) + return (ArrayMapping(),) all_values = [] - # Represent all dimensions of an value. When an value[0]==_ELLIPSIS, the + # Represent all dimensions of an value. When an value[0]==BATCHING, the # value may have 0 or more leading dimensions. value = [] current_factor = None @@ -84,12 +192,12 @@ def add_factor(x): current_compound_dim.append(x) for char in rule: - if char == _ELLIPSIS: + if char == BATCHING: if (current_factor is not None or current_compound_dim is not None or value): raise ValueError( "Ellipsis can only be used at the beginning of a dimension") - add_factor(_ELLIPSIS) + add_factor(BATCHING) continue if char in "(), ": if current_factor is not None: @@ -106,10 +214,10 @@ def add_factor(x): raise ValueError("Brackets are not balanced") if len(current_compound_dim) <= 1: raise ValueError("Brackets should contain at least two factors") - value.append(tuple(current_compound_dim)) + value.append(CompoundFactor(*current_compound_dim)) current_compound_dim = None elif char == ",": - all_values.append(tuple(value)) + all_values.append(ArrayMapping(*value)) value = [] elif char == "_" or char.isdigit() or char.isalpha(): if current_factor is None: @@ -125,256 +233,203 @@ def add_factor(x): raise ValueError(f"Brackets are not balanced in rule: '{rule}'") if current_factor is not None: add_factor(current_factor) - all_values.append(tuple(value)) + all_values.append(ArrayMapping(*value)) return tuple(all_values) +def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule: + """Constructs a SdyShardingRule object from the Einsum notation like string. -class SdyShardingRule: - """A representation for Shardy sharding rule. - - A SdyShardingRule includes an Enisum notation like string and an optional - list of factor sizes. A factor is a name in the Einsum notation. If a factor - is only used in compound factors, its size must be specified. + This is done by verifying that the input Einsum notation like string and + with optional factor sizes represents a valid sharding rule and converting + it to an internal representation. - SdyShardingRule examples: + Args: + rule: The Einsum notation like string for an operation. + **factor_sizes: The optional factor sizes. - * Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k') - * Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k') - * A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j') - * Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)`, i=4, j=2) - * An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...') + Raises: + ValueError: If there is any problem with the rule or factor_sizes. """ - - def __init__(self, rule: str, **factor_sizes): - """Constructs a SdyShardingRule object from the Einsum notation like string. - - This is done by verifying that the input Einsum notation like string and - with optional factor sizes represents a valid sharding rule and converting - it to an internal representation. - - Args: - rule: The Einsum notation like string for an operation. - **factor_sizes: The optional factor sizes. - - Raises: - ValueError: If there is any problem with the rule or factor_sizes. - """ - if not isinstance(rule, str): - raise TypeError(f"rule must be a str, but got {type(rule)}") - if not all(isinstance(size, int) for size in factor_sizes.values()): - raise TypeError( - f"factor_sizes must be a dict of str to int, but got {factor_sizes}") - - # Replace ... with a single char to simplify parsing. - if _ELLIPSIS in rule: - raise ValueError(f"Unknown character '{_ELLIPSIS}'") + if not isinstance(rule, str): + raise TypeError(f"rule must be a str, but got {type(rule)}") + if not all(isinstance(size, int) for size in factor_sizes.values()): + raise TypeError( + f"factor_sizes must be a dict of str to int, but got {factor_sizes}") + + # Replace ... with a single char to simplify parsing. + if BATCHING in rule: + raise ValueError(f"Unknown character '{BATCHING}'") + if "." in rule: + rule = rule.replace("...", BATCHING) if "." in rule: - rule = rule.replace("...", _ELLIPSIS) - if "." in rule: - raise ValueError("Character '.' must be used inside ellipsis '...'") + raise ValueError("Character '.' must be used inside ellipsis '...'") - try: - operands, results = rule.split("->") - except ValueError as e: - raise ValueError(f"There is no -> in rule: '{rule}'") from e + try: + operands, results = rule.split("->") + except ValueError as e: + raise ValueError(f"There is no -> in rule: '{rule}'") from e - self.operands = _parse_values(operands) - self.results = _parse_values(results) + operand_mappings = _parse_values(operands) + result_mappings = _parse_values(results) - # Find all factors and mark whether their size can be inferred. - factors_inferrable = dict() - for value in self.operands + self.results: - for dim in value: - if dim == _ELLIPSIS: - continue - if isinstance(dim, str): - factors_inferrable[dim] = True - else: - for factor in dim: - if factor not in factors_inferrable.keys(): - factors_inferrable[factor] = False + return SdyShardingRule(operand_mappings, result_mappings, **factor_sizes) - # Check that factors in factor_sizes are used in the rule. - for factor in factor_sizes: - if factor not in factors_inferrable: - raise ValueError( - f"Factor {factor} is not used in the rule, but size is provided") - # Check that factors that are used for a whole dimension aren't in - # factor_sizes and factors that are never used for a whole dimension are - # in factor_sizes. - for factor, inferrable in factors_inferrable.items(): - if factor not in factor_sizes and not inferrable: - raise ValueError( - f"Factor {factor} is only used in compound factors; must specify" - " its size") - if factor in factor_sizes and inferrable: - raise ValueError( - f"Factor {factor} represents a whole dimension; do not specify its" - " size") +def sdy_sharding_rule_to_mlir( + rule: SdyShardingRule, + operand_types: list[ir.Type], + result_types: list[ir.Type],) -> ir.Attribute: + """Builds the MLIR representation for the sharding rule. - self.factor_sizes = factor_sizes + This is done by verifying that the rule is consistent with the types of + the operation and converting the Einsum notation like string to + OpShardingRuleAttr. + """ + if len(rule.operand_mappings) != len(operand_types): + raise ValueError( + f"Sharding rule has {len(rule.operand_mappings)} operands, but the operation" + f" has {len(operand_types)} operands") + if len(rule.result_mappings) != len(result_types): + raise ValueError( + f"Sharding rule has {len(rule.result_mappings)} results, but the operation" + f" has {len(result_types)} results") + + factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict() + types = operand_types + result_types + UNKNOWN = -1 # Representation for unknown factor size or factor index. + + def get_message_for_value(i): + if i >= len(operand_types): + return f"{i - len(operand_types)}th result" + else: + return f"{i}th operand" - def __str__(self): - return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})" + def get_rank_for_value(i): + return ir.ShapedType(types[i]).rank + + def get_size_for_value_dim(i, j): + return ir.ShapedType(types[i]).shape[j] - def build( - self, - operand_types: list[ir.Type], - result_types: list[ir.Type],) -> ir.Attribute: - """Builds the MLIR representation for the sharding rule. + def add_factor(factor, size): + """Adds a factor to factors_to_indices_sizes. - This is done by verifying that the rule is consistent with the types of - the operation and converting the Einsum notation like string to - OpShardingRuleAttr. + `size` may be a dimensions size, a user specified factor size, or UNKNOWN + if a factor is first used as in a compound factor and then used for a + whole dimension. """ - if len(self.operands) != len(operand_types): + factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN]) + if factor_index != UNKNOWN: + # Not the first time seeing the factor. + if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size: + factor_or_batching_dim = ( + f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor + else f"Batching dimension {factor[1:]}") + raise ValueError( + f"{factor_or_batching_dim} corresponds to two sizes:" + f" {factor_size} and {size}") + if size != UNKNOWN and factor_size == UNKNOWN: + factors_to_indices_sizes[factor] = [factor_index, size] + else: + # First time seeing the factor. + factor_index = len(factors_to_indices_sizes) + factors_to_indices_sizes[factor] = [factor_index, size] + + def add_batching_dim_factor(batch_dim_order, factor_size): + ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order) + add_factor(ellipsis_batch_dim_name, factor_size) + + def build_dim_mapping_for_compound_factors(i, j, factors): + accumulated_size = 1 + all_indices = [] + for factor in factors: + factor_index, factor_size = factors_to_indices_sizes[factor] + accumulated_size *= factor_size + all_indices.append(factor_index) + + dim_size = get_size_for_value_dim(i, j) + if accumulated_size != dim_size: raise ValueError( - f"Sharding rule has {len(self.operands)} operands, but the operation" - f" has {len(operand_types)} operands" - ) - if len(self.results) != len(result_types): + f"{get_message_for_value(i)} actual size {dim_size} doesn't match" + f" the size {accumulated_size} derived from the compound factors" + f" {factors}") + + return sdy.DimMappingAttr.get(factor_indices=all_indices) + + # Add factors and their sizes in the order they appear in the rule, + # including the batching dimensions represented by ellipsis. + ellipsis_rank = None + for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings): + value = tuple(mapping) + if value and value[0] == BATCHING: + has_batching = True + value = value[1:] + else: + has_batching = False + rule_rank = len(value) + op_rank = get_rank_for_value(i) + # The number of dimensions represented by ellipsis. + current_batching_rank = 0 + if has_batching and op_rank >= rule_rank: + current_batching_rank = op_rank - rule_rank + if has_batching: + if ellipsis_rank is None: + ellipsis_rank = current_batching_rank + elif ellipsis_rank != current_batching_rank: + raise ValueError( + "Ellipsis represents different number of leading dimensions" + f" {ellipsis_rank} and {current_batching_rank}") + rule_rank += current_batching_rank + if rule_rank != op_rank: + msg = get_message_for_value(i) raise ValueError( - f"Sharding rule has {len(self.results)} results, but the operation" - f" has {len(result_types)} results" - ) + f"Sharding rule {msg} has rank {rule_rank}, but the operation" + f" {msg} has rank {op_rank}") - factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict() - types = operand_types + result_types - UNKNOWN = -1 # Representation for unknown factor size or factor index. + for j in range(current_batching_rank): + add_batching_dim_factor(j, get_size_for_value_dim(i, j)) - def get_message_for_value(i): - if i >= len(operand_types): - return f"{i - len(operand_types)}th result" + for j, dim in enumerate(value): + if isinstance(dim, str): + add_factor(dim, get_size_for_value_dim(i, j + current_batching_rank)) else: - return f"{i}th operand" - - def get_rank_for_value(i): - return ir.ShapedType(types[i]).rank - - def get_size_for_value_dim(i, j): - return ir.ShapedType(types[i]).shape[j] - - def add_factor(factor, size): - """Adds a factor to factors_to_indices_sizes. - - `size` may be a dimensions size, a user specified factor size, or UNKNOWN - if a factor is first used as in a compound factor and then used for a - whole dimension. - """ - factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN]) - if factor_index != UNKNOWN: - # Not the first time seeing the factor. - if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size: - factor_or_batching_dim = ( - f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor - else f"Batching dimension {factor[1:]}") - raise ValueError( - f"{factor_or_batching_dim} corresponds to two sizes:" - f" {factor_size} and {size}") - if size != UNKNOWN and factor_size == UNKNOWN: - factors_to_indices_sizes[factor] = [factor_index, size] - else: - # First time seeing the factor. - factor_index = len(factors_to_indices_sizes) - factors_to_indices_sizes[factor] = [factor_index, size] - - def add_batching_dim_factor(batch_dim_order, factor_size): - ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order) - add_factor(ellipsis_batch_dim_name, factor_size) - - def build_dim_mapping_for_compound_factors(i, j, factors): - accumulated_size = 1 - all_indices = [] - for factor in factors: - factor_index, factor_size = factors_to_indices_sizes[factor] - accumulated_size *= factor_size - all_indices.append(factor_index) - - dim_size = get_size_for_value_dim(i, j) - if accumulated_size != dim_size: - raise ValueError( - f"{get_message_for_value(i)} actual size {dim_size} doesn't match" - f" the size {accumulated_size} derived from the compound factors" - f" {factors}") - - return sdy.DimMappingAttr.get(factor_indices=all_indices) - - # Add factors and their sizes in the order they appear in the rule, - # including the batching dimensions represented by ellipsis. - ellipsis_rank = None - for i, value in enumerate(self.operands + self.results): - if value and value[0] == _ELLIPSIS: - has_ellipsis = True - value = value[1:] + for factor in dim: + add_factor(factor, rule.factor_sizes.get(factor, UNKNOWN)) + + # Build the tensor mappings for each operand and result. + tensor_mappings = [] + for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings): + value = tuple(mapping) + dim_mappings = [] + + if value and value[0] == BATCHING: + value = value[1:] + if ellipsis_rank is None: + current_batching_rank = 0 else: - has_ellipsis = False - rule_rank = len(value) - op_rank = get_rank_for_value(i) - # The number of dimensions represented by ellipsis. - current_ellipsis_rank = 0 - if has_ellipsis and op_rank > rule_rank: - current_ellipsis_rank = op_rank - rule_rank - if has_ellipsis: - if ellipsis_rank is None: - ellipsis_rank = current_ellipsis_rank - elif ellipsis_rank != current_ellipsis_rank: - raise ValueError( - "Ellipsis represents different number of leading dimensions" - f" {ellipsis_rank} and {current_ellipsis_rank}") - rule_rank += current_ellipsis_rank - if rule_rank != op_rank: - msg = get_message_for_value(i) - raise ValueError( - f"Sharding rule {msg} has rank {rule_rank}, but the operation" - f" {msg} has rank {op_rank}") - - for j in range(current_ellipsis_rank): - add_batching_dim_factor(j, get_size_for_value_dim(i, j)) - - for j, dim in enumerate(value): - if isinstance(dim, str): - add_factor( - dim, get_size_for_value_dim(i, j + current_ellipsis_rank)) - else: - for factor in dim: - add_factor(factor, self.factor_sizes.get(factor, UNKNOWN)) + current_batching_rank = ellipsis_rank + else: + current_batching_rank = 0 - # Build the tensor mappings for each operand and result. - tensor_mappings = [] - for i, value in enumerate(self.operands + self.results): - dim_mappings = [] + for j in range(current_batching_rank): + dim_mappings.append( + sdy.DimMappingAttr.get(factor_indices=[ + factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]])) - if value and value[0] == _ELLIPSIS: - value = value[1:] - if ellipsis_rank is None: - current_ellipsis_rank = 0 - else: - current_ellipsis_rank = ellipsis_rank + for j, dim in enumerate(value): + if isinstance(dim, str): + dim_mappings.append( + sdy.DimMappingAttr.get( + factor_indices=[factors_to_indices_sizes[dim][0]])) else: - current_ellipsis_rank = 0 - - for j in range(current_ellipsis_rank): dim_mappings.append( - sdy.DimMappingAttr.get(factor_indices=[ - factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]])) + build_dim_mapping_for_compound_factors( + i, j + current_batching_rank, dim)) - for j, dim in enumerate(value): - if isinstance(dim, str): - dim_mappings.append( - sdy.DimMappingAttr.get( - factor_indices=[factors_to_indices_sizes[dim][0]])) - else: - dim_mappings.append( - build_dim_mapping_for_compound_factors( - i, j + current_ellipsis_rank, dim)) - - tensor_mappings.append( - sdy.TensorMappingAttr.get(dim_mappings=dim_mappings)) - - op_sharding_rule = sdy.OpShardingRuleAttr.get( - factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], - operand_mappings=tensor_mappings[0:len(operand_types)], - result_mappings=tensor_mappings[len(operand_types):]) - return op_sharding_rule + tensor_mappings.append( + sdy.TensorMappingAttr.get(dim_mappings=dim_mappings)) + + return sdy.OpShardingRuleAttr.get( + factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], + operand_mappings=tensor_mappings[0:len(operand_types)], + result_mappings=tensor_mappings[len(operand_types):]) diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py index 2aac4e04862f..3aed16510a4f 100644 --- a/tests/custom_partitioning_sharding_rule_test.py +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -16,148 +16,189 @@ from jax._src import test_util as jtu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy -from jax._src.custom_partitioning_sharding_rule import SdyShardingRule +from jax._src.custom_partitioning_sharding_rule import ArrayMapping, BATCHING, CompoundFactor, sdy_sharding_rule_to_mlir, str_to_sdy_sharding_rule, SdyShardingRule from jax._src.lib.mlir.dialects import hlo as stablehlo class SdyShardingRuleTest(jtu.JaxTestCase): + def test_compound_factor_not_enough_factors(self): + with self.assertRaisesRegex(ValueError, "A compound factor should contain at least two factors"): + CompoundFactor("i") + + def test_compound_factor_batching_now_allowed(self): + with self.assertRaisesRegex(ValueError, "Ellipsis can't be used in a compound factor"): + CompoundFactor(BATCHING, "i") + + def test_compound_factor_element_not_a_str(self): + with self.assertRaisesRegex(ValueError, "Each element of CompoundFactor must be a str"): + CompoundFactor("i", 2) + + def test_compound_factor_str(self): + c = CompoundFactor("i", "j", "k") + self.assertEqual(str(c), "('i', 'j', 'k')") + + def test_value_mapping_element_not_a_str_or_compound_factor(self): + with self.assertRaisesRegex(ValueError, "Each element of ArrayMapping must be a str or CompoundFactor"): + ArrayMapping(CompoundFactor("i", "j"), 3) + + def test_value_mapping_factor_name_not_start_with_letter(self): + with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"): + ArrayMapping("3i", "j") + + def test_value_mapping_ellipsis_not_first(self): + with self.assertRaisesRegex(ValueError, "Ellipsis can only be used at the beginning of a dimension"): + ArrayMapping("i_j", BATCHING) + + def test_value_mapping_str(self): + v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k") + self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')") + + def test_sdy_sharding_rule_factor_size_not_used(self): + with self.assertRaisesRegex(ValueError, "Factor k is not used"): + SdyShardingRule(("i",), ("j",), k=10) + + def test_sdy_sharding_rule_factor_sizes_missing(self): + with self.assertRaisesRegex( + ValueError, + "Factor k is only used in compound factors; must specify its size"): + SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")), + (ArrayMapping(CompoundFactor("j", "k")),)) + + def test_sdy_sharding_rule_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule((ArrayMapping("i"),), (ArrayMapping("j"),), i=10) + + def test_sdy_sharding_rule_compound_factor_size_not_necessary(self): + with self.assertRaisesRegex( + ValueError, + "Factor i represents a whole dimension; do not specify its size"): + SdyShardingRule((ArrayMapping(CompoundFactor("i", "j")),), + (ArrayMapping("i"),), i=10, j=20) + + def test_sdy_sharding_rule_str(self): + r = SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")), + (ArrayMapping(CompoundFactor("j", "k")),), k=10) + self.assertEqual(str(r), "SdyShardingRule((('i',), ('j',)), ((('j', 'k'),),), {'k': 10})") + + +class StrToSdyShardingRuleTest(jtu.JaxTestCase): def test_rule_is_not_a_str(self): with self.assertRaisesRegex(TypeError, "rule must be a str"): - SdyShardingRule(1) + str_to_sdy_sharding_rule(1) def test_factor_sizes_is_not_a_proper_dict(self): with self.assertRaisesRegex( TypeError, "factor_sizes must be a dict of str to int"): - SdyShardingRule("i->j", i="j") + str_to_sdy_sharding_rule("i->j", i="j") def test_sharding_rule_ellipsis_not_complete(self): with self.assertRaisesRegex( ValueError, "Character '.' must be used inside ellipsis '...'"): - SdyShardingRule(".i -> j") + str_to_sdy_sharding_rule(".i -> j") def test_sharding_rule_invalid_factor_name(self): with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"): - SdyShardingRule("2i -> j") + str_to_sdy_sharding_rule("2i -> j") def test_sharding_rule_missing_results(self): with self.assertRaisesRegex(ValueError, "There is no -> in rule"): - SdyShardingRule("i") + str_to_sdy_sharding_rule("i") def test_sharding_rule_inbalenced_brackets(self): with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): - SdyShardingRule("i j, k)->j") + str_to_sdy_sharding_rule("i j, k)->j") def test_sharding_rule_inbalenced_brackets2(self): with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): - SdyShardingRule("i (j k->j") + str_to_sdy_sharding_rule("i (j k->j") def test_sharding_rule_empty_compound_dim(self): with self.assertRaisesRegex( ValueError, "Brackets should contain at least two factors"): - SdyShardingRule("i ( ) j k->j") + str_to_sdy_sharding_rule("i ( ) j k->j") def test_sharding_rule_one_factorcompound_dim(self): with self.assertRaisesRegex( ValueError, "Brackets should contain at least two factors"): - SdyShardingRule("i (j ) k->j") + str_to_sdy_sharding_rule("i (j ) k->j") def test_sharding_rule_nested_brackets(self): with self.assertRaisesRegex( ValueError, "Compound factors should be one level"): - SdyShardingRule("i (j (k))->j") + str_to_sdy_sharding_rule("i (j (k))->j") def test_sharding_rule_unknown_char(self): with self.assertRaisesRegex(ValueError, "Unknown character"): - SdyShardingRule("i; j->j") + str_to_sdy_sharding_rule("i; j->j") def test_sharding_rule_unknown_single_char_ellipse(self): with self.assertRaisesRegex(ValueError, "Unknown character"): - SdyShardingRule("…j->…j") + str_to_sdy_sharding_rule("…j->…j") def test_sharding_rule_ellipsis_not_leading_dim(self): with self.assertRaisesRegex( ValueError, "Ellipsis can only be used at the beginning of a dimension"): - SdyShardingRule("i ... -> j") + str_to_sdy_sharding_rule("i ... -> j") def test_sharding_rule_ellipsis_inside_compound_dim(self): with self.assertRaisesRegex( ValueError, "Ellipsis can only be used at the beginning of a dimension"): - SdyShardingRule("i, (..., j) -> j") + str_to_sdy_sharding_rule("i, (..., j) -> j") def test_sharding_rule_scalar_operand_scalar_result(self): - rule = SdyShardingRule("->") + rule = str_to_sdy_sharding_rule("->") self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})") def test_sharding_rule_one_scalar_operand(self): - rule = SdyShardingRule("i j, , k->j") + rule = str_to_sdy_sharding_rule("i j, , k->j") self.assertEqual( str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})") - def test_sharding_rule_factor_size_not_used(self): - with self.assertRaisesRegex(ValueError, "Factor k is not used"): - SdyShardingRule("i->j", k=10) - - def test_sharding_rule_factor_size_not_necessary(self): - with self.assertRaisesRegex( - ValueError, - "Factor i represents a whole dimension; do not specify its size"): - SdyShardingRule("i->j", i=10) - - def test_sharding_rule_compound_factor_size_not_necessary(self): - with self.assertRaisesRegex( - ValueError, - "Factor i represents a whole dimension; do not specify its size"): - SdyShardingRule("(i j) -> i", i=10, j=20) - - def test_sharding_rule_factor_sizes_missing(self): - with self.assertRaisesRegex( - ValueError, - "Factor k is only used in compound factors; must specify its size"): - SdyShardingRule("i j -> (j k)") - def test_sharding_rule_factor_elementwise_add(self): - rule = SdyShardingRule("... i j, ...i j -> ...i j") + rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j") self.assertEqual( str(rule), "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i'," " 'j'),), {})") def test_sharding_rule_factor_vector_scalar_add(self): - rule = SdyShardingRule("...i, -> ...i") + rule = str_to_sdy_sharding_rule("...i, -> ...i") self.assertEqual( str(rule), "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})") def test_sharding_rule_factor_reshape_combining(self): - rule = SdyShardingRule("i j -> (i j)") + rule = str_to_sdy_sharding_rule("i j -> (i j)") self.assertEqual( str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})") def test_sharding_rule_factor_reshape_reordering(self): - rule = SdyShardingRule("(j i) -> (i j)", i=10, j=20) + rule = str_to_sdy_sharding_rule("(j i) -> (i j)", i=10, j=20) self.assertEqual( str(rule), "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':" " 20})") def test_sharding_rule_factor_compound_then_individual(self): - rule = SdyShardingRule("(i j) (j k) i -> j k") + rule = str_to_sdy_sharding_rule("(i j) (j k) i -> j k") self.assertEqual( str(rule), "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})") def test_sharding_rule_factor_individual_then_compound(self): - rule = SdyShardingRule("i j k -> (i j) (j k)") + rule = str_to_sdy_sharding_rule("i j k -> (i j) (j k)") self.assertEqual( str(rule), "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})") def test_sharding_rule_factor_infer_k(self): - rule = SdyShardingRule("_i (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20) + rule = str_to_sdy_sharding_rule("i_ (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20) self.assertEqual( str(rule), - "SdyShardingRule((('_i', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" + "SdyShardingRule((('i_', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" ",), {'k': 10, 'm': 10, 'bar_24': 20})") @@ -189,11 +230,11 @@ def test_conversion_rule_op_mismatch_in_operands_num(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("i j-> i j") + rule = str_to_sdy_sharding_rule("i j-> i j") with self.assertRaisesRegex( ValueError, "Sharding rule has 1 operands, but the operation has 2 operands"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) @@ -205,12 +246,12 @@ def test_conversion_rule_op_mismatch_in_operands_rank(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("i j, i j k-> i j") + rule = str_to_sdy_sharding_rule("i j, i j k-> i j") with self.assertRaisesRegex( ValueError, "Sharding rule 1th operand has rank 3, but the operation 1th " "operand has rank 2"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) @@ -223,11 +264,11 @@ def test_conversion_rule_op_mismatch_in_results_num(self): operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("i j, i j -> i j, i j") + rule = str_to_sdy_sharding_rule("i j, i j -> i j, i j") with self.assertRaisesRegex( ValueError, "Sharding rule has 2 results, but the operation has 1 results"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) @@ -239,12 +280,12 @@ def test_conversion_rule_op_mismatch_in_results_dim(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("i j, i j -> i j k") + rule = str_to_sdy_sharding_rule("i j, i j -> i j k") with self.assertRaisesRegex( ValueError, "Sharding rule 0th result has rank 3, but the operation 0th " "result has rank 2"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) @@ -256,11 +297,11 @@ def test_conversion_factor_has_two_sizes(self): results=[self.get_tensor_type((16, 64))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("i j, i j -> i j") + rule = str_to_sdy_sharding_rule("i j, i j -> i j") with self.assertRaisesRegex( ValueError, "Factor j corresponds to two sizes: 32 and 64"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) @@ -272,14 +313,30 @@ def test_conversion_batching_dim_has_two_sizes(self): results=[self.get_tensor_type((16, 64))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("..., ... -> ...") + rule = str_to_sdy_sharding_rule("..., ... -> ...") with self.assertRaisesRegex( ValueError, "Batching dimension 1 corresponds to two sizes: 32 and 64"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,],) + def test_conversion_invalid_batching_dim(self): + opnd0 = self.create_tensor_value((16, 32)) + opnd1 = self.create_tensor_value((16, 32)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((16, 32))], + operands=[opnd0, opnd1,], + attributes=dict(call_target_name=ir.StringAttr.get("foo")),) + rule = str_to_sdy_sharding_rule("... i j k, ... i j k -> ... i j k") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 0th operand has rank 3, but the operation 0th operand has rank 2"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + def test_conversion_compound_dimension_size_mismatch(self): opnd = self.create_tensor_value((2, 4)) result = ir.Operation.create( @@ -287,12 +344,12 @@ def test_conversion_compound_dimension_size_mismatch(self): results=[self.get_tensor_type((9,))], operands=[opnd,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("i j -> (i j)") + rule = str_to_sdy_sharding_rule("i j -> (i j)") with self.assertRaisesRegex( ValueError, "0th result actual size 9 doesn't match the size 8 derived from the" " compound factors"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type], [result.result.type,]) @@ -304,14 +361,29 @@ def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("..., ... -> ...") + rule = str_to_sdy_sharding_rule("..., ... -> ...") with self.assertRaisesRegex( ValueError, "Ellipsis represents different number of leading dimensions 2 and 1"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) + def test_conversion_compound_then_individual(self): + opnd = self.create_tensor_value((8,)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((2,4))], + operands=[opnd,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("(i j) -> i j") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>") + def test_conversion_elementwise_rule_scalar_instance(self): opnd0 = self.create_tensor_value(()) opnd1 = self.create_tensor_value(()) @@ -320,8 +392,8 @@ def test_conversion_elementwise_rule_scalar_instance(self): results=[self.get_tensor_type(())], operands=[opnd0, opnd1], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("..., ... -> ...") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("..., ... -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( @@ -336,8 +408,8 @@ def test_conversion_elementwise_rule_2D_instance(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("..., ... -> ...") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("..., ... -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( @@ -352,8 +424,8 @@ def test_conversion_vector_scalar_add_2D_instance(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("..., -> ...") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("..., -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( @@ -367,8 +439,8 @@ def test_conversion_reshape_rule(self): results=[self.get_tensor_type((8,))], operands=[opnd0,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("i j -> (i j)") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("i j -> (i j)") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type], [result.result.type,]) self.assertEqual( @@ -383,8 +455,8 @@ def test_conversion_contracting_dim_matmul(self): results=[self.get_tensor_type((16, 8))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("... contracting_dim, contracting_dim k -> ... k") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("... contracting_dim, contracting_dim k -> ... k") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( From e6d6c4ef8a9b2297fb49c621abebceef5142a079 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 10 Dec 2024 13:24:21 -0800 Subject: [PATCH 698/698] Delete non-public API jax.lib.xla_bridge._backends This is doubly non-public: nothing under `jax.lib` is public, and also the object itself has a preceding underscore. Therefore it is safe to remove (chex had referenced this previously, but that's now addressed in https://github.com/google-deepmind/chex/commit/adaf1b2b7555e75a8ac118549e204520311f8ea0). PiperOrigin-RevId: 704825268 --- jax/lib/xla_bridge.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 654abc35fc78..2bcb1cb037f4 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -17,7 +17,6 @@ default_backend as _deprecated_default_backend, get_backend as _deprecated_get_backend, xla_client as _deprecated_xla_client, - _backends as _backends, ) from jax._src.compiler import (