Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pure quantum struct usage in kernels with restrictions #2199

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/cudaq/Frontend/nvqpp/ASTBridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ class QuakeBridgeVisitor
DataRecursionQueue *q = nullptr);
bool VisitCXXConstructExpr(clang::CXXConstructExpr *x);
bool VisitCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
bool VisitCXXParenListInitExpr(clang::CXXParenListInitExpr *x);
bool WalkUpFromCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
bool TraverseDeclRefExpr(clang::DeclRefExpr *x,
DataRecursionQueue *q = nullptr);
Expand Down Expand Up @@ -586,6 +587,10 @@ class QuakeBridgeVisitor
mlir::StringRef funcName,
mlir::FunctionType funcTy);

/// Return true if the input type is a CC StructType and
/// it ONLY contains quantum members.
bool isQuantumStructType(mlir::Type ty);

/// Stack of Values built by the visitor. (right-to-left ordering)
mlir::SmallVector<mlir::Value> valueStack;
clang::ASTContext *astContext;
Expand Down
14 changes: 14 additions & 0 deletions lib/Frontend/nvqpp/ASTBridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,20 @@ class QPUCodeFinder : public clang::RecursiveASTVisitor<QPUCodeFinder> {

#ifndef NDEBUG
namespace cudaq::details {
bool QuakeBridgeVisitor::isQuantumStructType(Type ty) {
auto structTy = dyn_cast<cc::StructType>(ty);
if (!structTy)
return false;

// If there is a classical data type, return false
for (auto member : structTy.getMembers())
if (!quake::isQuantumType(member))
return false;

// Is a struct and only contains quantum data types.
return true;
}

bool QuakeBridgeVisitor::pushValue(Value v) {
LLVM_DEBUG(llvm::dbgs() << std::string(valueStack.size(), ' ')
<< "+push value: ";
Expand Down
9 changes: 8 additions & 1 deletion lib/Frontend/nvqpp/ConvertDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ void QuakeBridgeVisitor::addArgumentSymbols(
auto parmTy = entryBlock->getArgument(index).getType();
if (isa<FunctionType, cc::CallableType, cc::PointerType, cc::SpanLikeType,
LLVM::LLVMStructType, quake::ControlType, quake::RefType,
quake::VeqType, quake::WireType>(parmTy)) {
quake::VeqType, quake::WireType>(parmTy) ||
isQuantumStructType(parmTy)) {
symbolTable.insert(name, entryBlock->getArgument(index));
} else {
auto stackSlot = builder.create<cc::AllocaOp>(loc, parmTy);
Expand Down Expand Up @@ -796,6 +797,12 @@ bool QuakeBridgeVisitor::VisitVarDecl(clang::VarDecl *x) {
return pushValue(cast);
}

// Don't allocate memory for a quantum or value-semantic struct.
if (auto insertValOp = initValue.getDefiningOp<cc::InsertValueOp>()) {
symbolTable.insert(x->getName(), initValue);
return pushValue(initValue);
}

// Initialization expression resulted in a value. Create a variable and save
// that value to the variable's memory address.
Value alloca = builder.create<cc::AllocaOp>(loc, type);
Expand Down
98 changes: 76 additions & 22 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1109,16 +1109,23 @@ bool QuakeBridgeVisitor::VisitMemberExpr(clang::MemberExpr *x) {
if (auto *field = dyn_cast<clang::FieldDecl>(x->getMemberDecl())) {
auto loc = toLocation(x->getSourceRange());
auto object = popValue(); // DeclRefExpr
auto eleTy = cast<cc::PointerType>(object.getType()).getElementType();
SmallVector<cc::ComputePtrArg> offsets;
if (auto arrTy = dyn_cast<cc::ArrayType>(eleTy))
if (arrTy.isUnknownSize())
offsets.push_back(0);
std::int32_t offset = field->getFieldIndex();
offsets.push_back(offset);
auto ty = popType();
return pushValue(builder.create<cc::ComputePtrOp>(
loc, cc::PointerType::get(ty), object, offsets));
if (auto ptrStructTy = dyn_cast<cc::PointerType>(object.getType())) {
auto eleTy = cast<cc::PointerType>(object.getType()).getElementType();
if (auto arrTy = dyn_cast<cc::ArrayType>(eleTy))
if (arrTy.isUnknownSize())
offsets.push_back(0);
offsets.push_back(offset);

return pushValue(builder.create<cc::ComputePtrOp>(
loc, cc::PointerType::get(ty), object, offsets));
}
// We have a struct value
offsets.push_back(offset);
return pushValue(
builder.create<cc::ExtractValueOp>(loc, ty, object, offsets));
}
return true;
}
Expand Down Expand Up @@ -2179,24 +2186,35 @@ bool QuakeBridgeVisitor::VisitCXXOperatorCallExpr(
auto idx_var = popValue();
auto qreg_var = popValue();

// Get name of the qreg, e.g. qr, and use it to construct a name for the
// element, which is intended to be qr%n when n is the index of the
// accessed qubit.
StringRef qregName = getNamedDecl(x->getArg(0))->getName();
auto name = getQubitSymbolTableName(qregName, idx_var);
char *varName = strdup(name.c_str());

// If the name exists in the symbol table, return its stored value.
if (symbolTable.count(name))
return replaceTOSValue(symbolTable.lookup(name));
if (isa<clang::DeclRefExpr>(x->getArg(0))) {
// Get name of the qreg, e.g. qr, and use it to construct a name for the
// element, which is intended to be qr%n when n is the index of the
// accessed qubit.
StringRef qregName = getNamedDecl(x->getArg(0))->getName();
auto name = getQubitSymbolTableName(qregName, idx_var);
char *varName = strdup(name.c_str());

// If the name exists in the symbol table, return its stored value.
if (symbolTable.count(name))
return replaceTOSValue(symbolTable.lookup(name));

// Otherwise create an operation to access the qubit, store that value
// in the symbol table, and return the AddressQubit operation's
// resulting value.
auto address_qubit =
builder.create<quake::ExtractRefOp>(loc, qreg_var, idx_var);

symbolTable.insert(StringRef(varName), address_qubit);
return replaceTOSValue(address_qubit);
}

// Otherwise create an operation to access the qubit, store that value in
// the symbol table, and return the AddressQubit operation's resulting
// value.
// We have a quantum value that is not in the symbol table.
// Here we will just extract the qubit. This is likely
// coming from a quantum struct member.
auto address_qubit =
builder.create<quake::ExtractRefOp>(loc, qreg_var, idx_var);

symbolTable.insert(StringRef(varName), address_qubit);
// symbolTable.insert(StringRef(varName), address_qubit);
return replaceTOSValue(address_qubit);
}
if (typeName == "vector") {
Expand Down Expand Up @@ -2367,7 +2385,7 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
bool allRef = std::all_of(last.begin(), last.end(), [](auto v) {
return isa<quake::RefType, quake::VeqType>(v.getType());
});
if (allRef) {
if (allRef && !isa<cc::StructType>(initListTy)) {
// Initializer list contains all quantum reference types. In this case we
// want to create quake code to concatenate the references into a veq.
if (size > 1) {
Expand Down Expand Up @@ -2438,6 +2456,19 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
auto globalInit = builder.create<cc::AddressOfOp>(loc, ptrTy, name);
return pushValue(globalInit);
}

// If quantum, use value semantics with cc insert / extract value.
if (isQuantumStructType(eleTy)) {
Value undefOpRes = builder.create<cc::UndefOp>(loc, eleTy);
for (auto iter : llvm::enumerate(last)) {
std::int32_t i = iter.index();
auto v = iter.value();
undefOpRes =
builder.create<cc::InsertValueOp>(loc, eleTy, undefOpRes, v, i);
}
return pushValue(undefOpRes);
}

Value alloca = (numEles > 1)
? builder.create<cc::AllocaOp>(loc, eleTy, arrSize)
: builder.create<cc::AllocaOp>(loc, eleTy);
Expand Down Expand Up @@ -2528,6 +2559,25 @@ static Type getEleTyFromVectorCtor(Type ctorTy) {
return ctorTy;
}

bool QuakeBridgeVisitor::VisitCXXParenListInitExpr(
clang::CXXParenListInitExpr *x) {
if (auto ty = peekType(); isQuantumStructType(ty)) {
auto loc = toLocation(x);
auto structTy = dyn_cast<cc::StructType>(ty);
auto last = lastValues(structTy.getMembers().size());
Value undefOpRes = builder.create<cc::UndefOp>(loc, structTy);
for (auto iter : llvm::enumerate(last)) {
std::int32_t i = iter.index();
auto v = iter.value();
undefOpRes =
builder.create<cc::InsertValueOp>(loc, structTy, undefOpRes, v, i);
}
return pushValue(undefOpRes);
}

return false;
}

bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
auto loc = toLocation(x);
auto *ctor = x->getConstructor();
Expand Down Expand Up @@ -2823,6 +2873,10 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
return true;
}

// Just walk through copy constructors for quantum struct types.
if (ctor->isCopyOrMoveConstructor() && isQuantumStructType(ctorTy))
return true;

if (ctor->isCopyOrMoveConstructor() && parent->isPOD()) {
// Copy or move constructor on a POD struct. The value stack should contain
// the object to load the value from.
Expand Down
51 changes: 47 additions & 4 deletions lib/Frontend/nvqpp/ConvertType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,53 @@ bool QuakeBridgeVisitor::VisitRecordDecl(clang::RecordDecl *x) {
SmallVector<Type> fieldTys =
lastTypes(std::distance(x->field_begin(), x->field_end()));
auto [width, alignInBytes] = getWidthAndAlignment(x);
if (name.empty())
return pushType(cc::StructType::get(ctx, fieldTys, width, alignInBytes));
return pushType(
cc::StructType::get(ctx, name, fieldTys, width, alignInBytes));
cc::StructType ty =
name.empty()
? cc::StructType::get(ctx, fieldTys, width, alignInBytes)
: cc::StructType::get(ctx, name, fieldTys, width, alignInBytes);

// Do some error analysis on the struct. Check the following:
// Does this struct contain contain a quantum struct? Recursive quantum types
// are not allowed
// Is this a struct with both classical and quantum types? Not allowed
// Does this struct have user-specified methods? Not allowd

for (auto fieldTy : fieldTys)
if (isQuantumStructType(fieldTy))
reportClangError(x, mangler,
"recursive quantum struct types are not allowed.");

if (!isQuantumStructType(ty))
for (auto fieldTy : fieldTys)
if (quake::isQuantumType(fieldTy))
reportClangError(
x, mangler,
"hybrid quantum-classical struct types are not allowed.");

// for any kind of struct struct, throw error if it has methods
if (auto *cxxRd = dyn_cast<clang::CXXRecordDecl>(x)) {
auto numMethods = [&cxxRd]() {
std::size_t count = 0;
for (auto methodIter = cxxRd->method_begin();
methodIter != cxxRd->method_end(); ++methodIter) {
// Don't check if this is a __qpu__ struct method
if (auto attr = (*methodIter)->getAttr<clang::AnnotateAttr>();
attr && attr->getAnnotation().str() == cudaq::kernelAnnotation)
continue;
// Check if the method is not implicit (i.e., user-defined)
if (!(*methodIter)->isImplicit())
count++;
}
return count;
}();

if (numMethods > 0)
reportClangError(
x, mangler,
"struct with user-defined methods is not allowed in quantum kernel.");
}

return pushType(ty);
}

bool QuakeBridgeVisitor::VisitFunctionProtoType(clang::FunctionProtoType *t) {
Expand Down
54 changes: 53 additions & 1 deletion lib/Optimizer/Dialect/CC/CCOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -967,14 +967,66 @@ struct FuseWithConstantArray
return success();
}
}
return failure();
}
};
struct FoldQuantumStructMember
: public OpRewritePattern<cudaq::cc::ExtractValueOp> {
using OpRewritePattern::OpRewritePattern;

// Replace
//
// %6 = cc.extract_value %5[2]:(!cc.struct<{..., !veq<?>}>) -> !veq<?>
//
// in code like
//
// %0 = quake.alloca !quake.veq<4>
// %1 = quake.relax_size %0 : (!quake.veq<4>) -> !quake.veq<?>
// ...
// %5 = cc.insert_value %1, %4[2] : (cc.struct..., !veq<?>) -> cc.struct...
// %6 = cc.extract_value %5[2]:(!cc.struct<{..., !veq<?>}>) -> !veq<?>
//
// with
//
// %0 = quake.alloca !quake.veq<4>

LogicalResult matchAndRewrite(cudaq::cc::ExtractValueOp evOp,
PatternRewriter &rewriter) const override {
using namespace cudaq;

// Only operate on quantum extractions.
if (!quake::isQuantumType(evOp.getResult().getType()))
return failure();

auto base = evOp.getAggregate();
auto idx = evOp.getRawConstantIndices()[0];

Value originalVeq;
cc::InsertValueOp ivOp = base.getDefiningOp<cc::InsertValueOp>();
while (ivOp) {
if (idx == ivOp.getPosition()[0]) {
originalVeq = ivOp.getValue();
break;
}
ivOp = ivOp.getContainer().getDefiningOp<cc::InsertValueOp>();
}

if (!ivOp)
return failure();

if (auto relaxSizeOp = originalVeq.getDefiningOp<quake::RelaxSizeOp>())
originalVeq = relaxSizeOp.getInputVec();

rewriter.replaceOp(evOp, originalVeq);
return success();
}
};

} // namespace

void cudaq::cc::ExtractValueOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<FuseWithConstantArray>(context);
patterns.add<FuseWithConstantArray, FoldQuantumStructMember>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
33 changes: 28 additions & 5 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,14 +667,12 @@ def convertArithmeticToSuperiorType(self, values, type):

def isQuantumStructType(self, structTy):
"""
Return True if the given struct type has one or more quantum member variables.
Return True if the given struct type has only quantum member variables.
"""
if not cc.StructType.isinstance(structTy):
self.emitFatalError(
f'isQuantumStructType called on type that is not a struct ({structTy})'
)
return False

return True in [
return False not in [
self.isQuantumType(t) for t in cc.StructType.getTypes(structTy)
]

Expand Down Expand Up @@ -1903,8 +1901,33 @@ def bodyBuilder(iterVal):
mlirTypeFromPyType(v, self.ctx)
for _, v in annotations.items()
]
# Ensure we don't use hybrid data types
numQuantumMemberTys = sum(
[1 if self.isQuantumType(ty) else 0 for ty in structTys])
if numQuantumMemberTys != 0: # we have quantum member types
if numQuantumMemberTys != len(structTys):
self.emitFatalError(
f'hybrid quantum-classical data types not allowed in kernel code',
node)

structTy = cc.StructType.getNamed(self.ctx, node.func.id,
structTys)
# Disallow recursive quantum struct types.
for fieldTy in cc.StructType.getTypes(structTy):
if self.isQuantumStructType(fieldTy):
self.emitFatalError(
'recursive quantum struct types not allowed.', node)

# Disallow user specified methods on structs
if len({
k: v
for k, v in cls.__dict__.items()
if not (k.startswith('__') and k.endswith('__'))
}) != 0:
self.emitFatalError(
'struct types with user specified methods are not allowed.',
node)

nArgs = len(self.valueStack)
ctorArgs = [self.popValue() for _ in range(nArgs)]
ctorArgs.reverse()
Expand Down
Loading
Loading