
Use HBM bandwidth for UVM caching when prefetching is enabled
Summary:
Use the HBM bandwidth for perf estimation when prefetch_pipeline is enabled: with prefetching, the cache rows needed by a batch are staged into HBM before the lookup runs, so steady-state reads are served at HBM speed rather than the blended HBM/DDR speed.

The underlying assumption is that the per-table cache_load_factor is set correctly when prefetch_pipeline is used.

Regarding that assumption: the planner has no way to tell whether a given cache_load_factor is good, since it has no access to the index distribution, and pooling_factor alone is not enough to infer it.
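For intuition, a minimal sketch of the bandwidth model this change toggles between; the constants and the blended formula are inferred from the expected values in the new test file, not part of this diff:

# Sketch only: values assumed to mirror the planner defaults
# (897 GB/s HBM, 51 GB/s DDR, expressed per millisecond).
HBM_MEM_BW = 897 * 1024**3 / 1000  # ~9.63e8
DDR_MEM_BW = 51 * 1024**3 / 1000   # ~5.48e7

def uvm_caching_bw(caching_ratio: float) -> float:
    # Without prefetching: blend HBM (cache hits) with DDR (cache misses),
    # discounted 10x for cache-management overhead.
    return (caching_ratio * HBM_MEM_BW + (1 - caching_ratio) * DDR_MEM_BW) / 10

print(uvm_caching_bw(0.5))  # ~5.09e7
print(HBM_MEM_BW)           # ~9.63e8: with prefetching, ~19x more bandwidth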

Differential Revision: D49084591
hlhtsang authored and facebook-github-bot committed Oct 10, 2023
1 parent 3870c59 commit 3f9b14e
Showing 3 changed files with 105 additions and 1 deletion.
10 changes: 10 additions & 0 deletions torchrec/distributed/planner/constants.py
@@ -42,6 +42,7 @@ def kernel_bw_lookup(
hbm_mem_bw: float,
ddr_mem_bw: float,
caching_ratio: Optional[float] = None,
prefetch_pipeline: bool = False,
) -> Optional[float]:
"""
Calculates the device bandwidth based on given compute device, compute kernel, and
@@ -54,6 +55,7 @@
ddr_mem_bw (float): the bandwidth of the system DDR memory.
caching_ratio (Optional[float]): caching ratio used to determine device bandwidth
if UVM caching is enabled.
prefetch_pipeline (bool): whether prefetch pipeline is enabled.
Returns:
Optional[float]: the device bandwidth.
@@ -84,4 +86,12 @@ def kernel_bw_lookup(
)
/ 10,
}

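# With prefetching, the cache rows for a batch are staged into HBM before
# the lookup runs, so model FUSED_UVM_CACHING at full HBM (FUSED) bandwidth.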
if (
prefetch_pipeline
and compute_device == "cuda"
and compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value
):
return lookup.get(("cuda", EmbeddingComputeKernel.FUSED.value))

return lookup.get((compute_device, compute_kernel))
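
Exercising the new parameter directly (imports as in the new test file below):

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.constants import DDR_MEM_BW, HBM_MEM_BW, kernel_bw_lookup

kernel = EmbeddingComputeKernel.FUSED_UVM_CACHING.value
# Default path: blended UVM-caching bandwidth at the given caching ratio.
blended_bw = kernel_bw_lookup("cuda", kernel, HBM_MEM_BW, DDR_MEM_BW, 0.5)
# Prefetch path: returns the ("cuda", FUSED) entry, i.e. full HBM bandwidth.
hbm_bw = kernel_bw_lookup("cuda", kernel, HBM_MEM_BW, DDR_MEM_BW, 0.5, True)
assert blended_bw is not None and hbm_bw is not None and hbm_bw > blended_bw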
27 changes: 26 additions & 1 deletion torchrec/distributed/planner/shard_estimators.py
@@ -148,6 +148,23 @@ def estimate(
bwd_sr_comm_data_type_size,
) = _extract_comm_data_type_size(sharder, sharding_option)

prefetch_pipeline = (
self._constraints[ # pyre-ignore[16]
sharding_option.name
].cache_params.prefetch_pipeline
if self._constraints
and self._constraints.get(sharding_option.name)
and self._constraints[sharding_option.name].cache_params
else False
)
# TODO: remove after deprecating fused_params in sharder
if prefetch_pipeline is False:
prefetch_pipeline = (
sharder.fused_params.get("prefetch_pipeline", False)
if hasattr(sharder, "fused_params") and sharder.fused_params
else False
)

shard_perfs = perf_func_emb_wall_time(
shard_sizes=[shard.size for shard in sharding_option.shards],
compute_kernel=sharding_option.compute_kernel,
@@ -172,6 +189,7 @@
is_weighted=is_weighted,
is_inference=self._is_inference,
caching_ratio=caching_ratio,
prefetch_pipeline=prefetch_pipeline,
)

for shard, perf in zip(sharding_option.shards, shard_perfs):
@@ -202,6 +220,7 @@ def perf_func_emb_wall_time(
is_weighted: bool = False,
caching_ratio: Optional[float] = None,
is_inference: bool = False,
prefetch_pipeline: bool = False,
) -> List[Perf]:
"""
Attempts to model perfs as a function of relative wall times.
@@ -236,14 +255,20 @@
is_inference (bool = False): if planning for inference.
caching_ratio (Optional[float] = None): cache ratio to determine the bandwidth
of device.
prefetch_pipeline (bool = False): whether prefetch pipeline is enabled.
Returns:
List[Perf]: the list of perfs for each shard.
"""

shard_perfs = []
device_bw = kernel_bw_lookup(
compute_device, compute_kernel, hbm_mem_bw, ddr_mem_bw, caching_ratio
compute_device,
compute_kernel,
hbm_mem_bw,
ddr_mem_bw,
caching_ratio,
prefetch_pipeline,
)
if device_bw is None:
raise PlannerError(
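
The estimator resolves prefetch_pipeline from per-table constraints first, falling back to the sharder's fused_params (the TODO above marks the fallback for removal). A minimal sketch of the constraints route, assuming the CacheParams/ParameterConstraints fields referenced in this diff:

from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import CacheParams

# "t1" is a hypothetical table name; load_factor is the per-table
# cache_load_factor that the summary assumes is tuned correctly.
constraints = {
    "t1": ParameterConstraints(
        cache_params=CacheParams(load_factor=0.2, prefetch_pipeline=True),
    ),
}
# Pass `constraints` to the sharding planner so this estimator picks it up.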
69 changes: 69 additions & 0 deletions torchrec/distributed/planner/tests/test_constants.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import List, Optional

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.constants import (
DDR_MEM_BW,
HBM_MEM_BW,
kernel_bw_lookup,
)


class TestKernelBWLookup(unittest.TestCase):
def test_uvm_caching_bw(self) -> None:
compute_device: str = "cuda"
compute_kernel: str = EmbeddingComputeKernel.FUSED_UVM_CACHING.value

caching_ratios: List[float] = [0, 0.25, 0.5, 0.75, 1]

uvm_caching_bw: List[Optional[float]] = [
kernel_bw_lookup(
compute_device, compute_kernel, HBM_MEM_BW, DDR_MEM_BW, caching_ratio
)
for caching_ratio in caching_ratios
]
expected_uvm_caching_bw: List[float] = [
23643794.96448,
28185722.880000003,
50895362.457600005,
73605002.0352,
96314641.6128,
]

self.assertEqual(expected_uvm_caching_bw, uvm_caching_bw)

def test_uvm_caching_bw_with_prefetch_pipeline(self) -> None:
compute_device: str = "cuda"
compute_kernel: str = EmbeddingComputeKernel.FUSED_UVM_CACHING.value
prefetch_pipeline: bool = True

caching_ratios: List[float] = [0, 0.25, 0.5, 0.75, 1]

uvm_caching_bw: List[Optional[float]] = [
kernel_bw_lookup(
compute_device,
compute_kernel,
HBM_MEM_BW,
DDR_MEM_BW,
caching_ratio,
prefetch_pipeline,
)
for caching_ratio in caching_ratios
]
print(f"henry uvm_caching_bw {uvm_caching_bw}")
expected_uvm_caching_bw: List[float] = [
963146416.128,
963146416.128,
963146416.128,
963146416.128,
963146416.128,
]

self.assertEqual(expected_uvm_caching_bw, uvm_caching_bw)
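
The expected values can be cross-checked by hand: with prefetch_pipeline=True the lookup returns the fused HBM entry for every caching ratio, and in the first test the falsy caching_ratio of 0 falls back to the planner's default ratio (0.2, presumably the module's UVM_CACHING_RATIO), which is why the first value is not a pure-DDR blend. The tests run with the standard unittest entry point:

python -m unittest torchrec.distributed.planner.tests.test_constants -v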
