From 9195dc7c4063b503b53023359773f406dc801644 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 12 Feb 2024 11:16:35 -0800 Subject: [PATCH] Refactor partitioners to make them more modular (#1696) Summary: Convert the static helper methods of GreedyPerfPartitioner into instance methods that are called through `self`, so that subclasses can override them more easily. Differential Revision: D53628705 --- torchrec/distributed/planner/partitioners.py | 39 +++++++------------ .../planner/tests/test_partitioners.py | 6 +-- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/torchrec/distributed/planner/partitioners.py b/torchrec/distributed/planner/partitioners.py index ff603bd10..25f9701e1 100644 --- a/torchrec/distributed/planner/partitioners.py +++ b/torchrec/distributed/planner/partitioners.py @@ -226,13 +226,11 @@ def partition( _topology: Topology = copy.deepcopy(storage_constraint) minheap_devices: Optional[List[OrderedDeviceHardware]] = None - _host_level_devices = GreedyPerfPartitioner._get_host_level_devices(_topology) + _host_level_devices = self._get_host_level_devices(_topology) # first partition the uniform sharding options (RW & DP) uniform_sharding_options = _get_uniform_sharding_options(proposal) - GreedyPerfPartitioner._uniform_partition( - uniform_sharding_options, _topology.devices - ) + self._uniform_partition(uniform_sharding_options, _topology.devices) # group the rest sharding options by colocation type (co-host, co-device, none) # and sort the groups by storage in reverse order @@ -245,9 +243,7 @@ def partition( sharding_option_group.sharding_options[0].partition_by == PartitionByType.HOST.value ): - GreedyPerfPartitioner._cohost_partition( - sharding_option_group, _host_level_devices - ) + self._cohost_partition(sharding_option_group, _host_level_devices) # _cohost_partition invalidates minheap_devices, force rebuild before using minheap_devices = None elif ( @@ -255,13 +251,13 @@ def partition( == PartitionByType.DEVICE.value ): if minheap_devices is None: - minheap_devices = GreedyPerfPartitioner._establish_minheap( + 
minheap_devices = self._establish_minheap( _topology.devices, _topology.local_world_size ) assert ( len(sharding_option_group.sharding_options) == 1 ), f"Unexpected length for sharding options: {len(sharding_option_group.sharding_options)}" - GreedyPerfPartitioner._device_partition( + self._device_partition( sharding_option_group.sharding_options[0], minheap_devices, ) @@ -273,9 +269,8 @@ def partition( self._topology: Topology = _topology return proposal - @staticmethod def _establish_minheap( - devices: List[DeviceHardware], local_world_size: int + self, devices: List[DeviceHardware], local_world_size: int ) -> List[OrderedDeviceHardware]: minheap_devices = [ OrderedDeviceHardware(device, local_world_size) for device in devices @@ -283,8 +278,8 @@ def _establish_minheap( heapq.heapify(minheap_devices) return minheap_devices - @staticmethod def _device_partition( + self, sharding_option: ShardingOption, minheap_devices: List[OrderedDeviceHardware], bulk_heapify_threshold: float = 0.25, @@ -322,8 +317,8 @@ def _device_partition( minheap_devices.extend(tmp_heap) heapq.heapify(minheap_devices) - @staticmethod def _cohost_partition( + self, sharding_option_group: ShardingOptionGroup, _host_level_devices: List[List[DeviceHardware]], ) -> None: @@ -344,9 +339,7 @@ def _cohost_partition( sharding_option.sharding_type == ShardingType.TABLE_ROW_WISE.value ): - GreedyPerfPartitioner._uniform_partition( - [sharding_option], host_devices - ) + self._uniform_partition([sharding_option], host_devices) # _uniform_partition invalidates minheap_devices, force rebuild # before using minheap_devices = None @@ -355,12 +348,10 @@ def _cohost_partition( == ShardingType.TABLE_COLUMN_WISE.value ): if minheap_devices is None: - minheap_devices = GreedyPerfPartitioner._establish_minheap( + minheap_devices = self._establish_minheap( host_devices, len(host_devices) ) - GreedyPerfPartitioner._device_partition( - sharding_option, minheap_devices - ) + self._device_partition(sharding_option, 
minheap_devices) else: raise RuntimeError( f"unexpected cohost sharding type: {sharding_option.sharding_type}" @@ -382,8 +373,9 @@ def _cohost_partition( message=f"can't find a host for sharding option group {sharding_option_group}", ) - @staticmethod - def _get_host_level_devices(_topology: Topology) -> List[List[DeviceHardware]]: + def _get_host_level_devices( + self, _topology: Topology + ) -> List[List[DeviceHardware]]: num_hosts: int = _topology.world_size // _topology.local_world_size host_level_devices: List[List[DeviceHardware]] = [] for i in range(num_hosts): @@ -393,9 +385,8 @@ def _get_host_level_devices(_topology: Topology) -> List[List[DeviceHardware]]: host_level_devices.append(devices_in_host) return host_level_devices - @staticmethod def _uniform_partition( - sharding_options: List[ShardingOption], devices: List[DeviceHardware] + self, sharding_options: List[ShardingOption], devices: List[DeviceHardware] ) -> None: for sharding_option in sharding_options: if sharding_option.num_shards != len(devices): diff --git a/torchrec/distributed/planner/tests/test_partitioners.py b/torchrec/distributed/planner/tests/test_partitioners.py index 83e38bed4..a0ddae96f 100644 --- a/torchrec/distributed/planner/tests/test_partitioners.py +++ b/torchrec/distributed/planner/tests/test_partitioners.py @@ -220,15 +220,15 @@ def empty_devices() -> List[DeviceHardware]: def validate(threshold: float) -> None: devices = empty_devices() - minheap_devices = GreedyPerfPartitioner._establish_minheap( + minheap_devices = GreedyPerfPartitioner()._establish_minheap( devices, local_world_size ) - GreedyPerfPartitioner._device_partition( + GreedyPerfPartitioner()._device_partition( sharding_option, minheap_devices, threshold ) - want_minheap_devices = GreedyPerfPartitioner._establish_minheap( + want_minheap_devices = GreedyPerfPartitioner()._establish_minheap( devices, local_world_size ) device_heaps_equal(minheap_devices, want_minheap_devices)