Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Jax: Stop returning a list of cost-analyses.
Browse files Browse the repository at this point in the history
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 707170952
zacmustin authored and Google-ML-Automation committed Jan 13, 2025
1 parent e72c148 commit c08a194
Showing 3 changed files with 11 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -59,6 +59,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jax 0.4.38 (Dec 17, 2024)

* Breaking Changes
* `XlaExecutable.cost_analysis` now returns a `dict[str, float]` (instead of a
single-element `list[dict[str, float]]`).

* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
as shortcuts of the corresponding `tree_util` functions.
2 changes: 1 addition & 1 deletion docs/aot.md
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 :
>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
>>> compiled.cost_analysis()['flops']
2.0

>>> # Execute the compiled function!
16 changes: 6 additions & 10 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
@@ -250,15 +250,13 @@ def as_text(self) -> str:
else:
raise

# TODO(b/384741132): this should return a single dict (I think returning a list
# was to support MPMD executables, which never fully landed).
def cost_analysis(self) -> list[dict[str, float]]:
def cost_analysis(self) -> dict[str, float]:
xla_ext_exe = self.xla_extension_executable()

# TODO(b/259255524): Unify/merge the two cost_analysis calls below.
if hasattr(xla_ext_exe, "cost_analysis"):
try:
return [xla_ext_exe.cost_analysis()]
return xla_ext_exe.cost_analysis()
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
@@ -276,11 +274,9 @@ def cost_analysis(self) -> list[dict[str, float]]:
" were found)."
)

return [
xla_extension.hlo_module_cost_analysis(
xla_ext_exe.client, hlo_modules[0]
)
]
return xla_extension.hlo_module_cost_analysis(
xla_ext_exe.client, hlo_modules[0]
)
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
supported = not (type(msg) is str and
@@ -295,7 +291,7 @@ def cost_analysis(self) -> list[dict[str, float]]:
and hasattr(self.unsafe_call, "compiled")
and hasattr(self.unsafe_call.compiled, "cost_analysis")
):
return [self.unsafe_call.compiled.cost_analysis()]
return self.unsafe_call.compiled.cost_analysis()

raise NotImplementedError(
f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}"

0 comments on commit c08a194

Please sign in to comment.