Skip to content

Commit

Permalink
Print per table storage (pytorch#1427)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1427

Add per table storage (hbm, ddr) info.

Reviewed By: ge0405

Differential Revision: D50025015

fbshipit-source-id: 60d8a3b33bd31b59a9d59d07c43992e17fa2a127
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Oct 9, 2023
1 parent b11eba4 commit ee359ee
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torchrec/distributed/planner/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def log(
"Sharding",
"Compute Kernel",
"Perf (ms)",
"Storage (HBM, DDR)",
"Pooling Factor",
"Num Poolings",
"Output",
Expand All @@ -260,6 +261,7 @@ def log(
"----------",
"----------------",
"-----------",
"--------------------",
"----------------",
"--------------",
"--------",
Expand Down Expand Up @@ -298,6 +300,12 @@ def log(

shard_perfs = _format_perf_breakdown(so_perf)

so_storage = Storage(hbm=0, ddr=0)
for shard in so.shards:
so_storage += cast(Storage, shard.storage)

shard_storages = _format_storage_breakdown(so_storage)

pooling_factor = str(round(sum(so.input_lengths), 3))
num_poolings = (
cast(List[float], constraints[so.name].num_poolings)
Expand Down Expand Up @@ -325,6 +333,7 @@ def log(
_get_sharding_type_abbr(so.sharding_type),
so.compute_kernel,
shard_perfs,
shard_storages,
pooling_factor,
num_poolings,
output,
Expand Down Expand Up @@ -609,6 +618,12 @@ def _format_perf_breakdown(perf: Perf) -> str:
return f"{str(round(perf.total, 3))} ({breakdown_string})"


def _format_storage_breakdown(storage: Storage) -> str:
storage_hbm = round(bytes_to_gb(storage.hbm), 3)
storage_ddr = round(bytes_to_gb(storage.ddr), 3)
return f"({storage_hbm} GB, {storage_ddr} GB)"


def round_to_one_sigfig(x: float) -> str:
return f'{float(f"{x:.1g}"):g}'

Expand Down

0 comments on commit ee359ee

Please sign in to comment.