Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IP integration: refactoring EmitHLSCpp into EmissionMethods and addition of -emit-IP feature #31

Open
wants to merge 5 commits into
base: ip_integration
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
395 changes: 395 additions & 0 deletions include/scalehls/Translation/EmissionMethods.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,395 @@
#include "scalehls/Translation/EmitHLSCpp.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Translation.h"
#include "scalehls/Dialect/HLSCpp/Visitor.h"
#include "scalehls/Dialect/HLSKernel/Visitor.h"
#include "scalehls/InitAllDialects.h"
#include "scalehls/Support/Utils.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"

#include <iostream>
#include <fstream>

using namespace mlir;
using namespace scalehls;

namespace mlir {
namespace scalehls {

//===----------------------------------------------------------------------===//
// Some Base Classes
//===----------------------------------------------------------------------===//

/// This class maintains the mutable state that cross-cuts and is shared by the
/// various emitters.
class ScaleHLSEmitterState {
public:
explicit ScaleHLSEmitterState(raw_ostream &os) : os(os) {}

// The stream to emit to.
raw_ostream &os;

bool encounteredError = false;
unsigned currentIndent = 0;

// This table contains all declared values.
DenseMap<Value, SmallString<8>> nameTable;

private:
ScaleHLSEmitterState(const ScaleHLSEmitterState &) = delete;
void operator=(const ScaleHLSEmitterState &) = delete;
};

/// This is the base class for all of the HLSCpp Emitter components. Simple methods are implemented here.
class ScaleHLSEmitterBase {
public:
explicit ScaleHLSEmitterBase(ScaleHLSEmitterState &state)
: state(state), os(state.os) {}

InFlightDiagnostic emitError(Operation *op, const Twine &message) {
state.encounteredError = true;
return op->emitError(message);
}

raw_ostream &indent() { return os.indent(state.currentIndent); }

void addIndent() { state.currentIndent += 2; }
void reduceIndent() { state.currentIndent -= 2; }

// All of the mutable state we are maintaining.
ScaleHLSEmitterState &state;

// The stream to emit to.
raw_ostream &os;
// std::fstream &os; // Instead of writing to stdout by default, write to file stream

/// Value name management methods.
SmallString<8> addName(Value val, bool isPtr = false);

SmallString<8> addAlias(Value val, Value alias);

SmallString<8> getName(Value val);

bool isDeclared(Value val) {
if (getName(val).empty()) {
return false;
} else
return true;
}

private:
ScaleHLSEmitterBase(const ScaleHLSEmitterBase &) = delete;
void operator=(const ScaleHLSEmitterBase &) = delete;

};

class ModuleEmitter : public ScaleHLSEmitterBase {
public:
using operand_range = Operation::operand_range;
explicit ModuleEmitter(ScaleHLSEmitterState &state)
: ScaleHLSEmitterBase(state) {}

/// SCF statement emitters.
void emitScfFor(scf::ForOp op);
void emitScfIf(scf::IfOp op);
void emitScfYield(scf::YieldOp op);

/// Affine statement emitters.
void emitAffineFor(AffineForOp op);
void emitAffineIf(AffineIfOp op);
void emitAffineParallel(AffineParallelOp op);
void emitAffineApply(AffineApplyOp op);
template <typename OpType>
void emitAffineMaxMin(OpType op, const char *syntax);
void emitAffineLoad(AffineLoadOp op);
void emitAffineStore(AffineStoreOp op);
void emitAffineYield(AffineYieldOp op);

/// Memref-related statement emitters.
template <typename OpType> void emitAlloc(OpType op);
void emitLoad(memref::LoadOp op);
void emitStore(memref::StoreOp op);

/// Tensor-related statement emitters.
void emitTensorLoad(memref::TensorLoadOp op);
void emitTensorStore(memref::TensorStoreOp op);
void emitTensorToMemref(memref::BufferCastOp op);
void emitDim(memref::DimOp op);
void emitRank(RankOp op);

/// Standard expression emitters.
void emitBinary(Operation *op, const char *syntax);
void emitUnary(Operation *op, const char *syntax);

/// IP operation emitter.
void emitIP(IPOp op);

/// Special operation emitters.
void emitCall(CallOp op);
void emitSelect(SelectOp op);
void emitConstant(arith::ConstantOp op);
template <typename CastOpType> void emitCast(CastOpType op);

/// Structure operations emitters.
void emitAssign(AssignOp op);

/// Top-level MLIR module emitter.
void emitModule(ModuleOp module);

private:
/// C++ component emitters.
void emitValue(Value val, unsigned rank = 0, bool isPtr = false);
void emitArrayDecl(Value array);
unsigned emitNestedLoopHead(Value val);
void emitNestedLoopTail(unsigned rank);
void emitInfoAndNewLine(Operation *op);

/// MLIR component and HLS C++ pragma emitters.
void emitBlock(Block &block);
void emitLoopDirectives(Operation *op);
void emitArrayDirectives(Value memref);
void emitFunctionDirectives(FuncOp func, ArrayRef<Value> portList);
void emitFunction(FuncOp func);
};

//===----------------------------------------------------------------------===//
// AffineEmitter Class
//===----------------------------------------------------------------------===//

class AffineExprEmitter : public ScaleHLSEmitterBase,
public AffineExprVisitor<AffineExprEmitter> {
public:
using operand_range = Operation::operand_range;
explicit AffineExprEmitter(ScaleHLSEmitterState &state, unsigned numDim,
operand_range operands)
: ScaleHLSEmitterBase(state), numDim(numDim), operands(operands) {}

void visitAddExpr(AffineBinaryOpExpr expr) { emitAffineBinary(expr, "+"); }
void visitMulExpr(AffineBinaryOpExpr expr) { emitAffineBinary(expr, "*"); }
void visitModExpr(AffineBinaryOpExpr expr) { emitAffineBinary(expr, "%"); }
void visitFloorDivExpr(AffineBinaryOpExpr expr) {
emitAffineBinary(expr, "/");
}
void visitCeilDivExpr(AffineBinaryOpExpr expr) {
// This is super inefficient.
os << "(";
visit(expr.getLHS());
os << " + ";
visit(expr.getRHS());
os << " - 1) / ";
visit(expr.getRHS());
os << ")";
}

void visitConstantExpr(AffineConstantExpr expr) { os << expr.getValue(); }

void visitDimExpr(AffineDimExpr expr) {
os << getName(operands[expr.getPosition()]);
}
void visitSymbolExpr(AffineSymbolExpr expr) {
os << getName(operands[numDim + expr.getPosition()]);
}

/// Affine expression emitters.
void emitAffineBinary(AffineBinaryOpExpr expr, const char *syntax) {
os << "(";
if (auto constRHS = expr.getRHS().dyn_cast<AffineConstantExpr>()) {
if ((unsigned)*syntax == (unsigned)*"*" && constRHS.getValue() == -1) {
os << "-";
visit(expr.getLHS());
os << ")";
return;
}
if ((unsigned)*syntax == (unsigned)*"+" && constRHS.getValue() < 0) {
visit(expr.getLHS());
os << " - ";
os << -constRHS.getValue();
os << ")";
return;
}
}
if (auto binaryRHS = expr.getRHS().dyn_cast<AffineBinaryOpExpr>()) {
if (auto constRHS = binaryRHS.getRHS().dyn_cast<AffineConstantExpr>()) {
if ((unsigned)*syntax == (unsigned)*"+" && constRHS.getValue() == -1 &&
binaryRHS.getKind() == AffineExprKind::Mul) {
visit(expr.getLHS());
os << " - ";
visit(binaryRHS.getLHS());
os << ")";
return;
}
}
}
visit(expr.getLHS());
os << " " << syntax << " ";
visit(expr.getRHS());
os << ")";
}

void emitAffineExpr(AffineExpr expr) { visit(expr); }

private:
unsigned numDim;
operand_range operands;
};

//===----------------------------------------------------------------------===//
// Definition of StmtVisitor, ExprVisitor, and PragmaVisitor Classes
//===----------------------------------------------------------------------===//

class StmtVisitor : public HLSCppVisitorBase<StmtVisitor, bool> {
public:
StmtVisitor(ModuleEmitter &emitter) : emitter(emitter) {}

using HLSCppVisitorBase::visitOp;
/// SCF statements.
bool visitOp(scf::ForOp op) { return emitter.emitScfFor(op), true; };
bool visitOp(scf::IfOp op) { return emitter.emitScfIf(op), true; };
bool visitOp(scf::ParallelOp op) { return true; };
bool visitOp(scf::ReduceOp op) { return true; };
bool visitOp(scf::ReduceReturnOp op) { return true; };
bool visitOp(scf::YieldOp op) { return emitter.emitScfYield(op), true; };

/// Affine statements.
bool visitOp(AffineForOp op) { return emitter.emitAffineFor(op), true; }
bool visitOp(AffineIfOp op) { return emitter.emitAffineIf(op), true; }
bool visitOp(AffineParallelOp op) {
return emitter.emitAffineParallel(op), true;
}
bool visitOp(AffineApplyOp op) { return emitter.emitAffineApply(op), true; }
bool visitOp(AffineMaxOp op) {
return emitter.emitAffineMaxMin<AffineMaxOp>(op, "max"), true;
}
bool visitOp(AffineMinOp op) {
return emitter.emitAffineMaxMin<AffineMinOp>(op, "min"), true;
}
bool visitOp(AffineLoadOp op) { return emitter.emitAffineLoad(op), true; }
bool visitOp(AffineStoreOp op) { return emitter.emitAffineStore(op), true; }
bool visitOp(AffineYieldOp op) { return emitter.emitAffineYield(op), true; }

/// Memref-related statements.
bool visitOp(memref::AllocOp op) {
return emitter.emitAlloc<memref::AllocOp>(op), true;
}
bool visitOp(memref::AllocaOp op) {
return emitter.emitAlloc<memref::AllocaOp>(op), true;
}
bool visitOp(memref::LoadOp op) { return emitter.emitLoad(op), true; }
bool visitOp(memref::StoreOp op) { return emitter.emitStore(op), true; }
bool visitOp(memref::DeallocOp op) { return true; }

/// Tensor-related statements.
bool visitOp(memref::TensorLoadOp op) {
return emitter.emitTensorLoad(op), true;
}
bool visitOp(memref::TensorStoreOp op) {
return emitter.emitTensorStore(op), true;
}
bool visitOp(memref::BufferCastOp op) {
return emitter.emitTensorToMemref(op), true;
}
bool visitOp(memref::DimOp op) { return emitter.emitDim(op), true; }
bool visitOp(RankOp op) { return emitter.emitRank(op), true; }

/// HLSCpp operations.
bool visitOp(AssignOp op) { return emitter.emitAssign(op), true; }
bool visitOp(CastOp op) { return emitter.emitCast<CastOp>(op), true; }
bool visitOp(MulOp op) { return emitter.emitBinary(op, "*"), true; }
bool visitOp(AddOp op) { return emitter.emitBinary(op, "+"), true; }

private:
ModuleEmitter &emitter;
};

class ExprVisitor : public HLSCppVisitorBase<ExprVisitor, bool> {
public:
ExprVisitor(ModuleEmitter &emitter) : emitter(emitter) {}

using HLSCppVisitorBase::visitOp;
/// Float binary expressions.
bool visitOp(arith::CmpFOp op);
bool visitOp(arith::AddFOp op) { return emitter.emitBinary(op, "+"), true; }
bool visitOp(arith::SubFOp op) { return emitter.emitBinary(op, "-"), true; }
bool visitOp(arith::MulFOp op) { return emitter.emitBinary(op, "*"), true; }
bool visitOp(arith::DivFOp op) { return emitter.emitBinary(op, "/"), true; }
bool visitOp(arith::RemFOp op) { return emitter.emitBinary(op, "%"), true; }

/// Integer binary expressions.
bool visitOp(arith::CmpIOp op);
bool visitOp(arith::AddIOp op) { return emitter.emitBinary(op, "+"), true; }
bool visitOp(arith::SubIOp op) { return emitter.emitBinary(op, "-"), true; }
bool visitOp(arith::MulIOp op) { return emitter.emitBinary(op, "*"), true; }
bool visitOp(arith::DivSIOp op) { return emitter.emitBinary(op, "/"), true; }
bool visitOp(arith::RemSIOp op) { return emitter.emitBinary(op, "%"), true; }
bool visitOp(arith::DivUIOp op) { return emitter.emitBinary(op, "/"), true; }
bool visitOp(arith::RemUIOp op) { return emitter.emitBinary(op, "%"), true; }
bool visitOp(arith::XOrIOp op) { return emitter.emitBinary(op, "^"), true; }
bool visitOp(arith::AndIOp op) { return emitter.emitBinary(op, "&"), true; }
bool visitOp(arith::OrIOp op) { return emitter.emitBinary(op, "|"), true; }
bool visitOp(arith::ShLIOp op) { return emitter.emitBinary(op, "<<"), true; }
bool visitOp(arith::ShRSIOp op) { return emitter.emitBinary(op, ">>"), true; }
bool visitOp(arith::ShRUIOp op) { return emitter.emitBinary(op, ">>"), true; }

/// Unary expressions.
bool visitOp(math::AbsOp op) { return emitter.emitUnary(op, "abs"), true; }
bool visitOp(math::CeilOp op) { return emitter.emitUnary(op, "ceil"), true; }
bool visitOp(math::CosOp op) { return emitter.emitUnary(op, "cos"), true; }
bool visitOp(math::SinOp op) { return emitter.emitUnary(op, "sin"), true; }
bool visitOp(math::TanhOp op) { return emitter.emitUnary(op, "tanh"), true; }
bool visitOp(math::SqrtOp op) { return emitter.emitUnary(op, "sqrt"), true; }
bool visitOp(math::RsqrtOp op) {
return emitter.emitUnary(op, "1.0 / sqrt"), true;
}
bool visitOp(math::ExpOp op) { return emitter.emitUnary(op, "exp"), true; }
bool visitOp(math::Exp2Op op) { return emitter.emitUnary(op, "exp2"), true; }
bool visitOp(math::LogOp op) { return emitter.emitUnary(op, "log"), true; }
bool visitOp(math::Log2Op op) { return emitter.emitUnary(op, "log2"), true; }
bool visitOp(math::Log10Op op) {
return emitter.emitUnary(op, "log10"), true;
}
bool visitOp(arith::NegFOp op) { return emitter.emitUnary(op, "-"), true; }

/// Special operations.
bool visitOp(CallOp op) { return emitter.emitCall(op), true; }
bool visitOp(ReturnOp op) { return true; }
bool visitOp(SelectOp op) { return emitter.emitSelect(op), true; }
bool visitOp(arith::ConstantOp op) { return emitter.emitConstant(op), true; }
bool visitOp(arith::IndexCastOp op) {
return emitter.emitCast<arith::IndexCastOp>(op), true;
}
bool visitOp(arith::UIToFPOp op) {
return emitter.emitCast<arith::UIToFPOp>(op), true;
}
bool visitOp(arith::SIToFPOp op) {
return emitter.emitCast<arith::SIToFPOp>(op), true;
}
bool visitOp(arith::FPToUIOp op) {
return emitter.emitCast<arith::FPToUIOp>(op), true;
}
bool visitOp(arith::FPToSIOp op) {
return emitter.emitCast<arith::FPToSIOp>(op), true;
}

private:
ModuleEmitter &emitter;
};

class KernelVisitor : public HLSKernelVisitorBase<KernelVisitor, bool> {
public:
KernelVisitor(ModuleEmitter &emitter) : emitter(emitter) {}

using HLSKernelVisitorBase::visitOp;
/// IP operation.
bool visitOp(IPOp op) { return emitter.emitIP(op), true; }

private:
ModuleEmitter &emitter;
};

}
}
Loading