Skip to content

Commit

Permalink
Add preference for device timing metrics instead of host-aligned timing metrics when constructing device_op_metrics_db.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 679731218
  • Loading branch information
bmass02 authored and Google-ML-Automation committed Sep 27, 2024
1 parent ccaeb74 commit 1ad477d
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 8 deletions.
2 changes: 1 addition & 1 deletion third_party/tsl/tsl/profiler/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ tsl_cc_test(
srcs = ["xplane_utils_test.cc"],
deps = [
":math_utils",
":tf_xplane_visitor",
":xplane_builder",
":xplane_schema",
":xplane_utils",
Expand All @@ -236,7 +237,6 @@ tsl_cc_test(
"//tsl/platform:test_main",
"//tsl/platform:types",
"//tsl/profiler/protobuf:xplane_proto_cc",
"//tsl/profiler/utils:tf_xplane_visitor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
Expand Down
2 changes: 2 additions & 0 deletions third_party/tsl/tsl/profiler/utils/xplane_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ const StatTypeMap& GetStatTypeMap() {
{"cuda_graph_orig_id", kCudaGraphOrigId},
{"step_idle_time_ps", kStepIdleTimePs},
{"gpu_device_name", kGpuDeviceName},
{"device_offset_ps", kDeviceOffsetPs},
{"device_duration_ps", kDeviceDurationPs},
});
DCHECK_EQ(stat_type_map->size(), kNumStatTypes);
return *stat_type_map;
Expand Down
4 changes: 3 additions & 1 deletion third_party/tsl/tsl/profiler/utils/xplane_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ enum StatType {
kCudaGraphOrigId,
kStepIdleTimePs,
kGpuDeviceName,
kLastStatType = kGpuDeviceName,
kDeviceOffsetPs,
kDeviceDurationPs,
kLastStatType = kDeviceDurationPs,
};

enum MegaScaleStatType : uint8_t {
Expand Down
24 changes: 19 additions & 5 deletions third_party/tsl/tsl/profiler/utils/xplane_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,19 @@ bool IsOpLineName(absl::string_view line_name) {
return line_name == kXlaOpLineName || line_name == kTensorFlowOpLineName;
}

// Returns the timespan to use for `event`.
//
// If the event carries both the kDeviceOffsetPs and kDeviceDurationPs stats,
// the result is built from those two values (offset as the begin, duration as
// the length). If either stat is missing, the event's own timespan is
// returned unchanged.
Timespan GetEventTimespan(const XEventVisitor& event) {
  const std::optional<XStatVisitor> offset_stat =
      event.GetStat(StatType::kDeviceOffsetPs);
  const std::optional<XStatVisitor> duration_stat =
      event.GetStat(StatType::kDeviceDurationPs);

  // Both stats are required; a lone offset or duration is ignored.
  if (!offset_stat.has_value() || !duration_stat.has_value()) {
    return event.GetTimespan();
  }

  return Timespan(offset_stat->IntOrUintValue(),
                  duration_stat->IntOrUintValue());
}

} // namespace

const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) {
Expand Down Expand Up @@ -563,26 +576,27 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) {
aggregated_line.SetName(line.Name());
std::vector<XEventVisitor> event_stack;
line.ForEachEvent([&](XEventVisitor event) {
Timespan timespan = GetEventTimespan(event);
first_op_start_ps = first_op_start_ps <= event.TimestampPs()
? first_op_start_ps
: event.TimestampPs();
: timespan.begin_ps();
last_op_end_ps = last_op_end_ps >= event.EndTimestampPs()
? last_op_end_ps
: event.EndTimestampPs();
: timespan.end_ps();
const auto& group_stat = event.GetStat(StatType::kGroupId);
int64_t group_id =
group_stat.has_value() ? group_stat->IntOrUintValue() : kint64max;

StatByEvent& line_stats = stats[line.Id()][group_id];
line_stats[event.Id()].stat.UpdateStat(event.DurationPs());
line_stats[event.Id()].stat.UpdateStat(timespan.duration_ps());
DCHECK(event_stack.empty() || !(event < event_stack.back()));
while (!event_stack.empty() &&
!event_stack.back().GetTimespan().Includes(event.GetTimespan())) {
!GetEventTimespan(event_stack.back()).Includes(timespan)) {
event_stack.pop_back();
}
if (!event_stack.empty()) {
line_stats[event_stack.back().Id()].children_duration +=
event.DurationPs();
timespan.duration_ps();
}
event_stack.push_back(std::move(event));
});
Expand Down
76 changes: 75 additions & 1 deletion third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,81 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) {
#endif
}

TEST(XPlanuUtilsTest, TestInstantEventDoesNotFail) {
// Checks that AggregateXPlane aggregates an op line using the
// kDeviceOffsetPs / kDeviceDurationPs stats attached to each event instead of
// the host offset/duration stored on the event itself.
TEST(XPlaneUtilsTest, TestAggregateXPlaneWithCycleStats) {
  XPlane xplane;
  XPlaneBuilder builder(&xplane);
  const XStatMetadata& offset_stat_md = *builder.GetOrCreateStatMetadata(
      GetStatTypeStr(StatType::kDeviceOffsetPs));
  const XStatMetadata& duration_stat_md = *builder.GetOrCreateStatMetadata(
      GetStatTypeStr(StatType::kDeviceDurationPs));

  const XEventMetadata& outer_md =
      *builder.GetOrCreateEventMetadata("EventMetadata1");
  const XEventMetadata& first_child_md =
      *builder.GetOrCreateEventMetadata("EventMetadata2");
  const XEventMetadata& second_child_md =
      *builder.GetOrCreateEventMetadata("EventMetadata3");

  XLineBuilder line = builder.GetOrCreateLine(1);
  line.SetName(kXlaOpLineName);

  // Appends one event to `line` with the given host timing (ns) and the
  // device timing stats (ps).
  const auto add_op = [&](const XEventMetadata& metadata, int host_offset_ns,
                          int host_duration_ns, int device_offset_ps,
                          int device_duration_ps) {
    XEventBuilder op = line.AddEvent(metadata);
    op.SetOffsetNs(host_offset_ns);
    op.SetDurationNs(host_duration_ns);
    op.AddStatValue(offset_stat_md, device_offset_ps);
    op.AddStatValue(duration_stat_md, device_duration_ps);
  };

  // First occurrence: the outer op spans [50, 5000) ps on the device and both
  // children fall inside that span.
  add_op(outer_md, /*host_offset_ns=*/0, /*host_duration_ns=*/5,
         /*device_offset_ps=*/50, /*device_duration_ps=*/4950);
  add_op(first_child_md, 2, 1, 1950, 890);
  add_op(second_child_md, 3, 2, 2950, 2050);
  // Second occurrence: the outer op spans [5000, 9950) ps on the device.
  add_op(outer_md, 5, 5, 5000, 4950);
  add_op(first_child_md, 7, 1, 7050, 900);
  add_op(second_child_md, 8, 2, 8050, 1900);

  XPlane aggregated_xplane;
  AggregateXPlane(xplane, aggregated_xplane);

  XPlaneVisitor visitor = CreateTfXPlaneVisitor(&aggregated_xplane);
  visitor.ForEachLine([&](const XLineVisitor& aggregated_line) {
    EXPECT_EQ(aggregated_line.Name(), kXlaOpLineName);
    aggregated_line.ForEachEvent([&](const XEventVisitor& aggregated_event) {
      EXPECT_EQ(aggregated_event.OffsetPs(), 0);
      const auto op_name = aggregated_event.Metadata().Name();
      if (op_name == "EventMetadata1") {
        EXPECT_EQ(aggregated_event.NumOccurrences(), 2);
        // 4950 + 4950 device ps.
        EXPECT_EQ(aggregated_event.DurationPs(), 9900);
        // Total minus the nested children: 9900 - (890 + 2050 + 900 + 1900).
        EXPECT_EQ(aggregated_event.GetStat(StatType::kSelfDurationPs)
                      ->IntOrUintValue(),
                  4160);
      } else if (op_name == "EventMetadata2") {
        EXPECT_EQ(aggregated_event.NumOccurrences(), 2);
        // 890 + 900 device ps.
        EXPECT_EQ(aggregated_event.DurationPs(), 1790);
      } else if (op_name == "EventMetadata3") {
        EXPECT_EQ(aggregated_event.NumOccurrences(), 2);
        // 2050 + 1900 device ps.
        EXPECT_EQ(aggregated_event.DurationPs(), 3950);
      }
    });
  });
}

TEST(XPlaneUtilsTest, TestInstantEventDoesNotFail) {
XPlane xplane;
XPlaneBuilder xplane_builder(&xplane);
XEventMetadata* event_metadata1 = xplane_builder.GetOrCreateEventMetadata(1);
Expand Down

0 comments on commit 1ad477d

Please sign in to comment.