Skip to content

Commit

Permalink
Add a private property to NamedSharding called _logical_device_ids
Browse files Browse the repository at this point in the history
…which allows you to pass a custom `tile_assignment_devices()` equivalent.

This is because for Shardy, GSPMDSharding doesn't work, so `device_put` on a mesh with different device order needs `NamedSharding` support. Bonus is that the logic is now simplified wrt the previous version in `_different_device_order_reshard`.

This will also allow us to remove OpSharding usage in other projects which require such kind of permutation capabilities.

PiperOrigin-RevId: 685925636
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Oct 15, 2024
1 parent 4c3bfad commit 97ddf1d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
16 changes: 11 additions & 5 deletions xla/python/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ bool ShardingEqual(nb::handle a, nb::handle b) {
a_named_sharding->memory_kind().equal(
b_named_sharding->memory_kind()) &&
a_named_sharding->manual_axes().equal(
b_named_sharding->manual_axes());
b_named_sharding->manual_axes()) &&
a_named_sharding->logical_device_ids().equal(
b_named_sharding->logical_device_ids());
}

if (a_type.is(GSPMDSharding::type())) {
Expand All @@ -175,15 +177,17 @@ bool ShardingEqual(nb::handle a, nb::handle b) {

NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
nb::object memory_kind, nb::object parsed_pspec,
nb::object manual_axes)
nb::object manual_axes,
nb::object logical_device_ids)
: Sharding(/*num_devices=*/[&mesh]() {
return nb::cast<int>(mesh.attr("size"));
}()),
mesh_(std::move(mesh)),
spec_(std::move(spec)),
memory_kind_(std::move(memory_kind)),
parsed_pspec_(std::move(parsed_pspec)),
manual_axes_(std::move(manual_axes)) {
manual_axes_(std::move(manual_axes)),
logical_device_ids_(std::move(logical_device_ids)) {
nb::object idl = nb::object(mesh_.attr("_internal_device_list"));
if (idl.is_none()) {
internal_device_list_ = std::nullopt;
Expand Down Expand Up @@ -261,16 +265,18 @@ void RegisterSharding(nb::module_& m) {
nb::class_<Sharding>(m, "Sharding").def(nb::init<>());

nb::class_<NamedSharding, Sharding>(m, "NamedSharding", nb::dynamic_attr())
.def(nb::init<nb::object, nb::object, nb::object, nb::object,
.def(nb::init<nb::object, nb::object, nb::object, nb::object, nb::object,
nb::object>(),
nb::arg("mesh"), nb::arg("spec").none(),
nb::arg("memory_kind").none() = nb::none(),
nb::arg("_parsed_pspec").none() = nb::none(),
nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)))
nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)),
nb::arg("_logical_device_ids").none() = nb::none())
.def_prop_ro("mesh", &NamedSharding::mesh)
.def_prop_ro("spec", &NamedSharding::spec)
.def_prop_ro("_memory_kind", &NamedSharding::memory_kind)
.def_prop_ro("_manual_axes", &NamedSharding::manual_axes)
.def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids)
.def_prop_rw("_parsed_pspec", &NamedSharding::parsed_pspec,
&NamedSharding::set_parsed_pspec)
.def_prop_ro("_internal_device_list", [](const NamedSharding& s) {
Expand Down
7 changes: 6 additions & 1 deletion xla/python/sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,17 @@ class NamedSharding : public Sharding {
public:
NamedSharding(nanobind::object mesh, nanobind::object spec,
nanobind::object memory_kind, nanobind::object parsed_pspec,
nanobind::object manual_axes);
nanobind::object manual_axes,
nanobind::object logical_device_ids);

const nanobind::object& mesh() const { return mesh_; }
const nanobind::object& spec() const { return spec_; }
const nanobind::object& memory_kind() const { return memory_kind_; }
const nanobind::object& parsed_pspec() const { return parsed_pspec_; }
const nanobind::object& manual_axes() const { return manual_axes_; }
const nanobind::object& logical_device_ids() const {
return logical_device_ids_;
}
void set_parsed_pspec(nanobind::object parsed_pspec) {
parsed_pspec_ = std::move(parsed_pspec);
}
Expand All @@ -102,6 +106,7 @@ class NamedSharding : public Sharding {
nanobind::object memory_kind_;
nanobind::object parsed_pspec_;
nanobind::object manual_axes_;
nanobind::object logical_device_ids_;
std::optional<xla::nb_class_ptr<PyDeviceList>> internal_device_list_;
};

Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 291
_version = 292

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
2 changes: 2 additions & 0 deletions xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -884,13 +884,15 @@ class NamedSharding(Sharding):
memory_kind: Optional[str] = None,
_parsed_pspec: Any = None,
_manual_axes: frozenset[Any] = frozenset(),
_logical_device_ids: tuple[int, ...] | None = None,
): ...
mesh: Any
spec: Any
_memory_kind: Optional[str]
_parsed_pspec: Any
_internal_device_list: DeviceList
_manual_axes: frozenset[Any]
_logical_device_ids: tuple[int, ...] | None

class SingleDeviceSharding(Sharding):
def __init__(self, device: Device, *, memory_kind: Optional[str] = None): ...
Expand Down

0 comments on commit 97ddf1d

Please sign in to comment.