Prefetch for ITEP v8
Summary:
Design doc: https://docs.google.com/document/d/17_nqdEtH6B_ev9Gnuw2mpgtFq4dqzC6-XUdw18R4F8Q/

# changes
## in ITEP
Make the remap step a callback of the input dist awaitable.

## in torchrec
Copy callbacks from the per-module input dist splits awaitables onto the fused input dist awaitable, so they still run after fusion.

# impact
For non-ITEP modules this should not change anything: the copied callback chains are empty, so the batched callback is a no-op.
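
To illustrate the mechanism, here is a minimal sketch (not part of this commit): a remap callback registered on a module's input dist splits awaitable is copied onto the fused awaitable as one batched callback, so it still fires once the fused request resolves. The plain list and the NoWait wrappers are hypothetical stand-ins for the real KJT awaitables, and the sketch assumes torchrec's base Awaitable runs registered callbacks when wait() is called; append_callback and batch_apply_callbacks are the helpers this commit adds to torchrec/distributed/utils.py, and the fusion step mirrors what _fuse_input_dist_splits now does.

from functools import partial
from typing import List

from torchrec.distributed.types import NoWait
from torchrec.distributed.utils import append_callback, batch_apply_callbacks

# One module's input dist splits request: an awaitable that resolves to another
# awaitable, which in turn resolves to the data (a stand-in for the KJT case).
data = [1.0]
splits_request = NoWait(NoWait(data))

def remap(values: List[float]) -> List[float]:
    # ITEP-style remap applied in place to the resolved data.
    values[0] += 1.0
    return values

# Two-layer callback: once the splits request resolves, attach remap to the
# inner data awaitable.
splits_request.callbacks.append(partial(append_callback, callback=remap))

# Fusing: carry each module's splits callbacks over to the fused awaitable as
# one batched callback, which is what _fuse_input_dist_splits now does.
fused = NoWait([NoWait(data)])  # resolves to the per-module data awaitables
fused.callbacks.append(
    partial(batch_apply_callbacks, list_callbacks=[splits_request.callbacks])
)

data_awaitables = fused.wait()  # the batched callback attaches remap here
assert data_awaitables[0].wait() == [2.0]  # remap ran on the resolved data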

Differential Revision: D57012790
henrylhtsang authored and facebook-github-bot committed May 10, 2024
1 parent 65f973e commit 063c8a6
Showing 3 changed files with 153 additions and 16 deletions.
92 changes: 92 additions & 0 deletions torchrec/distributed/train_pipeline/tests/test_utils.py
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from functools import partial
from typing import List
from unittest.mock import MagicMock

import torch

from torchrec.distributed.embedding_sharding import (
    FusedKJTListSplitsAwaitable,
    KJTListAwaitable,
    KJTListSplitsAwaitable,
)
from torchrec.distributed.embedding_types import KJTList
from torchrec.distributed.train_pipeline.utils import (
    _fuse_input_dist_splits,
    TrainPipelineContext,
)
from torchrec.distributed.types import Awaitable, NoWait
from torchrec.distributed.utils import append_callback
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class TestFuseInputDist(unittest.TestCase):
    def test_fuse_input_dist_splits_no_callbacks(self) -> None:
        name = "ebc"
        context = TrainPipelineContext()
        kjt = KeyedJaggedTensor(
            values=torch.tensor([1.0]), lengths=torch.tensor(1), keys=["t1"]
        )
        # pyre-ignore
        awaitables: List[Awaitable[Awaitable[KeyedJaggedTensor]]] = [
            NoWait(NoWait(kjt))
        ]
        ebc_context = MagicMock()
        context.input_dist_splits_requests[name] = KJTListSplitsAwaitable(
            awaitables, ebc_context
        )
        context.module_contexts_next_batch[name] = MagicMock()

        _fuse_input_dist_splits(context)

        self.assertTrue(len(context.fused_splits_awaitables))

    def test_fuse_input_dist_splits_with_callbacks(self) -> None:
        name = "ebc"
        context: TrainPipelineContext = TrainPipelineContext()
        kjt: KeyedJaggedTensor = KeyedJaggedTensor(
            values=torch.tensor([1.0]), lengths=torch.tensor(1), keys=["t1"]
        )

        # pyre-ignore
        awaitable: Awaitable[Awaitable[KeyedJaggedTensor]] = NoWait(NoWait(kjt))
        ebc_context = MagicMock()
        splits_awaitable: Awaitable[Awaitable[KJTList]] = KJTListSplitsAwaitable(
            [awaitable], ebc_context
        )

        # append a two-layer callback: once the splits awaitable resolves,
        # attach remap to the resulting data awaitable
        def remap(kjtlist: KJTList) -> KJTList:
            for kjt in kjtlist:
                kjt._values += 1
            return kjtlist

        callback = partial(append_callback, callback=remap)
        splits_awaitable.callbacks.append(callback)

        # test fuse input dist splits
        context.input_dist_splits_requests[name] = splits_awaitable
        context.module_contexts_next_batch[name] = MagicMock()
        _fuse_input_dist_splits(context)
        self.assertEqual(len(context.fused_splits_awaitables), 1)

        # each entry is a (names, FusedKJTListSplitsAwaitable) tuple; take the awaitable
        fused_splits_awaitable: FusedKJTListSplitsAwaitable = (
            context.fused_splits_awaitables[0][1]
        )
        self.assertEqual(len(fused_splits_awaitable.callbacks), 1)

        fused_awaitables: List[KJTListAwaitable] = fused_splits_awaitable.wait()
        kjtlist: KJTList = fused_awaitables[0].wait()
        kjt = kjtlist[0]
        self.assertIsInstance(kjt, KeyedJaggedTensor)
        self.assertEqual(kjt._values, torch.tensor([2.0]))
43 changes: 28 additions & 15 deletions torchrec/distributed/train_pipeline/utils.py
@@ -12,6 +12,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
from threading import Event, Thread
from typing import (
    Any,
@@ -36,12 +37,14 @@
from torchrec.distributed.dist_data import KJTAllToAll
from torchrec.distributed.embedding_sharding import (
    FusedKJTListSplitsAwaitable,
    KJTListAwaitable,
    KJTListSplitsAwaitable,
    KJTSplitsAllToAllMeta,
)
from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule

from torchrec.distributed.types import Awaitable
from torchrec.distributed.utils import batch_apply_callbacks

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable, Pipelineable
@@ -469,6 +472,7 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
    names_per_pg = defaultdict(list)
    for name, request in context.input_dist_splits_requests.items():
        pg = None
        # TODO: remove if statement which was used to appease pyre
        if isinstance(request, KJTListSplitsAwaitable):
            for awaitable in request.awaitables:
                if isinstance(awaitable, KJTSplitsAllToAllMeta):
@@ -477,24 +481,33 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
        names_per_pg[pg].append(name)

    for pg, names in names_per_pg.items():
        fused_splits_awaitable = FusedKJTListSplitsAwaitable(
            # pyre-ignore[6]
            requests=[context.input_dist_splits_requests[name] for name in names],
            contexts=[
                (
                    context.module_contexts_next_batch[name]
                    if context.version == 0
                    else context.module_contexts[name]
                )
                for name in names
            ],
            pg=pg,
        )

        splits_callbacks: List[List[Callable[[KJTListAwaitable], KJTListAwaitable]]] = [
            (context.input_dist_splits_requests[name].callbacks) for name in names
        ]

        batch_callback: Callable[[List[KJTListAwaitable]], List[KJTListAwaitable]] = (
            partial(batch_apply_callbacks, list_callbacks=splits_callbacks)
        )

        fused_splits_awaitable.callbacks.append(batch_callback)
        context.fused_splits_awaitables.append(
            (
                names,
                FusedKJTListSplitsAwaitable(
                    # pyre-ignore[6]
                    requests=[
                        context.input_dist_splits_requests[name] for name in names
                    ],
                    contexts=[
                        (
                            context.module_contexts_next_batch[name]
                            if context.version == 0
                            else context.module_contexts[name]
                        )
                        for name in names
                    ],
                    pg=pg,
                ),
                fused_splits_awaitable,
            )
        )

34 changes: 33 additions & 1 deletion torchrec/distributed/utils.py
@@ -11,14 +11,15 @@
import logging

from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torch import nn
from torchrec import optim as trec_optim
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import (
    Awaitable,
    DataType,
    ParameterSharding,
    ShardedModule,
@@ -28,6 +29,7 @@
from torchrec.types import CopyMixIn

logger: logging.Logger = logging.getLogger(__name__)
W = TypeVar("W")
_T = TypeVar("_T")

"""
@@ -442,3 +444,33 @@ def maybe_reset_parameters(m: nn.Module) -> None:
m.reset_parameters()

module.apply(maybe_reset_parameters)


def append_callback(
    awaitable: Awaitable[W], callback: Callable[[W], W]
) -> Awaitable[W]:
    """
    Utility function to append a callback to an awaitable.
    """
    awaitable.callbacks.append(callback)
    return awaitable


def apply_callbacks(ret: W, callbacks: List[Callable[[W], W]]) -> W:
    """
    Apply a list of callbacks to a value.
    """
    for callback in callbacks:
        ret = callback(ret)
    return ret


def batch_apply_callbacks(
    rets: List[W], list_callbacks: List[List[Callable[[W], W]]]
) -> List[W]:
    """
    Apply a list of callbacks to a list of values.
    """
    for ret, callbacks in zip(rets, list_callbacks):
        apply_callbacks(ret, callbacks)
    return rets
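
A small usage sketch for the new helpers (not part of the diff), again assuming torchrec's base Awaitable runs registered callbacks when wait() is called; the values and lambdas are hypothetical.

from torchrec.distributed.types import NoWait
from torchrec.distributed.utils import append_callback, apply_callbacks

# append_callback registers a post-wait hook and hands the awaitable back.
doubled = append_callback(NoWait(2), callback=lambda x: x * 2)
assert doubled.wait() == 4

# apply_callbacks threads a value through a chain and returns the final result.
assert apply_callbacks(1, [lambda x: x + 1, lambda x: x * 10]) == 20

# batch_apply_callbacks, by contrast, returns the original list entries, so its
# per-element callbacks are expected to act in place (as remap does above).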
