diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 1236a8a13..b4f0fc656 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -492,9 +492,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
 def benchmark(
     name: str,
     model: torch.nn.Module,
-    warmup_inputs: List[KeyedJaggedTensor],
-    bench_inputs: List[KeyedJaggedTensor],
-    prof_inputs: List[KeyedJaggedTensor],
+    warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
+    bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
+    prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
     world_size: int,
     output_dir: str,
     num_benchmarks: int,
@@ -558,7 +558,7 @@ def benchmark(
             [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
         )
     else:
-        elapsed_time = torch.tensor(times)
+        elapsed_time = torch.tensor(times) * 1e3
 
     if device_type == "cuda":
         if rank == -1:
diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py
index 1745910ea..47565260f 100644
--- a/torchrec/sparse/tests/jagged_tensor_benchmark.py
+++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -40,6 +40,7 @@ def bench(
     run_backward: bool,
     fn: Callable[..., List[torch.Tensor]],
     fn_kwargs: Dict[str, Any],
+    output_dir: str = "",
 ) -> None:
 
     # initial call
@@ -49,8 +50,8 @@ def wrapped_func(
         model: torch.nn.Module,  # not used
         bench_inputs: List[KeyedJaggedTensor],  # not used
         fn: Callable[..., List[torch.Tensor]],
-        fn_kwargs: Dict[str, Any],
         run_backward: bool,
+        **kwargs: Dict[str, Any],
     ) -> None:
         result = fn(**fn_kwargs)
         if run_backward:
@@ -64,26 +65,28 @@ def wrapped_func(
             loss = torch.nn.functional.l1_loss(pred, labels)
             loss.sum().backward()
 
+    model = DummyModel()
+    setattr(model, "forward", lambda kwargs: fn(**kwargs))
+    prof_num = 10
     if device_type == "cuda":
         result = benchmark(
             name=name,
-            model=DummyModel(),
+            model=model,
             warmup_inputs=[],
             bench_inputs=[],
-            prof_inputs=[],
+            prof_inputs=[fn_kwargs] * prof_num,
             world_size=1,
-            output_dir="",
+            output_dir=output_dir,
             num_benchmarks=20,
             func_to_benchmark=functools.partial(
                 wrapped_func, fn=fn, run_backward=run_backward, fn_kwargs=fn_kwargs
             ),
             benchmark_func_kwargs={},
             rank=0,
-            enable_logging=False,
+            enable_logging=True,
         )
 
     else:  # cpu
-        model = DummyModel()
         times = timeit.repeat(
             lambda: wrapped_func(
                 model=model,
@@ -97,7 +100,7 @@ def wrapped_func(
         )
         result = BenchmarkResult(
             short_name=name,
-            elapsed_time=torch.tensor(times),
+            elapsed_time=torch.tensor(times) * 1e3,
             max_mem_allocated=[0],
         )
 
@@ -160,6 +163,12 @@ def wrapped_func(
     default=2,
     help="Total num of regrouping",
 )
+@click.option(
+    "--profile",
+    type=str,
+    default="",
+    help="profile output directory",
+)
 def main(
     cuda_matrix: bool,
     run_backward: bool,
@@ -170,6 +179,7 @@ def main(
     dim_sparse: int,
     batch_size: int,
     n_groups: int,
+    profile: str,
 ) -> None:
     if cuda_matrix:
         n_denses = [64, 128, 256, 512, 1024]
@@ -184,54 +194,58 @@ def main(
 
     for device_type in device_types:
         for batch_size in batch_sizes:
-            for n_dense, n_sparse in zip(n_denses, n_sparses):
-
-                device = torch.device(device_type)
-                kts = build_kts(
-                    n_dense,
-                    n_sparse,
-                    dim_dense,
-                    dim_sparse,
-                    batch_size,
-                    device,
-                    run_backward,
-                )
-                labels = torch.randint(
-                    0, 1, (batch_size,), device=torch.device(device_type)
-                ).float()
-                groups = build_groups(kts, n_groups)
-                bench(
-                    "[fallback] _regroup_keyed_tenors",
-                    labels,
-                    batch_size,
-                    n_dense + n_sparse,
-                    device_type,
-                    run_backward,
-                    _regroup_keyed_tensors,
-                    {"keyed_tensors": kts, "groups": groups},
-                )
-                bench(
-                    "[prod] KeyedTensor.regroup",
-                    labels,
-                    batch_size,
-                    n_dense + n_sparse,
-                    device_type,
-                    run_backward,
-                    KeyedTensor.regroup,
-                    {"keyed_tensors": kts, "groups": groups},
-                )
-                bench(
-                    "[prod] KTRegroupAsDict",
-                    labels,
-                    batch_size,
-                    n_dense + n_sparse,
-                    device_type,
-                    run_backward,
-                    KTRegroupAsDict(
-                        groups=groups, keys=[str(i) for i in range(n_groups)]
-                    ),
-                    {"keyed_tensors": kts},
-                )
+            for duplicates in [False, True]:
+                for n_dense, n_sparse in zip(n_denses, n_sparses):
+                    dup = "_dup" if duplicates else ""
+                    device = torch.device(device_type)
+                    kts = build_kts(
+                        n_dense,
+                        n_sparse,
+                        dim_dense,
+                        dim_sparse,
+                        batch_size,
+                        device,
+                        run_backward,
+                    )
+                    labels = torch.randint(
+                        0, 1, (batch_size,), device=torch.device(device_type)
+                    ).float()
+                    groups = build_groups(kts, n_groups, duplicates=duplicates)
+                    bench(
+                        "_regroup_keyed_tenors" + dup,
+                        labels,
+                        batch_size,
+                        n_dense + n_sparse,
+                        device_type,
+                        run_backward,
+                        _regroup_keyed_tensors,
+                        {"keyed_tensors": kts, "groups": groups},
+                        profile,
+                    )
+                    bench(
+                        "KeyedTensor.regroup" + dup,
+                        labels,
+                        batch_size,
+                        n_dense + n_sparse,
+                        device_type,
+                        run_backward,
+                        KeyedTensor.regroup,
+                        {"keyed_tensors": kts, "groups": groups},
+                        profile,
+                    )
+                    bench(
+                        "KTRegroupAsDict" + dup,
+                        labels,
+                        batch_size,
+                        n_dense + n_sparse,
+                        device_type,
+                        run_backward,
+                        KTRegroupAsDict(
+                            groups=groups, keys=[str(i) for i in range(n_groups)]
+                        ),
+                        {"keyed_tensors": kts},
+                        profile,
+                    )
 
 
 if __name__ == "__main__":