diff --git a/torchrec/distributed/train_pipeline/tests/test_utils.py b/torchrec/distributed/train_pipeline/tests/test_utils.py
new file mode 100644
index 000000000..56246b183
--- /dev/null
+++ b/torchrec/distributed/train_pipeline/tests/test_utils.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from functools import partial
+from typing import List
+from unittest.mock import MagicMock
+
+import torch
+
+from torchrec.distributed.embedding_sharding import (
+    FusedKJTListSplitsAwaitable,
+    KJTListAwaitable,
+    KJTListSplitsAwaitable,
+)
+from torchrec.distributed.embedding_types import KJTList
+from torchrec.distributed.train_pipeline.utils import (
+    _fuse_input_dist_splits,
+    TrainPipelineContext,
+)
+from torchrec.distributed.types import Awaitable, NoWait
+from torchrec.distributed.utils import append_callback
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+class TestFuseInputDist(unittest.TestCase):
+    def test_fuse_input_dist_splits_no_callbacks(self) -> None:
+        name = "ebc"
+        context = TrainPipelineContext()
+        kjt = KeyedJaggedTensor(
+            values=torch.tensor([1.0]), lengths=torch.tensor(1), keys=["t1"]
+        )
+        # pyre-ignore
+        awaitables: List[Awaitable[Awaitable[KeyedJaggedTensor]]] = [
+            NoWait(NoWait(kjt))
+        ]
+        ebc_context = MagicMock()
+        context.input_dist_splits_requests[name] = KJTListSplitsAwaitable(
+            awaitables, ebc_context
+        )
+        context.module_contexts_next_batch[name] = MagicMock()
+
+        _fuse_input_dist_splits(context)
+
+        self.assertTrue(len(context.fused_splits_awaitables))
+
+    def test_fuse_input_dist_splits_with_callbacks(self) -> None:
+        name = "ebc"
+        context: TrainPipelineContext = TrainPipelineContext()
+        kjt: KeyedJaggedTensor = KeyedJaggedTensor(
+            values=torch.tensor([1.0]), lengths=torch.tensor(1), keys=["t1"]
+        )
+
+        # pyre-ignore
+        awaitable: Awaitable[Awaitable[KeyedJaggedTensor]] = NoWait(NoWait(kjt))
+        ebc_context = MagicMock()
+        splits_awaitable: Awaitable[Awaitable[KJTList]] = KJTListSplitsAwaitable(
+            [awaitable], ebc_context
+        )
+
+        # append a two-layer callback
+        def remap(kjtlist: KJTList) -> KJTList:
+            for kjt in kjtlist:
+                kjt._values += 1
+            return kjtlist
+
+        callback = partial(append_callback, callback=remap)
+        splits_awaitable.callbacks.append(callback)
+
+        # test fuse input dist splits
+        context.input_dist_splits_requests[name] = splits_awaitable
+        context.module_contexts_next_batch[name] = MagicMock()
+        _fuse_input_dist_splits(context)
+        self.assertEqual(len(context.fused_splits_awaitables), 1)
+
+        # the fused entry is a (names, awaitable) tuple; take the awaitable at index 1
+        fused_splits_awaitable: FusedKJTListSplitsAwaitable = (
+            context.fused_splits_awaitables[0][1]
+        )
+        self.assertEqual(len(fused_splits_awaitable.callbacks), 1)
+
+        fused_awaitables: List[KJTListAwaitable] = fused_splits_awaitable.wait()
+        kjtlist: KJTList = fused_awaitables[0].wait()
+        kjt = kjtlist[0]
+        self.assertIsInstance(kjt, KeyedJaggedTensor)
+        self.assertEqual(kjt._values, torch.tensor([2.0]))
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index b91a8f2c5..89af14399 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -12,6 +12,7 @@
 import logging
 from collections import defaultdict
 from dataclasses import dataclass, field
+from functools import partial
 from threading import Event, Thread
 from typing import (
     Any,
@@ -36,12 +37,14 @@
 from torchrec.distributed.dist_data import KJTAllToAll
 from torchrec.distributed.embedding_sharding import (
     FusedKJTListSplitsAwaitable,
+    KJTListAwaitable,
     KJTListSplitsAwaitable,
     KJTSplitsAllToAllMeta,
 )
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
 
 from torchrec.distributed.types import Awaitable
+from torchrec.distributed.utils import batch_apply_callbacks
 
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable, Pipelineable
@@ -469,6 +472,7 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
     names_per_pg = defaultdict(list)
     for name, request in context.input_dist_splits_requests.items():
         pg = None
+        # TODO: remove if statement which was used to appease pyre
         if isinstance(request, KJTListSplitsAwaitable):
             for awaitable in request.awaitables:
                 if isinstance(awaitable, KJTSplitsAllToAllMeta):
@@ -477,24 +481,33 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
         names_per_pg[pg].append(name)
 
     for pg, names in names_per_pg.items():
+        fused_splits_awaitable = FusedKJTListSplitsAwaitable(
+            # pyre-ignore[6]
+            requests=[context.input_dist_splits_requests[name] for name in names],
+            contexts=[
+                (
+                    context.module_contexts_next_batch[name]
+                    if context.version == 0
+                    else context.module_contexts[name]
+                )
+                for name in names
+            ],
+            pg=pg,
+        )
+
+        splits_callbacks: List[List[Callable[[KJTListAwaitable], KJTListAwaitable]]] = [
+            (context.input_dist_splits_requests[name].callbacks) for name in names
+        ]
+
+        batch_callback: Callable[[List[KJTListAwaitable]], List[KJTListAwaitable]] = (
+            partial(batch_apply_callbacks, list_callbacks=splits_callbacks)
+        )
+
+        fused_splits_awaitable.callbacks.append(batch_callback)
         context.fused_splits_awaitables.append(
             (
                 names,
-                FusedKJTListSplitsAwaitable(
-                    # pyre-ignore[6]
-                    requests=[
-                        context.input_dist_splits_requests[name] for name in names
-                    ],
-                    contexts=[
-                        (
-                            context.module_contexts_next_batch[name]
-                            if context.version == 0
-                            else context.module_contexts[name]
-                        )
-                        for name in names
-                    ],
-                    pg=pg,
-                ),
+                fused_splits_awaitable,
             )
         )
 
diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py
index 37c13f437..ff6003411 100644
--- a/torchrec/distributed/utils.py
+++ b/torchrec/distributed/utils.py
@@ -11,7 +11,7 @@
 import logging
 from collections import OrderedDict
 
-from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -19,6 +19,7 @@
 from torchrec import optim as trec_optim
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.types import (
+    Awaitable,
     DataType,
     ParameterSharding,
     ShardedModule,
@@ -28,6 +29,7 @@
 from torchrec.types import CopyMixIn
 
 logger: logging.Logger = logging.getLogger(__name__)
+W = TypeVar("W")
 _T = TypeVar("_T")
 
 """
@@ -442,3 +444,33 @@ def maybe_reset_parameters(m: nn.Module) -> None:
             m.reset_parameters()
 
     module.apply(maybe_reset_parameters)
+
+
+def append_callback(
+    awaitable: Awaitable[W], callback: Callable[[W], W]
+) -> Awaitable[W]:
+    """
+    Utility function to append a callback to an awaitable.
+    """
+    awaitable.callbacks.append(callback)
+    return awaitable
+
+
+def apply_callbacks(ret: W, callbacks: List[Callable[[W], W]]) -> W:
+    """
+    Apply a list of callbacks to a value.
+    """
+    for callback in callbacks:
+        ret = callback(ret)
+    return ret
+
+
+def batch_apply_callbacks(
+    rets: List[W], list_callbacks: List[List[Callable[[W], W]]]
+) -> List[W]:
+    """
+    Apply a list of callbacks to a list of values.
+    """
+    for ret, callbacks in zip(rets, list_callbacks):
+        apply_callbacks(ret, callbacks)
+    return rets
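For reference, a minimal usage sketch (not part of the patch) of how the new callback helpers compose. The `double` callback and the variable names are hypothetical, and the sketch assumes, as the test above also relies on, that `Awaitable.wait()` runs any registered callbacks on its result before returning it.

```python
from functools import partial

from torchrec.distributed.types import NoWait
from torchrec.distributed.utils import (
    append_callback,
    apply_callbacks,
    batch_apply_callbacks,
)


def double(x: int) -> int:
    # hypothetical callback, used only for illustration
    return x * 2


# apply_callbacks threads a value through a list of callbacks in order.
assert apply_callbacks(3, [double, double]) == 12

# append_callback registers a callback on an awaitable and returns the same
# awaitable, which is why the test can wrap it with functools.partial and use
# it as a callback itself.
assert append_callback(NoWait(5), double).wait() == 10

# batch_apply_callbacks pairs each value with its own callback list. It returns
# the input list unchanged, so the per-value callbacks are expected to mutate
# their argument or, as append_callback does, hand back the same object.
values = [NoWait(1), NoWait(10)]
batch_apply_callbacks(values, [[partial(append_callback, callback=double)], []])
assert [v.wait() for v in values] == [2, 10]
```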