
Relaxing divisibility constraints on num_canonical_nodes and num_physical_nodes #476

Merged
merged 10 commits on Oct 26, 2023
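In short, this PR adds a 'relaxed' partitioning algorithm so a training run can resume on a different number of physical nodes without the old requirement that num_canonical_nodes and num_physical_nodes evenly divide one another, provided the global batch size stays fixed. A minimal opt-in sketch, assuming the usual StreamingDataset constructor arguments (the paths below are placeholders, not from this PR):

from streaming import StreamingDataset

# Placeholder paths; partition_algo='relaxed' selects the algorithm added by
# this PR, while 'orig' keeps the previous behavior and its divisibility rules.
dataset = StreamingDataset(remote='s3://my-bucket/my-dataset',
                           local='/tmp/my-dataset',
                           batch_size=8,  # per-device batch size
                           partition_algo='relaxed',
                           num_canonical_nodes=15)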
2 changes: 1 addition & 1 deletion streaming/base/batching/per_stream.py
@@ -57,7 +57,7 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e
stream_partition = get_partitions(dataset.partition_algo, samples_in_stream,
dataset.num_canonical_nodes, world.num_nodes,
world.ranks_per_node, world.workers_per_rank, batch_size,
0)
0, dataset.initial_physical_nodes)
if dataset.shuffle:
# Ratio of stream's shuffle block size to overall shuffle block size should be the
# same as the ratio of the stream's samples to overall samples.
3 changes: 2 additions & 1 deletion streaming/base/batching/random.py
@@ -53,7 +53,8 @@ def generate_work_random_batching(dataset: StreamingDataset, world: World, epoch
# batch) such that we have an elastically deterministic sample order.
big_ids = get_partitions(dataset.partition_algo, dataset.epoch_size,
dataset.num_canonical_nodes, world.num_nodes, world.ranks_per_node,
world.workers_per_rank, dataset.batch_size, sample_in_epoch)
world.workers_per_rank, dataset.batch_size, sample_in_epoch,
dataset.initial_physical_nodes)

# If we need to shuffle, shuffle in a node-aware and *underlying* shard-aware way.
if dataset.shuffle:
3 changes: 2 additions & 1 deletion streaming/base/batching/stratified.py
@@ -68,7 +68,8 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
# We also handle used samples (drop_first) at the end.
stream_partition = get_partitions(dataset.partition_algo, samples_in_stream,
dataset.num_canonical_nodes, 1, world.ranks_per_node,
world.workers_per_rank, 1, 0)
world.workers_per_rank, 1, 0,
dataset.initial_physical_nodes)
if dataset.shuffle:
# Ratio of stream's shuffle block size to overall shuffle block size should be the
# same as the ratio of the stream's samples to overall samples.
8 changes: 7 additions & 1 deletion streaming/base/dataset.py
@@ -341,6 +341,10 @@ def __init__(self,
self.shuffle_block_size = shuffle_block_size
self.batching_method = batching_method

# Initialize initial_physical_nodes to None. If we are resuming, then we will set it to the
# number of physical nodes of the initial run in the _resume function.
self.initial_physical_nodes = None

# Check streams vs remote/local.
if bool(streams) == (bool(remote) or bool(local)):
raise ValueError(
@@ -678,6 +682,7 @@ def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
sample_in_epoch = obj['sample_in_epoch']
self.num_canonical_nodes = obj['num_canonical_nodes']
self.shuffle_seed = obj['shuffle_seed']
self.initial_physical_nodes = obj['initial_physical_nodes']
self._set_predownload()

return epoch, sample_in_epoch
@@ -736,7 +741,8 @@ def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
'epoch': epoch,
'sample_in_epoch': sample_in_epoch,
'num_canonical_nodes': self.num_canonical_nodes,
'shuffle_seed': self.shuffle_seed
'shuffle_seed': self.shuffle_seed,
'initial_physical_nodes': world.num_nodes
}

def load_state_dict(self, obj: Dict[str, Any]) -> None:
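For orientation, the checkpoint payload now records the node count at save time; a sketch of the resulting state dict, with illustrative values rather than output from a real run:

# Illustrative state_dict contents after this change (values are made up).
obj = {
    'epoch': 2,
    'sample_in_epoch': 4096,
    'num_canonical_nodes': 15,
    'shuffle_seed': 9176,
    'initial_physical_nodes': 15,  # world.num_nodes at save time
}
# On resumption, _resume() reads this field back into
# dataset.initial_physical_nodes for use by later get_partitions calls.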
9 changes: 7 additions & 2 deletions streaming/base/partition/__init__.py
@@ -9,9 +9,11 @@
from numpy.typing import NDArray

from streaming.base.partition.orig import get_partitions_orig
from streaming.base.partition.relaxed import get_partitions_relaxed

algos = {
'orig': get_partitions_orig,
'relaxed': get_partitions_relaxed,
}


@@ -22,7 +24,8 @@ def get_partitions(algo: str,
ranks_per_node: int,
workers_per_rank: int,
batch_size: Optional[int] = None,
drop_first: int = 0) -> NDArray[np.int64]:
drop_first: int = 0,
initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
"""Partition the given number of samples to nodes, ranks, and workers.

Either canonical or physical nodes must be evenly divisible by the other.
@@ -41,11 +44,13 @@
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
Defaults to ``None``.

Returns:
NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
batches per worker, batch size).
"""
get = algos[algo]
return get(num_samples, num_canonical_nodes, num_physical_nodes, ranks_per_node,
workers_per_rank, batch_size, drop_first)
workers_per_rank, batch_size, drop_first, initial_physical_nodes)
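Called directly, the dispatcher threads the new argument through to the selected algorithm. A small sketch with illustrative numbers (resuming on 4 nodes a run that began on 15, with 8 ranks per node and 8 workers per rank):

from streaming.base.partition import get_partitions

# Global batch size is 4 nodes * 8 ranks * 30 per-device = 960, which the
# initial 15 * 8 = 120 devices evenly divide, so 'relaxed' accepts it even
# though 15 and 4 do not divide one another ('orig' would raise here).
ids = get_partitions('relaxed', 1000, 15, 4, 8, 8, batch_size=30,
                     drop_first=0, initial_physical_nodes=15)
# Shape: (physical nodes, ranks per node, workers per rank,
#         batches per worker, batch size)
print(ids.shape)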
5 changes: 4 additions & 1 deletion streaming/base/partition/orig.py
@@ -19,7 +19,8 @@ def get_partitions_orig(num_samples: int,
ranks_per_node: int,
workers_per_rank: int,
batch_size: Optional[int] = None,
drop_first: int = 0) -> NDArray[np.int64]:
drop_first: int = 0,
initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
"""Partition the given number of samples to nodes, ranks, and workers.

Either canonical or physical nodes must be evenly divisible by the other.
@@ -37,6 +38,8 @@
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
Defaults to ``None``.

Returns:
NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
98 changes: 98 additions & 0 deletions streaming/base/partition/relaxed.py
@@ -0,0 +1,98 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Apportion shards/samples to nodes/ranks/workers for elastically deterministic sample order."""

import logging
from typing import Optional

import numpy as np
from numpy.typing import NDArray

from streaming.base.partition.orig import get_partitions_orig

logger = logging.getLogger(__name__)


def get_partitions_relaxed(num_samples: int,
num_canonical_nodes: int,
num_physical_nodes: int,
ranks_per_node: int,
workers_per_rank: int,
batch_size: Optional[int] = None,
drop_first: int = 0,
initial_physical_nodes: Optional[int] = None) -> NDArray[np.int64]:
"""Partition the given number of samples to nodes, ranks, and workers.

Either canonical or physical nodes must be evenly divisible by the other when partitioning over
the initial number of physical nodes. For partitions during resumption, the only constraint
is that the global batch size, which remains constant during training, must be evenly divisible
by the initial total number of devices, which is initial_physical_nodes * ranks_per_node.

It is suggested to set num_canonical_nodes higher than your expected number of physical nodes,
because scaling your number of nodes below that level may result in more shards being used
across node boundaries due to preserving the same global sample order.

Args:
num_samples (int): Dataset size.
num_canonical_nodes (int): Number of canonical nodes.
num_physical_nodes (int): Number of physical nodes.
ranks_per_node (int): Number of ranks per node.
workers_per_rank (int): Number of worker partitions per rank.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
drop_first (int): Number of samples seen already, which are dropped. Defaults to ``0``.
initial_physical_nodes (int, optional): Number of physical nodes at the start of training.
Defaults to ``None``.

Returns:
NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
batches per worker, batch size).
"""
if num_samples <= drop_first:
raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
f'({num_samples})')

if initial_physical_nodes is None or (num_physical_nodes <= num_canonical_nodes and
num_canonical_nodes % num_physical_nodes == 0) or \
(num_physical_nodes > num_canonical_nodes and
num_physical_nodes % num_canonical_nodes == 0):
# Case 1: We are partitioning for the first time. Use the original partitions algorithm,
# which also requires that NCN be divisible by PN or vice versa.
# Case 2: PN <= NCN and PN evenly divides NCN. The original partition algo can be used,
# and will give better downloads per node as well.
# Case 3: PN > NCN and NCN evenly divides PN. The original partition algo can be used.
return get_partitions_orig(num_samples, num_canonical_nodes, num_physical_nodes,
ranks_per_node, workers_per_rank, batch_size, drop_first)
else:
batch_size = batch_size or 1
# First, make a partition over the initial number of physical nodes and device batch size.
# We assume that ranks_per_node and workers_per_rank stay constant during resumptions.
global_batch_size = num_physical_nodes * ranks_per_node * batch_size
initial_total_devices = initial_physical_nodes * ranks_per_node
# Check for divisibility of the current global batch size and the initial total devices.
# This should be true since the global batch size should not change in the middle of
# training.
if global_batch_size % initial_total_devices != 0:
raise ValueError(f'A global batch size of {global_batch_size} is not evenly ' +
f'divisible by the initial total number of devices of ' +
f'{initial_total_devices}. Make sure that when using ' +
f'the `relaxed` partitioning algorithm, the global batch size does ' +
f'not change during resumption of training.')
initial_batch_size = global_batch_size // initial_total_devices
partition = get_partitions_orig(num_samples, num_canonical_nodes, initial_physical_nodes,
ranks_per_node, workers_per_rank, initial_batch_size,
drop_first)

# Flatten the initial partition in order of traversal.
# The partition shape is (nodes, ranks, workers, batches per worker, batch size);
# in traversal order, the dimensions are (batches per worker, workers, nodes, ranks, batch size).
partition = partition.transpose(3, 2, 0, 1, 4).flatten()

# Reshape the in-order traversal of the partition to the new physical nodes and batch size.
partition = partition.reshape(-1, workers_per_rank, num_physical_nodes, ranks_per_node,
batch_size)

# Re-transpose this partition matrix back to the original format below and return it:
# (physical nodes, ranks per node, workers per rank, batches per worker, batch size)
return partition.transpose(2, 3, 1, 0, 4)
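The heart of the algorithm is this flatten-and-re-window step: the partition made over the initial topology is read out in traversal order and simply re-windowed onto the new topology, so each global batch keeps exactly the same sample IDs. A self-contained numpy sketch with toy dimensions (not the library's real shapes):

import numpy as np

# Toy dimensions: 4 nodes, 2 ranks/node, 2 workers/rank, 2 batches/worker,
# per-device batch size 3 -> global batch size 4 * 2 * 3 = 24.
nodes, ranks, workers, bpw, bs = 4, 2, 2, 2, 3
flat = np.arange(nodes * ranks * workers * bpw * bs)  # stand-in sample IDs
# A partition in the library's layout: (nodes, ranks, workers, batches/worker, batch).
old = flat.reshape(bpw, workers, nodes, ranks, bs).transpose(2, 3, 1, 0, 4)

# "Resume" on 2 nodes with per-device batch 6: same global batch size of 24.
new_nodes, new_bs = 2, 6
inorder = old.transpose(3, 2, 0, 1, 4).flatten()  # traversal order
new = inorder.reshape(-1, workers, new_nodes, ranks,
                      new_bs).transpose(2, 3, 1, 0, 4)

# Every global batch holds the same sample IDs across the topology change.
old_batches = old.transpose(3, 2, 0, 1, 4).reshape(-1, nodes * ranks * bs)
new_batches = new.transpose(3, 2, 0, 1, 4).reshape(-1, new_nodes * ranks * new_bs)
assert (old_batches == new_batches).all()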
80 changes: 78 additions & 2 deletions tests/test_partition.py
@@ -1,11 +1,13 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

import pytest

from streaming.base.partition import get_partitions


def test_partition_walk():
partition_algo = 'orig'
@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed'])
def test_partition_walk(partition_algo: str):
num_samples = 1000
num_canonical_nodes = 176
num_physical_nodes = 22
@@ -29,3 +31,77 @@ def test_partition_walk(partition_algo: str):
x = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes,
ranks_per_node, workers_per_rank, batch_size, drop_first)
assert x.shape == (22, 8, 8, 1, 10)


def test_partition_relaxed_resumption():
# For global batch size 960, which is a highly divisible number, go through all possible
# values of physical nodes we can train on.
# Assuming 8 devices per node, we can train on the following numbers of nodes:
# 1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20, 24, 30, 40, 60, 120
# The initial number of physical nodes is 15, which is also num_canonical_nodes.
# Without relaxed partitioning, we can only train on the following numbers of nodes with
# deterministic resumption (with 15 == NCN == initial physical nodes):
# 1, 3, 5, 15, 30, 60, 120
# And with orig partitioning, we cannot train on the following numbers of nodes due to the NCN
# and PN divisibility constraints:
# 2, 4, 6, 8, 10, 12, 20, 24, 40

# Make initial partition with 15 == NCN == initial physical nodes
initial_physical_nodes = 15
num_canonical_nodes = 15
global_batch_size = 960

num_samples = 10000
ranks_per_node = 8
workers_per_rank = 8
drop_first = 0
initial_batch_size = global_batch_size // (initial_physical_nodes * ranks_per_node)
# relaxed partitioning is the same as orig partitioning for the initial partition
initial_partition = get_partitions('relaxed', num_samples, num_canonical_nodes,
initial_physical_nodes, ranks_per_node, workers_per_rank,
initial_batch_size, drop_first)
# Get the inorder global batches of the initial partition
initial_partition = initial_partition.transpose(3, 2, 0, 1, 4).reshape(-1, global_batch_size)
num_initial_batches = initial_partition.shape[0]

# For each possible number of physical nodes, get the new partition and check that the inorder
# global batches are the same with relaxed partitioning.
resumption_nodes = [1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20, 24, 30, 40, 60, 120]
for new_node_num in resumption_nodes:
new_batch_size = global_batch_size // (new_node_num * ranks_per_node)
new_partition = get_partitions('relaxed', num_samples, num_canonical_nodes, new_node_num,
ranks_per_node, workers_per_rank, new_batch_size,
drop_first, initial_physical_nodes)
# Get the inorder global batches of the new partition
new_partition = new_partition.transpose(3, 2, 0, 1, 4).reshape(-1, global_batch_size)
for batch_idx in range(num_initial_batches):
initial_samples = set(initial_partition[batch_idx])
new_samples = set(new_partition[batch_idx])
# don't check equality for batches with padding.
if -1 not in initial_samples and -1 not in new_samples:
assert initial_samples == new_samples

# For orig partitioning, test that we can only resume on a limited number of nodes.
resumption_nodes = [1, 3, 5, 15, 30, 60, 120]
for new_node_num in resumption_nodes:
new_batch_size = global_batch_size // (new_node_num * ranks_per_node)
new_partition = get_partitions('orig', num_samples, num_canonical_nodes, new_node_num,
ranks_per_node, workers_per_rank, new_batch_size,
drop_first)
# Get the inorder global batches of the new partition
new_partition = new_partition.transpose(3, 2, 0, 1, 4).reshape(-1, global_batch_size)
for batch_idx in range(num_initial_batches):
initial_samples = set(initial_partition[batch_idx])
new_samples = set(new_partition[batch_idx])
# don't check equality for batches with padding.
if -1 not in initial_samples and -1 not in new_samples:
assert initial_samples == new_samples

# For orig partitioning, test that we cannot resume on the other node values due to the NCN
# and PN divisibility constraints.
resumption_nodes = [2, 4, 6, 8, 10, 12, 20, 24, 40]
for new_node_num in resumption_nodes:
new_batch_size = global_batch_size // (new_node_num * ranks_per_node)
with pytest.raises(ValueError, match='Either canonical or physical nodes must be'):
_ = get_partitions('orig', num_samples, num_canonical_nodes, new_node_num,
ranks_per_node, workers_per_rank, new_batch_size, drop_first)
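As a quick sanity check on the node counts the test enumerates, a short sketch of the underlying arithmetic (values mirror the test, not any particular cluster):

# With global batch size 960 and 8 ranks per node, any node count dividing 120
# keeps the per-device batch size integral; 'orig' additionally needs the node
# count and NCN=15 to divide one another.
global_batch_size, ranks_per_node, ncn = 960, 8, 15
for nodes in [1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 20, 24, 30, 40, 60, 120]:
    devices = nodes * ranks_per_node
    assert global_batch_size % devices == 0
    orig_ok = ncn % nodes == 0 or nodes % ncn == 0
    print(nodes, global_batch_size // devices, 'orig' if orig_ok else 'relaxed only')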