From 97ddf1d7f817bb1714cc97c43a4c08576af67640 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 14 Oct 2024 20:08:03 -0700
Subject: [PATCH] Add a private property to NamedSharding called
 `_logical_device_ids` which allows you to pass a custom
 `tile_assignment_devices()` equivalent.

This is needed because GSPMDSharding does not work with Shardy, so
`device_put` on a mesh with a different device order needs `NamedSharding`
support. As a bonus, the logic is now simpler than the previous version in
`_different_device_order_reshard`.

This will also allow us to remove OpSharding usage from other projects that
need this kind of permutation capability.

PiperOrigin-RevId: 685925636
---
 xla/python/sharding.cc                | 16 +++++++++++-----
 xla/python/sharding.h                 |  7 ++++++-
 xla/python/xla_client.py              |  2 +-
 xla/python/xla_extension/__init__.pyi |  2 ++
 4 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/xla/python/sharding.cc b/xla/python/sharding.cc
index 2c3d70a465d63..a5678221c4208 100644
--- a/xla/python/sharding.cc
+++ b/xla/python/sharding.cc
@@ -148,7 +148,9 @@ bool ShardingEqual(nb::handle a, nb::handle b) {
            a_named_sharding->memory_kind().equal(
                b_named_sharding->memory_kind()) &&
            a_named_sharding->manual_axes().equal(
-               b_named_sharding->manual_axes());
+               b_named_sharding->manual_axes()) &&
+           a_named_sharding->logical_device_ids().equal(
+               b_named_sharding->logical_device_ids());
   }
 
   if (a_type.is(GSPMDSharding::type())) {
@@ -175,7 +177,8 @@ bool ShardingEqual(nb::handle a, nb::handle b) {
 
 NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
                              nb::object memory_kind, nb::object parsed_pspec,
-                             nb::object manual_axes)
+                             nb::object manual_axes,
+                             nb::object logical_device_ids)
     : Sharding(/*num_devices=*/[&mesh]() {
         return nb::cast<int>(mesh.attr("size"));
       }()),
@@ -183,7 +186,8 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
       spec_(std::move(spec)),
       memory_kind_(std::move(memory_kind)),
       parsed_pspec_(std::move(parsed_pspec)),
-      manual_axes_(std::move(manual_axes)) {
+      manual_axes_(std::move(manual_axes)),
+      logical_device_ids_(std::move(logical_device_ids)) {
   nb::object idl = nb::object(mesh_.attr("_internal_device_list"));
   if (idl.is_none()) {
     internal_device_list_ = std::nullopt;
@@ -261,16 +265,18 @@ void RegisterSharding(nb::module_& m) {
   nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
 
   nb::class_<NamedSharding, Sharding>(m, "NamedSharding", nb::dynamic_attr())
-      .def(nb::init<nb::object, nb::object, nb::object, nb::object,
-                    nb::object>(),
+      .def(nb::init<nb::object, nb::object, nb::object, nb::object, nb::object,
+                    nb::object>(),
           nb::arg("mesh"), nb::arg("spec").none(),
           nb::arg("memory_kind").none() = nb::none(),
           nb::arg("_parsed_pspec").none() = nb::none(),
-          nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)))
+          nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)),
+          nb::arg("_logical_device_ids").none() = nb::none())
       .def_prop_ro("mesh", &NamedSharding::mesh)
       .def_prop_ro("spec", &NamedSharding::spec)
       .def_prop_ro("_memory_kind", &NamedSharding::memory_kind)
       .def_prop_ro("_manual_axes", &NamedSharding::manual_axes)
+      .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids)
       .def_prop_rw("_parsed_pspec", &NamedSharding::parsed_pspec,
                    &NamedSharding::set_parsed_pspec)
       .def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
diff --git a/xla/python/sharding.h b/xla/python/sharding.h
index 847938478b2e8..5b41ae0411068 100644
--- a/xla/python/sharding.h
+++ b/xla/python/sharding.h
@@ -71,13 +71,17 @@ class NamedSharding : public Sharding {
  public:
   NamedSharding(nanobind::object mesh, nanobind::object spec,
                 nanobind::object memory_kind, nanobind::object parsed_pspec,
-                nanobind::object manual_axes);
+                nanobind::object manual_axes,
+                nanobind::object logical_device_ids);
 
   const nanobind::object& mesh() const { return mesh_; }
   const nanobind::object& spec() const { return spec_; }
   const nanobind::object& memory_kind() const { return memory_kind_; }
   const nanobind::object& parsed_pspec() const { return parsed_pspec_; }
   const nanobind::object& manual_axes() const { return manual_axes_; }
+  const nanobind::object& logical_device_ids() const {
+    return logical_device_ids_;
+  }
   void set_parsed_pspec(nanobind::object parsed_pspec) {
     parsed_pspec_ = std::move(parsed_pspec);
   }
@@ -102,6 +106,7 @@ class NamedSharding : public Sharding {
   nanobind::object memory_kind_;
   nanobind::object parsed_pspec_;
   nanobind::object manual_axes_;
+  nanobind::object logical_device_ids_;
   std::optional<nb_class_ptr<PyDeviceList>> internal_device_list_;
 };
 
diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py
index d04a70c55db3e..81caaaf73598d 100644
--- a/xla/python/xla_client.py
+++ b/xla/python/xla_client.py
@@ -50,7 +50,7 @@
 
 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.xla_extension_version.
-_version = 291
+_version = 292
 
 # Version number for MLIR:Python components.
 mlir_api_version = 57
diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi
index 49190362933c7..33b68ec36e70d 100644
--- a/xla/python/xla_extension/__init__.pyi
+++ b/xla/python/xla_extension/__init__.pyi
@@ -884,6 +884,7 @@ class NamedSharding(Sharding):
       memory_kind: Optional[str] = None,
       _parsed_pspec: Any = None,
       _manual_axes: frozenset[Any] = frozenset(),
+      _logical_device_ids: tuple[int, ...] | None = None,
   ): ...
   mesh: Any
   spec: Any
@@ -891,6 +892,7 @@ class NamedSharding(Sharding):
   _parsed_pspec: Any
   _internal_device_list: DeviceList
   _manual_axes: frozenset[Any]
+  _logical_device_ids: tuple[int, ...] | None
 
 class SingleDeviceSharding(Sharding):
   def __init__(self, device: Device, *, memory_kind: Optional[str] = None): ...
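
A minimal usage sketch of the new argument from the JAX side follows. It assumes a jaxlib built at xla_extension_version >= 292, that `jax.sharding.NamedSharding` (which is this nanobind class) accepts the private `_logical_device_ids` keyword, and that the permutation behaves like the `tile_assignment_devices()` equivalent described in the commit message; the reversed ordering below is purely illustrative.

# Sketch only: `_logical_device_ids` is the private argument added by this
# patch; its semantics (a tile_assignment_devices()-style permutation of
# logical device ids) are taken from the commit message above.
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Default NamedSharding: shard i of the leading axis lands on mesh device i.
default = NamedSharding(mesh, P("x"))

# Same mesh and spec, but with a permuted logical device order, avoiding the
# GSPMDSharding round-trip that Shardy cannot consume.
reordered = NamedSharding(
    mesh,
    P("x"),
    _logical_device_ids=tuple(reversed(range(mesh.size))),
)

x = jnp.arange(mesh.size * 8).reshape(mesh.size, 8)
y = jax.device_put(x, reordered)  # reshard with the custom device order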