diff --git a/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp index a5b35b0be8..5ad19580cc 100644 --- a/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp @@ -45,28 +45,25 @@ class KrnlInstrumentOpLowering : public ConversionPattern { KrnlInstrumentOpAdaptor operandAdaptor(operands); Location loc = op->getLoc(); KrnlInstrumentOp instrumentOp = llvm::dyn_cast(op); - MultiDialectBuilder create(rewriter, loc); - LLVMTypeConverter *typeConverter = - static_cast(getTypeConverter()); - - // Get a symbol reference to the memcpy function, inserting it if necessary. - ModuleOp parentModule = op->getParentOfType(); - auto instrumentRef = getOrInsertInstrument(rewriter, parentModule); StringRef opNameStr = instrumentOp.getOpName(); - LLVM::GlobalOp globalOpNameStr = krnl::getOrCreateGlobalString( - opNameStr, loc, rewriter, parentModule, typeConverter); - Value opNamePtr = - krnl::getPtrToGlobalString(globalOpNameStr, loc, rewriter); - Value tag = create.llvm.constant( - IntegerType::get(context, 64), (int64_t)instrumentOp.getTag()); + + // We may try to relate the node names generated by the instrumentation + // with the node names printed by onnx-op-report. Thus it is key to keep the + // code that generates these node name in sync. + // + // The code are found here: + // 1) `matchAndRewrite` from `src/Conversion/KrnlToLLVM/KrnlInstrument.cpp` + // 2) `getNodeNameLikeInKrnlInstrument` from + // `src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp` + StringRef nodeName; if (instrumentOp.getNodeName().has_value()) nodeName = instrumentOp.getNodeName().value(); else if (auto nameLoc = loc.dyn_cast()) nodeName = nameLoc.getName(); else if (auto fusedLoc = loc.dyn_cast()) { - // Combine each location name and set it as nodeName. + // Combine each location name and set it as nodeName, appended by "-". std::string name; for (Location locIt : fusedLoc.getLocations()) { if (auto nameLocIt = locIt.dyn_cast()) @@ -83,19 +80,34 @@ class KrnlInstrumentOpLowering : public ConversionPattern { name = "NOTSET"; else name.pop_back(); // remove last "-" - loc = NameLoc::get(rewriter.getStringAttr(name)); - nodeName = cast(loc).getName(); + Location newLoc = NameLoc::get(rewriter.getStringAttr(name)); + nodeName = cast(newLoc).getName(); } else if (auto fileLineColLoc = loc.dyn_cast()) { std::string filename = llvm::sys::path::filename(fileLineColLoc.getFilename().str()).str(); std::string name = filename + ":" + std::to_string(fileLineColLoc.getLine()); - loc = NameLoc::get(rewriter.getStringAttr(name)); - nodeName = cast(loc).getName(); + Location newLoc = NameLoc::get(rewriter.getStringAttr(name)); + nodeName = cast(newLoc).getName(); } else nodeName = StringRef("NOTSET"); LLVM_DEBUG( llvm::dbgs() << "Instrumentation_nodeName: " << nodeName << "\n"); + + MultiDialectBuilder create(rewriter, loc); + LLVMTypeConverter *typeConverter = + static_cast(getTypeConverter()); + + // Get a symbol reference to the memcpy function, inserting it if necessary. + ModuleOp parentModule = op->getParentOfType(); + auto instrumentRef = getOrInsertInstrument(rewriter, parentModule); + + LLVM::GlobalOp globalOpNameStr = krnl::getOrCreateGlobalString( + opNameStr, loc, rewriter, parentModule, typeConverter); + Value opNamePtr = + krnl::getPtrToGlobalString(globalOpNameStr, loc, rewriter); + Value tag = create.llvm.constant( + IntegerType::get(context, 64), (int64_t)instrumentOp.getTag()); LLVM::GlobalOp globalStr = krnl::getOrCreateGlobalString( nodeName, loc, rewriter, parentModule, typeConverter); Value nodeNamePtr = krnl::getPtrToGlobalString(globalStr, loc, rewriter); diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 6eb12d1b7f..3edea1abbf 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -16,6 +16,7 @@ #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "llvm/Support/Path.h" #include "src/Accelerators/Accelerator.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" @@ -23,6 +24,8 @@ #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp" +#define DEBUG_TYPE "lowering-to-krnl" + using namespace mlir; namespace onnx_mlir { @@ -661,18 +664,70 @@ bool hasNonIdentityLayout(ValueRange operands) { // Support functions for reporting. //===----------------------------------------------------------------------===// +// We may try to relate the node names generated by the instrumentation +// with the node names printed by onnx-op-report. Thus it is key to keep the +// code that generates these node name in sync. +// +// The code are found here: +// 1) `matchAndRewrite` from `src/Conversion/KrnlToLLVM/KrnlInstrument.cpp` +// 2) `getNodeNameLikeInKrnlInstrument` from +// `src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp` + +static std::string getNodeNameLikeInKrnlInstrument(Operation *op) { + StringAttr nodeName; + // Try with op onnx_node_name attribute. + nodeName = op->getAttrOfType("onnx_node_name"); + if (nodeName) { + LLVM_DEBUG(llvm::dbgs() << "op has node name\n"); + return nodeName.str(); + } + // Try with op location. + Location loc = op->getLoc(); + if (auto nameLoc = loc.dyn_cast()) { + LLVM_DEBUG(llvm::dbgs() << "op has node name\n"); + return nameLoc.getName().str(); + } + if (auto fusedLoc = loc.dyn_cast()) { + // Combine each location name and set it as nodeName. + LLVM_DEBUG(llvm::dbgs() << "fuse loc has name\n"); + std::string name; + for (Location locIt : fusedLoc.getLocations()) { + if (auto nameLocIt = locIt.dyn_cast()) + name += nameLocIt.getName().str() + "-"; + else if (auto fileLineColLoc = locIt.dyn_cast()) { + std::string filename = + llvm::sys::path::filename(fileLineColLoc.getFilename().str()).str(); + name += filename + ":" + std::to_string(fileLineColLoc.getLine()) + "-"; + } + } + if (name.empty()) + name = "NOTSET"; + else + name.pop_back(); // remove last "-" + return name; + } + if (auto fileLineColLoc = loc.dyn_cast()) { + LLVM_DEBUG(llvm::dbgs() << "file line col has name\n"); + std::string filename = + llvm::sys::path::filename(fileLineColLoc.getFilename().str()).str(); + std::string name = + filename + ":" + std::to_string(fileLineColLoc.getLine()); + return name; + } + return "NOTSET"; +} + void impl::onnxToKrnlParallelReport(Operation *op, bool successful, int64_t loopLevel, int64_t parallelLoopTripCount, const std::string &comment) { assert(OnnxToKrnlLoweringConfiguration::reportOnParallel && "must report"); assert(comment.find(',') == std::string::npos && "no comma in comments"); StringAttr opName = op->getName().getIdentifier(); - StringAttr nodeName = op->getAttrOfType("onnx_node_name"); + std::string nodeNameStr = getNodeNameLikeInKrnlInstrument(op); // Print report on this op. - fprintf(stderr, "==ONNX-PAR-REPORT==, %s%s, %s, %lld, %lld, %s\n", - opName.data(), (successful ? "-parallel" : ""), - (nodeName ? nodeName.data() : "no-node-name"), loopLevel, - parallelLoopTripCount, comment.c_str()); + printf("==PAR-REPORT==, %s%s, %s, %s, %lld, %lld\n", opName.data(), + (successful ? "-parallel" : ""), nodeNameStr.c_str(), comment.c_str(), + loopLevel, parallelLoopTripCount); } void impl::onnxToKrnlSimdReport(Operation *op, bool successful, @@ -681,7 +736,7 @@ void impl::onnxToKrnlSimdReport(Operation *op, bool successful, assert(OnnxToKrnlLoweringConfiguration::reportOnSimd && "must report"); assert(comment.find(',') == std::string::npos && "no comma in comments"); StringAttr opName = op->getName().getIdentifier(); - StringAttr nodeName = op->getAttrOfType("onnx_node_name"); + std::string nodeNameStr = getNodeNameLikeInKrnlInstrument(op); // Handling message. std::string message = OnnxToKrnlLoweringConfiguration::defaultSimdComment; if (message.empty()) @@ -690,10 +745,9 @@ void impl::onnxToKrnlSimdReport(Operation *op, bool successful, // No comments, all values indicate no simd message = "unsupported"; // Print report on this op. - fprintf(stderr, "==ONNX-SIMD-REPORT==, %s%s, %s, %lld, %lld, %s\n", - opName.data(), (successful ? "-simd" : ""), - (nodeName ? nodeName.data() : "no-node-name"), vectorLength, - simdLoopTripCount, message.c_str()); + printf("==SIMD-REPORT==, %s%s, %s, %s, %lld, %lld\n", opName.data(), + (successful ? "-simd" : ""), nodeNameStr.c_str(), message.c_str(), + vectorLength, simdLoopTripCount); } } // namespace onnx_mlir diff --git a/src/Runtime/OMInstrument.inc b/src/Runtime/OMInstrument.inc index 864c7a5104..5cedf1d7ca 100644 --- a/src/Runtime/OMInstrument.inc +++ b/src/Runtime/OMInstrument.inc @@ -39,6 +39,7 @@ #ifdef _WIN32 #include "windows.h" +// The windows.h include must go first. #include "psapi.h" static LARGE_INTEGER globalTime, initTime; @@ -50,13 +51,13 @@ static LARGE_INTEGER perfFrequency; static struct timeval globalTimeVal, initTimeVal; static pid_t mypid; -static int psErrorCount = 0; #endif static bool instrumentReportDisabled = false; static bool instrumentReportTimeDisabled = false; static bool instrumentReportMemoryDisabled = false; static int instrumentCounter = 0; +static int psErrorCount = 0; #ifdef __MVS__ #define timersub(a, b, result) \ @@ -96,9 +97,9 @@ void ReportTime() { LONGLONG resultSeconds, resultMicroseconds; QueryPerformanceCounter(&newTime); WinTimerSub(newTime, globalTime, &resultSeconds, &resultMicroseconds); - printf(" Time elapsed: %lld.%06lld", resultSeconds, resultMicroseconds); + printf(", %lld.%06lld", resultSeconds, resultMicroseconds); WinTimerSub(newTime, initTime, &resultSeconds, &resultMicroseconds); - printf(" accumulated: %lld.%06lld", resultSeconds, resultMicroseconds); + printf(", %lld.%06lld\n", resultSeconds, resultMicroseconds); globalTime = newTime; } #else @@ -106,11 +107,9 @@ void ReportTime() { struct timeval newTimeValue, result; gettimeofday(&newTimeValue, NULL); timersub(&newTimeValue, &globalTimeVal, &result); - printf(" Time elapsed: %ld.%06ld", (long int)result.tv_sec, - (long int)result.tv_usec); + printf(", %ld.%06ld", (long int)result.tv_sec, (long int)result.tv_usec); timersub(&newTimeValue, &initTimeVal, &result); - printf(" accumulated: %ld.%06ld", (long int)result.tv_sec, - (long int)result.tv_usec); + printf(", %ld.%06ld\n", (long int)result.tv_sec, (long int)result.tv_usec); globalTimeVal = newTimeValue; } #endif @@ -121,7 +120,7 @@ void ReportMemory() { GetProcessMemoryInfo( GetCurrentProcess(), (PROCESS_MEMORY_COUNTERS *)&pmc, sizeof(pmc)); SIZE_T vMemSizeKB = pmc.PrivateUsage / 1024; - printf(" VMem: %zu", vMemSizeKB); + printf("%zu\n", vMemSizeKB); } #else void ReportMemory() { @@ -132,22 +131,19 @@ void ReportMemory() { snprintf(memCommand, sizeof(memCommand), "ps -o vsz='' -p %d", mypid); memPipe = popen(memCommand, "r"); if (!memPipe) { - if (psErrorCount > 20) - return; - fprintf(stderr, "ERROR: Failed to execute ps"); + printf(", error-failed-to-execute-ps\n"); psErrorCount++; + return; } (void)fgets(memOutput, 200, memPipe); (void)fgetc(memPipe); memOutput[strcspn(memOutput, "\n")] = 0; - printf(" VMem:%s", memOutput); if (!feof(memPipe)) { - if (psErrorCount > 20) { - pclose(memPipe); - return; - } - fprintf(stderr, "ERROR: Unexpected output from ps"); + printf(", error-unexpected-output-from-pipe\n"); psErrorCount++; + } else { + // No error, print data/ + printf(", %s\n", memOutput); } pclose(memPipe); } @@ -180,24 +176,24 @@ void OMInstrumentPoint(const char *opName, int64_t tag, const char *nodeName) { if (instrumentReportDisabled) return; - // Print header - printf("#%d) %s %s", instrumentCounter, - tag & (1 << (int)InstrumentBeforeOp) ? "before" : "after ", opName); - instrumentCounter++; - - bool localReportTime = + bool isBefore = tag & (1 << (int)InstrumentBeforeOp); + bool reportTime = tag & (1 << (int)InstrumentReportTime) && !instrumentReportTimeDisabled; + bool reportMem = tag & (1 << (int)InstrumentReportMemory) && + !instrumentReportMemoryDisabled; - if (localReportTime) { + if (reportTime) { + // Print header and data for time. + printf("==PERF-REPORT==, %s, %s, %s", opName, nodeName, + (isBefore ? "before" : "after")); ReportTime(); } - - bool localReportMemory = tag & (1 << (int)InstrumentReportMemory) && - !instrumentReportMemoryDisabled; - if (localReportMemory) { + if (reportMem && psErrorCount < 20) { + // Print header and data for memory. + printf("==MEM-REPORT==, %s, %s, %s", opName, nodeName, + (isBefore ? "before" : "after")); ReportMemory(); } - if (strncmp(nodeName, "NOTSET", 6) != 0) - printf(" (%s)", nodeName); - printf("\n"); + // Not sure if needed, remove if that is not the case. + instrumentCounter++; } diff --git a/utils/make-report.py b/utils/make-report.py index a404367605..6135df7afe 100755 --- a/utils/make-report.py +++ b/utils/make-report.py @@ -16,170 +16,391 @@ ################################################################################ import sys -import os import getopt -import fileinput import re -import subprocess +import numpy as np def print_usage(msg = ""): if msg: print("Error:", msg, "\n") - print("make-report -[hs] -i [-d ] [-p ]") + print("make-report.py -[svh] [-c ] [-r ] [-l ] [-p ]") print("") - print("Usage: scan onnx op report generated by `onnx-mlir1 with `--onnx-op-report`") - print("equal to `Simd` or `Parallel` and report statistics about the ops found") - print("in the provided input log.") + print("Usage: Report statistics on compiler and runtime characteristics of onnx ops.") print("") - print("Parameters:") - print(" -d/--detail : Print statistics:") - print(" 0: Just count successful/unsuccessful ops.") - print(" 1: Also count reasons for success/failure.") - print(" 2: Also list metrics.") - print(" 3: Also list node name.") - print(" -h/--help: Print usage.") - print(" -i/--input : Input file to be scanned.") - print(" -p/--pattern : Focus only on ops that match the regexp pattern.") - print(" -s/supported: Focus only on ops that are supported, i.e. skip unsupported ops.") + print("Compile-time statistics are collected from a `onnx-mlir` compiler output") + print("with the `--onnx-op-report` option equal to `Simd` or other supported sub-options.") + print("") + print("Runtime statistics are collected from the runtime output of a model compiled.") + print("with the `--profile-ir` option equal to `Onnx` or other supported sub-options.") + print("") + print("When both compile time and runtime statistics are provided at the same time,") + print("it will correlate the performance metrics with data gathered at compile time.") print("") - print("More info:") + print("Additional help.") print(" If you need more specific info on individual success/failure, run ") print(" `onnx-mlir --debug-only=lowering-to-krnl` and look at the compiler output.") - print(" Use `-d 3` to correlate the node name printed here with compiler output.") + print(" Use `-l 3` to correlate the node name printed here with compiler output.") + print("") + print("Parameters:") + print(" -c/--compile : File name containing the compile time statistics.") + print(" -r/--runtime : File name containing the runtime statistics.") + print(" -l/--level : Print statistics:") + print(" 0: Just count successful/unsuccessful ops.") + print(" 1: Also count reasons for success/failure.") + print(" 2: Also list metrics.") + print(" 3: Also list node name.") + print(" -f/--focus : Focus only on ops that match the regexp pattern.") + print(" -s/supported: Focus only on ops that are supported. Namely, the report") + print(" will skip ops for which compile-time statistics list") + print(" the 'unsupported' keyword in its printout.") + print(" For SIMD/parallel statistics, this include all ops that") + print(" have currently no support for it.") + print(" -v/--verbose: Run in verbose mode (see error and warnings).") + print(" -h/--help: Print usage.") print("") exit(1) ################################################################################ # Global info. -report_pattern_count = {} -report_level_pattern_count = {} -# ==ONNX-SIMD-REPORT==, op name, node name, info -simd_report_str = r'^==ONNX-SIMD-REPORT==,\s*([0-9a-zA-Z\.\-]+)\s*,([^,]*),(.*)' -simd_stat_message = "SIMD vector length (in elements), SIMD loop trip count (-1 is runtime), message" -report_str = simd_report_str -stat_message = simd_stat_message +# For statistic info. +op_count_dict = {} # op -> count +op_detail_count_dict = {} # op -> {dictionary of detailed pattern -> count} +op_time_dict = {} # op -> cumulative time +op_detail_time_dict = {} # op -> {dictionary of detailed pattern -> cumulative time} + +# For timing info +node_time_dict = {} # op + node_name -> time statistic + focus_on_op_with_pattern = r'.*' +spurious_node_name_count = 0 +error_missing_time = 0 supported_only = False +has_timing = False +verbose = False report_level = 0 # 0: none; 1: details; 2: extra info; 3: plus node names +# Basic pattern for reports: "==" "==," "," "," +def common_report_str(stat_name): + return r'^==' + stat_name + r'-REPORT==,\s*([0-9a-zA-Z\.\-]+)\s*,\s*([^,]*),\s*(.*)' + +# ==SIMD-REPORT==, ..., , , +simd_stat_message = "SIMD vector length (in elements), SIMD loop trip count (-1 is runtime), message" + +# ==PERF-REPORT==, ..., "before" | "after", time since last call, absolute time +perf_stat_message = "(after|before), time for op(s), time since start(s)" + ################################################################################ -# Support. +# # Support. -def record_pattern(pat, det): - global report_pattern_count, report_level_pattern_count - global report_level +# To record time, use op name and node name to better disambiguate. +def timing_dict_key(op, node_name): + p = re.match(r'(.*)-(simd|par)', op) + if p: + op = p[1] + return op + "__" + node_name + +# Add num to dict[key] +def add_to_dict_entry(dict, key, num): + if key in dict: + return dict[key] + num + # First visit, entry does not exist. + return num + +# Dict1 is a dictionary of dictionaries. First locate the secondary directory, +# dict1[key1], and then add num to the key2 entry of that secondary dictionary. +def add_to_dict2_entry(dict1, key1, key2, num): + if key1 in dict1: + # Retrieve dict of dict. + dict2 = dict1[key1] + # Increment entry for key2. + dict2[key2] = add_to_dict_entry(dict2, key2, num) + return dict2 + # First visit, secondary dict is empty. + return { key2 : num} + +def append_to_dict_entry(dict, key, num): + if key in dict: + return np.append(dict[key], num) + # First visit, entry does not exist. + return np.array([num]) + +def append_to_dict2_entry(dict1, key1, key2, num): + if key1 in dict1: + # Retrieve dict of dict. + dict2 = dict1[key1] + # Increment entry for key2. + dict2[key2] = append_to_dict_entry(dict2, key2, num) + return dict2 + # First visit, secondary dict is empty. + return { key2 : np.array([num])} + + +def record_pattern(op, node_name, detail_key): + global op_count_dict, op_detail_count_dict + global op_time_dict, op_detail_time_dict + global node_time_dict + global verbose, report_level, has_timing, error_missing_time + + # Update statistic summaries + op_count_dict[op] = add_to_dict_entry(op_count_dict, op, 1) + if report_level > 0: + op_detail_count_dict[op] = add_to_dict2_entry( + op_detail_count_dict, op, detail_key, 1) + + # Has timing for this node? + if not has_timing: + return + # Process timing. + timing_key = timing_dict_key(op, node_name) + if not timing_key in node_time_dict: + error_missing_time += 1 + if verbose: + print("> timing key", timing_key, "with no times found in the performance data.") + return + # Update timing summaries + time = node_time_dict[timing_key] + op_time_dict[op] = append_to_dict_entry(op_time_dict, op, time) + if report_level > 0: + op_detail_time_dict[op] = append_to_dict2_entry( + op_detail_time_dict, op, detail_key, time) - if pat in report_pattern_count: - report_pattern_count[pat] = report_pattern_count[pat] + 1 - if report_level > 0: - det_dict = report_level_pattern_count[pat] - if det in det_dict: - det_dict[det] = det_dict[det] + 1 - else: - det_dict[det] = 1 - report_level_pattern_count[pat] = det_dict - else: - report_pattern_count[pat] = 1 - if report_level > 0: - det_dict = {} - det_dict[det] = 1 - report_level_pattern_count[pat] = det_dict ################################################################################ -# Main. +# Parse line (generic). + + +def parse_line(line, report_str, is_perf_stat): + global focus_on_op_with_pattern, supported_only + global verbose, spurious_node_name_count -def parse_file(file_name): - global report_str, report_level, focus_on_op_with_pattern, supported_only + p = re.match(report_str, line) + if p is None: + return (False, "", "", "") + # Have a line of relevant info, extract op, op name, and stat details. + op = p[1] + node_name = p[2] + details = p[3] + # If we process supported op only, search for "unsupported" in details. + if supported_only and re.search(r'unsupported', details) is not None: + return (False, "", "", "") + # If we process perf, we don't care about the "before" + if is_perf_stat and re.search(r'before', details) is not None: + return (False, "", "", "") + + # Check if we have an op that we focus on; if not skip. + f = re.match(focus_on_op_with_pattern, op) + if f is None: + return (False, "", "", "") + # Have a perfect match. + + # Issues due to runtime constants having issues. + new_node_name = node_name + # Spurious appending of the last node_name. + if parse_line.last_node_name: + i0 = re.match(r'(.+)'+parse_line.last_node_name, node_name) + if i0: + new_node_name = i0[1] + spurious_node_name_count += 1 + if verbose: + print("Cut last node_name:\n old:", node_name, + "\n cut:", parse_line.last_node_name, + "\n new:", new_node_name) + parse_line.last_node_name = node_name + # Repeating node name. + i1 = re.match(r'(.+)'+op, node_name) + if i1: + new_node_name = i1[1] + spurious_node_name_count += 1 + if verbose: + print("Cut op name:\n old:", node_name, + "\n cut:", op,"\n new:", new_node_name) + # Use new node name. + node_name = new_node_name + return (True, op, node_name, details) + +parse_line.last_node_name = "" + +################################################################################ +# Parse file for statistics. + +def parse_file_for_stat(file_name, stat_name): + global report_level try: file = open(file_name, 'r') except OSError: print_usage("Could not open file `"+file_name+"`") + report_str = common_report_str(stat_name) + is_perf_stat = re.match(r'PERF', stat_name) for line in file: - l = line.rstrip() - # Scan pattern, only keep at it only if we have the report info. - p = re.match(report_str, l) - if p is None: + # Parse line. + (has_stat, op, node_name, details) = parse_line(line.rstrip(), + report_str, is_perf_stat) + if not has_stat: continue - # Have a line. - op = p[1] - node_name = p[2] - details = p[3] - if supported_only and re.search(r'unsupported', details) is not None: - # Has an op that is unsupported, and we asked to skip them - continue - f = re.match(focus_on_op_with_pattern, op) - if f is None: - continue - # Have an interesting op. - if report_level == 0: - record_pattern(op, "") + # Use stat. + secondary_key = "" + detail_array = details.split(",") if report_level == 1: - detail_array = details.split(",") - record_pattern(op, detail_array[-1]) + # Use only first element in secondary key. + secondary_key = detail_array[0] elif report_level == 2: - record_pattern(op, details) + # Use all details in secondary key. + secondary_key = details elif report_level == 3: - record_pattern(op, node_name + ", " + details) - -def make_report(): - global report_pattern_count, report_level_pattern_count - global report_level, supported_only, stat_message - - if (report_level > 1): - print("Statistic legend:") - if (report_level == 2): - print(" num:", stat_message, "\n") - elif (report_level == 3): - print(" num: node-name, ", stat_message, "\n") - print("") + # Use node name in secondary key. + secondary_key = node_name + ", " + details + + record_pattern(op, node_name, secondary_key) + +################################################################################ +# Parse file for performance + +def parse_file_for_perf(file_name, stat_name): + global node_time_dict + global spurious_node_name_count, verbose, has_timing + + try: + file = open(file_name, 'r') + except OSError: + print_usage("Could not open file `"+file_name+"`") + + report_str = common_report_str(stat_name) + time_stat_dict = {} # op+op_name -> numpy array of times + last_node_name = "" + for line in file: + # Parse line. + (has_stat, op, node_name, details) = parse_line(line.rstrip(), + report_str, True) + if not has_stat: + continue + # Keep only after times. + detail_array = details.split(",") + key = timing_dict_key(op, node_name) + time_stat_dict[key] = append_to_dict_entry(time_stat_dict, + key, float(detail_array[1])) + + # Normally, an - pair should be seen only once in a run, + # except for loops. So we take here the sum of all the times. + # This approach would not work well if we had performance for multiple + # runs. + # TODO: If wanted to average/min/max over multiple runs, we would have + # need to pull this inside of the loop above, summing at the end of + # a run, and then taking min/max/average of the times gathered for each + # run. + for node in time_stat_dict: + node_time_dict[node] = np.sum(time_stat_dict[node]) + has_timing = True + +################################################################################ +# make report + +def make_report(stat_message): + global op_count_dict, op_detail_count_dict + global op_time_dict, op_detail_time_dict + global report_level, supported_only, verbose, spurious_node_name_count + global has_timing, error_missing_time + + num_desc = "num" + if has_timing: + num_desc += ", cumulative time(s)" + print("Statistic legend:") + if report_level < 2: + print(" op-name:", num_desc) + elif report_level == 2: + print(" " + num_desc + ":", stat_message, "\n") + elif report_level == 3: + print(" " + num_desc + ": node-name, ", stat_message, "\n") + print("") if supported_only: - print("Statistics (ignore unsupported ops):") + print("Statistics start (ignore unsupported ops).") else: - print("Statistics (all ops):") - for key in sorted(report_pattern_count): - print(" ", key, ":", report_pattern_count[key]) + print("Statistics start (all ops).") + for op in sorted(op_count_dict): + count_time_str = str(op_count_dict[op]) + if op in op_time_dict: + time = np.sum(op_time_dict[op]) + count_time_str += ", {:.7f}".format(time) + print(" " + op + ", " + count_time_str) if report_level: - det_dict = report_level_pattern_count[key] + det_dict = op_detail_count_dict[op] + det_time_dict = {} + if op in op_detail_time_dict: + det_time_dict = op_detail_time_dict[op] for det_key in sorted(det_dict): - if det_dict[det_key] == report_pattern_count[key]: - print(" *:", det_key) + if det_dict[det_key] == op_count_dict[op]: + count_time_str = "*" else: - print(" ", det_dict[det_key], ":", det_key) + count_time_str = str(det_dict[det_key]) + if det_key in det_time_dict: + time = np.sum(det_time_dict[det_key]) + count_time_str += ", {:.7f}".format(time) + print(" ", count_time_str, ":", det_key) + print("Statistics end.") + + # Report spurious node name if any. + if spurious_node_name_count: + if error_missing_time: + print("> Spurious node name were detected.") + print("> Timing information was missing for some of the nodes.") + else: + print("> Spurious node name were detected and fixed.") + print("> Run with `-v` for detailed list of fixes and errors.") + elif error_missing_time: + print("> Timing information was missing for some of the nodes.") + print("> Run with `-v` for detailed list of errors.") +################################################################################ +# Main. + def main(argv): - global report_level, focus_on_op_with_pattern, supported_only + global report_level, focus_on_op_with_pattern, supported_only, verbose - file_name = "" + compile_file_name = "" + runtime_file_name = "" try: opts, args = getopt.getopt( - argv, "d:fhi:p:s", ["detail=", "full", "help", "input=", "pattern=", "supported"]) + argv, "c:f:hl:r:sv", + ["compile=", "focus=", "help", "level=", "runtime=", "supported", "verbose"]) except getopt.GetoptError: print_usage("Failure to parse inputs") for opt, arg in opts: - if opt in ('-d', "--details"): - report_level = int(arg) - if (report_level<0 or report_level > 3): - print_usage("detail level is 0, 1, 2, or 3") - elif opt in ('-h', "--help"): - print_usage() - elif opt in ('-i', "--input"): - file_name = arg - elif opt in ('-p', "--pattern"): + if opt in ('-c', "--compile"): + compile_file_name = arg + elif opt in ('-f', "--focus"): focus_on_op_with_pattern = arg if report_level == 0: report_level = 1 + elif opt in ('-h', "--help"): + print_usage() + elif opt in ('-l', "--level"): + report_level = int(arg) + if (report_level<0 or report_level > 3): + print_usage("detail levels are 0, 1, 2, or 3") + elif opt in ('-r', "--runtime"): + runtime_file_name = arg elif opt in ('-s', "--supported"): supported_only = True + elif opt in ('-v', "--verbose"): + verbose = True - if not file_name: - print_usage("Command requires an input file name.\n") + if compile_file_name and runtime_file_name: + parse_file_for_perf(runtime_file_name, "PERF") + parse_file_for_stat(compile_file_name, "SIMD") + make_report(simd_stat_message) + elif compile_file_name: + parse_file_for_stat(compile_file_name, "SIMD") + make_report(simd_stat_message) + elif runtime_file_name: + parse_file_for_perf(runtime_file_name, "PERF") + parse_file_for_stat(runtime_file_name, "PERF") + make_report(perf_stat_message) + else: + print_usage("Command requires an input file name (compile/runtime or both).\n") - parse_file(file_name) - make_report() if __name__ == "__main__": main(sys.argv[1:])