Refactor partitioners to make them more modular (pytorch#1696)
Summary:

This change converts GreedyPerfPartitioner's static helper methods into instance methods so that subclasses can inherit and override them more easily.
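
For illustration only (not part of this commit), a subclass can now override one of these helpers and GreedyPerfPartitioner.partition will pick up the override, because the internal calls go through self rather than the class name. The subclass below is a minimal hypothetical sketch; the import paths assume the torchrec planner package layout:

from typing import List

from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner
from torchrec.distributed.planner.types import DeviceHardware, ShardingOption


class LoggingPerfPartitioner(GreedyPerfPartitioner):
    # Hypothetical override: report what the uniform pass is about to do, then
    # defer to the parent implementation, which partition() now reaches via self.
    def _uniform_partition(
        self, sharding_options: List[ShardingOption], devices: List[DeviceHardware]
    ) -> None:
        print(
            f"uniform-partitioning {len(sharding_options)} options "
            f"across {len(devices)} devices"
        )
        super()._uniform_partition(sharding_options, devices)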

Differential Revision: D53628705
henrylhtsang authored and facebook-github-bot committed Feb 12, 2024
1 parent 8e41b29 commit 9195dc7
Showing 2 changed files with 18 additions and 27 deletions.
39 changes: 15 additions & 24 deletions torchrec/distributed/planner/partitioners.py
@@ -226,13 +226,11 @@ def partition(

         _topology: Topology = copy.deepcopy(storage_constraint)
         minheap_devices: Optional[List[OrderedDeviceHardware]] = None
-        _host_level_devices = GreedyPerfPartitioner._get_host_level_devices(_topology)
+        _host_level_devices = self._get_host_level_devices(_topology)

         # first partition the uniform sharding options (RW & DP)
         uniform_sharding_options = _get_uniform_sharding_options(proposal)
-        GreedyPerfPartitioner._uniform_partition(
-            uniform_sharding_options, _topology.devices
-        )
+        self._uniform_partition(uniform_sharding_options, _topology.devices)

         # group the rest sharding options by colocation type (co-host, co-device, none)
         # and sort the groups by storage in reverse order
@@ -245,23 +243,21 @@ def partition(
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.HOST.value
             ):
-                GreedyPerfPartitioner._cohost_partition(
-                    sharding_option_group, _host_level_devices
-                )
+                self._cohost_partition(sharding_option_group, _host_level_devices)
                 # _cohost_partition invalidates minheap_devices, force rebuild before using
                 minheap_devices = None
             elif (
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.DEVICE.value
             ):
                 if minheap_devices is None:
-                    minheap_devices = GreedyPerfPartitioner._establish_minheap(
+                    minheap_devices = self._establish_minheap(
                         _topology.devices, _topology.local_world_size
                     )
                 assert (
                     len(sharding_option_group.sharding_options) == 1
                 ), f"Unexpected length for sharding options: {len(sharding_option_group.sharding_options)}"
-                GreedyPerfPartitioner._device_partition(
+                self._device_partition(
                     sharding_option_group.sharding_options[0],
                     minheap_devices,
                 )
@@ -273,18 +269,17 @@ def partition(
         self._topology: Topology = _topology
         return proposal

-    @staticmethod
     def _establish_minheap(
-        devices: List[DeviceHardware], local_world_size: int
+        self, devices: List[DeviceHardware], local_world_size: int
     ) -> List[OrderedDeviceHardware]:
         minheap_devices = [
             OrderedDeviceHardware(device, local_world_size) for device in devices
         ]
         heapq.heapify(minheap_devices)
         return minheap_devices

-    @staticmethod
     def _device_partition(
+        self,
         sharding_option: ShardingOption,
         minheap_devices: List[OrderedDeviceHardware],
         bulk_heapify_threshold: float = 0.25,
@@ -322,8 +317,8 @@ def _device_partition(
                     minheap_devices.extend(tmp_heap)
                     heapq.heapify(minheap_devices)

-    @staticmethod
     def _cohost_partition(
+        self,
         sharding_option_group: ShardingOptionGroup,
         _host_level_devices: List[List[DeviceHardware]],
     ) -> None:
@@ -344,9 +339,7 @@ def _cohost_partition(
                         sharding_option.sharding_type
                         == ShardingType.TABLE_ROW_WISE.value
                     ):
-                        GreedyPerfPartitioner._uniform_partition(
-                            [sharding_option], host_devices
-                        )
+                        self._uniform_partition([sharding_option], host_devices)
                         # _uniform_partition invalidates minheap_devices, force rebuild
                         # before using
                         minheap_devices = None
@@ -355,12 +348,10 @@ def _cohost_partition(
                         == ShardingType.TABLE_COLUMN_WISE.value
                     ):
                         if minheap_devices is None:
-                            minheap_devices = GreedyPerfPartitioner._establish_minheap(
+                            minheap_devices = self._establish_minheap(
                                 host_devices, len(host_devices)
                             )
-                        GreedyPerfPartitioner._device_partition(
-                            sharding_option, minheap_devices
-                        )
+                        self._device_partition(sharding_option, minheap_devices)
                     else:
                         raise RuntimeError(
                             f"unexpected cohost sharding type: {sharding_option.sharding_type}"
@@ -382,8 +373,9 @@ def _cohost_partition(
             message=f"can't find a host for sharding option group {sharding_option_group}",
         )

-    @staticmethod
-    def _get_host_level_devices(_topology: Topology) -> List[List[DeviceHardware]]:
+    def _get_host_level_devices(
+        self, _topology: Topology
+    ) -> List[List[DeviceHardware]]:
         num_hosts: int = _topology.world_size // _topology.local_world_size
         host_level_devices: List[List[DeviceHardware]] = []
         for i in range(num_hosts):
@@ -393,9 +385,8 @@ def _get_host_level_devices(_topology: Topology) -> List[List[DeviceHardware]]:
             host_level_devices.append(devices_in_host)
         return host_level_devices

-    @staticmethod
     def _uniform_partition(
-        sharding_options: List[ShardingOption], devices: List[DeviceHardware]
+        self, sharding_options: List[ShardingOption], devices: List[DeviceHardware]
     ) -> None:
         for sharding_option in sharding_options:
             if sharding_option.num_shards != len(devices):
6 changes: 3 additions & 3 deletions torchrec/distributed/planner/tests/test_partitioners.py
@@ -220,15 +220,15 @@ def empty_devices() -> List[DeviceHardware]:

         def validate(threshold: float) -> None:
             devices = empty_devices()
-            minheap_devices = GreedyPerfPartitioner._establish_minheap(
+            minheap_devices = GreedyPerfPartitioner()._establish_minheap(
                 devices, local_world_size
             )

-            GreedyPerfPartitioner._device_partition(
+            GreedyPerfPartitioner()._device_partition(
                 sharding_option, minheap_devices, threshold
             )

-            want_minheap_devices = GreedyPerfPartitioner._establish_minheap(
+            want_minheap_devices = GreedyPerfPartitioner()._establish_minheap(
                 devices, local_world_size
             )
             device_heaps_equal(minheap_devices, want_minheap_devices)