diff --git a/torchrec/distributed/planner/constants.py b/torchrec/distributed/planner/constants.py
index f95b1956b..b6abb8dfd 100644
--- a/torchrec/distributed/planner/constants.py
+++ b/torchrec/distributed/planner/constants.py
@@ -42,6 +42,7 @@ def kernel_bw_lookup(
     hbm_mem_bw: float,
     ddr_mem_bw: float,
     caching_ratio: Optional[float] = None,
+    prefetch_pipeline: bool = False,
 ) -> Optional[float]:
     """
     Calculates the device bandwidth based on given compute device, compute kernel, and
@@ -54,6 +55,7 @@ def kernel_bw_lookup(
         ddr_mem_bw (float): the bandwidth of the system DDR memory.
         caching_ratio (Optional[float]): caching ratio used to determine device bandwidth
             if UVM caching is enabled.
+        prefetch_pipeline (bool): whether prefetch pipeline is enabled.
 
     Returns:
         Optional[float]: the device bandwidth.
@@ -84,4 +86,12 @@ def kernel_bw_lookup(
         )
         / 10,
     }
+
+    if (
+        prefetch_pipeline
+        and compute_device == "cuda"
+        and compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value
+    ):
+        return lookup.get(("cuda", EmbeddingComputeKernel.FUSED.value))
+
     return lookup.get((compute_device, compute_kernel))
diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py
index fcad7fb63..729cff9f2 100644
--- a/torchrec/distributed/planner/shard_estimators.py
+++ b/torchrec/distributed/planner/shard_estimators.py
@@ -148,6 +148,23 @@ def estimate(
                 bwd_sr_comm_data_type_size,
             ) = _extract_comm_data_type_size(sharder, sharding_option)
 
+            prefetch_pipeline = (
+                self._constraints[  # pyre-ignore[16]
+                    sharding_option.name
+                ].cache_params.prefetch_pipeline
+                if self._constraints
+                and self._constraints.get(sharding_option.name)
+                and self._constraints[sharding_option.name].cache_params
+                else False
+            )
+            # TODO: remove after deprecating fused_params in sharder
+            if prefetch_pipeline is False:
+                prefetch_pipeline = (
+                    sharder.fused_params.get("prefetch_pipeline", False)
+                    if hasattr(sharder, "fused_params") and sharder.fused_params
+                    else False
+                )
+
             shard_perfs = perf_func_emb_wall_time(
                 shard_sizes=[shard.size for shard in sharding_option.shards],
                 compute_kernel=sharding_option.compute_kernel,
@@ -172,6 +189,7 @@ def estimate(
                 is_weighted=is_weighted,
                 is_inference=self._is_inference,
                 caching_ratio=caching_ratio,
+                prefetch_pipeline=prefetch_pipeline,
             )
 
             for shard, perf in zip(sharding_option.shards, shard_perfs):
@@ -202,6 +220,7 @@ def perf_func_emb_wall_time(
     is_weighted: bool = False,
     caching_ratio: Optional[float] = None,
     is_inference: bool = False,
+    prefetch_pipeline: bool = False,
 ) -> List[Perf]:
     """
     Attempts to model perfs as a function of relative wall times.
@@ -236,6 +255,7 @@ def perf_func_emb_wall_time(
         is_inference (bool = False): if planning for inference.
        caching_ratio (Optional[float] = None): cache ratio to determine the bandwidth
            of device.
+        prefetch_pipeline (bool = False): whether prefetch pipeline is enabled.
 
     Returns:
         List[float]: the list of perf for each shard.
@@ -243,7 +263,12 @@ def perf_func_emb_wall_time(
 
     shard_perfs = []
     device_bw = kernel_bw_lookup(
-        compute_device, compute_kernel, hbm_mem_bw, ddr_mem_bw, caching_ratio
+        compute_device,
+        compute_kernel,
+        hbm_mem_bw,
+        ddr_mem_bw,
+        caching_ratio,
+        prefetch_pipeline,
     )
     if device_bw is None:
         raise PlannerError(
diff --git a/torchrec/distributed/planner/tests/test_constants.py b/torchrec/distributed/planner/tests/test_constants.py
new file mode 100644
index 000000000..1b9b91a64
--- /dev/null
+++ b/torchrec/distributed/planner/tests/test_constants.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from typing import List, Optional
+
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.planner.constants import (
+    DDR_MEM_BW,
+    HBM_MEM_BW,
+    kernel_bw_lookup,
+)
+
+
+class TestKernelBWLookup(unittest.TestCase):
+    def test_uvm_caching_bw(self) -> None:
+        compute_device: str = "cuda"
+        compute_kernel: str = EmbeddingComputeKernel.FUSED_UVM_CACHING.value
+
+        caching_ratios: List[float] = [0, 0.25, 0.5, 0.75, 1]
+
+        uvm_caching_bw: List[Optional[float]] = [
+            kernel_bw_lookup(
+                compute_device, compute_kernel, HBM_MEM_BW, DDR_MEM_BW, caching_ratio
+            )
+            for caching_ratio in caching_ratios
+        ]
+        expected_uvm_caching_bw: List[float] = [
+            23643794.96448,
+            28185722.880000003,
+            50895362.457600005,
+            73605002.0352,
+            96314641.6128,
+        ]
+
+        self.assertEqual(expected_uvm_caching_bw, uvm_caching_bw)
+
+    def test_uvm_caching_bw_with_prefetch_pipeline(self) -> None:
+        compute_device: str = "cuda"
+        compute_kernel: str = EmbeddingComputeKernel.FUSED_UVM_CACHING.value
+        prefetch_pipeline: bool = True
+
+        caching_ratios: List[float] = [0, 0.25, 0.5, 0.75, 1]
+
+        uvm_caching_bw: List[Optional[float]] = [
+            kernel_bw_lookup(
+                compute_device,
+                compute_kernel,
+                HBM_MEM_BW,
+                DDR_MEM_BW,
+                caching_ratio,
+                prefetch_pipeline,
+            )
+            for caching_ratio in caching_ratios
+        ]
+        expected_uvm_caching_bw: List[float] = [
+            963146416.128,
+            963146416.128,
+            963146416.128,
+            963146416.128,
+            963146416.128,
+        ]
+
+        self.assertEqual(expected_uvm_caching_bw, uvm_caching_bw)
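
Note: a minimal sketch of the behavior added to kernel_bw_lookup (illustration only, not part of the diff; it uses only names that already appear above). With prefetch_pipeline=True on a CUDA FUSED_UVM_CACHING kernel, the caching ratio is ignored and the FUSED (pure HBM) bandwidth is returned, which is why the new test expects the same value for every caching ratio.

# Illustration of the new kernel_bw_lookup branch; not part of the diff.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.constants import DDR_MEM_BW, HBM_MEM_BW, kernel_bw_lookup

kernel = EmbeddingComputeKernel.FUSED_UVM_CACHING.value

# Default path: bandwidth is a caching-ratio blend of HBM and DDR.
blended_bw = kernel_bw_lookup("cuda", kernel, HBM_MEM_BW, DDR_MEM_BW, caching_ratio=0.5)

# Prefetch-pipeline path: caching_ratio no longer matters; the planner
# assumes lookups hit the cache and returns the FUSED (pure HBM) bandwidth.
pipelined_bw = kernel_bw_lookup(
    "cuda", kernel, HBM_MEM_BW, DDR_MEM_BW, caching_ratio=0.5, prefetch_pipeline=True
)
assert pipelined_bw == kernel_bw_lookup(
    "cuda", EmbeddingComputeKernel.FUSED.value, HBM_MEM_BW, DDR_MEM_BW
)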
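
Separately, a sketch of how the flag is expected to reach the estimator, based on the constraint-reading logic added to shard_estimators.py. This is assumed wiring, not shown in the diff: "large_table" is a hypothetical table name, and ParameterConstraints/CacheParams are the existing planner types the new code dereferences. Until fused_params is deprecated in the sharder, sharder.fused_params["prefetch_pipeline"] serves as the fallback source.

# Assumed usage; "large_table" is a hypothetical table name.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import CacheParams

constraints = {
    "large_table": ParameterConstraints(
        compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
        # Read by the estimator as
        # constraints[name].cache_params.prefetch_pipeline (see diff above).
        cache_params=CacheParams(prefetch_pipeline=True),
    ),
}

Passing this dict as the planner's constraints should then route the flag through perf_func_emb_wall_time into kernel_bw_lookup, as exercised by the new test.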