Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Separate compilation, supporting C++ host side function references. #2216

Merged
merged 7 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/cudaq/Frontend/nvqpp/ASTBridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ class QuakeBridgeVisitor
bool isItaniumCXXABI();

private:
/// Return true if the value on the top of the stack is an entry-point kernel.
bool hasTOSEntryKernel();

/// Map the block arguments to the names of the function parameters.
void addArgumentSymbols(mlir::Block *entryBlock,
mlir::ArrayRef<clang::ParmVarDecl *> parameters);
Expand Down
19 changes: 19 additions & 0 deletions include/cudaq/Optimizer/Builder/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,23 @@ static constexpr const char launchKernelHybridFuncName[] = "hybridLaunchKernel";

static constexpr const char mangledNameMap[] = "quake.mangled_name_map";

// String constant holding the name `__cudaq_deviceCodeHolderAdd`.
// NOTE(review): presumably the symbol name of a runtime function that
// generated code refers to -- confirm against the runtime library.
static constexpr const char deviceCodeHolderAdd[] =
"__cudaq_deviceCodeHolderAdd";

// String constants naming the "linkable kernel" helpers: one registration
// name and three lookup names (key, name, device function).
// NOTE(review): the identifier `getLinkableKernelDeviceSide` maps to a string
// ending in `DeviceFunction`; confirm the identifier/string mismatch is
// intended.
static constexpr const char registerLinkableKernel[] =
"__cudaq_registerLinkableKernel";
static constexpr const char getLinkableKernelKey[] =
"__cudaq_getLinkableKernelKey";
static constexpr const char getLinkableKernelName[] =
"__cudaq_getLinkableKernelName";
static constexpr const char getLinkableKernelDeviceSide[] =
"__cudaq_getLinkableKernelDeviceFunction";

// String constants for the `cudaqRegister*` names. Unlike the constants
// above, these identifiers begin with an upper-case letter and the strings
// carry no `__cudaq_` prefix.
static constexpr const char CudaqRegisterLambdaName[] =
"cudaqRegisterLambdaName";
static constexpr const char CudaqRegisterArgsCreator[] =
"cudaqRegisterArgsCreator";
static constexpr const char CudaqRegisterKernelName[] =
"cudaqRegisterKernelName";

} // namespace cudaq::runtime
18 changes: 9 additions & 9 deletions include/cudaq/Optimizer/CodeGen/Peephole.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ include "mlir/IR/PatternBase.td"
//===----------------------------------------------------------------------===//

def InvokeOnXWithOneControl : Constraint<CPred<
"callToInvokeWithXCtrlOneTarget($0.getValue(), $1)">>;
"$0 && callToInvokeWithXCtrlOneTarget($0.getValue(), $1)">>;

def CreateCallCnot : NativeCodeCall<
"[&]() -> std::size_t {"
Expand All @@ -35,7 +35,7 @@ def XCtrlOneTargetToCNot : Pat<

//===----------------------------------------------------------------------===//

def NeedsRenaming : Constraint<CPred<"needsToBeRenamed($0.getValue())">>;
def NeedsRenaming : Constraint<CPred<"$0 && needsToBeRenamed($0.getValue())">>;

def CreateAddressOf : NativeCodeCall<
"$_builder.create<mlir::LLVM::AddressOfOp>($_loc, $0.getType(),"
Expand All @@ -52,7 +52,7 @@ def AddrOfCisToBase : Pat<

// Apply special rule for `mz`. See below.
def FuncNotMeasure : Constraint<CPred<
"!$_self.getValue().startswith(cudaq::opt::QIRMeasure)">>;
"!($_self && $_self.getValue().startswith(cudaq::opt::QIRMeasure))">>;

def CreateCallOp : NativeCodeCall<
"[&]() -> std::size_t {"
Expand All @@ -72,7 +72,7 @@ def CalleeConv : Pat<
//===----------------------------------------------------------------------===//

def IsArrayGetElementPtrId : Constraint<CPred<
"$0.getValue().str() == cudaq::opt::QIRArrayGetElementPtr1d">>;
"$0 && $0.getValue().str() == cudaq::opt::QIRArrayGetElementPtr1d">>;

def EraseArrayGEPOp : NativeCodeCall<
"$_builder.create<mlir::LLVM::UndefOp>($_loc,"
Expand All @@ -85,7 +85,7 @@ def EraseDeadArrayGEP : Pat<
//===----------------------------------------------------------------------===//

def IsaAllocateCall : Constraint<CPred<
"$0.getValue().str() == cudaq::opt::QIRArrayQubitAllocateArray">>;
"$0 && $0.getValue().str() == cudaq::opt::QIRArrayQubitAllocateArray">>;

def EraseArrayAllocateOp : NativeCodeCall<
"$_builder.create<mlir::LLVM::UndefOp>($_loc,"
Expand All @@ -103,8 +103,8 @@ def EraseArrayAlloc : Pat<
//===----------------------------------------------------------------------===//

def IsaReleaseCall : Constraint<CPred<
"$0.getValue().str() == cudaq::opt::QIRArrayQubitReleaseArray || "
"$0.getValue().str() == cudaq::opt::QIRArrayQubitReleaseQubit">>;
"$0 && ($0.getValue().str() == cudaq::opt::QIRArrayQubitReleaseArray || "
"$0.getValue().str() == cudaq::opt::QIRArrayQubitReleaseQubit)">>;

def EraseArrayReleaseOp : NativeCodeCall<"static_cast<std::size_t>(0)">;

Expand All @@ -120,7 +120,7 @@ def EraseArrayRelease : Pat<
//===----------------------------------------------------------------------===//

def IsaMeasureCall : Constraint<CPred<
"$_self.getValue() == cudaq::opt::QIRMeasure">>;
"$_self && $_self.getValue() == cudaq::opt::QIRMeasure">>;

def IsaIntToPtrOperand : Constraint<CPred<"isIntToPtrOp($0[0])">>;

Expand All @@ -138,7 +138,7 @@ def MeasureCallConv : Pat<
//===----------------------------------------------------------------------===//

def IsaMeasureToRegisterCall : Constraint<CPred<
"$_self.getValue() == cudaq::opt::QIRMeasureToRegister">>;
"$_self && $_self.getValue() == cudaq::opt::QIRMeasureToRegister">>;

// %result = call @__quantum__qis__mz__to__register(%qbit, i8) : (!Qubit) -> i1
// ────────────────────────────────────────────────────────────────────────────
Expand Down
43 changes: 43 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1439,6 +1439,49 @@ def cc_CallCallableOp : CCOp<"call_callable", [CallOpInterface]> {
}];
}

// Operation with mnemonic `call_indirect_callable`. The callee is an SSA
// operand of type `cc_IndirectCallableType` rather than a symbol reference,
// and the op carries the `CallOpInterface` trait.
def cc_CallIndirectCallableOp : CCOp<"call_indirect_callable",
[CallOpInterface]> {
let summary = "Call a C++ callable, unresolved, at run-time.";
let description = [{
This effectively connects a call from one kernel to another kernel, which
would have been done at link-time in host code, at run-time on the device
side. This allows calls between kernels defined in separate compilation
units. The definitions of these caller/callee functions are not both present
at compile-time, so they are exposed to the CUDAQ runtime for stitching or
LTO at JIT compile time.
}];

// Operand 0 is the callee; every remaining operand is a call argument.
let arguments = (ins
cc_IndirectCallableType:$callee,
Variadic<AnyType>:$args
);
let results = (outs Variadic<AnyType>:$results);
// Requests a custom verifier and canonicalization patterns; their
// definitions are not part of this record.
let hasVerifier = 1;
let hasCanonicalizer = 1;

// Textual form: callee, optional comma-separated arguments, then a
// functional type. The functional type is built from *all* operands, so the
// callee's own type appears as its first input.
let assemblyFormat = [{
$callee (`,` $args^)? `:` functional-type(operands, results) attr-dict
}];

// C++ members added to the generated op class. `arg_operand_begin()` steps
// past operand 0 (the callee), so `getArgOperands()` yields only the call
// arguments. `getFunctionType()` is assembled from all operand types, callee
// included, mirroring the assembly format above.
let extraClassDeclaration = [{
/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}

operand_iterator arg_operand_begin() { return ++operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }

/// Return the callee of this operation.
mlir::CallInterfaceCallable getCallableForCallee() { return getCallee(); }

mlir::FunctionType getFunctionType() {
return mlir::FunctionType::get(getContext(), getOperands().getType(),
getResults().getTypes());
}
}];
}

def cc_InstantiateCallableOp : CCOp<"instantiate_callable", [Pure]> {
let summary = "Construction of a callable object.";
let description = [{
Expand Down
25 changes: 25 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,31 @@ def cc_CallableType : CCType<"Callable", "callable"> {
}];
}

// Type with mnemonic `indirect_callable`, generated as C++ class prefix
// `IndirectCallable`. It wraps a single function signature.
def cc_IndirectCallableType : CCType<"IndirectCallable", "indirect_callable"> {
let summary = "Proxy for cudaq::qkernel_ref.";
let description = [{
An entry-point kernel may take a reference to another kernel as an argument.
The passed kernel may be entirely opaque at compile-time with its definition
present in some other compilation module.

It is on the programmer to use the cudaq::qkernel_ref type. This wrapper
class is very much like std::function, but it extends that functionality
with some extra information for the runtime to be able to "link" the
distinct kernels on the device side and provide, for example, LTO at
JIT compile time.
}];

// The only type parameter: the signature of the referenced callable.
let parameters = (ins "mlir::FunctionType":$signature);

// Parsed and printed as the signature enclosed in angle brackets.
let assemblyFormat = "`<` $signature `>`";

// Convenience builder that takes the MLIRContext from `signature` itself, so
// callers pass only the function type.
let builders = [
TypeBuilderWithInferredContext<(ins "mlir::FunctionType":$signature), [{
return Base::get(signature.getContext(), signature);
}]>
];
}

//===----------------------------------------------------------------------===//
// StdVectorType - implemented as a span
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 0 additions & 4 deletions include/cudaq/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@

namespace cudaq::opt {

/// Pass to generate the device code loading stubs.
std::unique_ptr<mlir::Pass>
createGenerateDeviceCodeLoader(bool genAsQuake = false);

/// Add a pass pipeline to transform call between kernels to direct calls that
/// do not go through the runtime layers, inline all calls, and detect if calls
/// to kernels remain in the fully inlined into entry point kernel.
Expand Down
5 changes: 3 additions & 2 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,14 @@ def GenerateDeviceCodeLoader : Pass<"device-code-loader", "mlir::ModuleOp"> {
}];

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
let constructor = "cudaq::opt::createGenerateDeviceCodeLoader()";

let options = [
Option<"outputFilename", "output-filename", "std::string",
/*default=*/"\"-\"", "Name of output file.">,
Option<"generateAsQuake", "use-quake", "bool",
/*default=*/"false", "Output should be module in Quake dialect.">
/*default=*/"true", "Output should be module in Quake dialect.">,
Option<"jitTime", "jit-compile", "bool",
/*default=*/"false", "Running pass at JIT compile time (default=false).">
];
}

Expand Down
1 change: 1 addition & 0 deletions lib/Frontend/nvqpp/ASTBridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class QPUCodeFinder : public clang::RecursiveASTVisitor<QPUCodeFinder> {
return true;
}

// NB: DataRecursionQueue* argument intentionally omitted.
bool TraverseLambdaExpr(clang::LambdaExpr *x) {
bool saveQuantumTypesNotAllowed = quantumTypesNotAllowed;
// Rationale: a lambda expression may be passed from classical C++ code into
Expand Down
15 changes: 12 additions & 3 deletions lib/Frontend/nvqpp/ConvertDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ void QuakeBridgeVisitor::addArgumentSymbols(
// Transform pass-by-value arguments to stack slots.
auto loc = toLocation(argVal);
auto parmTy = entryBlock->getArgument(index).getType();
if (isa<FunctionType, cc::CallableType, cc::PointerType, cc::SpanLikeType,
LLVM::LLVMStructType, quake::ControlType, quake::RefType,
quake::VeqType, quake::WireType>(parmTy)) {
if (isa<FunctionType, cc::CallableType, cc::IndirectCallableType,
cc::PointerType, cc::SpanLikeType, LLVM::LLVMStructType,
quake::ControlType, quake::RefType, quake::VeqType,
quake::WireType>(parmTy)) {
symbolTable.insert(name, entryBlock->getArgument(index));
} else {
auto stackSlot = builder.create<cc::AllocaOp>(loc, parmTy);
Expand Down Expand Up @@ -176,6 +177,14 @@ bool QuakeBridgeVisitor::interceptRecordDecl(clang::RecordDecl *x) {
return pushType(cc::StateType::get(ctx));
if (name.equals("pauli_word"))
return pushType(cc::CharspanType::get(ctx));
if (name.equals("qkernel_ref")) {
auto *cts = cast<clang::ClassTemplateSpecializationDecl>(x);
// Traverse template argument 0 to get the function's signature.
if (!TraverseType(cts->getTemplateArgs()[0].getAsType()))
return false;
auto fnTy = cast<FunctionType>(popType());
return pushType(cc::IndirectCallableType::get(fnTy));
}
auto loc = toLocation(x);
TODO_loc(loc, "unhandled type, " + name + ", in cudaq namespace");
}
Expand Down
43 changes: 30 additions & 13 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2163,6 +2163,16 @@ bool QuakeBridgeVisitor::WalkUpFromCXXOperatorCallExpr(
return WalkUpFromCallExpr(x) && VisitCXXOperatorCallExpr(x);
}

/// Return true if the value currently on top of the value stack is defined by
/// a `func::ConstantOp` whose symbol name equals the CUDA-Q kernel name of one
/// of the entries in `functionsToEmit`. The stack is inspected with
/// `peekValue()` only; nothing is popped here.
bool QuakeBridgeVisitor::hasTOSEntryKernel() {
  auto fn = peekValue().getDefiningOp<func::ConstantOp>();
  // Anything other than a function constant cannot name a kernel.
  if (!fn)
    return false;
  const auto name = fn.getValue().str();
  // Bind each entry by const reference; iterating by value copied the whole
  // pair on every step for no benefit.
  for (const auto &fdPair : functionsToEmit)
    if (getCudaqKernelName(fdPair.first) == name)
      return true;
  return false;
}

bool QuakeBridgeVisitor::VisitCXXOperatorCallExpr(
clang::CXXOperatorCallExpr *x) {
auto loc = toLocation(x->getSourceRange());
Expand Down Expand Up @@ -2246,21 +2256,11 @@ bool QuakeBridgeVisitor::VisitCXXOperatorCallExpr(
auto tos = popValue();
auto tosTy = tos.getType();
auto ptrTy = dyn_cast<cc::PointerType>(tosTy);
bool isEntryKernel = [&]() {
// TODO: make this lambda a member function.
if (auto fn = peekValue().getDefiningOp<func::ConstantOp>()) {
auto name = fn.getValue().str();
for (auto fdPair : functionsToEmit)
if (getCudaqKernelName(fdPair.first) == name)
return true;
}
return false;
}();
if (ptrTy || isEntryKernel) {
bool isEntryKernel = hasTOSEntryKernel();
if ((ptrTy && isa<cc::StructType>(ptrTy.getElementType())) ||
isEntryKernel) {
// The call operator has an object in the call position, so we want to
// replace it with an indirect call to the func::ConstantOp.
assert((isEntryKernel || isa<cc::StructType>(ptrTy.getElementType())) &&
"expected kernel as callable class");
auto indirect = popValue();
auto funcTy = cast<FunctionType>(indirect.getType());
auto call = builder.create<func::CallIndirectOp>(
Expand All @@ -2269,6 +2269,23 @@ bool QuakeBridgeVisitor::VisitCXXOperatorCallExpr(
return true;
return pushValue(call.getResult(0));
}
auto indCallTy = [&]() -> cc::IndirectCallableType {
if (ptrTy) {
auto ty = dyn_cast<cc::IndirectCallableType>(ptrTy.getElementType());
if (ty)
return ty;
}
return dyn_cast<cc::IndirectCallableType>(tosTy);
}();
if (indCallTy) {
[[maybe_unused]] auto discardedCallOp = popValue();
auto funcTy = cast<FunctionType>(indCallTy.getSignature());
auto call = builder.create<cc::CallIndirectCallableOp>(
loc, funcTy.getResults(), tos, args);
if (call.getResults().empty())
return true;
return pushValue(call.getResult(0));
}
auto callableTy = cast<cc::CallableType>(tosTy);
auto callInd = builder.create<cc::CallCallableOp>(
loc, callableTy.getSignature().getResults(), tos, args);
Expand Down
11 changes: 7 additions & 4 deletions lib/Frontend/nvqpp/ConvertType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ static bool isKernelSignatureType(FunctionType t);
/// Return true if `t` is a callable type (direct or indirect) whose wrapped
/// signature satisfies `isKernelSignatureType`. Any other type yields false.
static bool isKernelCallable(Type t) {
  FunctionType signature;
  if (auto callableTy = dyn_cast<cudaq::cc::CallableType>(t))
    signature = callableTy.getSignature();
  else if (auto indirectTy = dyn_cast<cudaq::cc::IndirectCallableType>(t))
    signature = indirectTy.getSignature();
  else
    return false;
  return isKernelSignatureType(signature);
}

Expand Down Expand Up @@ -364,8 +366,8 @@ bool QuakeBridgeVisitor::VisitLValueReferenceType(
if (t->getPointeeType()->isUndeducedAutoType())
return pushType(cc::PointerType::get(builder.getContext()));
auto eleTy = popType();
if (isa<cc::CallableType, cc::SpanLikeType, quake::VeqType, quake::RefType>(
eleTy))
if (isa<cc::CallableType, cc::IndirectCallableType, cc::SpanLikeType,
quake::VeqType, quake::RefType>(eleTy))
return pushType(eleTy);
return pushType(cc::PointerType::get(eleTy));
}
Expand All @@ -376,8 +378,9 @@ bool QuakeBridgeVisitor::VisitRValueReferenceType(
return pushType(cc::PointerType::get(builder.getContext()));
auto eleTy = popType();
// FIXME: LLVMStructType is promoted as a temporary workaround.
if (isa<cc::CallableType, cc::SpanLikeType, cc::ArrayType, cc::StructType,
quake::VeqType, quake::RefType, LLVM::LLVMStructType>(eleTy))
if (isa<cc::ArrayType, cc::CallableType, cc::IndirectCallableType,
cc::SpanLikeType, cc::StructType, quake::VeqType, quake::RefType,
LLVM::LLVMStructType>(eleTy))
return pushType(eleTy);
return pushType(cc::PointerType::get(eleTy));
}
Expand Down
4 changes: 4 additions & 0 deletions lib/Optimizer/Builder/Factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ static Type genBufferType(Type ty) {
auto *ctx = ty.getContext();
if (isa<cudaq::cc::CallableType>(ty))
return cudaq::cc::PointerType::get(ctx);
if (isa<cudaq::cc::IndirectCallableType>(ty))
return IntegerType::get(ctx, 64);
if (auto vecTy = dyn_cast<cudaq::cc::SpanLikeType>(ty)) {
auto i64Ty = IntegerType::get(ctx, 64);
if (isOutput) {
Expand Down Expand Up @@ -368,6 +370,8 @@ static Type convertToHostSideType(Type ty) {
if (auto memrefTy = dyn_cast<cc::StdvecType>(ty))
return convertToHostSideType(
factory::stlVectorType(memrefTy.getElementType()));
if (isa<cc::IndirectCallableType>(ty))
return cc::PointerType::get(IntegerType::get(ty.getContext(), 8));
if (auto memrefTy = dyn_cast<cc::CharspanType>(ty)) {
// `pauli_word` is an object with a std::vector in the header files at
// present. This data type *must* be updated if it becomes a std::string
Expand Down
Loading
Loading