diff --git a/enzyme/Enzyme/Herbie.cpp b/enzyme/Enzyme/Herbie.cpp index 902d000b8b06..1337d21945ee 100644 --- a/enzyme/Enzyme/Herbie.cpp +++ b/enzyme/Enzyme/Herbie.cpp @@ -77,24 +77,91 @@ void runViaHerbie(const std::string &cmd) { output.close(); } +std::string getHerbieOperator(const Instruction &I) { + switch (I.getOpcode()) { + case Instruction::FAdd: + return "+"; + case Instruction::FSub: + return "-"; + case Instruction::FMul: + return "*"; + case Instruction::FDiv: + return "/"; + default: + return "UnknownOp"; + } +} + // Run (our choice of) floating point optimizations on function `F`. // Return whether or not we change the function. -bool fpOptimize(llvm::Function &F) { +bool fpOptimize(Function &F) { bool changed = false; - // 1) Identify subgraphs of the computation which can be entirely represented - // in herbie-style arithmetic + std::string herbieInput; + std::map valueToSymbolMap; + std::map symbolToValueMap; + std::set arguments; + int symbolCounter = 0; - llvm::errs() << "Optimizing function " << F.getName().str() << "\n"; + auto getNextSymbol = [&symbolCounter]() -> std::string { + return "v" + std::to_string(symbolCounter++); + }; + // 1) Identify subgraphs of the computation which can be entirely represented + // in herbie-style arithmetic // 2) Make the herbie FP-style expression by // converting llvm instructions into herbie string (FPNode ....) + for (auto &BB : F) { + for (auto &I : BB) { + if (auto *op = dyn_cast(&I)) { + if (op->getType()->isFloatingPointTy()) { + std::string lhs = + valueToSymbolMap.count(op->getOperand(0)) + ? valueToSymbolMap[op->getOperand(0)] + : (valueToSymbolMap[op->getOperand(0)] = getNextSymbol()); + std::string rhs = + valueToSymbolMap.count(op->getOperand(1)) + ? valueToSymbolMap[op->getOperand(1)] + : (valueToSymbolMap[op->getOperand(1)] = getNextSymbol()); + + arguments.insert(lhs); + arguments.insert(rhs); + + std::string symbol = getNextSymbol(); + valueToSymbolMap[&I] = symbol; + symbolToValueMap[symbol] = &I; + + std::string herbieNode = "("; + herbieNode += getHerbieOperator(I); + herbieNode += " "; + herbieNode += lhs; + herbieNode += " "; + herbieNode += rhs; + herbieNode += ")"; + herbieInput += herbieNode; + } + } + } + } - // 3) run fancy opts + if (herbieInput.empty()) { + return changed; + } - // runViaHerbie() + std::string argumentsStr = "("; + for (const auto &arg : arguments) { + argumentsStr += arg + " "; + } + argumentsStr.pop_back(); + argumentsStr += ")"; - // 4) parse the output string solution from herbieland + herbieInput = "(FPCore " + argumentsStr + " " + herbieInput + ")"; + llvm::errs() << "Herbie input:\n" << herbieInput << "\n"; + + // 3) run fancy opts + runViaHerbie(herbieInput); + + // 4) parse the output string solution from herbieland // 5) convert into a solution in llvm vals/instructions return changed; }