From d7d5449bef8805598a83bcf4a56cac06b23da0c0 Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Fri, 10 Jan 2025 16:54:12 +0000 Subject: [PATCH] Add ratio of hidden communication time to total communication time --- .../nsys_jax/analyses/communication.py | 28 +++++++++++++++++++ .../container/nsys_jax/nsys_jax/analysis.py | 3 ++ 2 files changed, 31 insertions(+) diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py index 5388a1f84..b02e4af01 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -8,6 +8,7 @@ load_profiler_data, ) from math import sqrt +from statistics import mean import pathlib from uncertainties import ufloat # type: ignore @@ -95,6 +96,33 @@ def format_bandwidth(data, collective): ) ) + collective_types = set() + summary_data = defaultdict(dict) + for collective, df in steady_state.communication.groupby( + ["Collective"] + ): + collective_types.add(collective) + summary_data[collective] = df["DurHiddenMsToDurMs"].mean() + + collective_width = max(len("Collective"), max(len(f"{collective}") for collective in collective_types)) + ratio_width = len("Mean HiddenToTotalMs") + + print() + print(f"{'Collective':<{collective_width}} | {'Mean HiddenToTotalMs':<{ratio_width}}") + print(f"{'-' * collective_width} | {'-' * ratio_width}") + + for collective in collective_types: + mean_value = summary_data[collective] + collective_str = str(collective[0]) + print(f"{collective_str:<{collective_width}} | {mean_value:>{ratio_width}}") + + overall_hidden_ms_to_total_ms = ( + steady_state.communication["ProjDurHiddenMs"].sum() / + (steady_state.communication["ProjDurMs"] + steady_state.communication["ProjDurHiddenMs"]).sum() + ) + + print() + print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms:>{ratio_width}}") if __name__ == "__main__": main() diff --git a/.github/container/nsys_jax/nsys_jax/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py index c4e37fdf9..31c16cbb7 100644 --- a/.github/container/nsys_jax/nsys_jax/analysis.py +++ b/.github/container/nsys_jax/nsys_jax/analysis.py @@ -331,6 +331,9 @@ def calculate_collective_metrics( comm_df["BusBandwidthGBPerSec"] = ( comm_df["AlgorithmBandwidthGBPerSec"] * comm_df["BusBandwidthCorrection"] ) + comm_df["DurHiddenMsToDurMs"] = ( + comm_df["ProjDurHiddenMs"] / (comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"]) + ) return comm_df.drop(columns=["BandwidthCorrection", "BusBandwidthCorrection"])