Skip to content

Commit

Permalink
reporting on compiler and runtime stats (#2451)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
AlexandreEichenberger authored Aug 22, 2023
1 parent 255ac40 commit 7ee11bc
Show file tree
Hide file tree
Showing 4 changed files with 447 additions and 164 deletions.
48 changes: 30 additions & 18 deletions src/Conversion/KrnlToLLVM/KrnlInstrument.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,25 @@ class KrnlInstrumentOpLowering : public ConversionPattern {
KrnlInstrumentOpAdaptor operandAdaptor(operands);
Location loc = op->getLoc();
KrnlInstrumentOp instrumentOp = llvm::dyn_cast<KrnlInstrumentOp>(op);
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
LLVMTypeConverter *typeConverter =
static_cast<LLVMTypeConverter *>(getTypeConverter());

// Get a symbol reference to the instrumentation function, inserting it if
// necessary.
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
auto instrumentRef = getOrInsertInstrument(rewriter, parentModule);

StringRef opNameStr = instrumentOp.getOpName();
LLVM::GlobalOp globalOpNameStr = krnl::getOrCreateGlobalString(
opNameStr, loc, rewriter, parentModule, typeConverter);
Value opNamePtr =
krnl::getPtrToGlobalString(globalOpNameStr, loc, rewriter);
Value tag = create.llvm.constant(
IntegerType::get(context, 64), (int64_t)instrumentOp.getTag());

// We may try to relate the node names generated by the instrumentation
// with the node names printed by onnx-op-report. Thus it is key to keep the
// code that generates these node names in sync.
//
// The code is found here:
// 1) `matchAndRewrite` from `src/Conversion/KrnlToLLVM/KrnlInstrument.cpp`
// 2) `getNodeNameLikeInKrnlInstrument` from
// `src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp`

StringRef nodeName;
if (instrumentOp.getNodeName().has_value())
nodeName = instrumentOp.getNodeName().value();
else if (auto nameLoc = loc.dyn_cast<NameLoc>())
nodeName = nameLoc.getName();
else if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
// Combine each location name and set it as nodeName.
// Combine each location name and set it as nodeName, appended by "-".
std::string name;
for (Location locIt : fusedLoc.getLocations()) {
if (auto nameLocIt = locIt.dyn_cast<NameLoc>())
Expand All @@ -83,19 +80,34 @@ class KrnlInstrumentOpLowering : public ConversionPattern {
name = "NOTSET";
else
name.pop_back(); // remove last "-"
loc = NameLoc::get(rewriter.getStringAttr(name));
nodeName = cast<NameLoc>(loc).getName();
Location newLoc = NameLoc::get(rewriter.getStringAttr(name));
nodeName = cast<NameLoc>(newLoc).getName();
} else if (auto fileLineColLoc = loc.dyn_cast<FileLineColLoc>()) {
std::string filename =
llvm::sys::path::filename(fileLineColLoc.getFilename().str()).str();
std::string name =
filename + ":" + std::to_string(fileLineColLoc.getLine());
loc = NameLoc::get(rewriter.getStringAttr(name));
nodeName = cast<NameLoc>(loc).getName();
Location newLoc = NameLoc::get(rewriter.getStringAttr(name));
nodeName = cast<NameLoc>(newLoc).getName();
} else
nodeName = StringRef("NOTSET");
LLVM_DEBUG(
llvm::dbgs() << "Instrumentation_nodeName: " << nodeName << "\n");

MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
LLVMTypeConverter *typeConverter =
static_cast<LLVMTypeConverter *>(getTypeConverter());

// Get a symbol reference to the instrumentation function, inserting it if
// necessary.
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
auto instrumentRef = getOrInsertInstrument(rewriter, parentModule);

LLVM::GlobalOp globalOpNameStr = krnl::getOrCreateGlobalString(
opNameStr, loc, rewriter, parentModule, typeConverter);
Value opNamePtr =
krnl::getPtrToGlobalString(globalOpNameStr, loc, rewriter);
Value tag = create.llvm.constant(
IntegerType::get(context, 64), (int64_t)instrumentOp.getTag());
LLVM::GlobalOp globalStr = krnl::getOrCreateGlobalString(
nodeName, loc, rewriter, parentModule, typeConverter);
Value nodeNamePtr = krnl::getPtrToGlobalString(globalStr, loc, rewriter);
Expand Down
74 changes: 64 additions & 10 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/Support/Path.h"

#include "src/Accelerators/Accelerator.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp"

#define DEBUG_TYPE "lowering-to-krnl"

using namespace mlir;

namespace onnx_mlir {
Expand Down Expand Up @@ -661,18 +664,70 @@ bool hasNonIdentityLayout(ValueRange operands) {
// Support functions for reporting.
//===----------------------------------------------------------------------===//

// We may try to relate the node names generated by the instrumentation
// with the node names printed by onnx-op-report. Thus it is key to keep the
// code that generates these node names in sync.
//
// The code is found here:
// 1) `matchAndRewrite` from `src/Conversion/KrnlToLLVM/KrnlInstrument.cpp`
// 2) `getNodeNameLikeInKrnlInstrument` from
// `src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp`

// Derive a printable node name for `op`, trying in order:
//   1) the op's "onnx_node_name" string attribute;
//   2) the op's location, when it is a NameLoc;
//   3) the op's location, when it is a FusedLoc: the names of its NameLoc
//      components and the "<file>:<line>" labels of its FileLineColLoc
//      components, joined with "-" (other component kinds are ignored);
//   4) the op's location, when it is a FileLineColLoc: "<file>:<line>", where
//      <file> is the filename without its directory part.
// Returns "NOTSET" when none of the above yields a name.
static std::string getNodeNameLikeInKrnlInstrument(Operation *op) {
  // Label a file location as "<file basename>:<line>".
  auto fileLineLabel = [](FileLineColLoc fileLoc) -> std::string {
    std::string baseName =
        llvm::sys::path::filename(fileLoc.getFilename().str()).str();
    return baseName + ":" + std::to_string(fileLoc.getLine());
  };

  // First choice: the explicit node name attribute.
  if (StringAttr attrName = op->getAttrOfType<StringAttr>("onnx_node_name")) {
    LLVM_DEBUG(llvm::dbgs() << "op has node name\n");
    return attrName.str();
  }

  // Otherwise, fall back on what the location can tell us.
  Location opLoc = op->getLoc();

  if (auto namedLoc = opLoc.dyn_cast<NameLoc>()) {
    LLVM_DEBUG(llvm::dbgs() << "op has node name\n");
    return namedLoc.getName().str();
  }

  if (auto fused = opLoc.dyn_cast<FusedLoc>()) {
    LLVM_DEBUG(llvm::dbgs() << "fuse loc has name\n");
    // Each usable component contributes "<label>-"; the trailing "-" is
    // dropped once all components have been visited.
    std::string joined;
    for (Location part : fused.getLocations()) {
      if (auto partName = part.dyn_cast<NameLoc>()) {
        joined += partName.getName().str() + "-";
      } else if (auto partFile = part.dyn_cast<FileLineColLoc>()) {
        joined += fileLineLabel(partFile) + "-";
      }
    }
    if (joined.empty())
      return "NOTSET";
    joined.pop_back();
    return joined;
  }

  if (auto fileLoc = opLoc.dyn_cast<FileLineColLoc>()) {
    LLVM_DEBUG(llvm::dbgs() << "file line col has name\n");
    return fileLineLabel(fileLoc);
  }

  return "NOTSET";
}

// Emit one comma-separated report line describing whether `op` was
// parallelized.
//
// Parameters:
//   op:                    operation being reported on; its operation name and
//                          its node name are printed.
//   successful:            when true, "-parallel" is appended to the op name.
//   loopLevel:             printed as an integer field of the report.
//   parallelLoopTripCount: printed as an integer field of the report.
//   comment:               free-form text field; must not contain a comma,
//                          since commas separate the report fields.
//
// NOTE(review): this span comes from a rendered diff and appears to interleave
// the removed version (fprintf to stderr, "==ONNX-PAR-REPORT==") with the added
// version (printf, "==PAR-REPORT=="). Presumably only the printf variant is
// meant to remain -- confirm against the committed file.
void impl::onnxToKrnlParallelReport(Operation *op, bool successful,
int64_t loopLevel, int64_t parallelLoopTripCount,
const std::string &comment) {
// Reporting must have been enabled in the lowering configuration.
assert(OnnxToKrnlLoweringConfiguration::reportOnParallel && "must report");
// A comma inside the comment would add a spurious field to the report line.
assert(comment.find(',') == std::string::npos && "no comma in comments");
StringAttr opName = op->getName().getIdentifier();
// Node name taken only from the "onnx_node_name" attribute (may be null).
StringAttr nodeName = op->getAttrOfType<StringAttr>("onnx_node_name");
// Node name derived from the attribute or, failing that, the op location.
std::string nodeNameStr = getNodeNameLikeInKrnlInstrument(op);
// Print report on this op.
// NOTE(review): "%lld" expects long long while the arguments are int64_t;
// presumably these match on the supported platforms -- TODO confirm, or cast.
// Field order here: op name, node name, loop level, trip count, comment.
fprintf(stderr, "==ONNX-PAR-REPORT==, %s%s, %s, %lld, %lld, %s\n",
opName.data(), (successful ? "-parallel" : ""),
(nodeName ? nodeName.data() : "no-node-name"), loopLevel,
parallelLoopTripCount, comment.c_str());
// Field order here: op name, node name, comment, loop level, trip count.
printf("==PAR-REPORT==, %s%s, %s, %s, %lld, %lld\n", opName.data(),
(successful ? "-parallel" : ""), nodeNameStr.c_str(), comment.c_str(),
loopLevel, parallelLoopTripCount);
}

void impl::onnxToKrnlSimdReport(Operation *op, bool successful,
Expand All @@ -681,7 +736,7 @@ void impl::onnxToKrnlSimdReport(Operation *op, bool successful,
assert(OnnxToKrnlLoweringConfiguration::reportOnSimd && "must report");
assert(comment.find(',') == std::string::npos && "no comma in comments");
StringAttr opName = op->getName().getIdentifier();
StringAttr nodeName = op->getAttrOfType<StringAttr>("onnx_node_name");
std::string nodeNameStr = getNodeNameLikeInKrnlInstrument(op);
// Handling message.
std::string message = OnnxToKrnlLoweringConfiguration::defaultSimdComment;
if (message.empty())
Expand All @@ -690,10 +745,9 @@ void impl::onnxToKrnlSimdReport(Operation *op, bool successful,
// No comments, all values indicate no simd
message = "unsupported";
// Print report on this op.
fprintf(stderr, "==ONNX-SIMD-REPORT==, %s%s, %s, %lld, %lld, %s\n",
opName.data(), (successful ? "-simd" : ""),
(nodeName ? nodeName.data() : "no-node-name"), vectorLength,
simdLoopTripCount, message.c_str());
printf("==SIMD-REPORT==, %s%s, %s, %s, %lld, %lld\n", opName.data(),
(successful ? "-simd" : ""), nodeNameStr.c_str(), message.c_str(),
vectorLength, simdLoopTripCount);
}

} // namespace onnx_mlir
58 changes: 27 additions & 31 deletions src/Runtime/OMInstrument.inc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

#ifdef _WIN32
#include "windows.h"
// The windows.h include must go first.
#include "psapi.h"

static LARGE_INTEGER globalTime, initTime;
Expand All @@ -50,13 +51,13 @@ static LARGE_INTEGER perfFrequency;

static struct timeval globalTimeVal, initTimeVal;
static pid_t mypid;
static int psErrorCount = 0;
#endif

static bool instrumentReportDisabled = false;
static bool instrumentReportTimeDisabled = false;
static bool instrumentReportMemoryDisabled = false;
static int instrumentCounter = 0;
static int psErrorCount = 0;

#ifdef __MVS__
#define timersub(a, b, result) \
Expand Down Expand Up @@ -96,21 +97,19 @@ void ReportTime() {
LONGLONG resultSeconds, resultMicroseconds;
QueryPerformanceCounter(&newTime);
WinTimerSub(newTime, globalTime, &resultSeconds, &resultMicroseconds);
printf(" Time elapsed: %lld.%06lld", resultSeconds, resultMicroseconds);
printf(", %lld.%06lld", resultSeconds, resultMicroseconds);
WinTimerSub(newTime, initTime, &resultSeconds, &resultMicroseconds);
printf(" accumulated: %lld.%06lld", resultSeconds, resultMicroseconds);
printf(", %lld.%06lld\n", resultSeconds, resultMicroseconds);
globalTime = newTime;
}
#else
// Print two time intervals measured at the moment of the call: the time since
// globalTimeVal and the time since initTimeVal, each formatted as
// "<seconds>.<microseconds, 6 digits>". globalTimeVal is then set to the
// current time, so the first interval is relative to the previous call.
//
// NOTE(review): this span comes from a rendered diff and appears to interleave
// the removed printf lines (" Time elapsed: ...", " accumulated: ...") with the
// added ones (", ..." and ", ...\n"). Presumably only the comma-separated
// variants are meant to remain -- confirm against the committed file.
void ReportTime() {
struct timeval newTimeValue, result;
// Sample the current time once; both intervals are computed from it.
gettimeofday(&newTimeValue, NULL);
// Interval since globalTimeVal.
timersub(&newTimeValue, &globalTimeVal, &result);
printf(" Time elapsed: %ld.%06ld", (long int)result.tv_sec,
(long int)result.tv_usec);
printf(", %ld.%06ld", (long int)result.tv_sec, (long int)result.tv_usec);
// Interval since initTimeVal.
timersub(&newTimeValue, &initTimeVal, &result);
printf(" accumulated: %ld.%06ld", (long int)result.tv_sec,
(long int)result.tv_usec);
printf(", %ld.%06ld\n", (long int)result.tv_sec, (long int)result.tv_usec);
// The next call measures its first interval from this point.
globalTimeVal = newTimeValue;
}
#endif
Expand All @@ -121,7 +120,7 @@ void ReportMemory() {
GetProcessMemoryInfo(
GetCurrentProcess(), (PROCESS_MEMORY_COUNTERS *)&pmc, sizeof(pmc));
SIZE_T vMemSizeKB = pmc.PrivateUsage / 1024;
printf(" VMem: %zu", vMemSizeKB);
printf("%zu\n", vMemSizeKB);
}
#else
void ReportMemory() {
Expand All @@ -132,22 +131,19 @@ void ReportMemory() {
snprintf(memCommand, sizeof(memCommand), "ps -o vsz='' -p %d", mypid);
memPipe = popen(memCommand, "r");
if (!memPipe) {
if (psErrorCount > 20)
return;
fprintf(stderr, "ERROR: Failed to execute ps");
printf(", error-failed-to-execute-ps\n");
psErrorCount++;
return;
}
(void)fgets(memOutput, 200, memPipe);
(void)fgetc(memPipe);
memOutput[strcspn(memOutput, "\n")] = 0;
printf(" VMem:%s", memOutput);
if (!feof(memPipe)) {
if (psErrorCount > 20) {
pclose(memPipe);
return;
}
fprintf(stderr, "ERROR: Unexpected output from ps");
printf(", error-unexpected-output-from-pipe\n");
psErrorCount++;
} else {
// No error, print data.
printf(", %s\n", memOutput);
}
pclose(memPipe);
}
Expand Down Expand Up @@ -180,24 +176,24 @@ void OMInstrumentPoint(const char *opName, int64_t tag, const char *nodeName) {
if (instrumentReportDisabled)
return;

// Print header
printf("#%d) %s %s", instrumentCounter,
tag & (1 << (int)InstrumentBeforeOp) ? "before" : "after ", opName);
instrumentCounter++;

bool localReportTime =
bool isBefore = tag & (1 << (int)InstrumentBeforeOp);
bool reportTime =
tag & (1 << (int)InstrumentReportTime) && !instrumentReportTimeDisabled;
bool reportMem = tag & (1 << (int)InstrumentReportMemory) &&
!instrumentReportMemoryDisabled;

if (localReportTime) {
if (reportTime) {
// Print header and data for time.
printf("==PERF-REPORT==, %s, %s, %s", opName, nodeName,
(isBefore ? "before" : "after"));
ReportTime();
}

bool localReportMemory = tag & (1 << (int)InstrumentReportMemory) &&
!instrumentReportMemoryDisabled;
if (localReportMemory) {
if (reportMem && psErrorCount < 20) {
// Print header and data for memory.
printf("==MEM-REPORT==, %s, %s, %s", opName, nodeName,
(isBefore ? "before" : "after"));
ReportMemory();
}
if (strncmp(nodeName, "NOTSET", 6) != 0)
printf(" (%s)", nodeName);
printf("\n");
// Not sure if needed, remove if that is not the case.
instrumentCounter++;
}
Loading

0 comments on commit 7ee11bc

Please sign in to comment.