Add GPU trace for KT.regroup benchmark (pytorch#2157)
Summary:
Pull Request resolved: pytorch#2157

# context
* we are adding fbgemm operators for the `KT.regroup` function.
* we wanted a good way to measure the performance besides the runtime
* **the trace is very important for evaluating the actual performance impact**
* for example, judging only from the GPU runtime readings, the native-pytorch implementation (`_regroup_keyed_tensors`) appears to outperform the fbgemm_gpu implementation (`KeyedTensor.regroup`)
* but the CPU/GPU traces show that the native-pytorch implementation is actually CPU-bound and hurts overall performance significantly (see the trace-capture sketch below)
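
For illustration, here is a minimal sketch (not the benchmark's exact code path; `regroup_fn`, `kts`, and `groups` are placeholder names) of capturing a combined CPU/GPU trace with `torch.profiler`, similar in spirit to what the `--profile` option below produces:
```
# minimal sketch, assuming CUDA is available; regroup_fn, kts, and groups
# are placeholders for the implementation and inputs under test
import torch
from torch.profiler import ProfilerActivity, profile

def capture_trace(regroup_fn, kts, groups, out_path="trace.json"):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(10):       # repeat so CPU-side launch gaps become visible
            regroup_fn(keyed_tensors=kts, groups=groups)
        torch.cuda.synchronize()  # wait for queued kernels before closing the profile
    prof.export_chrome_trace(out_path)
```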

# usage
* to generate trace files in the given path (`.`)
```
buck2 run fbcode//mode/opt fbcode//torchrec/sparse/tests:jagged_tensor_benchmark -- --profile=.
```
```
$ ll *.json
-rw-rw-r-- 1 hhy hhy 8062963 Jun 21 22:21 trace-KeyedTensor.regroup_dup.json
-rw-rw-r-- 1 hhy hhy  943675 Jun 21 22:21 trace-KeyedTensor.regroup.json
-rw-rw-r-- 1 hhy hhy 5140105 Jun 21 22:21 trace-KTRegroupAsDict_dup.json
-rw-rw-r-- 1 hhy hhy  350349 Jun 21 22:21 trace-KTRegroupAsDict.json
-rw-rw-r-- 1 hhy hhy 8025287 Jun 21 22:21 trace-_regroup_keyed_tenors_dup.json
-rw-rw-r-- 1 hhy hhy 8041473 Jun 21 22:21 trace-_regroup_keyed_tenors.json
```
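
The generated `trace-*.json` files are Chrome-trace JSON (as exported by the PyTorch profiler) and can be opened in Perfetto or `chrome://tracing`. As a rough sanity check, the device-side vs CPU-side time can also be totaled straight from the JSON; the event category names below are assumptions based on current Kineto output and may differ across versions:
```
import json

def summarize_trace(path: str) -> None:
    # rough split of where time goes; durations in the trace are microseconds
    with open(path) as f:
        events = json.load(f)["traceEvents"]
    gpu_us = sum(e.get("dur", 0) for e in events
                 if e.get("cat") in ("kernel", "gpu_memcpy", "gpu_memset"))
    cpu_us = sum(e.get("dur", 0) for e in events if e.get("cat") == "cpu_op")
    print(f"{path}: GPU busy ~{gpu_us / 1e3:.1f} ms, CPU ops ~{cpu_us / 1e3:.1f} ms")

summarize_trace("trace-KeyedTensor.regroup.json")
```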

# performance
* GPU (note: with duplicates (`_dup`), regroup falls back to the native-pytorch implementation, `_regroup_keyed_tensors`)
```
INFO:2024-06-21 22:22:51 1102779:1102779 CuptiCallbackApi.cpp:78] Callback: domain = 3, cbid = 1
INFO:2024-06-21 22:22:51 1102779:1102779 CuptiActivityProfiler.cpp:241] CUDA versions. CUPTI: 18; Runtime: 12000; Driver: 12000
INFO:2024-06-21 22:22:51 1102779:1102779 NcclProfiler.cpp:150] NCCL Profiler Instantiated
  _regroup_keyed_tenors               | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   2.8 ms | Memory (P90): 1011.0
  KeyedTensor.regroup                 | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   5.0 ms | Memory (P90): 1517.0
  KTRegroupAsDict                     | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   4.9 ms | Memory (P90): 1517.0
  _regroup_keyed_tenors_dup           | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   2.5 ms | Memory (P90): 1011.0
  KeyedTensor.regroup_dup             | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   2.5 ms | Memory (P90): 1011.0
  KTRegroupAsDict_dup                 | B: 1024     | F: 1020     | device: cuda     | Runtime (P90):   2.5 ms | Memory (P90): 1011.0
```
* CPU
```
  _regroup_keyed_tenors               | B: 1024     | F: 1020     | device: cpu      | Runtime (P90): 144.8 ms | Memory (P90):   0.0
  KeyedTensor.regroup                 | B: 1024     | F: 1020     | device: cpu      | Runtime (P90): 159.1 ms | Memory (P90):   0.0
  KTRegroupAsDict                     | B: 1024     | F: 1020     | device: cpu      | Runtime (P90): 203.0 ms | Memory (P90):   0.0
  _regroup_keyed_tenors_dup           | B: 1024     | F: 1020     | device: cpu      | Runtime (P90): 132.4 ms | Memory (P90):   0.0
  KeyedTensor.regroup_dup             | B: 1024     | F: 1020     | device: cpu      | Runtime (P90): 134.7 ms | Memory (P90):   0.0
  KTRegroupAsDict_dup                 | B: 1024     | F: 1020     | device: cpu      | Runtime (P90): 131.8 ms | Memory (P90):   0.0
```
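The P90 figures above are percentiles over the benchmark iterations; a minimal sketch of how such a number can be computed from per-iteration runtimes in milliseconds (the exact reduction in `benchmark_utils.py` may differ):
```
import torch

def p90_ms(times_ms: list) -> float:
    # 90th-percentile runtime over the benchmark iterations, in milliseconds
    return torch.quantile(torch.tensor(times_ms), 0.9).item()

print(f"Runtime (P90): {p90_ms([2.4, 2.5, 2.5, 2.6, 2.8]):.1f} ms")
```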
# traces
* _regroup_keyed_tenors
 {F1712147044}
* KeyedTensor.regroup
 {F1712148863}
* KTRegroupAsDict
 {F1712150411}

Reviewed By: dstaay-fb

Differential Revision: D58906521

fbshipit-source-id: 46e37184cd58c0f25e48112510388de9bd39ac71
TroyGarden authored and facebook-github-bot committed Jun 25, 2024
1 parent 03c3a72 commit 704afbe
Showing 2 changed files with 73 additions and 59 deletions.
8 changes: 4 additions & 4 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -492,9 +492,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
def benchmark(
name: str,
model: torch.nn.Module,
warmup_inputs: List[KeyedJaggedTensor],
bench_inputs: List[KeyedJaggedTensor],
prof_inputs: List[KeyedJaggedTensor],
warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
world_size: int,
output_dir: str,
num_benchmarks: int,
@@ -558,7 +558,7 @@ def benchmark(
[si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
)
else:
elapsed_time = torch.tensor(times)
elapsed_time = torch.tensor(times) * 1e3

if device_type == "cuda":
if rank == -1:
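Note on the `* 1e3` above: `timeit` reports seconds, while `torch.cuda.Event.elapsed_time()` reports milliseconds, so the CPU-path timings are scaled to keep both paths in milliseconds. A minimal sketch of the two timing paths, with `work()` standing in for the benchmarked call:
```
import timeit
import torch

def work() -> None:
    pass  # placeholder for the benchmarked call

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    work()
    end.record()
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end)               # already milliseconds
else:
    elapsed_ms = timeit.timeit(work, number=1) * 1e3   # seconds -> milliseconds
```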
124 changes: 69 additions & 55 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -40,6 +40,7 @@ def bench(
run_backward: bool,
fn: Callable[..., List[torch.Tensor]],
fn_kwargs: Dict[str, Any],
output_dir: str = "",
) -> None:

# initial call
@@ -49,8 +50,8 @@ def wrapped_func(
model: torch.nn.Module, # not used
bench_inputs: List[KeyedJaggedTensor], # not used
fn: Callable[..., List[torch.Tensor]],
fn_kwargs: Dict[str, Any],
run_backward: bool,
**kwargs: Dict[str, Any],
) -> None:
result = fn(**fn_kwargs)
if run_backward:
@@ -64,26 +65,28 @@ def wrapped_func(
loss = torch.nn.functional.l1_loss(pred, labels)
loss.sum().backward()

model = DummyModel()
setattr(model, "forward", lambda kwargs: fn(**kwargs))
prof_num = 10
if device_type == "cuda":
result = benchmark(
name=name,
model=DummyModel(),
model=model,
warmup_inputs=[],
bench_inputs=[],
prof_inputs=[],
prof_inputs=[fn_kwargs] * prof_num,
world_size=1,
output_dir="",
output_dir=output_dir,
num_benchmarks=20,
func_to_benchmark=functools.partial(
wrapped_func, fn=fn, run_backward=run_backward, fn_kwargs=fn_kwargs
),
benchmark_func_kwargs={},
rank=0,
enable_logging=False,
enable_logging=True,
)

else: # cpu
model = DummyModel()
times = timeit.repeat(
lambda: wrapped_func(
model=model,
@@ -97,7 +100,7 @@ def wrapped_func(
)
result = BenchmarkResult(
short_name=name,
elapsed_time=torch.tensor(times),
elapsed_time=torch.tensor(times) * 1e3,
max_mem_allocated=[0],
)

@@ -160,6 +163,12 @@ def wrapped_func(
default=2,
help="Total num of regrouping",
)
@click.option(
"--profile",
type=str,
default="",
help="profile output directory",
)
def main(
cuda_matrix: bool,
run_backward: bool,
@@ -170,6 +179,7 @@ def main(
dim_sparse: int,
batch_size: int,
n_groups: int,
profile: str,
) -> None:
if cuda_matrix:
n_denses = [64, 128, 256, 512, 1024]
@@ -184,54 +194,58 @@

for device_type in device_types:
for batch_size in batch_sizes:
for n_dense, n_sparse in zip(n_denses, n_sparses):

device = torch.device(device_type)
kts = build_kts(
n_dense,
n_sparse,
dim_dense,
dim_sparse,
batch_size,
device,
run_backward,
)
labels = torch.randint(
0, 1, (batch_size,), device=torch.device(device_type)
).float()
groups = build_groups(kts, n_groups)
bench(
"[fallback] _regroup_keyed_tenors",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_regroup_keyed_tensors,
{"keyed_tensors": kts, "groups": groups},
)
bench(
"[prod] KeyedTensor.regroup",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KeyedTensor.regroup,
{"keyed_tensors": kts, "groups": groups},
)
bench(
"[prod] KTRegroupAsDict",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KTRegroupAsDict(
groups=groups, keys=[str(i) for i in range(n_groups)]
),
{"keyed_tensors": kts},
)
for duplicates in [False, True]:
for n_dense, n_sparse in zip(n_denses, n_sparses):
dup = "_dup" if duplicates else ""
device = torch.device(device_type)
kts = build_kts(
n_dense,
n_sparse,
dim_dense,
dim_sparse,
batch_size,
device,
run_backward,
)
labels = torch.randint(
0, 1, (batch_size,), device=torch.device(device_type)
).float()
groups = build_groups(kts, n_groups, duplicates=duplicates)
bench(
"_regroup_keyed_tenors" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_regroup_keyed_tensors,
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"KeyedTensor.regroup" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KeyedTensor.regroup,
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"KTRegroupAsDict" + dup,
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
KTRegroupAsDict(
groups=groups, keys=[str(i) for i in range(n_groups)]
),
{"keyed_tensors": kts},
profile,
)


if __name__ == "__main__":
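For reference, a minimal sketch of the profiling hook this diff sets up: the regroup callable is wrapped as the `DummyModel` forward so that the profiling loop in `benchmark()` (assumed here to call `model(prof_input)` for each entry of `prof_inputs`) exercises the real implementation while the trace is recorded. `fn` and `fn_kwargs` below are stand-ins:
```
import torch

def fn(keyed_tensors, groups):               # stand-in for e.g. KeyedTensor.regroup
    return keyed_tensors

fn_kwargs = {"keyed_tensors": [torch.zeros(2, 2)], "groups": [["a"]]}

class DummyModel(torch.nn.Module):
    def forward(self, *args, **kwargs):
        raise NotImplementedError             # replaced per-instance below

model = DummyModel()
# forward(kwargs) simply unpacks the kwargs dict into the function under test
setattr(model, "forward", lambda kwargs: fn(**kwargs))

prof_num = 10
prof_inputs = [fn_kwargs] * prof_num          # passed to benchmark(..., prof_inputs=...)
for batch in prof_inputs:                     # what benchmark() is assumed to do while profiling
    model(batch)                              # -> fn(**fn_kwargs)
```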
