From ee359eef66b34b9f96b51610245fd1b96a158b8e Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 9 Oct 2023 11:14:53 -0700 Subject: [PATCH] Print per table storage (#1427) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1427 Add per table storage (hbm, ddr) info. Reviewed By: ge0405 Differential Revision: D50025015 fbshipit-source-id: 60d8a3b33bd31b59a9d59d07c43992e17fa2a127 --- torchrec/distributed/planner/stats.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index 9af64d0fe..653f236bf 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -245,6 +245,7 @@ def log( "Sharding", "Compute Kernel", "Perf (ms)", + "Storage (HBM, DDR)", "Pooling Factor", "Num Poolings", "Output", @@ -260,6 +261,7 @@ def log( "----------", "----------------", "-----------", + "--------------------", "----------------", "--------------", "--------", @@ -298,6 +300,12 @@ def log( shard_perfs = _format_perf_breakdown(so_perf) + so_storage = Storage(hbm=0, ddr=0) + for shard in so.shards: + so_storage += cast(Storage, shard.storage) + + shard_storages = _format_storage_breakdown(so_storage) + pooling_factor = str(round(sum(so.input_lengths), 3)) num_poolings = ( cast(List[float], constraints[so.name].num_poolings) @@ -325,6 +333,7 @@ def log( _get_sharding_type_abbr(so.sharding_type), so.compute_kernel, shard_perfs, + shard_storages, pooling_factor, num_poolings, output, @@ -609,6 +618,12 @@ def _format_perf_breakdown(perf: Perf) -> str: return f"{str(round(perf.total, 3))} ({breakdown_string})" +def _format_storage_breakdown(storage: Storage) -> str: + storage_hbm = round(bytes_to_gb(storage.hbm), 3) + storage_ddr = round(bytes_to_gb(storage.ddr), 3) + return f"({storage_hbm} GB, {storage_ddr} GB)" + + def round_to_one_sigfig(x: float) -> str: return f'{float(f"{x:.1g}"):g}'