Skip to content

Commit

Permalink
Add preference for device timing metrics instead of host-aligned timing metrics when constructing device_op_metrics_db.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 679731218
  • Loading branch information
bmass02 authored and Google-ML-Automation committed Oct 8, 2024
1 parent f924b1f commit 0e2507e
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 7 deletions.
4 changes: 3 additions & 1 deletion xla/tsl/profiler/utils/xplane_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ const StatTypeMap& GetStatTypeMap() {
{"cuda_graph_orig_id", kCudaGraphOrigId},
{"step_idle_time_ps", kStepIdleTimePs},
{"gpu_device_name", kGpuDeviceName},
{"source_stack", kSourceStack}});
{"source_stack", kSourceStack},
{"device_offset_ps", kDeviceOffsetPs},
{"device_duration_ps", kDeviceDurationPs}});
DCHECK_EQ(stat_type_map->size(), kNumStatTypes);
return *stat_type_map;
}
Expand Down
4 changes: 3 additions & 1 deletion xla/tsl/profiler/utils/xplane_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ enum StatType {
kStepIdleTimePs,
kGpuDeviceName,
kSourceStack,
kLastStatType = kSourceStack,
kDeviceOffsetPs,
kDeviceDurationPs,
kLastStatType = kDeviceDurationPs,
};

enum MegaScaleStatType : uint8_t {
Expand Down
24 changes: 19 additions & 5 deletions xla/tsl/profiler/utils/xplane_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,19 @@ bool IsOpLineName(absl::string_view line_name) {
return line_name == kXlaOpLineName || line_name == kTensorFlowOpLineName;
}

// Returns the timespan to use for `event`.
//
// When the event carries both the kDeviceOffsetPs and kDeviceDurationPs stats,
// the result is a Timespan constructed from those two stat values. If either
// stat is missing, the event's own timespan (event.GetTimespan()) is returned
// unchanged.
Timespan GetEventTimespan(const XEventVisitor& event) {
  const std::optional<XStatVisitor> offset_stat =
      event.GetStat(StatType::kDeviceOffsetPs);
  const std::optional<XStatVisitor> duration_stat =
      event.GetStat(StatType::kDeviceDurationPs);

  // Both device stats are required; a lone offset or duration is ignored.
  if (!offset_stat.has_value() || !duration_stat.has_value()) {
    return event.GetTimespan();
  }

  return Timespan(offset_stat->IntOrUintValue(),
                  duration_stat->IntOrUintValue());
}

} // namespace

const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) {
Expand Down Expand Up @@ -563,26 +576,27 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) {
aggregated_line.SetName(line.Name());
std::vector<XEventVisitor> event_stack;
line.ForEachEvent([&](XEventVisitor event) {
Timespan timespan = GetEventTimespan(event);
first_op_start_ps = first_op_start_ps <= event.TimestampPs()
? first_op_start_ps
: event.TimestampPs();
: timespan.begin_ps();
last_op_end_ps = last_op_end_ps >= event.EndTimestampPs()
? last_op_end_ps
: event.EndTimestampPs();
: timespan.end_ps();
const auto& group_stat = event.GetStat(StatType::kGroupId);
int64_t group_id =
group_stat.has_value() ? group_stat->IntOrUintValue() : kint64max;

StatByEvent& line_stats = stats[line.Id()][group_id];
line_stats[event.Id()].stat.UpdateStat(event.DurationPs());
line_stats[event.Id()].stat.UpdateStat(timespan.duration_ps());
DCHECK(event_stack.empty() || !(event < event_stack.back()));
while (!event_stack.empty() &&
!event_stack.back().GetTimespan().Includes(event.GetTimespan())) {
!GetEventTimespan(event_stack.back()).Includes(timespan)) {
event_stack.pop_back();
}
if (!event_stack.empty()) {
line_stats[event_stack.back().Id()].children_duration +=
event.DurationPs();
timespan.duration_ps();
}
event_stack.push_back(std::move(event));
});
Expand Down
74 changes: 74 additions & 0 deletions xla/tsl/profiler/utils/xplane_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,80 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) {
#endif
}

// Checks that AggregateXPlane aggregates by the kDeviceOffsetPs /
// kDeviceDurationPs stats attached to each event, not by the events' own
// offset/duration fields (which are set to different values here).
TEST(XPlaneUtilsTest, TestAggregateXPlaneWithCycleStats) {
  XPlane xplane;
  XPlaneBuilder builder(&xplane);
  const XStatMetadata& device_offset_stat = *builder.GetOrCreateStatMetadata(
      GetStatTypeStr(StatType::kDeviceOffsetPs));
  const XStatMetadata& device_duration_stat = *builder.GetOrCreateStatMetadata(
      GetStatTypeStr(StatType::kDeviceDurationPs));

  const XEventMetadata& event_metadata1 =
      *builder.GetOrCreateEventMetadata("EventMetadata1");
  const XEventMetadata& event_metadata2 =
      *builder.GetOrCreateEventMetadata("EventMetadata2");
  const XEventMetadata& event_metadata3 =
      *builder.GetOrCreateEventMetadata("EventMetadata3");

  XLineBuilder line = builder.GetOrCreateLine(1);
  line.SetName(kXlaOpLineName);

  // Appends one event to `line` with the given event-level offset/duration
  // (in ns) and the given device offset/duration stat values.
  auto add_event = [&](const XEventMetadata& metadata, int offset_ns,
                       int duration_ns, int device_offset,
                       int device_duration) {
    XEventBuilder added = line.AddEvent(metadata);
    added.SetOffsetNs(offset_ns);
    added.SetDurationNs(duration_ns);
    added.AddStatValue(device_offset_stat, device_offset);
    added.AddStatValue(device_duration_stat, device_duration);
  };

  // First occurrence of each op: ops 2 and 3 fall inside op 1's device span.
  add_event(event_metadata1, 0, 5, 50, 4950);
  add_event(event_metadata2, 2, 1, 1950, 890);
  add_event(event_metadata3, 3, 2, 2950, 2050);
  // Second occurrence of each op.
  add_event(event_metadata1, 5, 5, 5000, 4950);
  add_event(event_metadata2, 7, 1, 7050, 900);
  add_event(event_metadata3, 8, 2, 8050, 1900);

  XPlane aggregated_xplane;
  AggregateXPlane(xplane, aggregated_xplane);

  XPlaneVisitor visitor = CreateTfXPlaneVisitor(&aggregated_xplane);
  visitor.ForEachLine([&](const XLineVisitor& aggregated_line) {
    EXPECT_EQ(aggregated_line.Name(), kXlaOpLineName);
    aggregated_line.ForEachEvent([&](const XEventVisitor& aggregated_event) {
      EXPECT_EQ(aggregated_event.OffsetPs(), 0);
      const auto op_name = aggregated_event.Metadata().Name();
      if (op_name == "EventMetadata1") {
        EXPECT_EQ(aggregated_event.NumOccurrences(), 2);
        // 4950 + 4950 device picoseconds.
        EXPECT_EQ(aggregated_event.DurationPs(), 9900);
        // 9900 minus the children's device durations (890 + 2050 + 900 + 1900).
        EXPECT_EQ(aggregated_event.GetStat(StatType::kSelfDurationPs)
                      ->IntOrUintValue(),
                  4160);
      } else if (op_name == "EventMetadata2") {
        EXPECT_EQ(aggregated_event.NumOccurrences(), 2);
        // 890 + 900 device picoseconds.
        EXPECT_EQ(aggregated_event.DurationPs(), 1790);
      } else if (op_name == "EventMetadata3") {
        EXPECT_EQ(aggregated_event.NumOccurrences(), 2);
        // 2050 + 1900 device picoseconds.
        EXPECT_EQ(aggregated_event.DurationPs(), 3950);
      }
    });
  });
}

TEST(XPlanuUtilsTest, TestInstantEventDoesNotFail) {
XPlane xplane;
XPlaneBuilder xplane_builder(&xplane);
Expand Down

0 comments on commit 0e2507e

Please sign in to comment.