RW Dist change to support uneven sharding (pytorch#1525)
Summary:
Pull Request resolved: pytorch#1525

X-link: pytorch/FBGEMM#2142

As titled

Reviewed By: dstaay-fb

Differential Revision: D51053207

fbshipit-source-id: 8aee7d967ceec9ea0739f5d2c56d6e541a0d3648
gnahzg authored and facebook-github-bot committed Nov 18, 2023
1 parent 4ab5259 commit 873faf3
Showing 3 changed files with 226 additions and 15 deletions.
3 changes: 3 additions & 0 deletions torchrec/distributed/embedding_sharding.py
@@ -63,6 +63,7 @@ def bucketize_kjt_before_all2all(
block_sizes: torch.Tensor,
output_permute: bool = False,
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]:
"""
Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets,
@@ -78,6 +79,7 @@ def bucketize_kjt_before_all2all(
values to bucketized values or not.
bucketize_pos (bool): output the changed position of the bucketized values or
not.
block_bucketize_row_pos (Optional[List[torch.Tensor]]): the per-feature row offsets of the shard sizes.
Returns:
Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: the bucketized `KeyedJaggedTensor` and the optional permute mapping from the unbucketized values to bucketized value.
@@ -103,6 +105,7 @@ def bucketize_kjt_before_all2all(
block_sizes=block_sizes_new_type,
my_size=num_buckets,
weights=kjt.weights_or_none(),
block_bucketize_pos=block_bucketize_row_pos,
)

return (
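Note: the new block_bucketize_row_pos argument carries, per feature, the row offsets that delimit each rank's shard (a leading 0, each shard boundary, then the total row count). Below is a minimal sketch of building such offsets from hypothetical uneven shard sizes; the sizes are illustrative and not from this commit.

import itertools
import torch

# Hypothetical uneven per-rank row counts for two features (3 ranks, 512 rows each).
shard_sizes_per_feature = [
    [256, 128, 128],
    [200, 200, 112],
]

# One int32 offset tensor per feature: [0, 256, 384, 512] and [0, 200, 400, 512].
block_bucketize_row_pos = [
    torch.tensor([0] + list(itertools.accumulate(sizes)), dtype=torch.int32)
    for sizes in shard_sizes_per_feature
]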
82 changes: 73 additions & 9 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -5,7 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, TypeVar
import math
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import torch
import torch.distributed as dist
@@ -55,6 +56,40 @@
W = TypeVar("W")


def get_embedding_shard_metadata(
grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]],
) -> Tuple[List[List[int]], bool]:
is_even_sharding: bool = True
world_size = len(grouped_embedding_configs_per_rank)

def get_even_shard_sizes(hash_size: int, world_size: int) -> List[int]:
block_size: int = math.ceil(hash_size / world_size)
last_rank: int = hash_size // block_size

expected_even_shard_sizes = [block_size] * last_rank
if hash_size % world_size != 0:
expected_even_shard_sizes.append(hash_size - sum(expected_even_shard_sizes))
return expected_even_shard_sizes

embed_sharding = []
for table in grouped_embedding_configs_per_rank[0][0].embedding_tables:
embed_sharding_per_feature = []
total_rows = 0
sizes = []
# pyre-ignore [16]: `Optional` has no attribute `shards_metadata`
for metadata in table.global_metadata.shards_metadata:
embed_sharding_per_feature.append(metadata.shard_offsets[0])
total_rows += metadata.shard_sizes[0]
sizes.append(metadata.shard_sizes[0])
embed_sharding_per_feature.append(total_rows)
embed_sharding.extend([embed_sharding_per_feature] * len(table.embedding_names))
expected_even_sizes = get_even_shard_sizes(total_rows, world_size)
if sizes != expected_even_sizes:
is_even_sharding = False

return (embed_sharding, is_even_sharding)


class BaseRwEmbeddingSharding(EmbeddingSharding[C, F, T, W]):
"""
Base class for row-wise sharding.
@@ -428,14 +463,27 @@ def create_output_dist(
def get_block_sizes_runtime_device(
block_sizes: List[int],
runtime_device: torch.device,
tensor_cache: Dict[str, torch.Tensor],
) -> torch.Tensor:
tensor_cache: Dict[str, Tuple[torch.Tensor, List[torch.Tensor]]],
embedding_shard_metadata: Optional[List[List[int]]],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
cache_key: str = "__block_sizes"
if cache_key not in tensor_cache:
tensor_cache[cache_key] = torch.tensor(
block_sizes,
device=runtime_device,
dtype=torch.int32,
tensor_cache[cache_key] = (
torch.tensor(
block_sizes,
device=runtime_device,
dtype=torch.int32,
),
[]
if embedding_shard_metadata is None
else [
torch.tensor(
row_pos,
device=runtime_device,
dtype=torch.int32,
)
for row_pos in embedding_shard_metadata
],
)

return tensor_cache[cache_key]
@@ -451,6 +499,7 @@ def __init__(
is_sequence: bool = False,
has_feature_processor: bool = False,
need_pos: bool = False,
embedding_shard_metadata: Optional[List[List[int]]] = None,
) -> None:
super().__init__()
self._world_size: int = world_size
@@ -459,7 +508,10 @@
(hash_size + self._world_size - 1) // self._world_size
for hash_size in feature_hash_sizes
]
self.tensor_cache: Dict[str, torch.Tensor] = {}
self.tensor_cache: Dict[
str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
] = {}

self._dist = KJTOneToAll(
splits=self._world_size * [self._num_features],
world_size=world_size,
@@ -470,14 +522,19 @@
self._need_pos = need_pos
self.unbucketize_permute_tensor: Optional[torch.Tensor] = None

self._embedding_shard_metadata: Optional[
List[List[int]]
] = embedding_shard_metadata

def forward(
self,
sparse_features: KeyedJaggedTensor,
) -> KJTList:
block_sizes = get_block_sizes_runtime_device(
block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device(
self.feature_block_sizes,
sparse_features.device(),
self.tensor_cache,
self._embedding_shard_metadata,
)
(
bucketized_features,
@@ -490,6 +547,7 @@ def forward(
bucketize_pos=self._has_feature_processor
if sparse_features.weights_or_none() is None
else self._need_pos,
block_bucketize_row_pos=block_bucketize_row_pos,
)
return self._dist.forward(bucketized_features)

@@ -505,11 +563,17 @@ def create_input_dist(
) -> BaseSparseFeaturesDist[KJTList]:
num_features = self._get_num_features()
feature_hash_sizes = self._get_feature_hash_sizes()

(embed_sharding, is_even_sharding) = get_embedding_shard_metadata(
self._grouped_embedding_configs_per_rank
)

return InferRwSparseFeaturesDist(
world_size=self._world_size,
num_features=num_features,
feature_hash_sizes=feature_hash_sizes,
device=device if device is not None else self._device,
embedding_shard_metadata=embed_sharding if not is_even_sharding else None,
)

def create_lookup(
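Note: get_embedding_shard_metadata above flags a table as unevenly sharded when its shard sizes differ from the default even split computed by its inner get_even_shard_sizes helper. A standalone sketch of that split follows; the names are local to the sketch, not part of the PR.

import math

def even_shard_sizes(hash_size: int, world_size: int) -> list:
    # Ceil-divide rows across ranks; any remainder becomes a smaller trailing shard.
    block_size = math.ceil(hash_size / world_size)
    full_ranks = hash_size // block_size
    sizes = [block_size] * full_ranks
    if hash_size % world_size != 0:
        sizes.append(hash_size - sum(sizes))
    return sizes

assert even_shard_sizes(512, 4) == [128, 128, 128, 128]  # even split
assert even_shard_sizes(514, 4) == [129, 129, 129, 127]  # uneven trailing shard
# A plan whose actual shard sizes differ from this pattern (e.g. [256, 128, 64, 64])
# is treated as uneven, and the per-feature row offsets are passed to the input dist.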
156 changes: 150 additions & 6 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -49,6 +49,13 @@
from torchrec.fx import symbolic_trace


def placement_helper(device_type: str, index: int = 0) -> str:
if device_type == "cpu":
return f"rank:0/{device_type}" # cpu only use rank 0

return f"rank:{index}/{device_type}:{index}"


class InferShardingsTest(unittest.TestCase):
@unittest.skipIf(
torch.cuda.device_count() <= 1,
@@ -702,6 +709,7 @@ def test_rw_sequence(self, weight_dtype: torch.dtype) -> None:
ShardingType.ROW_WISE.value,
)

@unittest.skip("Need D51309697 before turn on")
@unittest.skipIf(
torch.cuda.device_count() <= 2,
"Not enough GPUs available",
@@ -728,7 +736,7 @@ def test_rw_uneven_sharding(
local_size = 3
world_size = 3
batch_size = 4
local_device = torch.device("cuda:0")
local_device = torch.device("cpu")
mi = create_test_model(
num_embeddings,
emb_dim,
@@ -792,11 +800,147 @@ def test_rw_uneven_sharding(
sharded_model.load_state_dict(non_sharded_model.state_dict())

# We need this first inference to make all lazy init in forward
_ = sharded_model(*inputs[0])
_ = non_sharded_model(*inputs[0])
# TODO (drqiangzhang): Add comparison between sharded and unsharded model outputs
sharded_output = sharded_model(*inputs[0])
non_sharded_output = non_sharded_model(*inputs[0])
assert_close(non_sharded_output, sharded_output)

gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
gm_script = torch.jit.script(gm)
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

@unittest.skip("Need D51309697 before turn on")
@unittest.skipIf(
torch.cuda.device_count() <= 3,
"Not enough GPUs available",
)
# pyre-fixme[56]Pyre was not able to infer the type of argument `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`.
@given(
weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
)
@settings(max_examples=4, deadline=None)
def test_rw_uneven_sharding_mutiple_table(
self,
weight_dtype: torch.dtype,
) -> None:
num_embeddings = 512
emb_dim = 64
local_size = 4
world_size = 4
batch_size = 1
local_device = torch.device("cpu")
mi = create_test_model(
num_embeddings,
emb_dim,
world_size,
batch_size,
dense_device=local_device,
sparse_device=local_device,
quant_state_dict_split_scale_bias=True,
weight_dtype=weight_dtype,
num_features=3,
)

non_sharded_model = mi.quant_model
expected_shards = [
[
(
(0, 0, 256, 64),
placement_helper("cpu", 0),
),
(
(256, 0, 128, 64),
placement_helper("cpu", 0),
),
(
(384, 0, 64, 64),
placement_helper("cpu", 0),
),
(
(448, 0, 64, 64),
placement_helper("cpu", 0),
),
],
[
(
(0, 0, 128, 64),
placement_helper("cpu", 0),
),
(
(128, 0, 128, 64),
placement_helper("cpu", 0),
),
(
(256, 0, 128, 64),
placement_helper("cpu", 0),
),
(
(384, 0, 128, 64),
placement_helper("cpu", 0),
),
],
[
(
(0, 0, 256, 64),
placement_helper("cpu", 0),
),
(
(256, 0, 128, 64),
placement_helper("cpu", 0),
),
(
(384, 0, 128, 64),
placement_helper("cpu", 0),
),
(
(512, 0, 0, 64),
placement_helper("cpu", 0),
),
],
]

sharder = TestQuantEBCSharder(
sharding_type=ShardingType.ROW_WISE.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
shardable_params=[table.name for table in mi.tables],
)

module_plan = construct_module_sharding_plan(
non_sharded_model._module.sparse.ebc,
per_param_sharding={
"table_0": row_wise(
([256, 128, 64, 64], "cpu"),
),
"table_1": row_wise(([128, 128, 128, 128], "cpu")),
"table_2": row_wise(([256, 128, 128, 0], "cpu")),
},
# pyre-ignore
sharder=sharder,
local_size=local_size,
world_size=world_size,
)

plan = ShardingPlan(plan={"_module.sparse.ebc": module_plan})

sharded_model = shard_qebc(
mi=mi,
sharding_type=ShardingType.ROW_WISE,
device=local_device,
expected_shards=expected_shards,
plan=plan,
)
inputs = [
model_input_to_forward_args(inp.to(local_device))
for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
]
sharded_model.load_state_dict(non_sharded_model.state_dict())

# We need this first inference to make all lazy init in forward
sharded_output = sharded_model(*inputs[0])
non_sharded_output = non_sharded_model(*inputs[0])
assert_close(non_sharded_output, sharded_output)

gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
gm_script = torch.jit.script(gm)
_ = gm_script(*inputs[0])
# TODO (drqiangzhang): Add comparison between scripted and nonscripted model outputs
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)
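
Note: the expected_shards tuples in these tests follow directly from the per-table row splits passed to row_wise(...). A small self-contained sketch (helper names are hypothetical) that reproduces table_0's layout:

from typing import List, Tuple

def cpu_placement(rank: int) -> str:
    # Mirrors placement_helper above: CPU shards all report rank 0.
    return "rank:0/cpu"

def shards_from_row_split(
    row_split: List[int], emb_dim: int
) -> List[Tuple[Tuple[int, int, int, int], str]]:
    # Turn a per-rank row split such as [256, 128, 64, 64] into the
    # (row_offset, col_offset, rows, cols) tuples asserted above.
    shards, offset = [], 0
    for rank, rows in enumerate(row_split):
        shards.append(((offset, 0, rows, emb_dim), cpu_placement(rank)))
        offset += rows
    return shards

print(shards_from_row_split([256, 128, 64, 64], 64))
# [((0, 0, 256, 64), 'rank:0/cpu'), ((256, 0, 128, 64), 'rank:0/cpu'),
#  ((384, 0, 64, 64), 'rank:0/cpu'), ((448, 0, 64, 64), 'rank:0/cpu')]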
